package org.allenai.ml.classification;

import com.gs.collections.api.tuple.primitive.IntObjectPair;
import java.beans.ConstructorProperties;
import org.allenai.ml.linalg.Vector;
import org.allenai.ml.math.SloppyMath;
import org.allenai.ml.objective.ExampleObjectiveFn;

/* loaded from: input_file:org/allenai/ml/classification/MaxEntObjective.class */
/**
 * Maximum-entropy (softmax / multinomial logistic) objective evaluated one example at a time.
 *
 * <p>An example is an {@link IntObjectPair} whose int component is the gold class label and whose
 * object component is the example's feature vector. Weights are stored feature-major in a single
 * flat vector: the weight for feature {@code f} and class {@code c} is at
 * {@code weightIdx(f, c, numClasses) = f * numClasses + c}.
 */
public class MaxEntObjective implements ExampleObjectiveFn<IntObjectPair<Vector>> {
    private final int numClasses;

    /**
     * Returns the flat weight index of a (feature, class) pair in the feature-major layout.
     *
     * @param featureIdx index of the feature
     * @param classIdx index of the class; expected to lie in {@code [0, numClasses)}, otherwise
     *     the result aliases a slot belonging to a neighbouring feature
     * @param numClasses total number of classes
     * @return {@code featureIdx * numClasses + classIdx}
     */
    public static int weightIdx(int featureIdx, int classIdx, int numClasses) {
        // NOTE(review): plain int arithmetic; overflows if featureIdx * numClasses exceeds
        // Integer.MAX_VALUE.
        return featureIdx * numClasses + classIdx;
    }

    /**
     * Computes {@code log P(label | x)} for one example and adds its gradient into {@code grad}.
     *
     * <p>For every entry {@code (f, x_f)} yielded by the example's feature-vector iterator, the
     * slot for {@code (f, label)} is incremented by {@code x_f}, and the slot for every class
     * {@code c} is incremented by {@code -x_f * P(c | x)}. Together these updates form the gradient
     * of {@code log P(label | x)} with respect to the weights.
     *
     * @param example pair of (gold class label, feature vector)
     * @param weights current weights, indexed by {@link #weightIdx(int, int, int)}
     * @param grad gradient accumulator; updated in place via {@code Vector.inc}, never cleared
     * @return {@code log P(label | x)} under {@code weights}
     * @throws IllegalArgumentException if the label is outside {@code [0, numClasses)}; thrown
     *     before {@code grad} is modified
     */
    @Override
    public double evaluate(IntObjectPair<Vector> example, Vector weights, Vector grad) {
        int label = example.getOne();
        if (label < 0 || label >= numClasses) {
            // Validate before any side effect. An out-of-range label would alias another feature's
            // weight slot in weightIdx and corrupt grad before any exception surfaced.
            throw new IllegalArgumentException(
                "label " + label + " out of range [0, " + numClasses + ")");
        }
        Vector features = example.getTwo();

        double[] probs = classLogProbs(features, weights, numClasses);
        // Read the log-probability before exponentiating. Returning log(exp(...)) instead would
        // lose precision and collapse to -Infinity whenever P(label | x) underflows to zero.
        double logLikelihood = probs[label];
        exponentiateInPlace(probs);

        Vector.Iterator it = features.iterator();
        while (!it.isExhausted()) {
            int featureIdx = (int) it.index();
            double value = it.value();
            // Observed feature count for the gold class...
            grad.inc(weightIdx(featureIdx, label, numClasses), value);
            // ...minus the expected count under the model's class distribution.
            for (int c = 0; c < numClasses; c++) {
                grad.inc(weightIdx(featureIdx, c, numClasses), -value * probs[c]);
            }
            it.advance();
        }
        return logLikelihood;
    }

    /**
     * Computes the softmax class distribution {@code P(c | x)} for a feature vector.
     *
     * @param features feature vector of the example
     * @param weights weights, indexed by {@link #weightIdx(int, int, int)}
     * @param numClasses number of classes
     * @return a fresh array of length {@code numClasses} whose entry {@code c} is
     *     {@code P(c | x)}, exponentiated with {@link SloppyMath#sloppyExp} (presumably a fast
     *     approximate exp)
     */
    public static double[] classProbs(Vector features, Vector weights, int numClasses) {
        double[] probs = classLogProbs(features, weights, numClasses);
        exponentiateInPlace(probs);
        return probs;
    }

    /**
     * Returns {@code log P(c | x)} for every class: each class's linear score
     * {@code sum_f w[f, c] * x_f}, minus the log-sum-exp of all class scores.
     */
    private static double[] classLogProbs(Vector features, Vector weights, int numClasses) {
        double[] scores = new double[numClasses];
        Vector.Iterator it = features.iterator();
        while (!it.isExhausted()) {
            int featureIdx = (int) it.index();
            double value = it.value();
            for (int c = 0; c < numClasses; c++) {
                scores[c] += weights.at(weightIdx(featureIdx, c, numClasses)) * value;
            }
            it.advance();
        }
        double logNormalizer = SloppyMath.logSumExp(scores);
        for (int c = 0; c < numClasses; c++) {
            scores[c] -= logNormalizer;
        }
        return scores;
    }

    /** Replaces every entry of {@code logValues} with {@code SloppyMath.sloppyExp} of it. */
    private static void exponentiateInPlace(double[] logValues) {
        for (int c = 0; c < logValues.length; c++) {
            logValues[c] = SloppyMath.sloppyExp(logValues[c]);
        }
    }

    /**
     * Creates an objective over a fixed number of classes.
     *
     * @param numClasses number of output classes; must be at least 1
     * @throws IllegalArgumentException if {@code numClasses < 1}
     */
    @ConstructorProperties({"numClasses"})
    public MaxEntObjective(int numClasses) {
        if (numClasses < 1) {
            throw new IllegalArgumentException("numClasses must be >= 1, got " + numClasses);
        }
        this.numClasses = numClasses;
    }
}
