/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.berkeley.Triple;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.Classifier;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.optimize.Solver;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.LossFunction;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossCalculation;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.util.FeatureUtil;
import org.nd4j.linalg.util.LinAlgExceptions;

/**
 * Base class for output layers: a layer that, in addition to producing activations, holds a set of
 * labels and can compute a score (loss) and gradients of that loss with respect to its "W" and "b"
 * parameters.
 *
 * <p>The loss function, the activation function and (for {@code CUSTOM} loss) the custom loss op
 * name are all read from the layer configuration.
 *
 * @param <LayerConfT> configuration type of the concrete output layer
 */
public abstract class BaseOutputLayer<LayerConfT extends org.deeplearning4j.nn.conf.layers.BaseOutputLayer>
extends BaseLayer<LayerConfT>
implements Serializable,
Classifier {
    /** Labels for the current input; set via {@link #setLabels(INDArray)} or the fit methods. */
    protected INDArray labels;
    /** Lazily created optimizer used by {@link #fit(INDArray, INDArray)}; not serialized. */
    private transient Solver solver;
    /** Network-wide L1 value last passed to {@link #computeScore}; fed to the loss calculation. */
    private double fullNetworkL1;
    /** Network-wide L2 value last passed to {@link #computeScore}; fed to the loss calculation. */
    private double fullNetworkL2;

    public BaseOutputLayer(NeuralNetConfiguration conf) {
        super(conf);
    }

    public BaseOutputLayer(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    /**
     * Computes the score (loss) of the current input against the current labels.
     *
     * <p>For NEGATIVELOGLIKELIHOOD / MCXENT combined with a "softmax" activation, the score is
     * computed from the pre-activation output only (no activation is applied here). Otherwise the
     * configured activation is applied to the pre-output first.
     *
     * @param fullNetworkL1 network-wide L1 term, passed to the loss calculation
     * @param fullNetworkL2 network-wide L2 term, passed to the loss calculation
     * @param training whether to compute the pre-output in training mode
     * @return the computed score (also stored in {@code this.score})
     * @throws IllegalStateException if input or labels are not set
     */
    public double computeScore(double fullNetworkL1, double fullNetworkL2, boolean training) {
        if (this.input == null || this.labels == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels");
        }
        this.fullNetworkL1 = fullNetworkL1;
        this.fullNetworkL2 = fullNetworkL2;
        INDArray preOut = this.preOutput2d(training);
        LossFunctions.LossFunction lf = ((org.deeplearning4j.nn.conf.layers.BaseOutputLayer)this.conf.getLayer()).getLossFunction();
        if ((lf == LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD || lf == LossFunctions.LossFunction.MCXENT) && ((org.deeplearning4j.nn.conf.layers.BaseOutputLayer)this.layerConf()).getActivationFunction().equals("softmax")) {
            // z is deliberately null: the loss calculation works from preOut for this combination.
            this.setScore(null, preOut);
        } else {
            INDArray output = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.conf().getLayer().getActivationFunction(), preOut));
            this.setScoreWithZ(output);
        }
        return this.score;
    }

    /**
     * Computes a per-example score of the current input against the current labels (inference mode).
     *
     * @param fullNetworkL1 network-wide L1 term, passed to the loss calculation
     * @param fullNetworkL2 network-wide L2 term, passed to the loss calculation
     * @return the result of {@code LossCalculation.scoreExamples()}
     * @throws IllegalStateException if input or labels are not set
     */
    public INDArray computeScoreForExamples(double fullNetworkL1, double fullNetworkL2) {
        if (this.input == null || this.labels == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels");
        }
        INDArray preOut = this.preOutput2d(false);
        // dup() so the in-place activation transform does not overwrite preOut, which is also passed on.
        INDArray output = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.conf().getLayer().getActivationFunction(), preOut.dup()));
        return LossCalculation.builder().l1(fullNetworkL1).l2(fullNetworkL2).labels(this.getLabels2d()).z(output).preOut(preOut).activationFn(this.conf().getLayer().getActivationFunction()).lossFunction(((org.deeplearning4j.nn.conf.layers.BaseOutputLayer)this.layerConf()).getLossFunction()).useRegularization(this.conf.isUseRegularization()).mask(this.maskArray).build().scoreExamples();
    }

    /**
     * Computes gradients (stored in {@code this.gradient}) and the score for the current input and
     * labels, in training mode. Silently does nothing if input or labels are not set.
     */
    @Override
    public void computeGradientAndScore() {
        if (this.input == null || this.labels == null) {
            return;
        }
        INDArray preOut = this.preOutput2d(true);
        Triple<Gradient, INDArray, INDArray> triple = this.getGradientsAndDelta(preOut);
        this.gradient = triple.getFirst();
        this.setScore(triple.getThird(), preOut);
    }

    /** Computes the score from the activated output {@code z} (no pre-output available). */
    @Override
    protected void setScoreWithZ(INDArray z) {
        this.setScore(z, null);
    }

    /**
     * Computes and stores {@code this.score}.
     *
     * @param z activated output (may be null when only {@code preOut} is meaningful)
     * @param preOut pre-activation output (may be null)
     */
    private void setScore(INDArray z, INDArray preOut) {
        if (((org.deeplearning4j.nn.conf.layers.BaseOutputLayer)this.layerConf()).getLossFunction() == LossFunctions.LossFunction.CUSTOM) {
            // NOTE(review): the custom loss op receives this.input (not the labels) alongside z — confirm intended.
            LossFunction create = Nd4j.getOpFactory().createLossFunction(((org.deeplearning4j.nn.conf.layers.BaseOutputLayer)this.layerConf()).getCustomLossFunction(), this.input, z);
            create.exec();
            this.score = create.getFinalResult().doubleValue();
        } else {
            this.score = LossCalculation.builder().l1(this.fullNetworkL1).l2(this.fullNetworkL2).labels(this.getLabels2d()).z(z).preOut(preOut).activationFn(this.conf().getLayer().getActivationFunction()).lossFunction(((org.deeplearning4j.nn.conf.layers.BaseOutputLayer)this.layerConf()).getLossFunction()).miniBatch(this.conf.isMiniBatch()).miniBatchSize(this.getInputMiniBatchSize()).useRegularization(this.conf.isUseRegularization()).mask(this.maskArray).build().score();
        }
    }

    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        return new Pair<Gradient, Double>(this.gradient(), this.score());
    }

    /**
     * Backpropagates from this output layer. The incoming {@code epsilon} is ignored: the error
     * signal is derived from the labels.
     *
     * @return the parameter gradients and the epsilon for the previous layer ({@code (W * delta^T)^T})
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        Triple<Gradient, INDArray, INDArray> triple = this.getGradientsAndDelta(this.preOutput2d(true));
        INDArray delta = triple.getSecond();
        INDArray epsilonNext = ((INDArray)this.params.get("W")).mmul(delta.transpose()).transpose();
        return new Pair<Gradient, INDArray>(triple.getFirst(), epsilonNext);
    }

    /** Returns the last computed gradient after checking that input and labels have matching row counts. */
    @Override
    public Gradient gradient() {
        LinAlgExceptions.assertRows(this.input, this.getLabels2d());
        return this.gradient;
    }

    /**
     * Computes the gradients of the configured loss, writing them into the "W" and "b" gradient views.
     *
     * @param preOut pre-activation output of this layer
     * @return (gradient, delta used for backprop, activated output)
     * @throws IllegalStateException for an unsupported loss function
     */
    private Triple<Gradient, INDArray, INDArray> getGradientsAndDelta(INDArray preOut) {
        Triple<Gradient, INDArray, INDArray> triple;
        // dup() so the in-place activation transform does not overwrite preOut (needed by MSE below).
        INDArray output = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.conf().getLayer().getActivationFunction(), preOut.dup()));
        INDArray labels2d = this.getLabels2d();
        INDArray outSubLabels = output.sub(labels2d);
        DefaultGradient gradient = new DefaultGradient();
        INDArray weightGradView = (INDArray)this.gradientViews.get("W");
        INDArray biasGradView = (INDArray)this.gradientViews.get("b");
        gradient.gradientForVariable().put("W", weightGradView);
        gradient.gradientForVariable().put("b", biasGradView);
        if (this.maskArray != null) {
            // Zero out the error of masked examples (one mask value per row).
            outSubLabels.muliColumnVector(this.maskArray);
        }
        LossFunctions.LossFunction lossFunction = ((org.deeplearning4j.nn.conf.layers.BaseOutputLayer)this.layerConf()).getLossFunction();
        // Each case writes input^T * (error term) into the weight view (gemm with beta = 0 overwrites it)
        // and the column sums of the error term into the bias view.
        switch (lossFunction) {
            case NEGATIVELOGLIKELIHOOD: 
            case MCXENT: {
                Nd4j.gemm(this.input, outSubLabels, weightGradView, true, false, 1.0, 0.0);
                biasGradView.assign(outSubLabels.sum(new int[]{0}));
                triple = new Triple<Gradient, INDArray, INDArray>(gradient, outSubLabels, output);
                break;
            }
            case XENT: {
                Nd4j.gemm(this.input, outSubLabels.div(output.mul(output.rsub((Number)1))), weightGradView, true, false, 1.0, 0.0);
                biasGradView.assign(outSubLabels.sum(new int[]{0}));
                triple = new Triple<Gradient, INDArray, INDArray>(gradient, outSubLabels, output);
                break;
            }
            case MSE: {
                INDArray delta = outSubLabels.mul(this.derivativeActivation(preOut));
                Nd4j.gemm(this.input, delta, weightGradView, true, false, 1.0, 0.0);
                biasGradView.assign(delta.sum(new int[]{0}));
                triple = new Triple<Gradient, INDArray, INDArray>(gradient, delta, output);
                break;
            }
            case EXPLL: {
                // Uses the 2d labels, consistent with every other branch (was raw this.labels).
                Nd4j.gemm(this.input, labels2d.rsub((Number)1).divi(output), weightGradView, true, false, 1.0, 0.0);
                biasGradView.assign(outSubLabels.sum(new int[]{0}));
                triple = new Triple<Gradient, INDArray, INDArray>(gradient, outSubLabels, output);
                break;
            }
            case RMSE_XENT: {
                INDArray squaredrmseXentDiff = Transforms.pow(outSubLabels, (Number)2.0);
                INDArray sqrt = Transforms.sqrt(squaredrmseXentDiff);
                Nd4j.gemm(this.input, sqrt, weightGradView, true, false, 1.0, 0.0);
                biasGradView.assign(outSubLabels.sum(new int[]{0}));
                triple = new Triple<Gradient, INDArray, INDArray>(gradient, outSubLabels, output);
                break;
            }
            case SQUARED_LOSS: {
                Nd4j.gemm(this.input, outSubLabels.mul(outSubLabels), weightGradView, true, false, 1.0, 0.0);
                biasGradView.assign(outSubLabels.sum(new int[]{0}));
                triple = new Triple<Gradient, INDArray, INDArray>(gradient, outSubLabels, output);
                break;
            }
            default: {
                throw new IllegalStateException("Invalid loss function: " + lossFunction);
            }
        }
        return triple;
    }

    /** Sets the input and returns the layer output in the given mode. */
    @Override
    public INDArray activate(INDArray input, boolean training) {
        this.setInput(input);
        return this.output(training);
    }

    /** Sets the input and returns the layer output in training mode ({@code training = true}). */
    @Override
    public INDArray activate(INDArray input) {
        this.setInput(input);
        return this.output(true);
    }

    /** Returns the layer output for the current input in inference mode ({@code training = false}). */
    @Override
    public INDArray activate() {
        return this.output(false);
    }

    /** Sets the input and returns the layer output in the given mode. */
    public INDArray output(INDArray input, boolean training) {
        this.setInput(input);
        return this.output(training);
    }

    /** Sets the input and returns the layer output in inference mode. */
    public INDArray output(INDArray input) {
        this.setInput(input);
        return this.output(false);
    }

    /**
     * Returns the activated output for the current input.
     *
     * @throws IllegalArgumentException if no input has been set
     */
    public INDArray output(boolean training) {
        if (this.input == null) {
            throw new IllegalArgumentException("No null input allowed");
        }
        return super.activate(training);
    }

    @Override
    public double f1Score(DataSet data) {
        return this.f1Score(data.getFeatureMatrix(), data.getLabels());
    }

    /** Returns the F1 score of this layer's predictions for {@code examples} against {@code labels}. */
    @Override
    public double f1Score(INDArray examples, INDArray labels) {
        Evaluation eval = new Evaluation();
        eval.eval(labels, this.labelProbabilities(examples));
        return eval.f1();
    }

    /**
     * Returns the number of label columns (size of dimension 1 of the current labels).
     * NOTE(review): throws NullPointerException if labels have not been set yet.
     */
    @Override
    public int numLabels() {
        return this.labels.size(1);
    }

    /** Fits on every remaining DataSet of the iterator, one at a time. */
    @Override
    public void fit(DataSetIterator iter) {
        while (iter.hasNext()) {
            this.fit((DataSet)iter.next());
        }
    }

    /** Returns, for each input row, the index of the output column with the largest magnitude (iamax). */
    @Override
    public int[] predict(INDArray input) {
        INDArray output = this.output(input);
        int[] ret = new int[input.rows()];
        for (int i = 0; i < ret.length; ++i) {
            ret[i] = Nd4j.getBlasWrapper().iamax(output.getRow(i));
        }
        return ret;
    }

    /**
     * Returns the predicted label name for each example of {@code dataSet}, in example order.
     */
    @Override
    public List<String> predict(DataSet dataSet) {
        int[] intRet = this.predict(dataSet.getFeatureMatrix());
        List<String> ret = new ArrayList<String>(intRet.length);
        for (int labelIndex : intRet) {
            // Append (the original add(index, ...) inserted at the class index and could throw
            // IndexOutOfBoundsException or scramble the order).
            ret.add(dataSet.getLabelName(labelIndex));
        }
        return ret;
    }

    @Override
    public INDArray labelProbabilities(INDArray examples) {
        return this.output(examples);
    }

    /**
     * Sets input and labels, applies dropout (training mode) and runs the optimizer. The solver is
     * created on first use and reused afterwards.
     */
    @Override
    public void fit(INDArray input, INDArray labels) {
        this.setInput(input);
        this.setLabels(labels);
        this.applyDropOutIfNecessary(true);
        if (this.solver == null) {
            this.solver = new Solver.Builder().configure(this.conf()).listeners(this.getListeners()).model(this).build();
        }
        this.solver.optimize();
    }

    @Override
    public void fit(DataSet data) {
        this.fit(data.getFeatureMatrix(), data.getLabels());
    }

    /**
     * Fits using integer class labels, converted to an outcome matrix of width {@link #numLabels()}.
     * NOTE(review): numLabels() reads the current labels, so this fails if labels were never set.
     */
    @Override
    public void fit(INDArray examples, int[] labels) {
        INDArray outcomeMatrix = FeatureUtil.toOutcomeMatrix(labels, this.numLabels());
        this.fit(examples, outcomeMatrix);
    }

    /**
     * Clears layer state, the labels and the cached solver.
     * NOTE(review): destroys the labels' data buffer, which may be shared with the caller's array.
     */
    @Override
    public void clear() {
        super.clear();
        if (this.labels != null) {
            this.labels.data().destroy();
            this.labels = null;
        }
        this.solver = null;
    }

    /** No-op in this class: fitting without labels is not performed here. */
    @Override
    public void fit(INDArray data) {
    }

    /** Not supported by output layers. */
    @Override
    public void iterate(INDArray input) {
        throw new UnsupportedOperationException();
    }

    public INDArray getLabels() {
        return this.labels;
    }

    public void setLabels(INDArray labels) {
        this.labels = labels;
    }

    /** Pre-output as a 2d matrix; hook for subclasses with higher-rank activations. */
    protected INDArray preOutput2d(boolean training) {
        return this.preOutput(training);
    }

    /** Output as a 2d matrix; hook for subclasses with higher-rank activations. */
    protected INDArray output2d(INDArray input) {
        return this.output(input);
    }

    /**
     * Returns the labels as a 2d matrix. Labels of rank greater than 2 are reshaped to
     * [size(2), size(1)].
     * NOTE(review): that reshape ignores dimension 0, so it is only shape-consistent when
     * size(0) == 1; subclasses with 3d labels presumably override this — confirm.
     */
    protected INDArray getLabels2d() {
        if (this.labels.rank() > 2) {
            return this.labels.reshape(this.labels.size(2), this.labels.size(1));
        }
        return this.labels;
    }
}

