/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.normalization;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper;
import org.deeplearning4j.optimize.api.IterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Batch normalization layer.
 *
 * <p>For rank-2 input, statistics are computed per column (reduction over dimension 0). For rank-4
 * input, they are computed per channel (reduction over dimensions 0, 2 and 3). The normalized
 * activations are then scaled by {@code gamma} and shifted by {@code beta}. When gamma/beta are
 * locked in the configuration, the fixed configured values are used instead of the {@code gamma} /
 * {@code beta} parameters.
 *
 * <p>In training mode the minibatch mean and variance (variance with epsilon already added) are
 * written into the {@code mean}/{@code var} parameters. Depending on {@code isMinibatch()}, this is
 * either an exponential moving average or a plain overwrite. In test mode those stored values are
 * used directly.
 *
 * <p>If the class {@code CudnnBatchNormalizationHelper} can be loaded reflectively and reports
 * support for the configured epsilon, rank-4 forward/backward passes are offered to it first. If it
 * returns {@code null}, the built-in ND4J implementation below is used.
 */
public class BatchNormalization
extends BaseLayer<org.deeplearning4j.nn.conf.layers.BatchNormalization> {
    private static final Logger log = LoggerFactory.getLogger(BatchNormalization.class);

    /** Parameter / gradient-view keys used by this layer. */
    private static final String GAMMA = "gamma";
    private static final String BETA = "beta";
    private static final String GLOBAL_MEAN = "mean";
    private static final String GLOBAL_VAR = "var";

    /** ND4J workspace id that cached forward-pass arrays are leveraged to. */
    private static final String WS_LOOP_EXTERNAL = "LOOP_EXTERNAL";

    /** Fully qualified name of the optional cuDNN helper, loaded reflectively. */
    private static final String CUDNN_HELPER_CLASS =
            "org.deeplearning4j.nn.layers.normalization.CudnnBatchNormalizationHelper";

    BatchNormalizationHelper helper = null;
    protected int index = 0;
    protected List<IterationListener> listeners = new ArrayList<IterationListener>();
    /** sqrt(variance) from the last forward pass (variance includes eps in training mode). */
    protected INDArray std;
    /** x - mean from the last forward pass; read by backprop and must not be mutated there. */
    protected INDArray xMu;
    /** (x - mean) / std from the last forward pass. */
    protected INDArray xHat;

    public BatchNormalization(NeuralNetConfiguration conf) {
        super(conf);
        this.initializeHelper();
    }

    /**
     * Tries to load the cuDNN helper reflectively.
     *
     * <p>The helper is kept only if it loads and reports support for the configured epsilon.
     * Otherwise {@link #helper} is left {@code null} and the built-in implementation is used.
     */
    void initializeHelper() {
        try {
            helper = Class.forName(CUDNN_HELPER_CLASS)
                    .asSubclass(BatchNormalizationHelper.class)
                    .newInstance();
            log.debug("CudnnBatchNormalizationHelper successfully initialized");
            if (!helper.checkSupported(bnConf().getEps())) {
                helper = null;
            }
        } catch (Throwable t) {
            // Never keep a helper whose construction or support check failed part-way through.
            helper = null;
            // ClassNotFoundException just means the helper is not on the classpath: stay quiet.
            if (!(t instanceof ClassNotFoundException)) {
                log.warn("Could not initialize CudnnBatchNormalizationHelper", t);
            }
        }
    }

    /** Returns this layer's configuration with its concrete type. */
    private org.deeplearning4j.nn.conf.layers.BatchNormalization bnConf() {
        return (org.deeplearning4j.nn.conf.layers.BatchNormalization) this.layerConf();
    }

    /** Executes a broadcast op through the default ND4J executioner and returns its result array. */
    private static INDArray broadcast(BroadcastOp op) {
        return Nd4j.getExecutioner().execAndReturn(op);
    }

    @Override
    public double calcL2(boolean backpropParamsOnly) {
        return 0.0;
    }

    @Override
    public double calcL1(boolean backpropParamsOnly) {
        return 0.0;
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.NORMALIZATION;
    }

    @Override
    public Gradient error(INDArray input) {
        return null;
    }

    @Override
    public Gradient calcGradient(Gradient layerError, INDArray indArray) {
        return null;
    }

    /**
     * Backpropagates through the batch normalization transform.
     *
     * <p>Uses the {@code std}, {@code xMu} and {@code xHat} cached by the most recent forward pass.
     * Gradients of the {@code mean}/{@code var} parameters are always zero; those parameters are
     * updated directly in {@link #preOutput(INDArray, Layer.TrainingMode)}.
     *
     * @param epsilon dL/dOutput, rank 2 or rank 4
     * @return the parameter gradients and dL/dInput
     * @throws IllegalStateException if {@code epsilon} is neither rank 2 nor rank 4
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        int[] shape = this.getShape(epsilon);
        org.deeplearning4j.nn.conf.layers.BatchNormalization layerConf = bnConf();
        boolean locked = layerConf.isLockGammaBeta();
        INDArray dGlobalMeanView = this.gradientViews.get(GLOBAL_MEAN);
        INDArray dGlobalVarView = this.gradientViews.get(GLOBAL_VAR);

        INDArray gamma = null;
        INDArray dGammaView;
        INDArray dBetaView;
        if (locked) {
            // gamma/beta are not parameters here, so compute their gradients into scratch arrays.
            int[] tempShape = {1, shape[1]};
            dGammaView = Nd4j.createUninitialized(tempShape, 'c');
            dBetaView = Nd4j.createUninitialized(tempShape, 'c');
        } else {
            gamma = this.getParam(GAMMA);
            dGammaView = this.gradientViews.get(GAMMA);
            dBetaView = this.gradientViews.get(BETA);
        }

        if (this.helper != null && epsilon.rank() == 4) {
            if (locked) {
                gamma = Nd4j.valueArrayOf(new int[] {1, shape[1]}, layerConf.getGamma());
            }
            Pair<Gradient, INDArray> helperResult = this.helper.backpropGradient(
                    this.input, epsilon, shape, gamma, dGammaView, dBetaView, layerConf.getEps());
            if (helperResult != null) {
                return helperResult;
            }
        }

        int rank = epsilon.rank();
        if (rank != 2 && rank != 4) {
            throw new IllegalStateException("The layer prior to BatchNorm in the configuration is not currently supported. " + this.layerId());
        }
        int[] reduceDims = rank == 2 ? new int[] {0} : new int[] {0, 2, 3};

        INDArray dBeta = epsilon.sum(reduceDims);                  // dL/dBeta  = sum(dL/dOut)
        INDArray dGamma = epsilon.mul(this.xHat).sum(reduceDims);  // dL/dGamma = sum(dL/dOut * xHat)
        INDArray nextEpsilon = rank == 2
                ? backpropEpsilon2d(epsilon, gamma, layerConf)
                : backpropEpsilon4d(epsilon, gamma, layerConf);

        dGammaView.assign(dGamma);
        dBetaView.assign(dBeta);
        dGlobalMeanView.assign(0);
        dGlobalVarView.assign(0);

        // NOTE(review): gamma/beta gradients are reported even when locked (scratch arrays above).
        DefaultGradient retGradient = new DefaultGradient();
        retGradient.setGradientFor(GAMMA, dGammaView);
        retGradient.setGradientFor(BETA, dBetaView);
        retGradient.setGradientFor(GLOBAL_MEAN, dGlobalMeanView);
        retGradient.setGradientFor(GLOBAL_VAR, dGlobalVarView);
        return new Pair<Gradient, INDArray>(retGradient, nextEpsilon);
    }

    /** Computes dL/dInput for rank-2 activations (statistics over dimension 0). */
    private INDArray backpropEpsilon2d(INDArray epsilon, INDArray gamma,
                                       org.deeplearning4j.nn.conf.layers.BatchNormalization layerConf) {
        int batchSize = epsilon.size(0);
        INDArray dxhat = layerConf.isLockGammaBeta()
                ? epsilon.mul(layerConf.getGamma())
                : epsilon.mulRowVector(gamma);
        INDArray dLdVar = dxhat.mul(this.xMu).sum(0).muli(-0.5).muli(Transforms.pow(this.std, -3.0, true));
        INDArray dxmu1 = dxhat.sum(0).divi(this.std).negi();
        // Must use dLdVar before it is rescaled in place below.
        INDArray dxmu2 = this.xMu.sum(0).muli(-2.0 / batchSize).muli(dLdVar);
        INDArray dLdmu = dxmu1.addi(dxmu2);
        // Out-of-place multiply: the cached xMu from the forward pass must stay intact.
        INDArray varTerm = this.xMu.mulRowVector(dLdVar.muli(2.0 / batchSize));
        return dxhat.diviRowVector(this.std)
                .addi(varTerm)
                .addiRowVector(dLdmu.muli(1.0 / batchSize));
    }

    /** Computes dL/dInput for rank-4 activations (per-channel statistics over dimensions 0, 2, 3). */
    private INDArray backpropEpsilon4d(INDArray epsilon, INDArray gamma,
                                       org.deeplearning4j.nn.conf.layers.BatchNormalization layerConf) {
        // epsilon has the same shape as the layer input.
        int effectiveBatchSize = epsilon.size(0) * epsilon.size(2) * epsilon.size(3);
        INDArray dxhat = layerConf.isLockGammaBeta()
                ? epsilon.mul(layerConf.getGamma())
                : broadcast(new BroadcastMulOp(epsilon, gamma,
                        Nd4j.createUninitialized(epsilon.shape(), epsilon.ordering()), new int[] {1}));
        INDArray dLdVar = dxhat.mul(this.xMu).sum(0, 2, 3).muli(-0.5).muli(Transforms.pow(this.std, -3.0, true));
        INDArray dxmu1 = dxhat.sum(0, 2, 3).divi(this.std).negi();
        // Must use dLdVar before it is rescaled in place below.
        INDArray dxmu2 = this.xMu.sum(0, 2, 3).muli(-2.0 / effectiveBatchSize).muli(dLdVar);
        INDArray dLdmu = dxmu1.addi(dxmu2);

        INDArray dLdx = broadcast(new BroadcastDivOp(dxhat, this.std, dxhat, new int[] {1}));
        // Write into a fresh array rather than into xMu: the cached forward state must stay intact.
        INDArray varTerm = broadcast(new BroadcastMulOp(this.xMu, dLdVar.muli(2.0 / effectiveBatchSize),
                Nd4j.createUninitialized(this.xMu.shape(), this.xMu.ordering()), new int[] {1}));
        dLdx.addi(varTerm);
        broadcast(new BroadcastAddOp(dLdx, dLdmu.muli(1.0 / effectiveBatchSize), dLdx, new int[] {1}));
        return dLdx;
    }

    @Override
    public void merge(Layer layer, int batchSize) {
        throw new UnsupportedOperationException(this.layerId());
    }

    /** No-op: this layer has no standalone (pretraining) fit. */
    @Override
    public void fit(INDArray data) {
    }

    @Override
    public INDArray activate(boolean training) {
        return this.preOutput(this.input, training ? Layer.TrainingMode.TRAIN : Layer.TrainingMode.TEST);
    }

    @Override
    public Gradient gradient() {
        return this.gradient;
    }

    @Override
    public INDArray preOutput(INDArray x) {
        return this.preOutput(x, Layer.TrainingMode.TRAIN);
    }

    /**
     * Runs the batch normalization forward pass on {@code x}.
     *
     * <p>In {@code TRAIN} mode this uses minibatch statistics and updates the stored running
     * mean/variance. In {@code TEST} mode it uses the stored running mean/variance. The arrays
     * {@code std}, {@code xMu} and {@code xHat} are cached for backprop. If the helper produces the
     * result, those caches (other than {@code std}) and the running statistics are left to the
     * helper.
     *
     * @param x rank-2 or rank-4 activations
     * @param training whether to use minibatch (TRAIN) or stored (TEST) statistics
     * @return the normalized, scaled and shifted activations
     * @throws IllegalStateException for unsupported input rank
     */
    @Override
    public INDArray preOutput(INDArray x, Layer.TrainingMode training) {
        org.deeplearning4j.nn.conf.layers.BatchNormalization layerConf = bnConf();
        boolean isTraining = training == Layer.TrainingMode.TRAIN;
        boolean locked = layerConf.isLockGammaBeta();
        int[] shape = this.getShape(x);

        INDArray mean;
        INDArray var;
        if (isTraining) {
            int[] reduceDims;
            switch (x.rank()) {
                case 2:
                    reduceDims = new int[] {0};
                    break;
                case 4:
                    reduceDims = new int[] {0, 2, 3};
                    break;
                default:
                    throw new IllegalStateException("Batch normalization on activations of rank " + x.rank() + " not supported " + this.layerId());
            }
            mean = x.mean(reduceDims);
            var = x.var(false, reduceDims);
            // eps is folded in here, so the running variance stored below already includes it.
            var.addi(layerConf.getEps());
        } else {
            mean = this.getParam(GLOBAL_MEAN);
            var = this.getParam(GLOBAL_VAR);
        }
        this.std = Transforms.sqrt(var, true).leverageTo(WS_LOOP_EXTERNAL);

        INDArray globalMeanView = this.getParam(GLOBAL_MEAN);
        INDArray globalVarView = this.getParam(GLOBAL_VAR);
        // Decide on the rank of the array actually being processed, not on the cached input field.
        boolean tryHelper = this.helper != null && x.rank() == 4;

        INDArray gamma = null;
        INDArray beta = null;
        if (locked) {
            if (tryHelper) {
                // The helper expects explicit gamma/beta arrays even when they are fixed values.
                int[] gammaBetaShape = {1, layerConf.getNOut()};
                gamma = Nd4j.valueArrayOf(gammaBetaShape, layerConf.getGamma());
                beta = Nd4j.valueArrayOf(gammaBetaShape, layerConf.getBeta());
            }
        } else {
            gamma = this.getParam(GAMMA);
            beta = this.getParam(BETA);
        }

        if (tryHelper) {
            INDArray helperResult = this.helper.preOutput(x, isTraining, shape, gamma, beta,
                    globalMeanView, globalVarView, layerConf.getDecay(), layerConf.getEps());
            if (helperResult != null) {
                return helperResult;
            }
        }

        INDArray activations;
        if (x.rank() == 2) {
            this.xMu = x.subRowVector(mean).leverageTo(WS_LOOP_EXTERNAL);
            this.xHat = this.xMu.divRowVector(this.std).leverageTo(WS_LOOP_EXTERNAL);
            activations = locked
                    ? applyFixedGammaBeta(this.xHat, layerConf)
                    : this.xHat.mulRowVector(gamma).addiRowVector(beta);
        } else if (x.rank() == 4) {
            // Work on a copy when x's strides fail Shape.strideDescendingCAscendingF.
            if (!Shape.strideDescendingCAscendingF(x)) {
                x = x.dup();
            }
            this.xMu = broadcast(new BroadcastSubOp(x, mean,
                    Nd4j.createUninitialized(x.shape(), x.ordering()), new int[] {1})).leverageTo(WS_LOOP_EXTERNAL);
            this.xHat = broadcast(new BroadcastDivOp(this.xMu, this.std,
                    Nd4j.createUninitialized(x.shape(), x.ordering()), new int[] {1})).leverageTo(WS_LOOP_EXTERNAL);
            if (locked) {
                activations = applyFixedGammaBeta(this.xHat, layerConf);
            } else {
                activations = broadcast(new BroadcastMulOp(this.xHat, gamma,
                        Nd4j.createUninitialized(x.shape(), x.ordering()), new int[] {1}));
                activations = broadcast(new BroadcastAddOp(activations, beta, activations, new int[] {1}));
            }
        } else {
            throw new IllegalStateException("The layer prior to BatchNorm in the configuration is not currently supported. " + this.layerId());
        }

        if (isTraining) {
            if (layerConf.isMinibatch()) {
                double decay = layerConf.getDecay();
                // Moving average: global = decay * global + (1 - decay) * batch (mean/var are scratch here).
                globalMeanView.muli(decay).addi(mean.muli(1.0 - decay));
                globalVarView.muli(decay).addi(var.muli(1.0 - decay));
            } else {
                globalMeanView.assign(mean);
                globalVarView.assign(var);
            }
        }
        return activations;
    }

    /**
     * Applies the configured fixed gamma/beta to {@code xHat}.
     *
     * <p>Returns {@code xHat} itself when the transform is the identity (gamma == 1 and beta == 0).
     * Otherwise it returns {@code gamma * xHat + beta} as a new array. The transform must run when
     * EITHER value differs from identity.
     */
    private static INDArray applyFixedGammaBeta(INDArray xHat,
                                                org.deeplearning4j.nn.conf.layers.BatchNormalization conf) {
        double g = conf.getGamma();
        double b = conf.getBeta();
        if (g == 1.0 && b == 0.0) {
            return xHat;
        }
        return xHat.mul(g).addi(b);
    }

    @Override
    public INDArray activate(Layer.TrainingMode training) {
        throw new UnsupportedOperationException(this.layerId());
    }

    @Override
    public INDArray activate(INDArray input, Layer.TrainingMode training) {
        return this.preOutput(input, training);
    }

    @Override
    public INDArray preOutput(INDArray x, boolean training) {
        return this.preOutput(x, training ? Layer.TrainingMode.TRAIN : Layer.TrainingMode.TEST);
    }

    @Override
    public Layer transpose() {
        throw new UnsupportedOperationException(this.layerId());
    }

    @Override
    public Layer clone() {
        throw new UnsupportedOperationException(this.layerId());
    }

    @Override
    public Collection<IterationListener> getListeners() {
        return this.listeners;
    }

    @Override
    public void setListeners(IterationListener ... listeners) {
        this.listeners = new ArrayList<IterationListener>(Arrays.asList(listeners));
    }

    @Override
    public void setIndex(int index) {
        this.index = index;
    }

    @Override
    public int getIndex() {
        return this.index;
    }

    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    /**
     * Returns the {@code {1, n}} shape used for per-feature arrays.
     *
     * <p>For rank 2 and rank 4, {@code n} is {@code x.size(1)}. For rank 3, {@code n} is
     * {@code x.size(1) * x.size(2)}.
     *
     * @throws IllegalStateException for any other rank
     */
    public int[] getShape(INDArray x) {
        if (x.rank() == 2 || x.rank() == 4) {
            return new int[]{1, x.size(1)};
        }
        if (x.rank() == 3) {
            int wDim = x.size(1);
            int hdim = x.size(2);
            // NOTE(review): for a non-empty rank-3 array with size(0) > 1, wDim * hdim can never equal
            // x.length(), so this check never fires. Confirm the intended validation.
            if (x.size(0) > 1 && wDim * hdim == x.length()) {
                throw new IllegalArgumentException("Illegal input for batch size " + this.layerId());
            }
            return new int[]{1, wDim * hdim};
        }
        throw new IllegalStateException("Unable to process input of rank " + x.rank() + " " + this.layerId());
    }
}

