/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.stats;

import com.aliasi.corpus.ObjectHandler;
import com.aliasi.io.Reporter;
import com.aliasi.io.Reporters;
import com.aliasi.matrix.DenseVector;
import com.aliasi.matrix.Matrices;
import com.aliasi.matrix.Vector;
import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.RegressionPrior;
import com.aliasi.stats.Statistics;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Compilable;
import com.aliasi.util.Math;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Formatter;
import java.util.IllegalFormatException;
import java.util.Locale;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * A multinomial logistic regression model over {@link Vector} inputs.
 *
 * <p>A model with {@code K} outcomes stores {@code K-1} weight vectors, all
 * of the same dimensionality. Outcome {@code k < K-1} has linear predictor
 * {@code x.dotProduct(w[k])}; the last outcome's predictor is fixed at
 * {@code 0.0}. {@link #classify(Vector)} returns the normalized exponentials
 * (softmax) of these predictors.
 *
 * <p>Models are estimated by stochastic gradient descent through the static
 * {@code estimate} methods, and are serialized and compiled through the
 * nested {@link Externalizer}.
 */
public class LogisticRegression
implements Compilable,
Serializable {
    static final long serialVersionUID = -8585743596322227589L;

    /** Weight vectors for outcomes {@code 0..K-2}; outcome {@code K-1} is implicitly zero. */
    private final Vector[] mWeightVectors;

    /**
     * Construct a model from one weight vector per outcome, excluding the
     * last outcome.
     *
     * <p>The array is copied, but the vectors are not: later changes to the
     * vectors themselves are visible through this model.
     *
     * @param weightVectors Weight vectors for outcomes {@code 0..K-2}.
     * @throws IllegalArgumentException If the array is empty or the vectors
     * do not all have the same number of dimensions.
     */
    public LogisticRegression(Vector[] weightVectors) {
        if (weightVectors.length < 1) {
            String msg = "Require at least one weight vector.";
            throw new IllegalArgumentException(msg);
        }
        int numDimensions = weightVectors[0].numDimensions();
        for (int k = 1; k < weightVectors.length; ++k) {
            if (numDimensions != weightVectors[k].numDimensions()) {
                String msg = "All weight vectors must be same dimensionality."
                    + " Found weightVectors[0].numDimensions()=" + numDimensions
                    + " weightVectors[" + k + "]=" + weightVectors[k].numDimensions();
                throw new IllegalArgumentException(msg);
            }
        }
        // Copy the array so callers cannot swap vectors out from under the
        // model; the vectors are shared deliberately, because estimate()
        // updates them in place while classifying with this model.
        this.mWeightVectors = weightVectors.clone();
    }

    /**
     * Construct a two-outcome model from a single weight vector.
     *
     * @param weightVector Weight vector for outcome 0.
     */
    public LogisticRegression(Vector weightVector) {
        // Delegate so the single-vector case gets the same validation.
        this(new Vector[] { weightVector });
    }

    /**
     * Return the dimensionality expected of input vectors.
     *
     * @return Number of dimensions of the weight vectors.
     */
    public int numInputDimensions() {
        return this.mWeightVectors[0].numDimensions();
    }

    /**
     * Return the number of outcomes, which is one more than the number of
     * weight vectors.
     *
     * @return Number of outcomes.
     */
    public int numOutcomes() {
        return this.mWeightVectors.length + 1;
    }

    /**
     * Return unmodifiable views of the weight vectors.
     *
     * @return A new array of unmodifiable weight-vector views.
     */
    public Vector[] weightVectors() {
        Vector[] immutables = new Vector[this.mWeightVectors.length];
        for (int i = 0; i < immutables.length; ++i) {
            immutables[i] = Matrices.unmodifiableVector(this.mWeightVectors[i]);
        }
        return immutables;
    }

    /**
     * Return the conditional probability of each outcome given the input.
     *
     * @param x Input vector.
     * @return Array of {@link #numOutcomes()} probabilities summing to one.
     * @throws IllegalArgumentException If {@code x} has the wrong dimensionality.
     */
    public double[] classify(Vector x) {
        double[] ysHat = new double[this.numOutcomes()];
        this.classify(x, ysHat);
        return ysHat;
    }

    /**
     * Write the conditional probability of each outcome given the input into
     * the supplied array, which is overwritten.
     *
     * @param x Input vector.
     * @param ysHat Output array; must have exactly {@link #numOutcomes()} entries.
     * @throws IllegalArgumentException If {@code x} has the wrong
     * dimensionality or {@code ysHat} has the wrong length.
     */
    public void classify(Vector x, double[] ysHat) {
        if (this.numInputDimensions() != x.numDimensions()) {
            String msg = "Vector and classifer must be of same dimensionality."
                + " Regression model this.numInputDimensions()=" + this.numInputDimensions()
                + " Vector x.numDimensions()=" + x.numDimensions();
            throw new IllegalArgumentException(msg);
        }
        if (ysHat.length != this.numOutcomes()) {
            String msg = "Output array must have one entry per outcome."
                + " Found ysHat.length=" + ysHat.length
                + " numOutcomes()=" + this.numOutcomes();
            throw new IllegalArgumentException(msg);
        }
        int numOutcomesMinus1 = ysHat.length - 1;
        // The last outcome's linear predictor is fixed at zero, so 0.0 is a
        // valid starting value for the running maximum.
        ysHat[numOutcomesMinus1] = 0.0;
        double max = 0.0;
        for (int k = 0; k < numOutcomesMinus1; ++k) {
            ysHat[k] = x.dotProduct(this.mWeightVectors[k]);
            if (ysHat[k] > max) {
                max = ysHat[k];
            }
        }
        // Shift by the maximum before exponentiating so exp() cannot overflow.
        double z = 0.0;
        for (int k = 0; k < ysHat.length; ++k) {
            ysHat[k] = java.lang.Math.exp(ysHat[k] - max);
            z += ysHat[k];
        }
        for (int k = 0; k < ysHat.length; ++k) {
            ysHat[k] /= z;
        }
    }

    /**
     * Compile this model to the specified output by writing its externalizer.
     *
     * @param out Object output to which the model is written.
     * @throws IOException If there is an underlying I/O error.
     */
    @Override
    public void compileTo(ObjectOutput out) throws IOException {
        out.writeObject(new Externalizer(this));
    }

    /** Serialization hook: replace this object with its externalizer. */
    Object writeReplace() {
        return new Externalizer(this);
    }

    /**
     * Estimate a model with default settings for the less common options:
     * no hot start, no per-epoch handler, a rolling average window of 10,
     * and a prior block size of about 2% of the training set (at least 1).
     *
     * @param xs Training inputs.
     * @param cs Training outcomes, parallel to {@code xs}.
     * @param prior Prior on the weights.
     * @param annealingSchedule Learning-rate schedule.
     * @param reporter Progress reporter, or {@code null} for silent.
     * @param minImprovement Convergence threshold, or {@code Double.NaN} to
     * disable convergence monitoring.
     * @param minEpochs Minimum number of epochs before convergence may stop training.
     * @param maxEpochs Maximum number of epochs.
     * @return The estimated model.
     */
    public static LogisticRegression estimate(Vector[] xs, int[] cs, RegressionPrior prior, AnnealingSchedule annealingSchedule, Reporter reporter, double minImprovement, int minEpochs, int maxEpochs) {
        LogisticRegression hotStart = null;
        ObjectHandler<LogisticRegression> handler = null;
        int rollingAverageSize = 10;
        int priorBlockSize = java.lang.Math.max(1, cs.length / 50);
        return LogisticRegression.estimate(xs, cs, prior, priorBlockSize, hotStart, annealingSchedule, minImprovement, rollingAverageSize, minEpochs, maxEpochs, handler, reporter);
    }

    /**
     * Estimate a model by stochastic gradient descent.
     *
     * <p>Each epoch makes one pass over the training data in order, applying
     * the likelihood gradient after every instance and the prior gradient
     * after every {@code priorBlockSize} instances, plus a catch-up prior
     * step at the end of the epoch. When convergence is monitored, each
     * epoch's log2 likelihood plus log2 prior is reported to the annealing
     * schedule. If the schedule rejects the update, the weights revert to
     * their values at the start of that epoch. Training stops once the
     * rolling average of relative improvements drops below
     * {@code minImprovement} and at least {@code minEpochs} epochs have run,
     * or after {@code maxEpochs} epochs.
     *
     * @param xs Training inputs; all must have the same dimensionality.
     * @param cs Training outcomes, parallel to {@code xs}; the number of
     * outcomes is {@code max(cs) + 1}.
     * @param prior Prior on the weights, or {@code null} for no prior.
     * @param priorBlockSize Number of instances between prior updates; must be positive.
     * @param hotStart Model whose weights initialize estimation, or
     * {@code null} to start from zero. It is copied, never modified.
     * @param annealingSchedule Learning-rate schedule.
     * @param minImprovement Convergence threshold ({@code >= 0.0}), or
     * {@code Double.NaN} to disable convergence monitoring.
     * @param rollingAverageSize Number of epochs in the rolling average of
     * relative improvements; must be positive when monitoring convergence.
     * @param minEpochs Minimum number of epochs before convergence may stop training.
     * @param maxEpochs Maximum number of epochs.
     * @param handler Receives the model after every epoch, or {@code null}.
     * @param reporter Progress reporter, or {@code null} for silent.
     * @return The estimated model.
     * @throws IllegalArgumentException If any argument is inconsistent or out of range.
     */
    public static LogisticRegression estimate(Vector[] xs, int[] cs, RegressionPrior prior, int priorBlockSize, LogisticRegression hotStart, AnnealingSchedule annealingSchedule, double minImprovement, int rollingAverageSize, int minEpochs, int maxEpochs, ObjectHandler<LogisticRegression> handler, Reporter reporter) {
        if (reporter == null) {
            reporter = Reporters.silent();
        }
        reporter.info("Logistic Regression Estimation");
        boolean monitoringConvergence = !Double.isNaN(minImprovement);
        reporter.info("Monitoring convergence=" + monitoringConvergence);
        if (minImprovement < 0.0) {
            String msg = "Min improvement should be Double.NaN to turn off convergence or >= 0.0 otherwise. Found minImprovement=" + minImprovement;
            reporter.fatal(msg);
            throw new IllegalArgumentException(msg);
        }
        if (priorBlockSize < 1) {
            // Zero would divide by zero in the block bookkeeping below;
            // negative values would reverse the direction of the prior step.
            String msg = "Prior block size must be positive. Found priorBlockSize=" + priorBlockSize;
            reporter.fatal(msg);
            throw new IllegalArgumentException(msg);
        }
        if (monitoringConvergence && rollingAverageSize < 1) {
            String msg = "Rolling average size must be positive when monitoring convergence. Found rollingAverageSize=" + rollingAverageSize;
            reporter.fatal(msg);
            throw new IllegalArgumentException(msg);
        }
        if (xs.length < 1) {
            String msg = "Require at least one training instance.";
            reporter.fatal(msg);
            throw new IllegalArgumentException(msg);
        }
        if (xs.length != cs.length) {
            String msg = "Require same number of training instances as outcomes. Found xs.length=" + xs.length + " cs.length=" + cs.length;
            reporter.fatal(msg);
            throw new IllegalArgumentException(msg);
        }
        int numTrainingInstances = xs.length;
        int numOutcomesMinus1 = Math.max(cs);
        int numOutcomes = numOutcomesMinus1 + 1;
        int numDimensions = xs[0].numDimensions();
        if (prior != null) {
            prior.verifyNumberOfDimensions(numDimensions);
        }
        for (int i = 1; i < xs.length; ++i) {
            if (xs[i].numDimensions() != numDimensions) {
                String msg = "Number of dimensions must match for all input vectors. Found xs[0].numDimensions()=" + numDimensions + " xs[" + i + "].numDimensions()=" + xs[i].numDimensions();
                reporter.fatal(msg);
                throw new IllegalArgumentException(msg);
            }
        }
        DenseVector[] weightVectors = initialWeightVectors(hotStart, numOutcomesMinus1, numDimensions, reporter);
        LogisticRegression regression = new LogisticRegression(weightVectors);
        boolean hasPrior = prior != null && !prior.isUniform();
        reporter.info("Number of dimensions=" + numDimensions);
        reporter.info("Number of Outcomes=" + numOutcomes);
        reporter.info("Number of Parameters=" + (long) (numOutcomes - 1) * (long) numDimensions);
        reporter.info("Number of Training Instances=" + cs.length);
        reporter.info("Prior=" + prior);
        reporter.info("Annealing Schedule=" + annealingSchedule);
        reporter.info("Minimum Epochs=" + minEpochs);
        reporter.info("Maximum Epochs=" + maxEpochs);
        reporter.info("Minimum Improvement Per Period=" + minImprovement);
        reporter.info("Has Informative Prior=" + hasPrior);
        // Very negative but finite starting point for the relative-difference computation.
        double lastLog2LikelihoodAndPrior = -Double.MAX_VALUE / 2.0;
        double[] rollingAbsDiffs = new double[rollingAverageSize];
        // Infinite entries are meant to keep the rolling average from meeting
        // minImprovement until rollingAverageSize accepted epochs have been
        // recorded (assuming Statistics.mean is an arithmetic mean).
        Arrays.fill(rollingAbsDiffs, Double.POSITIVE_INFINITY);
        int rollingAveragePosition = 0;
        double bestLog2LikelihoodAndPrior = Double.NEGATIVE_INFINITY;
        // Report progress roughly every 10%; at least every instance, so
        // small training sets never divide by zero.
        int progressInterval = java.lang.Math.max(1, numTrainingInstances / 10);
        // classify() overwrites every entry, so one buffer serves all epochs.
        double[] conditionalProbs = new double[numOutcomes];
        for (int epoch = 0; epoch < maxEpochs; ++epoch) {
            // Snapshot the weights so a rejected update can be rolled back.
            DenseVector[] weightVectorCopies = LogisticRegression.copy(weightVectors);
            double learningRate = annealingSchedule.learningRate(epoch);
            for (int j = 0; j < numTrainingInstances; ++j) {
                if (reporter.isDebugEnabled() && j % progressInterval == 0) {
                    reporter.debug("          epoch " + epoch + " is " + 100L * j / numTrainingInstances + "% complete");
                }
                Vector xsJ = xs[j];
                int csJ = cs[j];
                if (hasPrior && j > 0 && j % priorBlockSize == 0) {
                    LogisticRegression.adjustWeightsWithPrior(weightVectors, prior, learningRate * (double) priorBlockSize / (double) numTrainingInstances);
                }
                regression.classify(xsJ, conditionalProbs);
                for (int k = 0; k < numOutcomesMinus1; ++k) {
                    LogisticRegression.adjustWeightsWithConditionalProbs(weightVectors[k], conditionalProbs[k], learningRate, xsJ, k, csJ);
                }
            }
            reporter.debug("catching up regularizations at end of epoch");
            // Instances seen since the last in-loop prior update.
            int blockRemainder = numTrainingInstances % priorBlockSize;
            if (blockRemainder == 0) {
                blockRemainder = priorBlockSize;
            }
            if (hasPrior) {
                LogisticRegression.adjustWeightsWithPrior(weightVectors, prior, learningRate * (double) blockRemainder / (double) numTrainingInstances);
            }
            if (handler != null) {
                reporter.debug("handling regression for epoch");
                handler.handle(regression);
            }
            if (!monitoringConvergence) {
                reporter.info("Unmonitored Epoch=" + epoch);
                continue;
            }
            reporter.debug("computing log likelihood");
            double log2Likelihood = LogisticRegression.log2Likelihood(xs, cs, regression);
            double log2Prior = prior == null ? 0.0 : prior.log2Prior(weightVectors);
            double log2LikelihoodAndPrior = log2Likelihood + log2Prior;
            if (log2LikelihoodAndPrior > bestLog2LikelihoodAndPrior) {
                bestLog2LikelihoodAndPrior = log2LikelihoodAndPrior;
            }
            if (reporter.isInfoEnabled()) {
                reportEpoch(reporter, epoch, learningRate, log2Likelihood, log2Prior, log2LikelihoodAndPrior, bestLog2LikelihoodAndPrior);
            }
            boolean acceptUpdate = annealingSchedule.receivedError(epoch, learningRate, -log2LikelihoodAndPrior);
            if (!acceptUpdate) {
                reporter.info("Annealing rejected update at learningRate=" + learningRate + " error=" + (-log2LikelihoodAndPrior));
                weightVectors = weightVectorCopies;
                regression = new LogisticRegression(weightVectors);
                continue;
            }
            double relativeAbsDiff = Math.relativeAbsoluteDifference(lastLog2LikelihoodAndPrior, log2LikelihoodAndPrior);
            rollingAbsDiffs[rollingAveragePosition] = relativeAbsDiff;
            if (++rollingAveragePosition == rollingAbsDiffs.length) {
                rollingAveragePosition = 0;
            }
            double rollingAvgAbsDiff = Statistics.mean(rollingAbsDiffs);
            reporter.debug("relativeAbsDiff=" + relativeAbsDiff + " rollingAvg=" + rollingAvgAbsDiff);
            lastLog2LikelihoodAndPrior = log2LikelihoodAndPrior;
            // epoch is zero-based, so epoch + 1 epochs have completed here;
            // honor minEpochs before allowing convergence to stop training.
            if (epoch + 1 >= minEpochs && rollingAvgAbsDiff < minImprovement) {
                reporter.info("Converged with Rolling Average Absolute Difference=" + rollingAvgAbsDiff);
                break;
            }
        }
        return regression;
    }

    /**
     * Build fresh, modifiable starting weights: zero vectors when there is
     * no hot start, otherwise copies of the hot-start model's weights.
     *
     * @throws IllegalArgumentException If the hot start's outcome count or
     * dimensionality differs from what the training data implies.
     */
    private static DenseVector[] initialWeightVectors(LogisticRegression hotStart, int numOutcomesMinus1, int numDimensions, Reporter reporter) {
        DenseVector[] weightVectors = new DenseVector[numOutcomesMinus1];
        if (hotStart == null) {
            for (int k = 0; k < numOutcomesMinus1; ++k) {
                weightVectors[k] = new DenseVector(numDimensions);
            }
            return weightVectors;
        }
        if (hotStart.numOutcomes() != numOutcomesMinus1 + 1 || hotStart.numInputDimensions() != numDimensions) {
            String msg = "Hot start model must match training data."
                + " Found hotStart.numOutcomes()=" + hotStart.numOutcomes()
                + " hotStart.numInputDimensions()=" + hotStart.numInputDimensions()
                + " training numOutcomes=" + (numOutcomesMinus1 + 1)
                + " training numDimensions=" + numDimensions;
            reporter.fatal(msg);
            throw new IllegalArgumentException(msg);
        }
        // Copy rather than alias so estimation never mutates the hot-start model.
        Vector[] hotStartWeightVectors = hotStart.weightVectors();
        for (int k = 0; k < numOutcomesMinus1; ++k) {
            weightVectors[k] = new DenseVector(hotStartWeightVectors[k]);
        }
        return weightVectors;
    }

    /** Report one monitored epoch's statistics as a single formatted info line. */
    private static void reportEpoch(Reporter reporter, int epoch, double learningRate, double log2Likelihood, double log2Prior, double log2LikelihoodAndPrior, double bestLog2LikelihoodAndPrior) {
        Formatter formatter = new Formatter(Locale.ENGLISH);
        try {
            formatter.format("epoch=%5d lr=%11.9f ll=%11.4f lp=%11.4f llp=%11.4f llp*=%11.4f", epoch, learningRate, log2Likelihood, log2Prior, log2LikelihoodAndPrior, bestLog2LikelihoodAndPrior);
            reporter.info(formatter.toString());
        } catch (IllegalFormatException e) {
            reporter.warn("Illegal format in Logistic Regression");
        } finally {
            formatter.close();
        }
    }

    /**
     * Return the sum, over all inputs, of the log2 probability the model
     * assigns to the corresponding category.
     *
     * @param inputs Input vectors.
     * @param cats Categories, parallel to {@code inputs}.
     * @param regression Model to evaluate.
     * @return Total log2 likelihood.
     * @throws IllegalArgumentException If the arrays differ in length.
     */
    public static double log2Likelihood(Vector[] inputs, int[] cats, LogisticRegression regression) {
        if (inputs.length != cats.length) {
            String msg = "Inputs and categories must be same length. Found inputs.length=" + inputs.length + " cats.length=" + cats.length;
            throw new IllegalArgumentException(msg);
        }
        double log2Likelihood = 0.0;
        double[] conditionalProbs = new double[regression.numOutcomes()];
        for (int j = 0; j < inputs.length; ++j) {
            regression.classify(inputs[j], conditionalProbs);
            log2Likelihood += Math.log2(conditionalProbs[cats[j]]);
        }
        return log2Likelihood;
    }

    /**
     * Apply one prior-gradient step to every weight. The step never moves
     * a weight past the prior's mode for its dimension; it is clipped at the mode.
     */
    private static void adjustWeightsWithPrior(DenseVector[] weightVectors, RegressionPrior prior, double learningRate) {
        for (int k = 0; k < weightVectors.length; ++k) {
            DenseVector weightVectorsK = weightVectors[k];
            int numDimensions = weightVectorsK.numDimensions();
            for (int i = 0; i < numDimensions; ++i) {
                double weight_k_i = weightVectorsK.value(i);
                double mode = prior.mode(i);
                if (weight_k_i == mode) {
                    continue;
                }
                double delta = prior.gradient(weight_k_i, i) * learningRate;
                if (delta == 0.0) {
                    continue;
                }
                double adjWeight_k_i = weight_k_i - delta;
                // Clip so the prior pulls the weight toward the mode but never across it.
                if (weight_k_i > mode) {
                    if (adjWeight_k_i < mode) {
                        adjWeight_k_i = mode;
                    }
                } else if (adjWeight_k_i > mode) {
                    adjWeight_k_i = mode;
                }
                weightVectorsK.setValue(i, adjWeight_k_i);
            }
        }
    }

    /**
     * Apply the likelihood-gradient step for outcome {@code k} on one
     * instance: {@code w[k] -= learningRate * (p(k|x) - [k == c]) * x}.
     */
    private static void adjustWeightsWithConditionalProbs(DenseVector weightVectorsK, double conditionalProb, double learningRate, Vector xsJ, int k, int csJ) {
        double conditionalProbMinusTruth = k == csJ ? conditionalProb - 1.0 : conditionalProb;
        if (conditionalProbMinusTruth == 0.0) {
            return;
        }
        weightVectorsK.increment(-learningRate * conditionalProbMinusTruth, xsJ);
    }

    /** Return a new array of copies of the specified dense vectors. */
    private static DenseVector[] copy(DenseVector[] xs) {
        DenseVector[] result = new DenseVector[xs.length];
        for (int k = 0; k < xs.length; ++k) {
            result[k] = new DenseVector(xs[k]);
        }
        return result;
    }

    /**
     * Serialization proxy. Writes the number of outcomes, the number of
     * dimensions, and then every weight of every weight vector as a double,
     * vector by vector.
     */
    static class Externalizer
    extends AbstractExternalizable {
        static final long serialVersionUID = -2256261505231943102L;
        final LogisticRegression mRegression;

        /** Construct an externalizer for deserialization. */
        public Externalizer() {
            this(null);
        }

        /**
         * Construct an externalizer for the specified model.
         *
         * @param regression Model to serialize.
         */
        public Externalizer(LogisticRegression regression) {
            this.mRegression = regression;
        }

        public void writeExternal(ObjectOutput out) throws IOException {
            int numOutcomes = this.mRegression.mWeightVectors.length + 1;
            out.writeInt(numOutcomes);
            int numDimensions = this.mRegression.mWeightVectors[0].numDimensions();
            out.writeInt(numDimensions);
            for (int c = 0; c < numOutcomes - 1; ++c) {
                Vector vC = this.mRegression.mWeightVectors[c];
                for (int i = 0; i < numDimensions; ++i) {
                    out.writeDouble(vC.value(i));
                }
            }
        }

        public Object read(ObjectInput in) throws IOException {
            int numOutcomes = in.readInt();
            int numDimensions = in.readInt();
            Vector[] weightVectors = new Vector[numOutcomes - 1];
            for (int c = 0; c < weightVectors.length; ++c) {
                DenseVector weightVectorsC = new DenseVector(numDimensions);
                weightVectors[c] = weightVectorsC;
                for (int i = 0; i < numDimensions; ++i) {
                    weightVectorsC.setValue(i, in.readDouble());
                }
            }
            return new LogisticRegression(weightVectors);
        }
    }
}

