package org.allenai.ml.eval;

import java.util.Objects;
import java.util.function.Predicate;
import java.util.function.ToDoubleFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Dip-counting criterion that tracks an evaluation metric across successive models.
 *
 * <p>Each call to {@link #test} scores a model with the supplied evaluation function; larger
 * scores are treated as better. A score is <em>accepted</em> when it is no more than
 * {@link #dipTolerance} below the last accepted score. Accepting a score resets the
 * consecutive-dip counter, makes that score the new reference value and records the model as
 * {@link #bestModel}. Any other score counts as a "dip". {@link #test} returns {@code false}
 * once more than {@link #maxNumDipIters} consecutive dips have been seen, and {@code true}
 * otherwise.
 *
 * <p>Note: because within-tolerance decreases are accepted and become the new reference value,
 * {@link #bestModel} is the most recently accepted model, which may score slightly lower than an
 * earlier accepted one. Instances are stateful and not synchronized.
 *
 * @param <M> the model type being evaluated
 */
public class TrainCriterionEval<M> implements Predicate<M> {
    private static final Logger log = LoggerFactory.getLogger(TrainCriterionEval.class);

    /** Scores a model; larger values are treated as better. */
    private final ToDoubleFunction<M> baseEvalFn;
    /** Number of times {@link #test} has been called; used only in log output. */
    private int numIters = 0;
    /** Consecutive dips seen since the last accepted score (0 right after an acceptance). */
    private int numDipIters = 0;
    /** Last accepted score; starts at negative infinity so any finite first score is accepted. */
    private double lastVal = Double.NEGATIVE_INFINITY;

    /** Most recently accepted model, or {@code null} if no score has been accepted yet. */
    public M bestModel = null;
    /** How far a score may fall below the last accepted score and still be accepted. */
    public double dipTolerance = 1.0E-4d;
    /** Number of consecutive dips tolerated before {@link #test} returns {@code false}. */
    public int maxNumDipIters = 0;

    /**
     * Creates a criterion that scores models with {@code toDoubleFunction}.
     *
     * @param toDoubleFunction evaluation function; larger values are treated as better
     * @throws NullPointerException if {@code toDoubleFunction} is null
     */
    public TrainCriterionEval(ToDoubleFunction<M> toDoubleFunction) {
        // Fail fast here rather than on the first call to test().
        this.baseEvalFn = Objects.requireNonNull(toDoubleFunction, "toDoubleFunction");
    }

    /**
     * Returns the most recently accepted model.
     *
     * @return the value of {@link #bestModel}, or {@code null} if no score has been accepted
     */
    public M getBestModel() {
        return this.bestModel;
    }

    /**
     * Scores {@code m} and updates the dip-tracking state.
     *
     * @param m model to evaluate; stored as {@link #bestModel} if its score is accepted
     * @return {@code false} if this call pushed the consecutive-dip count past
     *     {@link #maxNumDipIters}; {@code true} otherwise
     */
    @Override
    public boolean test(M m) {
        final double val = this.baseEvalFn.applyAsDouble(m);
        if (log.isInfoEnabled()) {
            // Guarded because String.format runs eagerly; SLF4J placeholders cannot express %.3f.
            log.info(String.format("[Iteration %d] Eval metric: %.3f", this.numIters, val));
        }
        this.numIters++;

        // Accept unless the score fell more than dipTolerance below the last accepted score.
        // A NaN difference compares false, so a NaN score (or -Infinity on the first call)
        // counts as a dip.
        if (this.lastVal - val <= this.dipTolerance) {
            this.numDipIters = 0;
            this.lastVal = val;
            this.bestModel = m;
            return true;
        }

        this.numDipIters++;
        if (this.numDipIters > this.maxNumDipIters) {
            log.info("Exceeded max dip iters, bailing");
            return false;
        }
        log.info("Another down iteration, waiting {} more iters", this.maxNumDipIters - this.numDipIters);
        return true;
    }
}
