/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.ocnn;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationReLU;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Broadcast;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;

/**
 * Runtime implementation of the one-class neural network (OC-NN) output layer.
 *
 * <p>The layer owns three parameters, addressed by the keys {@code "w"}, {@code "v"} and
 * {@code "r"}. Its output is computed from the input alone as
 * {@code activation(input * v) * w} (see {@link #doOutput}); no external labels are needed.
 * Every forward pass stores that output in the inherited {@code labels} field, and
 * {@link #setLabels(INDArray)} is deliberately a no-op.
 *
 * <p>The loss is implemented by the inner {@link OCNNLossFunction}, which is installed on the
 * layer configuration by both constructors.
 */
public class OCNNOutputLayer
extends BaseOutputLayer<org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer> {
    /** Shared ReLU used by {@link OCNNLossFunction}; independent of the configured activation. */
    private static final IActivation RELU = new ActivationReLU();

    /**
     * Activation exposed through {@link #getActivation()} / {@link #setActivation(IActivation)}.
     * The forward and backward code in this class does not read it; it uses the activation
     * function of the layer configuration instead.
     */
    private IActivation activation = new ActivationReLU();

    /** Loss function instance that is registered on the layer configuration. */
    private final ILossFunction lossFunction;

    /**
     * Creates the layer and registers a fresh {@link OCNNLossFunction} on the configuration.
     *
     * @param conf network configuration whose layer is the OC-NN output layer configuration
     */
    public OCNNOutputLayer(NeuralNetConfiguration conf) {
        super(conf);
        this.lossFunction = new OCNNLossFunction();
        this.registerLossFunction(conf);
    }

    /**
     * Creates the layer with an initial input and registers a fresh {@link OCNNLossFunction}
     * on the configuration.
     *
     * @param conf  network configuration whose layer is the OC-NN output layer configuration
     * @param input initial input, handed to the superclass constructor
     */
    public OCNNOutputLayer(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
        // Previously this constructor never created the loss function and therefore
        // registered null on the configuration.
        this.lossFunction = new OCNNLossFunction();
        this.registerLossFunction(conf);
    }

    /** Installs this layer's loss function on the layer configuration held by {@code conf}. */
    private void registerLossFunction(NeuralNetConfiguration conf) {
        org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer ocnnConf =
                (org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) conf.getLayer();
        ocnnConf.setLossFn(this.lossFunction);
    }

    /** Returns {@code layerConf()} typed as the OC-NN layer configuration. */
    private org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer ocnnLayerConf() {
        return (org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) this.layerConf();
    }

    /** Delegates to the superclass. */
    @Override
    public INDArray getLabels() {
        return super.getLabels();
    }

    /**
     * Intentionally ignores the supplied labels: this layer assigns {@code labels} itself on
     * every forward pass (see {@link #doOutput}).
     *
     * @param labels ignored
     */
    @Override
    public void setLabels(INDArray labels) {
    }

    /**
     * Computes the layer score from the current input, adds the supplied network-wide
     * regularisation terms, optionally divides by the mini-batch size, and stores the result
     * in the {@code score} field.
     *
     * @param fullNetworkL1 L1 term added to the score
     * @param fullNetworkL2 L2 term added to the score
     * @param training      forwarded to the pre-output computation
     * @param workspaceMgr  workspace manager used for the forward pass
     * @return the computed score
     * @throws IllegalStateException if no input has been set
     */
    @Override
    public double computeScore(double fullNetworkL1, double fullNetworkL2, boolean training, LayerWorkspaceMgr workspaceMgr) {
        if (this.input == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels " + this.layerId());
        }
        // The forward pass runs first; it also refreshes the labels read on the next lines.
        INDArray preOut = this.preOutput2d(training, workspaceMgr);
        ILossFunction lossFn = this.ocnnLayerConf().getLossFn();
        INDArray labels2d = this.getLabels2d(workspaceMgr, ArrayType.FF_WORKING_MEM);
        double score = lossFn.computeScore(labels2d, preOut, this.ocnnLayerConf().getActivationFn(), this.maskArray, false);
        score += fullNetworkL1 + fullNetworkL2;
        if (this.conf().isMiniBatch()) {
            score /= (double) this.getInputMiniBatchSize();
        }
        this.score = score;
        return score;
    }

    /** @return always {@code false}; the layer derives its own labels from the input */
    @Override
    public boolean needsLabels() {
        return false;
    }

    /**
     * Computes the parameter gradients and the epsilon passed to the previous layer.
     *
     * <p>The incoming {@code epsilon} argument is not used. The returned epsilon is the loss
     * delta broadcast to {@code [nIn, delta.length()]} and then transposed.
     *
     * @param epsilon      unused
     * @param workspaceMgr workspace manager used for the forward pass and for allocating the result
     * @return the gradients paired with the epsilon for the layer below
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(true);
        Pair<Gradient, INDArray> gradAndDelta = this.getGradientsAndDelta(this.preOutput2d(true, workspaceMgr), workspaceMgr);
        int nIn = ((org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) this.getConf().getLayer()).getNIn();
        INDArray delta = gradAndDelta.getSecond();
        INDArray epsilonNext = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, new int[]{nIn, delta.length()}, 'f');
        epsilonNext = epsilonNext.assign(delta.broadcast(epsilonNext.shape())).transpose();
        return new Pair<>(gradAndDelta.getFirst(), epsilonNext);
    }

    /**
     * Computes the loss delta and writes the gradients for {@code "w"}, {@code "v"} and
     * {@code "r"} into their gradient views.
     *
     * <p>When the configuration's last-updated epoch differs from the current epoch count,
     * {@code r} is first reset to the {@code 100 * nu} percentile of the current layer output.
     *
     * @param preOut       pre-output of the layer for the current input
     * @param workspaceMgr workspace manager used for the extra forward pass
     * @return the gradient object paired with the loss delta
     */
    private Pair<Gradient, INDArray> getGradientsAndDelta(INDArray preOut, LayerWorkspaceMgr workspaceMgr) {
        ILossFunction lossFn = this.ocnnLayerConf().getLossFn();
        INDArray labels2d = this.getLabels2d(workspaceMgr, ArrayType.BP_WORKING_MEM);
        INDArray delta = lossFn.computeGradient(labels2d, preOut, this.ocnnLayerConf().getActivationFn(), this.maskArray);

        org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer conf =
                (org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) this.conf().getLayer();
        if (conf.getLastEpochSinceRUpdated() != this.epochCount) {
            INDArray currentR = this.doOutput(false, workspaceMgr);
            // Use the configured nu rather than a fixed 4th percentile; with nu = 0.04 the
            // expression evaluates to exactly 4.0, i.e. the previous behaviour.
            double percentile = currentR.percentileNumber(100.0 * conf.getNu()).doubleValue();
            this.getParam("r").putScalar(0, percentile);
            conf.setLastEpochSinceRUpdated(this.epochCount);
        }

        Gradient gradient = new DefaultGradient();
        INDArray vGradView = (INDArray) this.gradientViews.get("v");
        double oneDivNu = 1.0 / this.ocnnLayerConf().getNu();
        INDArray xTimesV = this.input.mmul(this.getParam("v"));

        // Gradient for w: mean over dimension 0 of (-activation(x*v) * delta), scaled by 1/nu, plus w.
        INDArray derivW = this.ocnnLayerConf().getActivationFn().getActivation(xTimesV.dup(), true).negi();
        INDArray weightedW = delta.isRowVector() ? derivW.muliRowVector(delta) : derivW.muli(delta);
        derivW = weightedW.mean(new int[]{0}).muli(oneDivNu).addi(this.getParam("w"));
        gradient.setGradientFor("w", ((INDArray) this.gradientViews.get("w")).assign(derivW));

        // Gradient for v, first factor: activation backprop of x*v (epsilon 1) times -w times delta,
        // reshaped to [batch, 1, hiddenSize].
        INDArray firstVertDerivV = ((INDArray) this.ocnnLayerConf().getActivationFn()
                .backprop(xTimesV.dup(), Nd4j.scalar(1.0)).getFirst())
                .muliRowVector(this.getParam("w").neg());
        INDArray weightedV = delta.isRowVector() ? firstVertDerivV.muliRowVector(delta) : firstVertDerivV.muli(delta);
        firstVertDerivV = weightedV.reshape('f', new int[]{this.input.size(0), 1, this.ocnnLayerConf().getHiddenSize()});

        // Second factor: the input reshaped to [batch, v.size(0), 1].
        INDArray secondTermDerivV = this.input.reshape('f', new int[]{this.input.size(0), this.getParam("v").size(0), 1});

        // Broadcast target shape: element-wise maximum of the two factors' dimensions.
        int[] broadcastShape = new int[firstVertDerivV.shape().length];
        for (int dim = 0; dim < firstVertDerivV.rank(); ++dim) {
            broadcastShape[dim] = Math.max(firstVertDerivV.size(dim), secondTermDerivV.size(dim));
        }
        INDArray firstDerivVBroadcast = Nd4j.createUninitialized(broadcastShape);
        INDArray mulResult = firstVertDerivV.broadcast(firstDerivVBroadcast);
        int[] bcDims = Shape.getBroadcastDimensions(mulResult.shape(), secondTermDerivV.shape());
        Broadcast.mul(mulResult, secondTermDerivV, mulResult, bcDims);
        INDArray derivV = mulResult.mean(new int[]{0}).muli(oneDivNu).addi(this.getParam("v"));
        gradient.setGradientFor("v", vGradView.assign(derivV));

        // Gradient for r: mean(delta) / nu - 1.
        INDArray derivR = Nd4j.scalar(delta.meanNumber()).muli(oneDivNu).addi(-1);
        gradient.setGradientFor("r", ((INDArray) this.gradientViews.get("r")).assign(derivR));

        this.clearNoiseWeightParams();
        return new Pair<>(gradient, delta);
    }

    /**
     * Stores {@code input} as the layer input and returns the layer output for it.
     *
     * @param input        new layer input
     * @param training     forwarded to the forward pass
     * @param workspaceMgr workspace manager used for the forward pass
     * @return the layer output
     */
    @Override
    public INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr workspaceMgr) {
        this.input = input;
        return this.doOutput(training, workspaceMgr);
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double f1Score(INDArray examples, INDArray labels) {
        throw new UnsupportedOperationException();
    }

    /**
     * Thresholds the values of {@code examples} at zero: negative values become {@code 0},
     * all others become {@code 1}.
     *
     * @param examples values to threshold; read through its flat float data buffer
     * @return a new array created from the thresholded flat float buffer
     */
    @Override
    public INDArray labelProbabilities(INDArray examples) {
        float[] decision = examples.data().asFloat();
        for (int i = 0; i < decision.length; ++i) {
            decision[i] = decision[i] < 0.0f ? 0.0f : 1.0f;
        }
        return Nd4j.create(decision);
    }

    /** @return {@link Layer.Type#FEED_FORWARD} */
    @Override
    public Layer.Type type() {
        return Layer.Type.FEED_FORWARD;
    }

    /** For this layer the pre-output is the same value as the layer output. */
    @Override
    protected INDArray preOutput2d(boolean training, LayerWorkspaceMgr workspaceMgr) {
        return this.doOutput(training, workspaceMgr);
    }

    /** Returns the {@code labels} field as-is; both arguments are unused. */
    @Override
    protected INDArray getLabels2d(LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) {
        return this.labels;
    }

    /** Returns the layer output for the input that is currently set. */
    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        return this.doOutput(training, workspaceMgr);
    }

    /**
     * Forward pass: {@code activation(input * v) * w}, one value per input row.
     *
     * <p>Side effect: the result is also stored in the {@code labels} field.
     *
     * @param training     passed to parameter-noise lookup, dropout and the activation function
     * @param workspaceMgr workspace manager used to allocate the output
     * @return array of length {@code input.size(0)} allocated as {@link ArrayType#ACTIVATIONS}
     */
    private INDArray doOutput(boolean training, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(false);
        INDArray w = this.getParamWithNoise("w", training, workspaceMgr);
        INDArray v = this.getParamWithNoise("v", training, workspaceMgr);
        this.applyDropOutIfNecessary(training, workspaceMgr);

        INDArray hidden = Nd4j.createUninitialized(this.input.size(0), v.size(1));
        this.input.mmuli(v, hidden);
        INDArray act2d = this.ocnnLayerConf().getActivationFn().getActivation(hidden, training);

        INDArray output = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, new int[]{this.input.size(0)});
        act2d.mmuli(w.reshape(new int[]{w.length()}), output);
        this.labels = output;
        return output;
    }

    /**
     * Computes per-example scores: the loss function's score array summed along dimension 1,
     * with {@code fullNetworkL1 + fullNetworkL2} added to every entry when that sum is non-zero.
     *
     * @param fullNetworkL1 L1 term added to each example score
     * @param fullNetworkL2 L2 term added to each example score
     * @param workspaceMgr  workspace manager used for the forward pass
     * @return the summed score array
     * @throws IllegalStateException if the input or the labels field is {@code null}
     */
    @Override
    public INDArray computeScoreForExamples(double fullNetworkL1, double fullNetworkL2, LayerWorkspaceMgr workspaceMgr) {
        if (this.input == null || this.labels == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels " + this.layerId());
        }
        INDArray preOut = this.preOutput2d(false, workspaceMgr);
        ILossFunction lossFn = this.ocnnLayerConf().getLossFn();
        INDArray labels2d = this.getLabels2d(workspaceMgr, ArrayType.FF_WORKING_MEM);
        INDArray scoreArray = lossFn.computeScoreArray(labels2d, preOut, this.ocnnLayerConf().getActivationFn(), this.maskArray);
        INDArray summedScores = scoreArray.sum(new int[]{1});
        double l1l2 = fullNetworkL1 + fullNetworkL2;
        if (l1l2 != 0.0) {
            summedScores.addi(l1l2);
        }
        return summedScores;
    }

    /** Sets the activation returned by {@link #getActivation()}. */
    public void setActivation(IActivation activation) {
        this.activation = activation;
    }

    /** @return the activation last set via {@link #setActivation(IActivation)} (ReLU by default) */
    public IActivation getActivation() {
        return this.activation;
    }

    /**
     * Loss function of the enclosing layer. It reads the parameters {@code "w"}, {@code "v"}
     * and {@code "r"} directly from the enclosing layer instance.
     *
     * <p>The {@code labels}, {@code activationFn}, {@code mask} and {@code average} arguments
     * of its methods are not used by the computations below.
     */
    public class OCNNLossFunction
    implements ILossFunction {
        /**
         * Computes {@code 0.5 * sum(w^2) + 0.5 * sum(v^2) + (1 / nu) * mean(relu(r - preOutput)) - r}.
         *
         * @param preOutput layer output the score is evaluated on
         * @return the scalar score
         */
        @Override
        public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
            double wSum = Transforms.pow(OCNNOutputLayer.this.getParam("w"), 2).sumNumber().doubleValue() * 0.5;
            double vSum = Transforms.pow(OCNNOutputLayer.this.getParam("v"), 2).sumNumber().doubleValue() * 0.5;
            org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer ocnnConf =
                    (org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) OCNNOutputLayer.this.conf().getLayer();
            INDArray rMeanSub = RELU.getActivation(OCNNOutputLayer.this.getParam("r").sub(preOutput), true);
            double rMean = rMeanSub.meanNumber().doubleValue();
            double rValue = OCNNOutputLayer.this.getParam("r").getDouble(0);
            double nuDiv = 1.0 / ocnnConf.getNu() * rMean;
            double lastTerm = -rValue;
            return wSum + vSum + nuDiv + lastTerm;
        }

        /**
         * Returns {@code r - preOutput} as a new array (no ReLU is applied here).
         *
         * @param preOutput layer output
         * @return the element-wise difference
         */
        @Override
        public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
            return OCNNOutputLayer.this.getParam("r").sub(preOutput);
        }

        /**
         * Returns the ReLU backprop result for {@code r - preOutput} with an all-ones epsilon
         * of the same shape.
         *
         * @param preOutput layer output
         * @return first element of the pair returned by the ReLU backprop call
         */
        @Override
        public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
            // sub() yields a fresh array, so it can be handed to backprop directly instead of
            // computing the same difference a second time.
            INDArray preAct = OCNNOutputLayer.this.getParam("r").sub(preOutput);
            return (INDArray) RELU.backprop(preAct, Nd4j.ones(preAct.shape())).getFirst();
        }

        /** Returns the results of {@link #computeScore} and {@link #computeGradient} as a pair. */
        @Override
        public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
            double score = this.computeScore(labels, preOutput, activationFn, mask, average);
            INDArray gradient = this.computeGradient(labels, preOutput, activationFn, mask);
            return new Pair<>(score, gradient);
        }

        /** @return the constant name {@code "OCNNLossFunction"} */
        @Override
        public String name() {
            return "OCNNLossFunction";
        }
    }
}

