package com.aliasi.dca;

import com.aliasi.io.Reporter;
import com.aliasi.io.Reporters;
import com.aliasi.matrix.DenseVector;
import com.aliasi.matrix.Matrices;
import com.aliasi.matrix.Vector;
import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.RegressionPrior;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Math;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.Serializable;
import java.util.Formatter;
import java.util.IllegalFormatException;
import java.util.Locale;

/* loaded from: input_file:com/aliasi/dca/DiscreteChooser.class */
/**
 * A discrete choice model over a fixed coefficient vector.
 *
 * <p>Each alternative is represented by a feature vector and is scored by
 * the dot product of that vector with the coefficient vector. Choice
 * log probabilities are the scores minus their log-sum-of-exponentials
 * (a softmax in log space), and {@link #choose(Vector[])} returns the
 * highest-scoring alternative.
 *
 * <p>The constructor does not copy the supplied coefficient vector, so
 * later changes to that vector are visible through this chooser.
 *
 * <p>Note: within this class the simple name {@code Math} refers to the
 * imported {@code com.aliasi.util.Math}, which shadows {@code java.lang.Math};
 * JDK math functions are therefore written fully qualified.
 */
public class DiscreteChooser implements Serializable {
    static final long serialVersionUID = 9199242060691577692L;

    /** Coefficients dotted with each alternative's feature vector; not copied. */
    private final Vector mCoefficients;

    /**
     * Serialization proxy: writes only the coefficient vector and rebuilds
     * a chooser from it on read.
     */
    static class Externalizer extends AbstractExternalizable {
        static final long serialVersionUID = -8567713287299117186L;
        private final DiscreteChooser mChooser;

        /** No-arg constructor required for deserialization. */
        public Externalizer() {
            this(null);
        }

        public Externalizer(DiscreteChooser chooser) {
            this.mChooser = chooser;
        }

        @Override
        public void writeExternal(ObjectOutput out) throws IOException {
            out.writeObject(this.mChooser.mCoefficients);
        }

        @Override
        public Object read(ObjectInput in) throws IOException, ClassNotFoundException {
            return new DiscreteChooser((Vector) in.readObject());
        }
    }

    /**
     * Constructs a chooser backed directly (without copying) by the
     * specified coefficient vector.
     *
     * @param coefficients coefficient vector dotted with each alternative
     */
    public DiscreteChooser(Vector coefficients) {
        this.mCoefficients = coefficients;
    }

    /**
     * Returns the index of the highest-scoring alternative. Ties are broken
     * in favor of the lowest index.
     *
     * @param alternatives feature vectors for the alternatives
     * @return index of the best alternative
     * @throws IllegalArgumentException if {@code alternatives} is empty
     */
    public int choose(Vector[] alternatives) {
        verifyNonEmpty(alternatives);
        if (alternatives.length == 1) {
            return 0;
        }
        int bestIndex = 0;
        double bestScore = linearBasis(alternatives[0]);
        for (int k = 1; k < alternatives.length; ++k) {
            double score = linearBasis(alternatives[k]);
            // strict comparison keeps the earliest index on ties
            if (score > bestScore) {
                bestScore = score;
                bestIndex = k;
            }
        }
        return bestIndex;
    }

    /**
     * Returns the probability of choosing each alternative, computed by
     * exponentiating {@link #choiceLogProbs(Vector[])}.
     *
     * @param alternatives feature vectors for the alternatives
     * @return probability for each alternative, parallel to the input
     * @throws IllegalArgumentException if {@code alternatives} is empty
     */
    public double[] choiceProbs(Vector[] alternatives) {
        verifyNonEmpty(alternatives);
        double[] probs = choiceLogProbs(alternatives);
        for (int k = 0; k < probs.length; ++k) {
            probs[k] = java.lang.Math.exp(probs[k]);
        }
        return probs;
    }

    /**
     * Returns the natural-log probability of choosing each alternative:
     * each alternative's score minus the log-sum-of-exponentials of all
     * scores.
     *
     * @param alternatives feature vectors for the alternatives
     * @return natural log probability for each alternative
     * @throws IllegalArgumentException if {@code alternatives} is empty
     */
    public double[] choiceLogProbs(Vector[] alternatives) {
        verifyNonEmpty(alternatives);
        double[] scores = new double[alternatives.length];
        for (int k = 0; k < alternatives.length; ++k) {
            scores[k] = this.mCoefficients.dotProduct(alternatives[k]);
        }
        double logNormalizer = Math.logSumOfExponentials(scores);
        for (int k = 0; k < scores.length; ++k) {
            scores[k] -= logNormalizer;
        }
        return scores;
    }

    /**
     * Returns an unmodifiable vector wrapping this chooser's coefficients.
     *
     * @return unmodifiable coefficient vector
     */
    public Vector coefficients() {
        return Matrices.unmodifiableVector(this.mCoefficients);
    }

    /**
     * Returns a string listing the non-zero coefficients as
     * {@code dimension=value} pairs.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("DiscreteChooser(");
        int[] nonZeroDimensions = this.mCoefficients.nonZeroDimensions();
        for (int i = 0; i < nonZeroDimensions.length; ++i) {
            int dim = nonZeroDimensions[i];
            if (i > 0) {
                sb.append(",");
            }
            sb.append(Integer.toString(dim));
            sb.append('=');
            sb.append(Double.toString(this.mCoefficients.value(dim)));
        }
        sb.append(")");
        return sb.toString();
    }

    /** Score of an alternative: its dot product with the coefficients. */
    double linearBasis(Vector alternative) {
        return alternative.dotProduct(this.mCoefficients);
    }

    /** Serializes through {@link Externalizer}. */
    Object writeReplace() {
        return new Externalizer(this);
    }

    /**
     * Estimates a chooser by stochastic gradient ascent on the conditional
     * log likelihood of the observed choices plus the log prior.
     *
     * <p>For each training instance, every alternative {@code k} with model
     * probability {@code p[k]} moves the coefficients by
     * {@code learningRate * ([k == choice] - p[k]) * x[k]}. The prior's
     * gradient is applied once after every {@code priorBlockSize} instances,
     * and once more for any leftover instances, so that it contributes a
     * total weight of one learning-rate step per epoch.
     *
     * <p>After each epoch the (natural-log) log likelihood plus log prior is
     * computed. Training stops when a rolling average of the relative change
     * in that value drops below {@code minImprovement}, but not before
     * {@code minEpochs} epochs have completed, and in any case after
     * {@code maxEpochs} epochs.
     *
     * @param alternativess the alternatives for each training instance
     * @param choices index of the chosen alternative for each instance
     * @param prior prior on the coefficients
     * @param priorBlockSize number of instances between prior updates; must be positive
     * @param annealingSchedule learning rate per epoch
     * @param minImprovement convergence threshold on the rolling relative difference
     * @param minEpochs minimum number of epochs before convergence may be declared
     * @param maxEpochs maximum number of epochs
     * @param reporter progress reporter, or {@code null} for silent
     * @return the estimated chooser
     * @throws IllegalArgumentException if the training data are empty,
     *     inconsistent, out of range, or {@code priorBlockSize < 1}
     */
    public static DiscreteChooser estimate(Vector[][] alternativess,
                                           int[] choices,
                                           RegressionPrior prior,
                                           int priorBlockSize,
                                           AnnealingSchedule annealingSchedule,
                                           double minImprovement,
                                           int minEpochs,
                                           int maxEpochs,
                                           Reporter reporter) {
        if (reporter == null) {
            reporter = Reporters.silent();
        }
        int numTrainingInstances = alternativess.length;
        reporter.info("estimate()");
        reporter.info("# training cases=" + numTrainingInstances);
        reporter.info("regression prior=" + prior);
        reporter.info("annealing schedule=" + annealingSchedule);
        reporter.info("min improvement=" + minImprovement);
        reporter.info("min epochs=" + minEpochs);
        reporter.info("max epochs=" + maxEpochs);

        int numDimensions = verifyTrainingData(alternativess, choices, priorBlockSize);

        DenseVector coefficientVector = new DenseVector(numDimensions);
        // The chooser wraps coefficientVector without copying, so the in-place
        // updates below are immediately reflected in chooser's predictions.
        DiscreteChooser chooser = new DiscreteChooser(coefficientVector);

        double lastLogLikelihoodPlusPrior = Double.NaN;
        double rollingAverageRelativeDiff = 1.0;
        double bestLogLikelihoodPlusPrior = Double.NEGATIVE_INFINITY;
        int leftoverInstances = numTrainingInstances % priorBlockSize;

        for (int epoch = 0; epoch < maxEpochs; ++epoch) {
            double learningRate = annealingSchedule.learningRate(epoch);
            for (int j = 0; j < numTrainingInstances; ++j) {
                Vector[] alternatives = alternativess[j];
                int choice = choices[j];
                double[] probs = chooser.choiceProbs(alternatives);
                for (int k = 0; k < alternatives.length; ++k) {
                    double condProbMinusTruth = (k == choice) ? probs[k] - 1.0 : probs[k];
                    if (condProbMinusTruth != 0.0) {
                        coefficientVector.increment(-learningRate * condProbMinusTruth,
                                                    alternatives[k]);
                    }
                }
                // Update at the END of each full block; the leftover partial
                // block is handled after the loop, so the total prior weight
                // per epoch is exactly learningRate.
                if ((j + 1) % priorBlockSize == 0) {
                    updatePrior(prior, coefficientVector,
                                (learningRate * priorBlockSize) / numTrainingInstances);
                }
            }
            if (leftoverInstances > 0) {
                updatePrior(prior, coefficientVector,
                            (learningRate * leftoverInstances) / numTrainingInstances);
            }

            double logLikelihood = logLikelihood(chooser, alternativess, choices);
            double logPrior = Math.logBase2ToNaturalLog(prior.log2Prior(coefficientVector));
            double logLikelihoodPlusPrior = logLikelihood + logPrior;
            if (logLikelihoodPlusPrior > bestLogLikelihoodPlusPrior) {
                bestLogLikelihoodPlusPrior = logLikelihoodPlusPrior;
            }
            if (epoch > 0) {
                rollingAverageRelativeDiff
                    = (9.0 * rollingAverageRelativeDiff
                       + Math.relativeAbsoluteDifference(lastLogLikelihoodPlusPrior,
                                                         logLikelihoodPlusPrior))
                    / 10.0;
            }
            lastLogLikelihoodPlusPrior = logLikelihoodPlusPrior;

            if (reporter.isDebugEnabled()) {
                reportEpoch(reporter, epoch, learningRate, logLikelihood, logPrior,
                            logLikelihoodPlusPrior, bestLogLikelihoodPlusPrior);
            }
            // epoch + 1 epochs have completed at this point
            if (epoch + 1 >= minEpochs && rollingAverageRelativeDiff < minImprovement) {
                reporter.info("Converged with rollingAverageRelativeDiff="
                              + rollingAverageRelativeDiff);
                break;
            }
        }
        return chooser;
    }

    /**
     * Validates the training data and block size for {@code estimate}.
     *
     * @return the shared dimensionality of all alternative vectors
     * @throws IllegalArgumentException on any violation
     */
    private static int verifyTrainingData(Vector[][] alternativess,
                                          int[] choices,
                                          int priorBlockSize) {
        if (alternativess.length == 0) {
            throw new IllegalArgumentException("Require at least 1 training instance.   Found alternativess.length=0");
        }
        if (alternativess.length != choices.length) {
            throw new IllegalArgumentException("Alternatives and choices must be the same length. Found alternativess.length="
                                               + alternativess.length
                                               + " choices.length=" + choices.length);
        }
        for (int i = 0; i < alternativess.length; ++i) {
            if (alternativess[i].length < 1) {
                throw new IllegalArgumentException("Require at least one alternative. Found alternativess["
                                                   + i + "].length=0");
            }
        }
        for (int i = 0; i < alternativess.length; ++i) {
            if (choices[i] < 0) {
                throw new IllegalArgumentException("Choices must be non-negative. Found choices["
                                                   + i + "]=" + choices[i]);
            }
            // a choice equal to the number of alternatives is out of range
            if (choices[i] >= alternativess[i].length) {
                throw new IllegalArgumentException("Choices must be less than alts length. Found choices["
                                                   + i + "]=" + choices[i]
                                                   + " alternativess[" + i + "].length="
                                                   + alternativess[i].length + ".");
            }
        }
        int numDimensions = alternativess[0][0].numDimensions();
        for (int i = 0; i < alternativess.length; ++i) {
            for (int k = 0; k < alternativess[i].length; ++k) {
                int dims = alternativess[i][k].numDimensions();
                if (numDimensions != dims) {
                    throw new IllegalArgumentException("All alternatives must be same length. alternativess[0][0].length="
                                                       + numDimensions
                                                       + " alternativess[" + i + "][" + k + "].length="
                                                       + dims + ".");
                }
            }
        }
        if (priorBlockSize < 1) {
            throw new IllegalArgumentException("Prior block size must be positive. Found priorBlockSize="
                                               + priorBlockSize);
        }
        return numDimensions;
    }

    /** Writes one formatted per-epoch debug line to the reporter. */
    private static void reportEpoch(Reporter reporter, int epoch, double learningRate,
                                    double logLikelihood, double logPrior,
                                    double logLikelihoodPlusPrior,
                                    double bestLogLikelihoodPlusPrior) {
        try (Formatter formatter = new Formatter(Locale.ENGLISH)) {
            formatter.format("epoch=%5d lr=%11.9f ll=%11.4f lp=%11.4f llp=%11.4f llp*=%11.4f",
                             epoch, learningRate, logLikelihood, logPrior,
                             logLikelihoodPlusPrior, bestLogLikelihoodPlusPrior);
            reporter.debug(formatter.toString());
        } catch (IllegalFormatException e) {
            reporter.warn("Illegal format in discrete chooser");
        }
    }

    /**
     * Moves each coefficient toward the prior's mode by
     * {@code learningRate * prior.gradient(value, dim)}, truncating at the
     * mode so a single step never crosses it. Does nothing for a uniform
     * prior. For zero-mode priors this is the usual truncation at zero.
     *
     * @param prior the regression prior
     * @param coefficients coefficient vector, updated in place
     * @param learningRate weight applied to the prior gradient
     */
    static void updatePrior(RegressionPrior prior, Vector coefficients, double learningRate) {
        if (prior.isUniform()) {
            return;
        }
        int numDimensions = coefficients.numDimensions();
        for (int dim = 0; dim < numDimensions; ++dim) {
            double mode = prior.mode(dim);
            double value = coefficients.value(dim);
            if (value == mode) {
                continue;
            }
            double updated = value - learningRate * prior.gradient(value, dim);
            double truncated = value > mode
                ? java.lang.Math.max(mode, updated)
                : java.lang.Math.min(mode, updated);
            coefficients.setValue(dim, truncated);
        }
    }

    /** Sum of the natural-log probabilities of the observed choices. */
    static double logLikelihood(DiscreteChooser chooser, Vector[][] alternativess, int[] choices) {
        double sum = 0.0;
        for (int i = 0; i < alternativess.length; ++i) {
            sum += logLikelihood(chooser, alternativess[i], choices[i]);
        }
        return sum;
    }

    /** Natural-log probability that {@code chooser} picks {@code choice}. */
    static double logLikelihood(DiscreteChooser chooser, Vector[] alternatives, int choice) {
        return chooser.choiceLogProbs(alternatives)[choice];
    }

    /**
     * @throws IllegalArgumentException if {@code alternatives} has length zero
     */
    static void verifyNonEmpty(Vector[] alternatives) {
        if (alternatives.length == 0) {
            throw new IllegalArgumentException("Require at least one choice. Found choices.length=0.");
        }
    }
}
