/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.normalization;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper;
import org.deeplearning4j.optimize.api.IterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Batch normalization layer implementation.
 *
 * <p>The forward pass subtracts a mean, divides by a standard deviation, then scales by
 * {@code gamma} and shifts by {@code beta}. In {@link Layer.TrainingMode#TRAIN} mode with
 * {@code useBatchMean} enabled, the mean and variance are computed from the current input along
 * dimension 0 and blended into the stored running estimates; otherwise the stored estimates are
 * used as-is.
 *
 * <p>If a {@code CudnnBatchNormalizationHelper} class can be loaded it is tried first for both
 * the forward and backward pass; whenever it returns {@code null} the built-in implementation is
 * used instead.
 *
 * <p>Only rank 2 and rank 4 inputs are supported by the built-in implementation.
 */
public class BatchNormalization
extends BaseLayer<org.deeplearning4j.nn.conf.layers.BatchNormalization> {
    protected static final Logger log = LoggerFactory.getLogger(BatchNormalization.class);

    /** Optional accelerated implementation; {@code null} when it could not be loaded. */
    BatchNormalizationHelper helper = null;
    protected int index = 0;
    protected List<IterationListener> listeners = new ArrayList<IterationListener>();
    /** Parameter shape derived from the most recent input or epsilon, see {@link #getShape(INDArray)}. */
    protected int[] shape;
    /** Stored (running) mean; {@code null} until first initialized by a forward pass. */
    protected INDArray mean;
    /** Stored (running) variance; {@code null} until first initialized by a forward pass. */
    protected INDArray var;
    /** Square root of the variance used in the most recent forward pass. */
    protected INDArray std;
    /** Input minus mean, cached by the forward pass for use in backprop. */
    protected INDArray xMu;
    /** Normalized input ({@code xMu / std}), cached by the forward pass for use in backprop. */
    protected INDArray xHat;
    protected Layer.TrainingMode trainingMode;
    /** {@code true} until the stored mean/variance have been lazily initialized once. */
    protected boolean setMeanVar = true;

    /**
     * Creates the layer and attempts to load the optional cuDNN helper.
     *
     * @param conf network configuration passed through to {@link BaseLayer}
     */
    public BatchNormalization(NeuralNetConfiguration conf) {
        super(conf);
        this.initializeHelper();
    }

    /**
     * Tries to instantiate the cuDNN helper reflectively. A missing class is silently ignored
     * (the helper is optional); any other failure is logged as a warning. In both cases
     * {@link #helper} is left unchanged.
     */
    void initializeHelper() {
        try {
            this.helper = Class.forName("org.deeplearning4j.nn.layers.normalization.CudnnBatchNormalizationHelper")
                    .asSubclass(BatchNormalizationHelper.class)
                    .newInstance();
        } catch (Throwable t) {
            // Throwable (not Exception) so that Errors raised while loading or instantiating
            // the helper are also contained here instead of aborting layer construction.
            if (!(t instanceof ClassNotFoundException)) {
                log.warn("Could not load CudnnBatchNormalizationHelper", t);
            }
        }
    }

    /** Executes a broadcast op on the default executioner and returns its result array. */
    private static INDArray broadcast(BroadcastOp op) {
        return Nd4j.getExecutioner().execAndReturn(op);
    }

    /**
     * Lazily creates the stored mean (zeros) and variance (filled with {@code eps}) the first time
     * it is called, using the shapes of the supplied arrays. Existing non-null values are kept.
     * Does nothing once {@link #setMeanVar} has been cleared.
     */
    private void initRunningStatsIfNeeded(INDArray meanLike, INDArray varLike, double eps) {
        if (!this.setMeanVar) {
            return;
        }
        if (this.mean == null) {
            this.mean = Nd4j.zeros(meanLike.shape());
        }
        if (this.var == null) {
            this.var = Nd4j.valueArrayOf(varLike.shape(), eps);
        }
        this.setMeanVar = false;
    }

    /** Always returns {@code 0.0}. */
    @Override
    public double calcL2() {
        return 0.0;
    }

    /** Always returns {@code 0.0}. */
    @Override
    public double calcL1() {
        return 0.0;
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.NORMALIZATION;
    }

    /** Not implemented; always returns {@code null}. */
    @Override
    public Gradient error(INDArray input) {
        return null;
    }

    /** Not implemented; always returns {@code null}. */
    @Override
    public Gradient calcGradient(Gradient layerError, INDArray indArray) {
        return null;
    }

    /**
     * Computes the gradients for {@code gamma} and {@code beta} and the epsilon to pass to the
     * previous layer. Relies on {@link #xMu}, {@link #xHat} and {@link #std} cached by the most
     * recent forward pass.
     *
     * @param epsilon gradient arriving from the next layer; must be rank 2 or rank 4
     * @return the parameter gradients (written into the gradient views) paired with the
     *         epsilon for the previous layer
     * @throws IllegalStateException if {@code epsilon} is neither rank 2 nor rank 4
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        this.shape = this.getShape(epsilon);
        int batchSize = epsilon.size(0);
        org.deeplearning4j.nn.conf.layers.BatchNormalization layerConf =
                (org.deeplearning4j.nn.conf.layers.BatchNormalization) this.layerConf();
        INDArray gamma = layerConf.isLockGammaBeta() ? Nd4j.ones(this.shape) : this.getParam("gamma");
        INDArray dGammaView = (INDArray) this.gradientViews.get("gamma");
        INDArray dBetaView = (INDArray) this.gradientViews.get("beta");

        // Give the accelerated helper the first chance; null means "not handled".
        if (this.helper != null) {
            Pair<Gradient, INDArray> helperResult = this.helper.backpropGradient(
                    this.input, epsilon, this.shape, gamma, dGammaView, dBetaView, layerConf.getEps());
            if (helperResult != null) {
                return helperResult;
            }
        }

        INDArray dGamma;
        INDArray dBeta;
        INDArray nextEpsilon;
        if (epsilon.rank() == 2) {
            dGamma = epsilon.mul(this.xHat).sum(new int[]{0});
            dBeta = epsilon.sum(new int[]{0});
            INDArray dxhat = epsilon.mulRowVector(gamma);
            INDArray dsq = dxhat.mul(this.xMu).sum(new int[]{0}).mul(0.5)
                    .div(Transforms.pow(this.std, 3)).neg().div(batchSize);
            INDArray dxmu1 = dxhat.divRowVector(this.std);
            INDArray dxmu2 = this.xMu.mul(2).mulRowVector(dsq);
            INDArray dx1 = dxmu1.add(dxmu2);
            INDArray dmu = dx1.sum(new int[]{0}).neg();
            INDArray dx2 = dmu.div(batchSize);
            nextEpsilon = broadcast(new BroadcastAddOp(dx1, dx2, dx1.dup(), new int[]{-1}));
        } else if (epsilon.rank() == 4) {
            dGamma = epsilon.mul(this.xHat).sum(new int[]{0, 2, 3});
            dBeta = epsilon.sum(new int[]{0, 2, 3});
            INDArray dxhat = broadcast(new BroadcastMulOp(epsilon, gamma, epsilon.dup(), new int[]{1}));
            // dsq must be computed before the next op, which overwrites dxhat in place.
            INDArray dsq = dxhat.mul(this.xMu).sum(new int[]{0}).mul(0.5)
                    .div(Transforms.pow(this.std, 3)).neg().div(batchSize);
            INDArray dxmu1 = broadcast(new BroadcastDivOp(dxhat, this.std, dxhat, new int[]{1, 2, 3}));
            INDArray dxmu2 = broadcast(
                    new BroadcastMulOp(this.xMu.mul(2), dsq, this.xMu.mul(2), new int[]{1, 2, 3}));
            INDArray dx1 = dxmu1.add(dxmu2);
            INDArray dmu = dx1.sum(new int[]{0}).neg();
            INDArray dx2 = dmu.div(batchSize);
            nextEpsilon = broadcast(new BroadcastAddOp(dx1, dx2, dx1.dup(), new int[]{1, 2, 3}));
        } else {
            throw new IllegalStateException("The layer prior to BatchNorm in the configuration is not currently supported.");
        }

        // Write into the pre-allocated gradient views and expose those views as the gradient.
        dGammaView.assign(dGamma);
        dBetaView.assign(dBeta);
        DefaultGradient retGradient = new DefaultGradient();
        retGradient.setGradientFor("gamma", dGammaView);
        retGradient.setGradientFor("beta", dBetaView);
        return new Pair<Gradient, INDArray>(retGradient, nextEpsilon);
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    @Override
    public void merge(Layer layer, int batchSize) {
        throw new UnsupportedOperationException();
    }

    /** No-op. */
    @Override
    public void fit(INDArray data) {
    }

    /** Runs the forward pass on the layer's current input in TRAIN or TEST mode. */
    @Override
    public INDArray activate(boolean training) {
        return this.preOutput(this.input, training ? Layer.TrainingMode.TRAIN : Layer.TrainingMode.TEST);
    }

    @Override
    public Gradient gradient() {
        return this.gradient;
    }

    /** Runs the forward pass on {@code x} in TRAIN mode. */
    @Override
    public INDArray preOutput(INDArray x) {
        return this.preOutput(x, Layer.TrainingMode.TRAIN);
    }

    /**
     * Forward pass: normalizes {@code x}, then applies {@code gamma} and {@code beta}.
     *
     * <p>Side effects: updates {@link #trainingMode}, {@link #shape}, {@link #std}, and (on the
     * built-in path) {@link #xMu} and {@link #xHat}. When batch statistics are used, the stored
     * mean/variance are updated as {@code batch * decay + stored * (1 - decay)}, with the stored
     * variance term additionally scaled by {@code n / max(n - 1, 1)} where {@code n} is
     * {@code x.size(0)}.
     *
     * @param x        input; must be rank 2 or rank 4 for the built-in implementation
     * @param training mode; batch statistics are used only for TRAIN with {@code useBatchMean} set
     * @return the activations
     * @throws IllegalStateException if stored statistics are required but not yet initialized,
     *                               or if {@code x} has an unsupported rank
     */
    @Override
    public INDArray preOutput(INDArray x, Layer.TrainingMode training) {
        this.trainingMode = training;
        org.deeplearning4j.nn.conf.layers.BatchNormalization layerConf =
                (org.deeplearning4j.nn.conf.layers.BatchNormalization) this.layerConf();
        int batchSize = x.size(0);
        this.shape = this.getShape(x);
        boolean useBatchStats = training == Layer.TrainingMode.TRAIN && layerConf.isUseBatchMean();

        INDArray usedMean;
        INDArray usedVar;
        if (useBatchStats) {
            usedMean = x.mean(new int[]{0});
            usedVar = x.var(false, new int[]{0});
            usedVar.addi(layerConf.getEps());
        } else {
            if (this.mean == null || this.var == null) {
                // Previously surfaced as a bare NullPointerException further down.
                throw new IllegalStateException(
                        "Stored mean/variance are not initialized; they are only populated by a "
                        + "TRAIN-mode forward pass with useBatchMean enabled.");
            }
            usedMean = this.mean;
            usedVar = this.var;
        }
        this.std = Transforms.sqrt(usedVar);

        INDArray gamma;
        INDArray beta;
        if (layerConf.isLockGammaBeta()) {
            gamma = Nd4j.ones(this.shape);
            beta = Nd4j.zeros(this.shape);
        } else {
            gamma = this.getParam("gamma");
            beta = this.getParam("beta");
        }

        if (this.helper != null) {
            // Decay of 1.0 on the very first call; must be read before the flag is cleared below.
            double helperDecay = this.setMeanVar ? 1.0 : layerConf.getDecay();
            this.initRunningStatsIfNeeded(usedMean, usedVar, layerConf.getEps());
            INDArray helperResult = this.helper.preOutput(x, useBatchStats, this.shape, gamma, beta,
                    this.mean, this.var, helperDecay, layerConf.getEps());
            if (helperResult != null) {
                return helperResult;
            }
        }

        INDArray activations;
        if (x.rank() == 2) {
            this.xMu = broadcast(new BroadcastSubOp(x, usedMean, x.dup(), new int[]{-1}));
            this.xHat = broadcast(new BroadcastDivOp(this.xMu, this.std, this.xMu.dup(), new int[]{-1}));
            activations = this.xHat.dup().mulRowVector(gamma).addRowVector(beta);
        } else if (x.rank() == 4) {
            this.xMu = broadcast(new BroadcastSubOp(x, usedMean, x.dup(), new int[]{1, 2, 3}));
            this.xHat = broadcast(new BroadcastDivOp(this.xMu, this.std, this.xMu.dup(), new int[]{1, 2, 3}));
            activations = broadcast(new BroadcastMulOp(this.xHat, gamma, this.xHat.dup(), new int[]{1}));
            activations = broadcast(new BroadcastAddOp(activations, beta, activations, new int[]{1}));
        } else {
            throw new IllegalStateException("The layer prior to BatchNorm in the configuration is not currently supported.");
        }

        if (useBatchStats) {
            this.initRunningStatsIfNeeded(usedMean, usedVar, layerConf.getEps());
            double decay = layerConf.getDecay();
            double adjust = (double) batchSize / Math.max((double) batchSize - 1.0, 1.0);
            this.mean = usedMean.mul(decay).add(this.mean.mul(1.0 - decay));
            this.var = usedVar.mul(decay).add(this.var.mul((1.0 - decay) * adjust));
        }
        return activations;
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    @Override
    public INDArray activate(Layer.TrainingMode training) {
        throw new UnsupportedOperationException();
    }

    /** Equivalent to {@link #preOutput(INDArray, Layer.TrainingMode)}. */
    @Override
    public INDArray activate(INDArray input, Layer.TrainingMode training) {
        return this.preOutput(input, training);
    }

    /** Runs the forward pass on {@code x}, mapping {@code training} to TRAIN or TEST mode. */
    @Override
    public INDArray preOutput(INDArray x, boolean training) {
        return this.preOutput(x, training ? Layer.TrainingMode.TRAIN : Layer.TrainingMode.TEST);
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    @Override
    public Layer transpose() {
        throw new UnsupportedOperationException();
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    @Override
    public Layer clone() {
        throw new UnsupportedOperationException();
    }

    /** Returns the live listener list held by this layer (not a copy). */
    @Override
    public Collection<IterationListener> getListeners() {
        return this.listeners;
    }

    /** Replaces the listener list with a new list containing the given listeners. */
    @Override
    public void setListeners(IterationListener ... listeners) {
        this.listeners = new ArrayList<IterationListener>(Arrays.asList(listeners));
    }

    @Override
    public void setIndex(int index) {
        this.index = index;
    }

    @Override
    public int getIndex() {
        return this.index;
    }

    /**
     * Derives the two-element parameter shape for an input.
     *
     * <ul>
     *   <li>rank 2 or 4: {@code {1, x.size(1)}}</li>
     *   <li>rank 3: {@code {1, x.size(1) * x.size(2)}}</li>
     * </ul>
     *
     * @param x input array
     * @return the shape as described above
     * @throws IllegalArgumentException for rank 3 input when {@code x.size(0) > 1} and
     *                                  {@code x.size(1) * x.size(2)} equals {@code x.length()}
     * @throws IllegalStateException    for any rank other than 2, 3 or 4
     */
    public int[] getShape(INDArray x) {
        int rank = x.rank();
        if (rank == 2 || rank == 4) {
            return new int[]{1, x.size(1)};
        }
        if (rank == 3) {
            int wDim = x.size(1);
            int hDim = x.size(2);
            if (x.size(0) > 1 && wDim * hDim == x.length()) {
                throw new IllegalArgumentException("Illegal input for batch size");
            }
            return new int[]{1, wDim * hDim};
        }
        throw new IllegalStateException("Unable to process input of rank " + x.rank());
    }
}

