/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.normalization;

import java.util.Properties;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer;
import org.deeplearning4j.nn.layers.normalization.LocalResponseNormalizationHelper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.util.OneTimeLogger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Local response normalization layer: every value is rescaled by a factor derived from the
 * squared values of neighbouring channels at the same position.
 *
 * <p>With the configured hyper-parameters {@code k}, {@code n}, {@code alpha} and {@code beta},
 * the built-in forward pass computes:
 * <pre>
 *   unitScale   = k + alpha * sum(input^2 over up to n/2 channels on each side)
 *   scale       = unitScale ^ (-beta)
 *   activations = input * scale
 * </pre>
 * The channel dimension is dimension 1 of the input and the input is indexed with four
 * dimensions.
 *
 * <p>When a {@link LocalResponseNormalizationHelper} is available and supports the configured
 * hyper-parameters it is tried first; the built-in implementation is used whenever the helper
 * is absent or returns {@code null}.
 *
 * <p>The layer has no trainable parameters.
 */
public class LocalResponseNormalization extends AbstractLayer<org.deeplearning4j.nn.conf.layers.LocalResponseNormalization> {
    // NOTE(review): the logger is registered under the configuration class, not this
    // implementation class. Left unchanged so existing logging configuration keeps working.
    protected static final Logger log = LoggerFactory.getLogger(org.deeplearning4j.nn.conf.layers.LocalResponseNormalization.class);

    /** Fully-qualified name of the optional helper implementation, loaded reflectively. */
    private static final String CUDNN_HELPER_CLASS = "org.deeplearning4j.nn.layers.normalization.CudnnLocalResponseNormalizationHelper";

    /** Optional accelerated implementation; {@code null} when unavailable or unsupported. */
    LocalResponseNormalizationHelper helper = null;

    // Hyper-parameters, refreshed from the layer configuration on every activate() call.
    private double k;
    private double n;
    private double alpha;
    private double beta;
    /** Number of neighbouring channels considered on each side ({@code (int) n / 2}). */
    private int halfN;

    // Forward-pass results cached for backpropGradient().
    private INDArray activations;
    private INDArray unitScale;
    private INDArray scale;

    /**
     * Creates the layer from its configuration and tries to set up the optional helper.
     *
     * @param conf network configuration holding this layer's settings
     */
    public LocalResponseNormalization(NeuralNetConfiguration conf) {
        super(conf);
        this.initializeHelper();
    }

    /**
     * Creates the layer from its configuration with an initial input and tries to set up the
     * optional helper.
     *
     * @param conf  network configuration holding this layer's settings
     * @param input initial layer input
     */
    public LocalResponseNormalization(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
        this.initializeHelper();
    }

    /**
     * Returns a new layer built from a clone of this layer's configuration.
     */
    @Override
    public Layer clone() {
        return new LocalResponseNormalization(this.conf.clone());
    }

    /**
     * Attempts to load the cuDNN helper reflectively and keeps it only if it reports support
     * for the configured hyper-parameters. Any failure leaves {@link #helper} as {@code null}
     * so the built-in implementation is used.
     */
    void initializeHelper() {
        try {
            this.helper = Class.forName(CUDNN_HELPER_CLASS)
                    .asSubclass(LocalResponseNormalizationHelper.class)
                    .newInstance();
            log.debug("CudnnLocalResponseNormalizationHelper successfully initialized");
            org.deeplearning4j.nn.conf.layers.LocalResponseNormalization layerConf =
                    (org.deeplearning4j.nn.conf.layers.LocalResponseNormalization) this.layerConf();
            boolean supported = this.helper.checkSupported(
                    layerConf.getK(), layerConf.getN(), layerConf.getAlpha(), layerConf.getBeta());
            if (!supported) {
                this.helper = null;
            }
        } catch (Throwable t) {
            // A helper that was instantiated but failed during the support check must not be
            // used later: fall back to the built-in implementation.
            this.helper = null;
            // A missing class simply means the optional module is not on the classpath.
            if (!(t instanceof ClassNotFoundException)) {
                log.warn("Could not initialize CudnnLocalResponseNormalizationHelper", t);
            }
            Properties environment = Nd4j.getExecutioner().getEnvironmentInformation();
            // Constant-first comparison: a missing "backend" property must not raise a
            // NullPointerException out of the constructor.
            if (environment != null && "CUDA".equals(environment.getProperty("backend"))) {
                OneTimeLogger.info(log, "cuDNN not found: use cuDNN for better GPU performance by including the deeplearning4j-cuda module. For more information, please refer to: https://deeplearning4j.org/cudnn", t);
            }
        }
    }

    /**
     * @return always {@code 0.0}; this layer has no parameters to regularize
     */
    @Override
    public double calcL2(boolean backpropParamsOnly) {
        return 0.0;
    }

    /**
     * @return always {@code 0.0}; this layer has no parameters to regularize
     */
    @Override
    public double calcL1(boolean backpropParamsOnly) {
        return 0.0;
    }

    /**
     * @return {@link Layer.Type#NORMALIZATION}
     */
    @Override
    public Layer.Type type() {
        return Layer.Type.NORMALIZATION;
    }

    /**
     * No-op: there is nothing to fit in this layer.
     */
    @Override
    public void fit(INDArray input) {
    }

    /**
     * Back-propagates {@code epsilon} through the normalization.
     *
     * <p>The helper, if present, is tried first. Otherwise the built-in implementation computes
     * <pre>
     *   nextEpsilon = epsilon * scale
     *               - 2 * alpha * beta * input * windowSum(activations * epsilon) / unitScale
     * </pre>
     * using the values cached by the last {@link #activate(boolean)} call.
     *
     * @param epsilon gradient with respect to this layer's output
     * @return an empty parameter gradient paired with the gradient with respect to the input
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        if (this.helper != null) {
            Pair<Gradient, INDArray> helperResult =
                    this.helper.backpropGradient(this.input, epsilon, this.k, this.n, this.alpha, this.beta);
            if (helperResult != null) {
                return helperResult;
            }
        }

        // NOTE(review): scale and unitScale are only assigned by the built-in forward pass.
        // If the helper produced the activations but declined the backward pass, they are
        // unset or stale here - confirm whether that combination can occur.
        int channels = this.input.size(1);
        Gradient emptyGradient = new DefaultGradient();

        INDArray weightedEpsilon = this.activations.mul(epsilon);
        INDArray windowSum = sumOverChannelWindow(weightedEpsilon, channels);

        INDArray correction = windowSum.muli(this.input)
                .divi(this.unitScale)
                .muli(2.0 * this.alpha * this.beta);
        INDArray nextEpsilon = epsilon.mul(this.scale).subi(correction);
        return new Pair<>(emptyGradient, nextEpsilon);
    }

    /**
     * Runs the forward pass on the current input.
     *
     * <p>Hyper-parameters are re-read from the layer configuration on every call. The helper,
     * if present, is tried first; when it is absent or returns {@code null} the built-in
     * implementation runs and caches {@code unitScale}, {@code scale} and the activations for
     * the backward pass.
     *
     * @param training whether the call is part of training (forwarded to the helper)
     * @return the normalized activations
     */
    @Override
    public INDArray activate(boolean training) {
        org.deeplearning4j.nn.conf.layers.LocalResponseNormalization layerConf =
                (org.deeplearning4j.nn.conf.layers.LocalResponseNormalization) this.layerConf();
        this.k = layerConf.getK();
        this.n = layerConf.getN();
        this.alpha = layerConf.getAlpha();
        this.beta = layerConf.getBeta();
        this.halfN = (int) this.n / 2;

        if (this.helper != null) {
            this.activations = this.helper.activate(this.input, training, this.k, this.n, this.alpha, this.beta);
            if (this.activations != null) {
                return this.activations;
            }
        }

        int channels = this.input.size(1);
        INDArray squaredInput = this.input.mul(this.input);
        INDArray windowSum = sumOverChannelWindow(squaredInput, channels);

        this.unitScale = windowSum.mul(this.alpha).addi(this.k).leverageTo("LOOP_EXTERNAL");
        this.scale = Transforms.pow(this.unitScale, -this.beta).leverageTo("LOOP_EXTERNAL");
        this.activations = this.input.mul(this.scale).leverageTo("LOOP_EXTERNAL");
        return this.activations;
    }

    /**
     * Not supported for this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Layer transpose() {
        throw new UnsupportedOperationException("Not supported " + this.layerId());
    }

    /**
     * @return {@code false}; this layer is not a pretrain layer
     */
    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    /**
     * No-op: this layer holds no weight-noise parameters.
     */
    @Override
    public void clearNoiseWeightParams() {
    }

    /**
     * @return {@code null}; this layer has no parameters
     */
    @Override
    public INDArray params() {
        return null;
    }

    /**
     * @return the result of {@link #params()}, i.e. {@code null}, regardless of {@code param}
     */
    @Override
    public INDArray getParam(String param) {
        return this.params();
    }

    /**
     * No-op: this layer has no parameters to set.
     */
    @Override
    public void setParams(INDArray params) {
    }

    /**
     * @return the same result as {@link #activate(boolean)}
     */
    @Override
    public INDArray preOutput(boolean training) {
        return this.activate(training);
    }

    /**
     * Sums, for every channel, the values of {@code values} over that channel and up to
     * {@code halfN} neighbouring channels on each side (dimension 1).
     *
     * @param values   source array; it is not modified
     * @param channels size of dimension 1 of {@code values}
     * @return a new array holding the windowed sums
     */
    private INDArray sumOverChannelWindow(INDArray values, int channels) {
        INDArray windowSum = values.dup();
        // An offset at or beyond the channel count has no neighbouring channel to add, so the
        // loop is capped instead of building empty or negative index intervals.
        int maxOffset = Math.min(this.halfN, channels - 1);
        for (int offset = 1; offset <= maxOffset; offset++) {
            // Channels [offset, channels) receive the value of the channel `offset` below.
            INDArray upperPart = windowSum.get(channelRange(offset, channels));
            INDArray lowerNeighbours = values.get(channelRange(0, channels - offset));
            windowSum.put(channelRange(offset, channels), upperPart.addi(lowerNeighbours));

            // Channels [0, channels - offset) receive the value of the channel `offset` above.
            INDArray lowerPart = windowSum.get(channelRange(0, channels - offset));
            INDArray upperNeighbours = values.get(channelRange(offset, channels));
            windowSum.put(channelRange(0, channels - offset), lowerPart.addi(upperNeighbours));
        }
        return windowSum;
    }

    /**
     * Builds a fresh four-dimensional index selecting channels {@code [from, to)} along
     * dimension 1 and everything along the other dimensions. A new array is created per call
     * so that index objects are never shared between operations.
     */
    private static INDArrayIndex[] channelRange(int from, int to) {
        return new INDArrayIndex[]{
                NDArrayIndex.all(),
                NDArrayIndex.interval(from, to),
                NDArrayIndex.all(),
                NDArrayIndex.all()
        };
    }
}

