package org.allenai.ml.sequences.crf.conll;

import com.gs.collections.api.list.ImmutableList;
import com.gs.collections.api.map.primitive.ObjectDoubleMap;
import com.gs.collections.api.tuple.Pair;
import com.gs.collections.api.tuple.primitive.IntIntPair;
import com.gs.collections.impl.factory.Lists;
import com.gs.collections.impl.map.mutable.primitive.ObjectDoubleHashMap;
import com.gs.collections.impl.tuple.Tuples;
import com.gs.collections.impl.tuple.primitive.PrimitiveTuples;
import java.beans.ConstructorProperties;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.allenai.ml.linalg.DenseVector;
import org.allenai.ml.linalg.Vector;
import org.allenai.ml.sequences.StateSpace;
import org.allenai.ml.sequences.crf.CRFFeatureEncoder;
import org.allenai.ml.sequences.crf.CRFModel;
import org.allenai.ml.sequences.crf.CRFPredicateExtractor;
import org.allenai.ml.sequences.crf.CRFWeightsEncoder;
import org.allenai.ml.util.IOUtils;
import org.allenai.ml.util.Indexer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Reading CoNLL-style column data and feature templates for the CRF package, plus model
 * save/load in this format.
 *
 * <p>Data is a stream of lines in which blank lines separate sequences. Each non-blank line is
 * split on a delimiter regex (whitespace by default) into a {@link Row}. For labeled data, the last
 * column is the label. Every sequence is bracketed by synthetic {@link #startState} and
 * {@link #stopState} rows.
 *
 * <p>Feature template lines look like {@code U00:%x[-1,0]/%x[0,0]}. The prefix (before the first
 * {@code :}) starts with {@code U} for node templates or {@code B} for edge templates. It is
 * followed by {@code /}-separated {@code %x[rowOffset,column]} references.
 */
public class ConllFormat {
    private static final Logger log = LoggerFactory.getLogger(ConllFormat.class);
    public static final String startState = "<s>";
    public static final String stopState = "</s>";
    private static final String DATA_VERSION = "1.1";

    /**
     * Extracts per-position node and edge predicates from a row sequence by evaluating feature
     * templates. Every emitted predicate has value {@code 1.0}.
     *
     * <p>Input rows are expected to be bracketed by the start/stop rows that
     * {@link ConllFormat#readDatum} adds; this is checked with {@code assert}.
     */
    public static class ConllPredicateExtractor implements CRFPredicateExtractor<Row, String> {
        private final List<FeatureTemplate> nodeTemplates;
        private final List<FeatureTemplate> edgeTemplates;

        /**
         * @param nodeTemplates templates evaluated at each position for node predicates
         * @param edgeTemplates templates evaluated at each position for edge predicates
         */
        @ConstructorProperties({"nodeTemplates", "edgeTemplates"})
        public ConllPredicateExtractor(List<FeatureTemplate> nodeTemplates, List<FeatureTemplate> edgeTemplates) {
            this.nodeTemplates = nodeTemplates;
            this.edgeTemplates = edgeTemplates;
        }

        /**
         * Evaluates every template at every position of {@code rows}.
         *
         * @return one mutable predicate map per row, in row order (a mutable {@link ArrayList})
         */
        private static List<ObjectDoubleMap<String>> buildPredVals(List<FeatureTemplate> templates, List<Row> rows) {
            // Loop-invariant sanity checks: done once rather than at every position.
            assert rows.isEmpty() || rows.get(0).features.equals(Arrays.asList(startState))
                : "first row must be the start row";
            assert rows.isEmpty() || rows.get(rows.size() - 1).features.equals(Arrays.asList(stopState))
                : "last row must be the stop row";
            List<ObjectDoubleMap<String>> predVals = new ArrayList<>(rows.size());
            for (int idx = 0; idx < rows.size(); idx++) {
                ObjectDoubleHashMap<String> vals = new ObjectDoubleHashMap<>(templates.size());
                for (FeatureTemplate template : templates) {
                    String value = template.value(rows, idx);
                    if (value != null) {
                        vals.put(value, 1.0);
                    }
                }
                predVals.add(vals);
            }
            return predVals;
        }

        /**
         * Node predicates for each row. The start and stop rows get empty predicate maps.
         */
        @Override
        public List<ObjectDoubleMap<String>> nodePredicates(List<Row> rows) {
            List<ObjectDoubleMap<String>> predVals = buildPredVals(this.nodeTemplates, rows);
            predVals.set(0, ObjectDoubleHashMap.newMap());
            predVals.set(predVals.size() - 1, ObjectDoubleHashMap.newMap());
            return predVals;
        }

        /**
         * Edge predicates for each transition {@code i -> i+1}, evaluated at position {@code i}.
         *
         * @return {@code rows.size() - 1} predicate maps
         */
        @Override
        public List<ObjectDoubleMap<String>> edgePredicates(List<Row> rows) {
            return buildPredVals(this.edgeTemplates, rows).subList(0, rows.size() - 1);
        }
    }

    /**
     * A single feature template: a prefix plus a list of (row offset, column) references into the
     * row sequence.
     */
    public static class FeatureTemplate {
        public final String prefix;
        public final ImmutableList<IntIntPair> rowCols;
        public final Type type;
        private static final Pattern ROW_COL_PATTERN = Pattern.compile("\\%x\\[(-?\\d+),(\\d+)\\]");

        /** Whether a template produces node ({@code U...}) or edge ({@code B...}) predicates. */
        public enum Type {
            NODE,
            EDGE
        }

        /**
         * @param prefix template name; must start with {@code U} (node) or {@code B} (edge)
         * @param rowCols (row offset, column) pairs referenced by this template
         * @throws IllegalArgumentException if {@code prefix} starts with neither {@code U} nor {@code B}
         */
        public FeatureTemplate(String prefix, List<IntIntPair> rowCols) {
            if (prefix.startsWith("U")) {
                this.type = Type.NODE;
            } else if (prefix.startsWith("B")) {
                this.type = Type.EDGE;
            } else {
                throw new IllegalArgumentException("FeatureTemplate prefix must begin with 'U' or 'B'");
            }
            this.prefix = prefix;
            this.rowCols = Lists.immutable.ofAll(rowCols);
        }

        /**
         * Parses a template line such as {@code U00:%x[-2,0]/%x[2,0]}.
         *
         * <p>The prefix is everything before the first {@code ':'}, or the whole line when there
         * is none. Segments after the colon that are not exactly {@code %x[row,col]} are ignored.
         */
        public static FeatureTemplate fromLineSpec(String spec) {
            int colonIdx = spec.indexOf(':');
            String prefix = colonIdx < 0 ? spec : spec.substring(0, colonIdx);
            // With no colon, colonIdx + 1 == 0, so the whole spec is scanned (original behavior).
            List<IntIntPair> rowCols = Stream.of(spec.substring(colonIdx + 1).split("/"))
                .map(ROW_COL_PATTERN::matcher)
                .filter(m -> m.matches())
                .map(m -> PrimitiveTuples.pair(Integer.parseInt(m.group(1)), Integer.parseInt(m.group(2))))
                .collect(Collectors.toList());
            return new FeatureTemplate(prefix, rowCols);
        }

        /**
         * Instantiates this template at position {@code idx} of {@code rows}.
         *
         * @return the bare prefix when the template has no references, otherwise
         *     {@code prefix:part1/part2/...}
         */
        public String value(List<Row> rows, int idx) {
            if (this.rowCols.isEmpty()) {
                return this.prefix;
            }
            int numRows = rows.size();
            List<String> parts = new ArrayList<>(this.rowCols.size());
            for (IntIntPair rowCol : this.rowCols) {
                int rowIdx = idx + rowCol.getOne();
                if (rowIdx < 0 || rowIdx >= numRows) {
                    // Out-of-sequence placeholder. The absolute row index is part of the string,
                    // which saved models depend on, so keep this exact format.
                    parts.add("@_X" + rowIdx);
                    continue;
                }
                ImmutableList<String> features = rows.get(rowIdx).features;
                int colIdx = rowCol.getTwo();
                if (colIdx < 0 || colIdx >= features.size()) {
                    // NOTE(review): this encodes the row index, not the column index. It is kept
                    // as-is because trained models index these exact strings.
                    parts.add("@_Y" + rowIdx);
                } else {
                    parts.add(features.get(colIdx));
                }
            }
            return this.prefix + ':' + String.join("/", parts);
        }

        /** Renders the template back into its {@code prefix:%x[r,c]/...} line form. */
        @Override
        public String toString() {
            if (this.rowCols.isEmpty()) {
                return this.prefix;
            }
            String refs = this.rowCols.toList().stream()
                .map(rowCol -> String.format("%%x[%d,%d]", rowCol.getOne(), rowCol.getTwo()))
                .collect(Collectors.joining("/"));
            return this.prefix + ':' + refs;
        }
    }

    /** One line of input: its feature columns and, for labeled data, its label. */
    public static class Row {
        public final ImmutableList<String> features;
        private final String label;

        /** Creates an unlabeled row. */
        public Row(List<String> features) {
            this(features, null);
        }

        /**
         * @param features feature columns (copied)
         * @param label the row's label, or {@code null} if unlabeled
         */
        public Row(List<String> features, String label) {
            this.features = Lists.immutable.ofAll(features);
            this.label = label;
        }

        /** @return the label, or empty for an unlabeled row */
        public Optional<String> getLabel() {
            return Optional.ofNullable(this.label);
        }

        /**
         * @return {@code (label, this)}
         * @throws IllegalStateException if this row has no label
         */
        public Pair<String, Row> asLabeledPair() {
            if (this.label == null) {
                throw new IllegalStateException("Must be a labeled example");
            }
            return Tuples.pair(this.label, this);
        }
    }

    /**
     * Groups lines into blocks separated by blank (whitespace-only) lines.
     *
     * <p>The final block is kept even when the input does not end with a blank line. Runs of blank
     * lines do not produce empty blocks.
     */
    private static List<List<String>> chunkedLines(Stream<String> lines) {
        List<List<String>> chunks = new ArrayList<>();
        List<String> current = new ArrayList<>();
        lines.forEach(line -> {
            if (line.trim().isEmpty()) {
                flushChunk(chunks, current);
            } else {
                current.add(line);
            }
        });
        flushChunk(chunks, current);
        return chunks;
    }

    /** Moves a copy of {@code current} into {@code chunks} (if non-empty) and clears it. */
    private static void flushChunk(List<List<String>> chunks, List<String> current) {
        if (!current.isEmpty()) {
            chunks.add(new ArrayList<>(current));
            current.clear();
        }
    }

    /** Same as {@link #readDatum(List, boolean, String)} splitting columns on {@code \s+}. */
    public static List<Row> readDatum(List<String> lines, boolean labeled) {
        return readDatum(lines, labeled, "\\s+");
    }

    /**
     * Parses one sequence, one row per line.
     *
     * @param lines the lines of a single sequence
     * @param labeled whether the last column of every line is the label
     * @param delimiterRegex regex used to split each line into columns
     * @return a mutable list of rows: a start row, one row per line, and a stop row. The
     *     start/stop rows are labeled with their own state names even for unlabeled data.
     * @throws IllegalArgumentException if {@code labeled} and a line has fewer than two columns
     */
    public static List<Row> readDatum(List<String> lines, boolean labeled, String delimiterRegex) {
        // Compile once instead of once per line; Pattern.split matches String.split semantics.
        Pattern delimiter = Pattern.compile(delimiterRegex);
        List<Row> rows = new ArrayList<>(lines.size() + 2);
        rows.add(new Row(Arrays.asList(startState), startState));
        for (String line : lines) {
            List<String> columns = Arrays.asList(delimiter.split(line));
            if (!labeled) {
                rows.add(new Row(columns));
                continue;
            }
            if (columns.size() < 2) {
                throw new IllegalArgumentException("Labeled row doesn't appear to have at least two columns");
            }
            int labelIdx = columns.size() - 1;
            rows.add(new Row(columns.subList(0, labelIdx), columns.get(labelIdx)));
        }
        rows.add(new Row(Arrays.asList(stopState), stopState));
        return rows;
    }

    /**
     * Reads all sequences from a blank-line-separated stream of lines.
     *
     * @param lines input lines
     * @param labeled whether the last column of each line is the label
     * @return one row list per sequence, as produced by {@link #readDatum(List, boolean)}
     */
    public static List<List<Row>> readData(Stream<String> lines, boolean labeled) {
        return chunkedLines(lines).stream()
            .map(chunk -> readDatum(chunk, labeled))
            .collect(Collectors.toCollection(ArrayList::new));
    }

    /**
     * Builds a predicate extractor from template lines. Lines not starting with {@code U} or
     * {@code B} (for example comments or blank lines) are ignored.
     */
    public static CRFPredicateExtractor<Row, String> predicatesFromTemplate(Stream<String> templateLines) {
        List<FeatureTemplate> templates = templateLines
            .filter(line -> line.startsWith("U") || line.startsWith("B"))
            .map(FeatureTemplate::fromLineSpec)
            .collect(Collectors.toList());
        return new ConllPredicateExtractor(
            templatesOfType(templates, FeatureTemplate.Type.NODE),
            templatesOfType(templates, FeatureTemplate.Type.EDGE));
    }

    /** Returns the templates of the given type, preserving their order. */
    private static List<FeatureTemplate> templatesOfType(List<FeatureTemplate> templates, FeatureTemplate.Type type) {
        return templates.stream()
            .filter(template -> template.type == type)
            .collect(Collectors.toList());
    }

    /**
     * Writes a model: version tag, template lines, state space, node and edge feature indexers,
     * then the weight vector. {@link #loadModel} reads the fields back in the same order.
     */
    public static void saveModel(DataOutputStream dos, List<String> templateLines, CRFFeatureEncoder<String, Row, String> featureEncoder, Vector weights) throws IOException {
        dos.writeUTF(DATA_VERSION);
        IOUtils.saveList(dos, templateLines);
        featureEncoder.stateSpace.save(dos);
        featureEncoder.nodeFeatures.save(dos);
        featureEncoder.edgeFeatures.save(dos);
        IOUtils.saveDoubles(dos, weights.toDoubles());
    }

    /**
     * Reads a model written by {@link #saveModel}.
     *
     * @throws IOException on read failure. The version check is delegated to
     *     {@code IOUtils.ensureVersionMatch}.
     */
    public static CRFModel<String, Row, String> loadModel(DataInputStream dis) throws IOException {
        IOUtils.ensureVersionMatch(dis, DATA_VERSION);
        // Reads must mirror the write order in saveModel.
        CRFPredicateExtractor<Row, String> predicateExtractor = predicatesFromTemplate(IOUtils.loadList(dis).stream());
        StateSpace<String> stateSpace = StateSpace.load(dis);
        Indexer<String> nodeFeatures = Indexer.load(dis);
        Indexer<String> edgeFeatures = Indexer.load(dis);
        CRFFeatureEncoder<String, Row, String> featureEncoder =
            new CRFFeatureEncoder<>(predicateExtractor, stateSpace, nodeFeatures, edgeFeatures);
        // NOTE(review): declared raw as in the original. CRFWeightsEncoder's generic signature is
        // not visible from this file.
        CRFWeightsEncoder weightsEncoder = new CRFWeightsEncoder(stateSpace, nodeFeatures.size(), edgeFeatures.size());
        DenseVector weights = DenseVector.of(IOUtils.loadDoubles(dis));
        return new CRFModel<>(featureEncoder, weightsEncoder, weights);
    }

    /** Debug entry point: parses and prints a sample template. */
    public static void main(String[] args) {
        System.out.println(FeatureTemplate.fromLineSpec("U00:%x[-2,0]/%x[2,0]"));
    }
}
