package org.allenai.ml.sequences.crf.conll;

import com.gs.collections.api.tuple.Pair;
import com.gs.collections.impl.tuple.Tuples;
import java.io.DataOutputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.function.ToDoubleFunction;
import java.util.stream.Collectors;
import org.allenai.ml.eval.TrainCriterionEval;
import org.allenai.ml.linalg.Vector;
import org.allenai.ml.sequences.Evaluation;
import org.allenai.ml.sequences.crf.CRFModel;
import org.allenai.ml.sequences.crf.CRFPredicateExtractor;
import org.allenai.ml.sequences.crf.CRFTrainer;
import org.allenai.ml.sequences.crf.conll.ConllFormat;
import org.allenai.ml.util.IOUtils;
import org.allenai.ml.util.Parallel;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/allenai/ml/sequences/crf/conll/Trainer.class */
/**
 * Command-line entry point that trains a CRF sequence model from CoNLL-formatted data and
 * writes the resulting model to a file.
 *
 * <p>See {@link Opts} for the supported command-line flags.
 */
public class Trainer {
    private static final Logger logger = LoggerFactory.getLogger(Trainer.class);

    /** Command-line options; fields are populated by args4j from the {@code @Option} annotations. */
    public static class Opts {

        /** Path of the feature-template file; its lines are read and also stored with the model. */
        @Option(name = "-featureTemplates", usage = "FeatureTemplate template pattern file", required = true)
        public String templateFile;

        /** Path of the labeled training data. */
        @Option(name = "-trainData", usage = "Path to training data", required = true)
        public String trainPath;

        /** Path of the file the trained model is written to. */
        @Option(name = "-modelSave", usage = "where to write model", required = true)
        public String modelPath;

        /** Copied to {@code CRFTrainer.Opts.sigmaSq}. */
        @Option(name = "-sigmaSq", usage = "L2 regularization to use")
        public double sigmaSquared = 1.0d;

        /** Thread count used both by the trainer and by the per-iteration evaluation. */
        @Option(name = "-numThreads", usage = "number of threads to train with")
        public int numThreads = 1;

        /** Converted to {@code CRFTrainer.Opts.minExpectedFeatureCount} as {@code (int) (1 / featureKeepProb)}. */
        @Option(name = "-featureKeepProb", usage = "probability of keeping a feature predicate")
        public double featureKeepProb = 1.0d;

        /** Copied to the optimizer's {@code maxIters}. */
        @Option(name = "-maxTrainIters", usage = "max number of train iterations")
        public int maxIterations = Integer.MAX_VALUE;

        /** Copied to {@code CRFTrainer.Opts.lbfgsHistorySize}. */
        @Option(name = "-lbfgsHistorySize", usage = "history size for LBFGS")
        public int lbfgsHistorySize = 3;

        /** Fraction of the labeled data held out of training and used for per-iteration evaluation. */
        @Option(name = "-testSplitRatio", usage = "Data to hold for eval ever iter")
        public double testSplitRatio = 0.2d;

        /** Copied to {@code TrainCriterionEval.maxNumDipIters}. */
        @Option(name = "-maxNumDipIters", usage = "How many iterations after test eval drop to continue training")
        public int maxNumDipIters = 3;
    }

    /**
     * Splits {@code list} into a (first, second) pair where the second part holds roughly the
     * fraction {@code d} of the elements.
     *
     * <p>When {@code d > 0} the elements are shuffled with a fixed seed, so the split is
     * reproducible across runs. The shuffle is done on a copy: the caller's list is never
     * modified. When {@code d <= 0} (or NaN) everything goes into the first part, in the
     * original order, and the second part is empty.
     *
     * @param list elements to split
     * @param d fraction of elements to place in the second part; must not exceed 1
     * @return pair of (remaining elements, held-out elements), both freshly allocated lists
     * @throws IllegalArgumentException if {@code d} is greater than 1
     */
    private static <T> Pair<List<T>, List<T>> splitData(List<T> list, double d) {
        if (d > 1.0d) {
            throw new IllegalArgumentException("Split ratio must be at most 1.0 but was " + d);
        }
        List<T> kept = new ArrayList<>();
        List<T> heldOut = new ArrayList<>();
        if (d > 0.0d) {
            // Shuffle a private copy so the argument is left untouched (and may be unmodifiable).
            List<T> shuffled = new ArrayList<>(list);
            Collections.shuffle(shuffled, new Random(0L));
            int splitIndex = (int) ((1.0d - d) * shuffled.size());
            kept.addAll(shuffled.subList(0, splitIndex));
            heldOut.addAll(shuffled.subList(splitIndex, shuffled.size()));
        } else {
            kept.addAll(list);
        }
        return Tuples.pair(kept, heldOut);
    }

    /** Returns a copy of {@code sequences} in which every pair of every inner list is swapped. */
    private static <A, B> List<List<Pair<B, A>>> swapPairs(List<List<Pair<A, B>>> sequences) {
        return sequences.stream()
            .map(sequence -> sequence.stream().map(pair -> pair.swap()).collect(Collectors.toList()))
            .collect(Collectors.toList());
    }

    /**
     * Trains a CRF on the data at {@code opts.trainPath} and saves the best model to
     * {@code opts.modelPath}.
     *
     * <p>The labeled data is split with {@link #splitData}; the first part is used for training
     * and the second part is evaluated after every iteration through a
     * {@code TrainCriterionEval}, whose best model is the one that gets saved. Token accuracy on
     * the training part is logged each iteration.
     *
     * @param opts parsed command-line options
     * @throws UncheckedIOException if the model file cannot be opened or written
     */
    // NOTE(review): raw types are kept because the element types are declared in project
    // classes that are not visible from this file — confirm and parameterize.
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static void trainAndSaveModel(Opts opts) {
        List templateLines = (List) IOUtils.linesFromPath(opts.templateFile).collect(Collectors.toList());
        logger.info("Loading train data from {}", opts.trainPath);
        CRFPredicateExtractor<ConllFormat.Row, String> predicateExtractor =
            ConllFormat.predicatesFromTemplate(templateLines.stream());
        List labeledData = (List) ConllFormat.readData(IOUtils.linesFromPath(opts.trainPath), true).stream()
            .map(rows -> rows.stream().map(row -> row.asLabeledPair().swap()).collect(Collectors.toList()))
            .collect(Collectors.toList());
        logger.info("CRF training with {} threads and {} labeled examples", opts.numThreads, labeledData.size());

        Pair split = splitData(labeledData, opts.testSplitRatio);
        List trainData = (List) split.getOne();
        List testData = (List) split.getTwo();

        CRFTrainer.Opts trainerOpts = new CRFTrainer.Opts();
        trainerOpts.sigmaSq = opts.sigmaSquared;
        trainerOpts.lbfgsHistorySize = opts.lbfgsHistorySize;
        trainerOpts.optimizerOpts.maxIters = opts.maxIterations;
        trainerOpts.minExpectedFeatureCount = (int) (1.0d / opts.featureKeepProb);
        trainerOpts.numThreads = opts.numThreads;
        CRFTrainer crfTrainer = new CRFTrainer(trainData, predicateExtractor, trainerOpts);

        Parallel.MROpts evalMrOpts = Parallel.MROpts.withIdAndThreads("mr-crf-train-eval", opts.numThreads);
        // Evaluation is given the pairs swapped back relative to what the trainer receives.
        List trainEvalData = swapPairs(trainData);
        List testEvalData = swapPairs(testData);
        ToDoubleFunction trainAccuracy = model -> {
            return Evaluation.compute(model, trainEvalData, evalMrOpts).tokenAccuracy.accuracy();
        };
        TrainCriterionEval trainCriterionEval = new TrainCriterionEval(model -> {
            return Evaluation.compute(model, testEvalData, evalMrOpts).tokenAccuracy.accuracy();
        });
        trainCriterionEval.maxNumDipIters = opts.maxNumDipIters;
        trainerOpts.iterCallback = model -> {
            logger.info("Train Accuracy: {}", Double.valueOf(trainAccuracy.applyAsDouble(model)));
            return trainCriterionEval.test(model);
        };

        CRFModel bestModel;
        try {
            crfTrainer.train(trainData);
            bestModel = (CRFModel) trainCriterionEval.getBestModel();
        } finally {
            // Release the evaluation executor even when training fails.
            Parallel.shutdownExecutor(evalMrOpts.executorService, Long.MAX_VALUE);
        }
        writeModel(opts.modelPath, templateLines, bestModel);
    }

    /**
     * Writes {@code model} together with the template lines to {@code modelPath}, closing the
     * output stream whether or not the write succeeds.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static void writeModel(String modelPath, List templateLines, CRFModel model) {
        Vector weights = model.weights();
        try (DataOutputStream out = new DataOutputStream(new FileOutputStream(modelPath))) {
            logger.info("Writing model to {}", modelPath);
            ConllFormat.saveModel(out, templateLines, model.featureEncoder, weights);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to write model to " + modelPath, e);
        }
    }

    /**
     * Parses the command line into {@link Opts} and runs {@link #trainAndSaveModel}. On a parse
     * error the reason and the usage text are printed to standard error and the JVM exits with
     * status 2.
     *
     * @param strArr command-line arguments
     */
    public static void main(String[] strArr) {
        Opts opts = new Opts();
        CmdLineParser parser = new CmdLineParser(opts);
        try {
            parser.parseArgument(strArr);
        } catch (CmdLineException e) {
            // Tell the user what was wrong before showing the usage text.
            System.err.println(e.getMessage());
            parser.printUsage(System.err);
            System.exit(2);
        }
        trainAndSaveModel(opts);
    }
}
