/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.training;

import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.primitives.Pair;

/**
 * Output layer whose score is the sum of two terms:
 * <ul>
 *   <li>an "inter-class" term computed by the configured {@link ILossFunction} on the layer's
 *       pre-output, and</li>
 *   <li>an "intra-class" (center loss) term, {@code lambda / 2 * ||input - center||^2}, where the
 *       center of each example is selected by {@code labels.mmul(centers)}.</li>
 * </ul>
 * The centers are stored as the {@code "cL"} parameter next to the weights ({@code "W"}) and the
 * bias ({@code "b"}).
 *
 * <p>NOTE(review): selecting centers with {@code labels.mmul(centers)} presumably relies on the
 * labels being one-hot rows of shape [miniBatch, numClasses] - confirm against callers.
 */
public class CenterLossOutputLayer
extends BaseOutputLayer<org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer> {
    /** Key of the weight matrix in the parameter and gradient-view maps. */
    private static final String WEIGHT_KEY = "W";
    /** Key of the bias in the parameter and gradient-view maps. */
    private static final String BIAS_KEY = "b";
    /** Key of the class centers in the parameter and gradient-view maps. */
    private static final String CENTER_KEY = "cL";

    // Regularization terms passed to the most recent computeScore(...) call; they are reused by
    // computeGradientAndScore(), which has no parameters of its own.
    private double fullNetworkL1;
    private double fullNetworkL2;

    /**
     * Creates the layer from its configuration.
     *
     * @param conf network configuration forwarded to the superclass
     */
    public CenterLossOutputLayer(NeuralNetConfiguration conf) {
        super(conf);
    }

    /**
     * Creates the layer from its configuration and an initial input.
     *
     * @param conf  network configuration forwarded to the superclass
     * @param input input array forwarded to the superclass
     */
    public CenterLossOutputLayer(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    /**
     * Computes the mini-batch score: loss-function score plus center loss plus the supplied
     * regularization terms, divided by the mini-batch size. The result is also stored in
     * {@code this.score}, and the two regularization terms are remembered for
     * {@link #computeGradientAndScore()}.
     *
     * @param fullNetworkL1 L1 regularization term to add to the score
     * @param fullNetworkL2 L2 regularization term to add to the score
     * @param training      forwarded to {@code preOutput2d}
     * @return the averaged score
     * @throws IllegalStateException if the input or the labels have not been set
     */
    @Override
    public double computeScore(double fullNetworkL1, double fullNetworkL2, boolean training) {
        if (this.input == null || this.labels == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels " + this.layerId());
        }
        this.fullNetworkL1 = fullNetworkL1;
        this.fullNetworkL2 = fullNetworkL2;

        INDArray preOut = this.preOutput2d(training);
        org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer cfg = this.centerLossConf();
        ILossFunction interClassLoss = cfg.getLossFn();

        // Intra-class component: lambda / 2 * sum over examples of ||input - center||^2.
        double sum = this.squaredDistancesToCenters().sumNumber().doubleValue();
        double lambda = cfg.getLambda();
        double intraClassScore = lambda / 2.0 * sum;
        // Opt-in debug output, enabled by defining the PRINT_CENTERLOSS environment variable.
        if (System.getenv("PRINT_CENTERLOSS") != null) {
            System.out.println("Center loss is " + intraClassScore);
        }

        // Inter-class component from the configured loss function (final argument: no averaging
        // requested here - the division by the mini-batch size happens below).
        double interClassScore = interClassLoss.computeScore(this.getLabels2d(), preOut, cfg.getActivationFn(), this.maskArray, false);

        double score = interClassScore + intraClassScore;
        score += fullNetworkL1 + fullNetworkL2;
        score /= (double)this.getInputMiniBatchSize();
        this.score = score;
        return score;
    }

    /**
     * Computes one score per example: the loss function's per-example score plus
     * {@code lambda / 2 * ||input - center||^2} for that example, plus the regularization terms.
     * Unlike {@link #computeScore(double, double, boolean)} nothing is divided by the mini-batch
     * size.
     *
     * <p>The center loss term uses the squared row-wise distance, exactly as
     * {@link #computeScore(double, double, boolean)} does, so that both methods agree on the
     * definition of the intra-class score.
     *
     * @param fullNetworkL1 L1 regularization term added to every example's score
     * @param fullNetworkL2 L2 regularization term added to every example's score
     * @return the per-example score array produced by the loss function, with the center loss and
     *         regularization terms added in place
     * @throws IllegalStateException if the input or the labels have not been set
     */
    @Override
    public INDArray computeScoreForExamples(double fullNetworkL1, double fullNetworkL2) {
        if (this.input == null || this.labels == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels " + this.layerId());
        }
        INDArray preOut = this.preOutput2d(false);
        org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer cfg = this.centerLossConf();

        // One squared distance per example (reduced along dimension 1), not the raw per-feature
        // difference matrix.
        // NOTE(review): assumes norm2 along dimension 1 yields the same shape as the loss
        // function's per-example score array - confirm for the ND4J version in use.
        INDArray intraClassScoreArray = this.squaredDistancesToCenters();

        ILossFunction interClassLoss = cfg.getLossFn();
        INDArray scoreArray = interClassLoss.computeScoreArray(this.getLabels2d(), preOut, cfg.getActivationFn(), this.maskArray);
        scoreArray.addi(intraClassScoreArray.muli(cfg.getLambda() / 2.0));

        double l1l2 = fullNetworkL1 + fullNetworkL2;
        if (l1l2 != 0.0) {
            scoreArray.addi(l1l2);
        }
        return scoreArray;
    }

    /**
     * Computes and stores both the gradient and the score for the current input and labels,
     * using the regularization terms remembered from the last
     * {@link #computeScore(double, double, boolean)} call. Does nothing when the input or the
     * labels are missing.
     */
    @Override
    public void computeGradientAndScore() {
        if (this.input == null || this.labels == null) {
            return;
        }
        INDArray preOut = this.preOutput2d(true);
        Pair<Gradient, INDArray> pair = this.getGradientsAndDelta(preOut);
        this.gradient = pair.getFirst();
        this.score = this.computeScore(this.fullNetworkL1, this.fullNetworkL2, true);
    }

    /**
     * Not supported by this layer.
     *
     * @param z ignored
     * @throws RuntimeException always
     */
    @Override
    protected void setScoreWithZ(INDArray z) {
        throw new RuntimeException("Not supported " + this.layerId());
    }

    /**
     * @return the current gradient paired with the current score
     */
    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        return new Pair<Gradient, Double>(this.gradient(), this.score());
    }

    /**
     * Computes the parameter gradients and the epsilon to pass to the layer below. The epsilon is
     * {@code (W * delta^T)^T} plus {@code lambda * (input - center)}, the center loss
     * contribution with respect to the input.
     *
     * @param epsilon not read by this implementation
     * @return the parameter gradients and the epsilon for the previous layer
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        Pair<Gradient, INDArray> pair = this.getGradientsAndDelta(this.preOutput2d(true));
        INDArray delta = pair.getSecond();

        // Derivative of the center loss term with respect to the input, before scaling by lambda.
        INDArray dLcdai = this.input.sub(this.centersForExamples());

        INDArray weights = (INDArray)this.params.get(WEIGHT_KEY);
        INDArray epsilonNext = weights.mmul(delta.transpose()).transpose();
        double lambda = this.centerLossConf().getLambda();
        epsilonNext.addi(dLcdai.muli(lambda));
        return new Pair<Gradient, INDArray>(pair.getFirst(), epsilonNext);
    }

    /**
     * @return the gradient stored by the last {@link #computeGradientAndScore()} call
     */
    @Override
    public Gradient gradient() {
        return this.gradient;
    }

    /**
     * Computes the loss delta for the given pre-output and writes the weight, bias and center
     * gradients into their gradient views.
     *
     * @param preOut pre-output of this layer for the current input
     * @return the gradients keyed by parameter name, paired with the loss delta
     * @throws DL4JInvalidInputException if the label column count differs from that of
     *         {@code preOut}
     */
    private Pair<Gradient, INDArray> getGradientsAndDelta(INDArray preOut) {
        org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer cfg = this.centerLossConf();
        ILossFunction lossFunction = cfg.getLossFn();
        INDArray labels2d = this.getLabels2d();
        if (labels2d.size(1) != preOut.size(1)) {
            throw new DL4JInvalidInputException("Labels array numColumns (size(1) = " + labels2d.size(1) + ") does not match output layer number of outputs (nOut = " + preOut.size(1) + ") " + this.layerId());
        }
        INDArray delta = lossFunction.computeGradient(labels2d, preOut, cfg.getActivationFn(), this.maskArray);

        DefaultGradient gradient = new DefaultGradient();
        INDArray weightGradView = (INDArray)this.gradientViews.get(WEIGHT_KEY);
        INDArray biasGradView = (INDArray)this.gradientViews.get(BIAS_KEY);
        INDArray centersGradView = (INDArray)this.gradientViews.get(CENTER_KEY);

        // Center update: alpha * (center - input), accumulated per class through labels^T.
        double alpha = cfg.getAlpha();
        INDArray diff = this.centersForExamples().sub(this.input).muli(alpha);
        INDArray numerator = this.labels.transpose().mmul(diff);
        // Per-class label sum plus one, so the division below never divides by zero.
        INDArray denominator = this.labels.sum(new int[]{0}).addi(1.0).transpose();

        INDArray deltaC;
        if (cfg.getGradientCheck()) {
            // Gradient-check mode: scale by lambda instead of normalizing by the class counts.
            double lambda = cfg.getLambda();
            deltaC = numerator.muli(lambda);
        } else {
            deltaC = numerator.diviColumnVector(denominator);
        }
        centersGradView.assign(deltaC);

        // weightGradView = input^T * delta, written directly into the gradient view.
        Nd4j.gemm(this.input, delta, weightGradView, true, false, 1.0, 0.0);
        // Bias gradient: delta summed along dimension 0, written into the gradient view.
        delta.sum(biasGradView, new int[]{0});

        gradient.gradientForVariable().put(WEIGHT_KEY, weightGradView);
        gradient.gradientForVariable().put(BIAS_KEY, biasGradView);
        gradient.gradientForVariable().put(CENTER_KEY, centersGradView);
        return new Pair<Gradient, INDArray>(gradient, delta);
    }

    /** Returns this layer's configuration as the center-loss configuration type. */
    private org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer centerLossConf() {
        return (org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer)this.layerConf();
    }

    /** Returns {@code labels.mmul(centers)}: the center selected for each example. */
    private INDArray centersForExamples() {
        INDArray centers = (INDArray)this.params.get(CENTER_KEY);
        return this.labels.mmul(centers);
    }

    /** Returns the squared L2 norm of {@code input - center}, reduced along dimension 1. */
    private INDArray squaredDistancesToCenters() {
        INDArray distances = this.input.sub(this.centersForExamples()).norm2(new int[]{1});
        // Square the norms in place.
        distances.muli(distances);
        return distances;
    }
}

