package org.allenai.ml.classification;

import com.gs.collections.api.map.primitive.ObjectDoubleMap;
import com.gs.collections.api.tuple.Pair;
import com.gs.collections.impl.tuple.primitive.PrimitiveTuples;
import java.beans.ConstructorProperties;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.lang.Comparable;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import org.allenai.ml.linalg.DenseVector;
import org.allenai.ml.linalg.SparseVector;
import org.allenai.ml.linalg.Vector;
import org.allenai.ml.objective.BatchObjectiveFn;
import org.allenai.ml.optimize.CachingGradientFn;
import org.allenai.ml.optimize.NewtonMethod;
import org.allenai.ml.optimize.QuasiNewton;
import org.allenai.ml.optimize.Regularizer;
import org.allenai.ml.util.IOUtils;
import org.allenai.ml.util.Indexer;
import org.allenai.ml.util.Parallel;

/* loaded from: input_file:org/allenai/ml/classification/MaxEntModel.class */
/**
 * Maximum-entropy classifier over features produced by a {@link FeatureExtractor}.
 *
 * <p>To classify a datum, its extracted features are turned into a sparse vector via
 * {@code featureIndexer}. Class probabilities are then computed by
 * {@code MaxEntObjective.classProbs} from that vector and {@code weights}, and returned keyed by
 * label through {@code classIndexer}.
 *
 * @param <L> label type
 * @param <D> datum type
 * @param <F> feature type
 */
public class MaxEntModel<L extends Comparable<L>, D, F extends Comparable<F>> implements ProbabilisticClassifier<D, L> {
    private final Indexer<F> featureIndexer;
    private final Indexer<L> classIndexer;
    private final Vector weights;
    private final FeatureExtractor<D, F> featureExtractor;

    /** Options controlling {@link MaxEntModel#train}. */
    public static class TrainOpts {
        /**
         * Parameter passed to {@link Regularizer#l2}; presumably the variance of a Gaussian prior
         * on the weights. NOTE(review): this field has no initializer, so it defaults to 0.0 —
         * confirm that {@code Regularizer.l2} handles that value as intended.
         */
        public double sigmaSq;

        /**
         * When greater than 1, each occurrence of a feature in the training data is independently
         * admitted to the feature index with probability {@code 1 / minExpectedFeatureCount}, so
         * rarely occurring features are less likely to be kept. Values of 1 or less keep every
         * feature.
         */
        public int minExpectedFeatureCount = 0;

        /** Thread count passed to {@code Parallel.MROpts.withIdAndThreads} for objective evaluation. */
        public int numThreads = 1;

        /** Seed for the random number generator used by feature subsampling. */
        public long randSeed = 0;

        /** Optimizer options; {@code null} means a default {@code new NewtonMethod.Opts()}. */
        public NewtonMethod.Opts optimizerOpts = null;
    }

    /**
     * Computes the model's class distribution for a datum.
     *
     * @param d the datum to classify
     * @return a map from each known label to its probability under the model
     */
    @Override
    public ObjectDoubleMap<L> probabilities(D d) {
        SparseVector featureVector = SparseVector.indexed(this.featureExtractor.features(d), this.featureIndexer);
        double[] classProbs = MaxEntObjective.classProbs(featureVector, this.weights, this.classIndexer.size());
        return this.classIndexer.toMap(DenseVector.of(classProbs));
    }

    /**
     * Reads a model previously written by {@link #save}.
     *
     * <p>Reads, in order: the feature indexer, the class indexer, and the weight vector. The
     * extractor is not serialized, so the caller must supply one compatible with the saved feature
     * index.
     *
     * @param dataInputStream stream positioned at the start of a saved model
     * @param featureExtractor extractor to attach to the loaded model
     * @return the deserialized model
     */
    public static <D> MaxEntModel<String, D, String> load(DataInputStream dataInputStream, FeatureExtractor<D, String> featureExtractor) {
        Indexer<String> featIndexer = Indexer.load(dataInputStream);
        Indexer<String> labelIndexer = Indexer.load(dataInputStream);
        Vector loadedWeights = DenseVector.of(IOUtils.loadDoubles(dataInputStream));
        return new MaxEntModel<>(featIndexer, labelIndexer, loadedWeights, featureExtractor);
    }

    /**
     * Writes this model so that {@link #load} can read it back.
     *
     * <p>Writes, in order: the feature indexer, the class indexer, and the weight vector. The
     * feature extractor is not written.
     *
     * @param dataOutputStream destination stream
     */
    public void save(DataOutputStream dataOutputStream) {
        this.featureIndexer.save(dataOutputStream);
        this.classIndexer.save(dataOutputStream);
        IOUtils.saveDoubles(dataOutputStream, this.weights.toDoubles());
    }

    /**
     * Trains a model on labeled data.
     *
     * <p>Builds a (possibly subsampled, see {@link TrainOpts#minExpectedFeatureCount}) feature
     * index and a label index from {@code list}. It then minimizes the batch objective plus an L2
     * regularizer with a Newton-method optimizer using L-BFGS updates. The objective's dimension
     * is {@code numFeatures * numClasses}. The executor created for parallel objective evaluation
     * is always shut down before this method returns or throws.
     *
     * @param list labeled training data as (datum, label) pairs
     * @param featureExtractor extractor applied to each datum
     * @param trainOpts training options
     * @return the trained model
     */
    public static <D> MaxEntModel<String, D, String> train(List<Pair<D, String>> list, FeatureExtractor<D, String> featureExtractor, TrainOpts trainOpts) {
        Random random = new Random(trainOpts.randSeed);
        // Each feature occurrence is admitted independently with this probability, so features
        // that occur rarely are less likely to make it into the index.
        double acceptProb = trainOpts.minExpectedFeatureCount > 0 ? 1.0d / trainOpts.minExpectedFeatureCount : 1.0d;
        Indexer<String> featIndexer = Indexer.fromStream(list.stream()
            .flatMap(pair -> featureExtractor.features(pair.getOne()).keySet().stream())
            .filter(feature -> random.nextDouble() < acceptProb));
        Indexer<String> labelIndexer = Indexer.fromStream(list.stream().map(pair -> pair.getTwo()));
        MaxEntObjective objective = new MaxEntObjective(labelIndexer.size());
        // Widen before multiplying: an int * int product overflows silently for large
        // feature/class vocabularies and only then gets widened to long.
        long dimension = (long) featIndexer.size() * labelIndexer.size();
        List examples = list.stream()
            .map(pair -> PrimitiveTuples.pair(
                labelIndexer.indexOf(pair.getTwo()),
                SparseVector.indexed(featureExtractor.features(pair.getOne()), featIndexer)))
            .collect(Collectors.toList());
        Parallel.MROpts mrOpts = Parallel.MROpts.withIdAndThreads("mr-max-ent-train", trainOpts.numThreads);
        try {
            BatchObjectiveFn objectiveFn = new BatchObjectiveFn(examples, objective, dimension, mrOpts);
            CachingGradientFn gradientFn = new CachingGradientFn(3, objectiveFn.add(Regularizer.l2(objectiveFn.dimension(), trainOpts.sigmaSq)));
            QuasiNewton lbfgs = QuasiNewton.lbfgs(3);
            NewtonMethod.Opts optimizerOpts = trainOpts.optimizerOpts != null ? trainOpts.optimizerOpts : new NewtonMethod.Opts();
            Vector learnedWeights = new NewtonMethod(unused -> lbfgs, optimizerOpts).minimize(gradientFn).xmin;
            return new MaxEntModel<>(featIndexer, labelIndexer, learnedWeights, featureExtractor);
        } finally {
            // Release the worker threads even when objective construction or optimization fails.
            Parallel.shutdownExecutor(mrOpts.executorService, Long.MAX_VALUE);
        }
    }

    /**
     * Creates a model from its components.
     *
     * @param indexer maps features to indices
     * @param indexer2 maps labels to indices
     * @param vector learned weights
     * @param featureExtractor extracts features from a datum
     */
    @ConstructorProperties({"featureIndexer", "classIndexer", "weights", "featureExtractor"})
    public MaxEntModel(Indexer<F> indexer, Indexer<L> indexer2, Vector vector, FeatureExtractor<D, F> featureExtractor) {
        this.featureIndexer = indexer;
        this.classIndexer = indexer2;
        this.weights = vector;
        this.featureExtractor = featureExtractor;
    }
}
