package org.allenai.ml.sequences.crf;

import com.gs.collections.api.map.primitive.ObjectDoubleMap;
import com.gs.collections.api.tuple.Pair;
import java.lang.Comparable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.allenai.ml.linalg.SparseVector;
import org.allenai.ml.linalg.Vector;
import org.allenai.ml.sequences.StateSpace;
import org.allenai.ml.util.Indexer;
import org.allenai.ml.util.Parallel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/allenai/ml/sequences/crf/CRFFeatureEncoder.class */
public class CRFFeatureEncoder<S, O, F extends Comparable<F>> {
    // Class-wide SLF4J logger; used by build() to report indexing options.
    private static final Logger log = LoggerFactory.getLogger(CRFFeatureEncoder.class);
    // Supplies the node and edge predicates for an observation sequence.
    private final CRFPredicateExtractor<O, F> predicateExtractor;
    // Maps labels to state indices and provides the start/stop state indices checked by indexLabeledExample().
    public final StateSpace<S> stateSpace;
    // Indexer used to turn node predicates into vectors.
    public final Indexer<F> nodeFeatures;
    // Indexer used to turn edge predicates into vectors.
    public final Indexer<F> edgeFeatures;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.allenai.ml.sequences.crf.CRFFeatureEncoder$1IndexData, reason: invalid class name */
    /* loaded from: input_file:org/allenai/ml/sequences/crf/CRFFeatureEncoder$1IndexData.class */
    /**
     * Accumulator used while collecting features in {@code build}: one set of node
     * features, one set of edge features, and a {@link Random} seeded from the build
     * options' {@code randSeed}.
     */
    public class C1IndexData {
        /** Node features collected so far. */
        private final Set<F> nodeFeatures = new HashSet<>();
        /** Edge features collected so far. */
        private final Set<F> edgeFeatures = new HashSet<>();
        /** Random source used when deciding whether to keep a feature. */
        private final Random rand;
        /** Build options this accumulator was created with. */
        final /* synthetic */ BuildOpts val$opts;

        /**
         * @param buildOpts options whose {@code randSeed} seeds this accumulator's random source
         */
        C1IndexData(BuildOpts buildOpts) {
            this.val$opts = buildOpts;
            // Same seed as reading it back through val$opts; the argument is the same object.
            this.rand = new Random(buildOpts.randSeed);
        }
    }

    /* loaded from: input_file:org/allenai/ml/sequences/crf/CRFFeatureEncoder$BuildOpts.class */
    /**
     * Options for {@link CRFFeatureEncoder#build}.
     *
     * <ul>
     *   <li>{@code randSeed} seeds the random source used when sampling features.</li>
     *   <li>{@code numThreads} is the thread count handed to the map-reduce options.</li>
     *   <li>{@code probabilityToAccept} is the threshold a uniform random draw must fall
     *       below for a feature to be kept.</li>
     * </ul>
     *
     * <p>Instances are created through {@link #builder()}. A builder on which nothing is
     * set yields seed {@code 0}, one thread, and acceptance probability {@code 1.0}.
     */
    public static class BuildOpts {
        private final long randSeed;
        private final int numThreads;
        private final double probabilityToAccept;

        /** Fluent builder for {@link BuildOpts}. */
        public static class BuildOptsBuilder {
            // The defaults live here so that builder().build() produces usable options.
            // Previously they were assigned in the BuildOpts constructor and immediately
            // overwritten, leaving an untouched builder with 0 threads and probability 0.0.
            private long randSeed = 0L;
            private int numThreads = 1;
            private double probabilityToAccept = 1.0d;

            BuildOptsBuilder() {
            }

            /**
             * @param j seed for the feature-sampling random source
             * @return this builder
             */
            public BuildOptsBuilder randSeed(long j) {
                this.randSeed = j;
                return this;
            }

            /**
             * @param i number of threads to use while indexing
             * @return this builder
             */
            public BuildOptsBuilder numThreads(int i) {
                this.numThreads = i;
                return this;
            }

            /**
             * @param d probability threshold for keeping a feature
             * @return this builder
             */
            public BuildOptsBuilder probabilityToAccept(double d) {
                this.probabilityToAccept = d;
                return this;
            }

            /**
             * @return a new {@link BuildOpts} carrying the values currently set on this builder
             */
            public BuildOpts build() {
                return new BuildOpts(this.randSeed, this.numThreads, this.probabilityToAccept);
            }

            @Override
            public String toString() {
                return "CRFFeatureEncoder.BuildOpts.BuildOptsBuilder(randSeed=" + this.randSeed + ", numThreads=" + this.numThreads + ", probabilityToAccept=" + this.probabilityToAccept + ")";
            }
        }

        /**
         * @param j seed for the feature-sampling random source
         * @param i number of threads to use while indexing
         * @param d probability threshold for keeping a feature
         */
        BuildOpts(long j, int i, double d) {
            this.randSeed = j;
            this.numThreads = i;
            this.probabilityToAccept = d;
        }

        /**
         * @return a fresh builder pre-populated with the default option values
         */
        public static BuildOptsBuilder builder() {
            return new BuildOptsBuilder();
        }
    }

    /**
     * Creates an encoder from already-built components.
     *
     * @param cRFPredicateExtractor extractor producing node and edge predicates
     * @param stateSpace state space used to resolve labels to state indices
     * @param indexer indexer for node features
     * @param indexer2 indexer for edge features
     */
    public CRFFeatureEncoder(CRFPredicateExtractor<O, F> cRFPredicateExtractor, StateSpace<S> stateSpace, Indexer<F> indexer, Indexer<F> indexer2) {
        this.nodeFeatures = indexer;
        this.edgeFeatures = indexer2;
        this.stateSpace = stateSpace;
        this.predicateExtractor = cRFPredicateExtractor;
    }

    /**
     * Encodes an unlabeled observation sequence.
     *
     * @param list the observations to encode
     * @return an example holding the indexed node vectors and indexed edge vectors
     */
    public CRFIndexedExample indexedExample(List<O> list) {
        // Node predicates are extracted and indexed before edge predicates.
        List<Vector> nodeVectors = indexFeatures(this.predicateExtractor.nodePredicates(list), this.nodeFeatures);
        List<Vector> edgeVectors = indexFeatures(this.predicateExtractor.edgePredicates(list), this.edgeFeatures);
        return new CRFIndexedExample(nodeVectors, edgeVectors);
    }

    /**
     * Converts each predicate map into a vector via {@link SparseVector#indexed},
     * preserving the order of {@code list}.
     *
     * @param list predicate maps, one per position
     * @param indexer indexer passed to {@code SparseVector.indexed} for every map
     * @return a new list with one vector per input map
     */
    private static <F extends Comparable<F>> List<Vector> indexFeatures(List<ObjectDoubleMap<F>> list, Indexer<F> indexer) {
        List<Vector> vectors = new ArrayList<>(list.size());
        for (ObjectDoubleMap<F> predicates : list) {
            vectors.add(SparseVector.indexed(predicates, indexer));
        }
        return vectors;
    }

    /**
     * Encodes a labeled sequence of (observation, state) pairs.
     *
     * <p>The first label must map to the state space's start state index and the last
     * label to its stop state index.
     *
     * @param list the labeled sequence; must not be empty
     * @return an example holding the indexed node vectors, indexed edge vectors, and the
     *         state index of every label
     * @throws IllegalArgumentException if {@code list} is empty, does not begin with the
     *         start state, or does not end with the stop state
     */
    public CRFIndexedExample indexLabeledExample(List<Pair<O, S>> list) {
        // Fail with a descriptive error instead of an ArrayIndexOutOfBoundsException below.
        if (list.isEmpty()) {
            throw new IllegalArgumentException("Labeled sequence must not be empty; it must begin with the StateSpace startState and end with the stopState");
        }
        List<O> observations = list.stream().map(Pair::getOne).collect(Collectors.toList());
        List<Vector> nodeVectors = indexFeatures(this.predicateExtractor.nodePredicates(observations), this.nodeFeatures);
        List<Vector> edgeVectors = indexFeatures(this.predicateExtractor.edgePredicates(observations), this.edgeFeatures);
        int[] labels = list.stream().map(Pair::getTwo).mapToInt(this.stateSpace::stateIndex).toArray();
        if (labels[0] != this.stateSpace.startStateIndex()) {
            throw new IllegalArgumentException("Must use StateSpace startState to start sequence, instead got " + list.get(0).getTwo());
        }
        int last = labels.length - 1;
        if (labels[last] != this.stateSpace.stopStateIndex()) {
            throw new IllegalArgumentException("Must use StateSpace stopState to end sequence, instead got " + list.get(last).getTwo());
        }
        return new CRFIndexedExample(nodeVectors, edgeVectors, labels);
    }

    /**
     * Builds an encoder by collecting node and edge features from {@code list} and
     * indexing them.
     *
     * <p>Feature collection runs through {@code Parallel.mapReduce} using
     * {@code buildOpts.numThreads} threads. Each feature occurrence is kept when a
     * uniform random draw is below {@code buildOpts.probabilityToAccept}. The keys of
     * every predicate map are visited in sorted order, so the sequence of random draws
     * does not depend on the map's iteration order.
     *
     * @param list the observation sequences to scan for features
     * @param cRFPredicateExtractor extractor producing node and edge predicates
     * @param stateSpace state space handed to the resulting encoder
     * @param buildOpts seed, thread count, and acceptance probability
     * @return an encoder whose node and edge indexers contain the collected features
     */
    public static <S, O, F extends Comparable<F>> CRFFeatureEncoder build(List<List<O>> list, final CRFPredicateExtractor<O, F> cRFPredicateExtractor, StateSpace<S> stateSpace, final BuildOpts buildOpts) {
        log.info("Indexing features with {} prob to keep and {} threads", Double.valueOf(buildOpts.probabilityToAccept), Integer.valueOf(buildOpts.numThreads));
        Parallel.MROpts withIdAndThreads = Parallel.MROpts.withIdAndThreads("mr-feature-index", buildOpts.numThreads);
        C1IndexData c1IndexData = (C1IndexData) Parallel.mapReduce(list, new Parallel.MapReduceDriver<List<O>, C1IndexData>() {
            /** Creates a fresh accumulator seeded from the build options. */
            @Override // org.allenai.ml.util.Parallel.MapReduceDriver
            public C1IndexData newData() {
                return new C1IndexData(buildOpts);
            }

            /** Samples the node and then the edge predicates of one sequence into the accumulator. */
            @Override // org.allenai.ml.util.Parallel.MapReduceDriver
            public void update(C1IndexData c1IndexData2, List<O> list2) {
                stochasticAddAll(c1IndexData2.rand, c1IndexData2.nodeFeatures, cRFPredicateExtractor.nodePredicates(list2));
                stochasticAddAll(c1IndexData2.rand, c1IndexData2.edgeFeatures, cRFPredicateExtractor.edgePredicates(list2));
            }

            /** Folds the second accumulator's features into the first. */
            @Override // org.allenai.ml.util.Parallel.MapReduceDriver
            public void merge(C1IndexData c1IndexData2, C1IndexData c1IndexData3) {
                c1IndexData2.nodeFeatures.addAll(c1IndexData3.nodeFeatures);
                c1IndexData2.edgeFeatures.addAll(c1IndexData3.edgeFeatures);
            }

            /** Adds each key of each map to {@code set} with probability {@code probabilityToAccept}. */
            private void stochasticAddAll(Random random, Set<F> set, List<ObjectDoubleMap<F>> list2) {
                for (ObjectDoubleMap<F> objectDoubleMap : list2) {
                    // Iterate the sorted copy. The previous code sorted a temporary list and
                    // then walked the unsorted key view, so the sort had no effect.
                    List<F> sortedKeys = new ArrayList<>(objectDoubleMap.keySet());
                    Collections.sort(sortedKeys);
                    for (F key : sortedKeys) {
                        if (random.nextDouble() < buildOpts.probabilityToAccept) {
                            set.add(key);
                        }
                    }
                }
            }
        }, withIdAndThreads);
        // NOTE(review): the second argument is presumably a wait limit; confirm in Parallel.
        Parallel.shutdownExecutor(withIdAndThreads.executorService, Long.MAX_VALUE);
        return new CRFFeatureEncoder(cRFPredicateExtractor, stateSpace, Indexer.fromStream(c1IndexData.nodeFeatures.stream()), Indexer.fromStream(c1IndexData.edgeFeatures.stream()));
    }
}
