package org.allenai.ml.sequences.crf.conll;

import com.gs.collections.api.tuple.Pair;
import com.gs.collections.impl.tuple.Tuples;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.List;
import java.util.stream.Collectors;
import org.allenai.ml.sequences.Evaluation;
import org.allenai.ml.sequences.crf.CRFModel;
import org.allenai.ml.sequences.crf.conll.ConllFormat;
import org.allenai.ml.util.IOUtils;
import org.allenai.ml.util.Parallel;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Command-line tool that evaluates a serialized CoNLL CRF model on a labeled CoNLL data file and
 * logs the token accuracy together with the average inference time per example.
 */
public class Evaluator {
    private static final Logger log = LoggerFactory.getLogger(Evaluator.class);

    /** Command-line options, populated by args4j. */
    public static class Opts {

        @Option(name = "-model", usage = "where to read model", required = true)
        public String modelPath;

        @Option(name = "-data", usage = "where to read data", required = true)
        public String dataPath;
    }

    /**
     * Loads the model at {@code opts.modelPath}, reads the data at {@code opts.dataPath} with
     * {@code ConllFormat.readData(..., true)}, and scores the model's predictions on it using a
     * single-threaded evaluation pool.
     *
     * @param opts paths to the serialized model and the evaluation data
     * @return a pair of (token accuracy, average elapsed milliseconds of evaluation per example);
     *     the second element is {@link Double#NaN} when the data contains no examples
     * @throws UncheckedIOException if the model file cannot be opened, read, or closed
     */
    public static Pair<Double, Double> evaluateModel(Opts opts) {
        CRFModel<String, ConllFormat.Row, String> model = readModel(opts.modelPath);
        List<List<ConllFormat.Row>> data =
            ConllFormat.readData(IOUtils.linesFromPath(opts.dataPath), true);
        List<List<Pair<ConllFormat.Row, String>>> labeledData = data.stream()
            .map(sequence -> sequence.stream()
                .map(ConllFormat.Row::asLabeledPair)
                .collect(Collectors.toList()))
            .collect(Collectors.toList());

        Parallel.MROpts mrOpts = Parallel.MROpts.withIdAndThreads("mr-test-eval", 1);
        double accuracy;
        long elapsedNanos;
        try {
            // nanoTime is monotonic, so the measurement is not skewed by wall-clock adjustments.
            long startNanos = System.nanoTime();
            accuracy = Evaluation.compute(model, labeledData, mrOpts).tokenAccuracy.accuracy();
            elapsedNanos = System.nanoTime() - startNanos;
        } finally {
            // Release the worker pool even when evaluation throws.
            Parallel.shutdownExecutor(mrOpts.executorService, Long.MAX_VALUE);
        }

        // Floating-point division keeps sub-millisecond precision and avoids dividing by zero.
        double avgMillisPerExample = data.isEmpty()
            ? Double.NaN
            : (elapsedNanos / 1e6) / data.size();
        return Tuples.pair(accuracy, avgMillisPerExample);
    }

    /**
     * Deserializes a CRF model from the file at {@code path} and closes the file afterwards.
     *
     * @throws UncheckedIOException if the file cannot be opened, read, or closed
     */
    private static CRFModel<String, ConllFormat.Row, String> readModel(String path) {
        try (DataInputStream modelInput = new DataInputStream(new FileInputStream(path))) {
            return ConllFormat.loadModel(modelInput);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read model from " + path, e);
        }
    }

    /**
     * Parses {@code -model} and {@code -data}, runs {@link #evaluateModel(Opts)}, and logs the
     * results. Exits with status 2 after printing usage when the arguments are invalid.
     */
    public static void main(String[] args) {
        Opts opts = new Opts();
        CmdLineParser cmdLineParser = new CmdLineParser(opts);
        try {
            cmdLineParser.parseArgument(args);
        } catch (CmdLineException e) {
            // Say what was wrong before the generic usage text.
            System.err.println(e.getMessage());
            cmdLineParser.printUsage(System.err);
            System.exit(2);
        }
        Pair<Double, Double> result = evaluateModel(opts);
        log.info("Accuracy: {}", result.getOne());
        log.info("Inference avg ms per example: {}", result.getTwo());
    }
}
