/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.normalization;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer;
import org.deeplearning4j.nn.layers.LayerHelper;
import org.deeplearning4j.nn.layers.normalization.LocalResponseNormalizationHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.primitives.Triple;
import org.nd4j.util.OneTimeLogger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Layer implementation of cross-channel Local Response Normalization (LRN).
 *
 * <p>The input is indexed as rank 4 with channels on axis 1. This matches how the built-in code
 * slices it; callers are presumed to supply NCHW data. For each element {@code x} in channel
 * {@code c}, the built-in forward pass computes {@code out = x * (k + alpha * S)^(-beta)}.
 * {@code S} is the sum of {@code x^2} over channels {@code [c - n/2, c + n/2]}, clipped to the
 * valid channel range.
 *
 * <p>On a CUDA backend, a cuDNN helper is loaded reflectively when it is on the classpath. If the
 * helper throws, the behavior depends on {@code isCudnnAllowFallback()}:
 * <ul>
 *   <li>When it is true, the failure is counted and the built-in implementation is used from then
 *       on.</li>
 *   <li>Otherwise, the failure is rethrown wrapped in a {@link RuntimeException}.</li>
 * </ul>
 * Allocation failures (messages containing "Failed to allocate") are always rethrown unchanged.
 */
public class LocalResponseNormalization
        extends AbstractLayer<org.deeplearning4j.nn.conf.layers.LocalResponseNormalization> {
    protected static final Logger log = LoggerFactory.getLogger(org.deeplearning4j.nn.conf.layers.LocalResponseNormalization.class);
    /** Optional accelerated implementation; {@code null} when unavailable or unsupported. */
    protected LocalResponseNormalizationHelper helper = null;
    /** Number of helper failures that caused a fallback to the built-in implementation. */
    protected int helperCountFail = 0;

    @Override
    public Layer clone() {
        return new LocalResponseNormalization(this.conf.clone(), this.dataType);
    }

    public LocalResponseNormalization(NeuralNetConfiguration conf, DataType dataType) {
        super(conf, dataType);
        this.initializeHelper();
    }

    /**
     * Tries to load the cuDNN helper when the ND4J backend reports "CUDA".
     *
     * <p>The helper is then discarded if it reports that the configured (k, n, alpha, beta) are
     * unsupported. Any failure while loading is logged, never thrown.
     */
    void initializeHelper() {
        String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
        if ("CUDA".equalsIgnoreCase(backend)) {
            try {
                this.helper = Class.forName("org.deeplearning4j.nn.layers.normalization.CudnnLocalResponseNormalizationHelper")
                        .asSubclass(LocalResponseNormalizationHelper.class)
                        .getConstructor(DataType.class)
                        .newInstance(this.dataType);
                log.debug("CudnnLocalResponseNormalizationHelper successfully initialized");
            } catch (Throwable t) {
                if (t instanceof ClassNotFoundException) {
                    // The CUDA helper module is not on the classpath: informational only.
                    OneTimeLogger.info(log, "cuDNN not found: use cuDNN for better GPU performance by including the deeplearning4j-cuda module. For more information, please refer to: https://deeplearning4j.org/docs/latest/deeplearning4j-config-cudnn", new Object[]{t});
                } else {
                    log.warn("Could not initialize CudnnLocalResponseNormalizationHelper", t);
                }
            }
        }
        org.deeplearning4j.nn.conf.layers.LocalResponseNormalization lrnConf = layerConf();
        if (this.helper != null && !this.helper.checkSupported(lrnConf.getK(), lrnConf.getN(), lrnConf.getAlpha(), lrnConf.getBeta())) {
            log.debug("Removed helper {} as not supported (k={}, n={}, alpha={}, beta={})", new Object[]{this.helper.getClass(), lrnConf.getK(), lrnConf.getN(), lrnConf.getAlpha(), lrnConf.getBeta()});
            this.helper = null;
        }
    }

    /** Returns true when a helper exists and has not been disabled by an earlier fallback. */
    private boolean useHelper() {
        return this.helper != null
                && (this.helperCountFail == 0 || !layerConf().isCudnnAllowFallback());
    }

    /** Returns true when {@code t} reports a memory-allocation failure, which must never fall back. */
    private static boolean isAllocationFailure(Throwable t) {
        String message = t.getMessage();
        return message != null && message.contains("Failed to allocate");
    }

    /** Builds fresh indices selecting channels [from, to) on axis 1 of a rank-4 array. */
    private static INDArrayIndex[] channelRange(long from, long to) {
        return new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(from, to), NDArrayIndex.all(), NDArrayIndex.all()};
    }

    /**
     * Sums {@code values} over a window of channels.
     *
     * <p>Returns a new array in which channel {@code c} holds the sum of {@code values} over
     * channels {@code [c - halfN, c + halfN]}, clipped to the valid range. {@code values} itself
     * is not modified.
     */
    private static INDArray sumOverChannelWindow(INDArray values, int halfN) {
        long channels = values.size(1);
        INDArray sum = values.dup();
        // A shift of `channels` or more has no overlap and adds nothing. Stopping early also
        // avoids building empty or negative intervals when n/2 >= number of channels.
        long maxShift = Math.min((long) halfN, channels - 1);
        for (long shift = 1; shift <= maxShift; ++shift) {
            // channel c += values[c - shift]
            INDArray target = sum.get(channelRange(shift, channels));
            sum.put(channelRange(shift, channels), target.addi(values.get(channelRange(0L, channels - shift))));
            // channel c += values[c + shift]
            target = sum.get(channelRange(0L, channels - shift));
            sum.put(channelRange(0L, channels - shift), target.addi(values.get(channelRange(shift, channels))));
        }
        return sum;
    }

    @Override
    public double calcRegularizationScore(boolean backpropParamsOnly) {
        return 0.0;
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.NORMALIZATION;
    }

    @Override
    public void fit(INDArray input, LayerWorkspaceMgr workspaceMgr) {
        throw new UnsupportedOperationException("Not supported");
    }

    /**
     * Computes the gradient with respect to the layer input. This layer has no parameters, so the
     * returned {@link Gradient} is always empty.
     *
     * <p>The helper is tried first. The built-in path computes
     * {@code eps * scale - 2*alpha*beta * x * S' / unitScale}, where {@code S'} is the
     * channel-window sum of {@code out * eps}.
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(true);
        org.deeplearning4j.nn.conf.layers.LocalResponseNormalization lrnConf = layerConf();
        double k = lrnConf.getK();
        double n = lrnConf.getN();
        double alpha = lrnConf.getAlpha();
        double beta = lrnConf.getBeta();
        int halfN = (int) n / 2;

        if (useHelper()) {
            Pair<Gradient, INDArray> ret = null;
            try {
                ret = this.helper.backpropGradient(this.input, epsilon, k, n, alpha, beta, workspaceMgr);
            } catch (Throwable t) {
                if (isAllocationFailure(t)) {
                    throw t;
                }
                if (!lrnConf.isCudnnAllowFallback()) {
                    throw new RuntimeException("Error during LocalResponseNormalization CuDNN helper backprop - isCudnnAllowFallback() is set to false", t);
                }
                ++this.helperCountFail;
                log.warn("CuDNN LocalResponseNormalization backprop execution failed - falling back on built-in implementation", t);
            }
            if (ret != null) {
                return ret;
            }
        }

        // The built-in path needs the forward intermediates (unitScale, scale).
        Triple<INDArray, INDArray, INDArray> triple = this.activateHelper(true, workspaceMgr, true);
        INDArray activations = triple.getFirst();
        INDArray unitScale = triple.getSecond();
        INDArray scale = triple.getThird();

        INDArray sumPart = sumOverChannelWindow(activations.mul(epsilon), halfN);

        Gradient retGradient = new DefaultGradient();
        INDArray nextEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, epsilon.dataType(), epsilon.shape(), epsilon.ordering());
        Nd4j.getExecutioner().exec((Op) new OldMulOp(epsilon, scale, nextEpsilon));
        // In-place arrangement of: eps * scale - 2*alpha*beta * sumPart * x / unitScale
        nextEpsilon.subi(sumPart.muli(this.input).divi(unitScale).muli(2.0 * alpha * beta));
        return new Pair<>(retGradient, nextEpsilon);
    }

    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        return this.activateHelper(training, workspaceMgr, false).getFirst();
    }

    /**
     * Runs the forward pass.
     *
     * @param forBackprop when true, the built-in implementation is always used. The second and
     *     third elements of the result are then {@code unitScale = k + alpha * S} and
     *     {@code scale = unitScale^(-beta)}, as required by the built-in backprop. When false,
     *     those two elements are {@code null}.
     * @return (activations, unitScale or null, scale or null)
     */
    private Triple<INDArray, INDArray, INDArray> activateHelper(boolean training, LayerWorkspaceMgr workspaceMgr, boolean forBackprop) {
        this.assertInputSet(false);
        org.deeplearning4j.nn.conf.layers.LocalResponseNormalization lrnConf = layerConf();
        double k = lrnConf.getK();
        double n = lrnConf.getN();
        double alpha = lrnConf.getAlpha();
        double beta = lrnConf.getBeta();
        int halfN = (int) n / 2;

        // The helper does not expose unitScale/scale. Using it when those are required would
        // make the built-in backprop dereference nulls, so it is only used for plain inference.
        if (!forBackprop && useHelper()) {
            INDArray activations = null;
            try {
                activations = this.helper.activate(this.input, training, k, n, alpha, beta, workspaceMgr);
            } catch (Throwable t) {
                if (isAllocationFailure(t)) {
                    throw t;
                }
                if (!lrnConf.isCudnnAllowFallback()) {
                    throw new RuntimeException("Error during LocalResponseNormalization CuDNN helper forward pass - isCudnnAllowFallback() is set to false", t);
                }
                ++this.helperCountFail;
                log.warn("CuDNN LocalResponseNormalization forward pass execution failed - falling back on built-in implementation", t);
            }
            if (activations != null) {
                return new Triple<>(activations, null, null);
            }
        }

        INDArray sumPart = sumOverChannelWindow(this.input.mul(this.input), halfN);
        INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, this.input.dataType(), this.input.shape(), this.input.ordering());
        if (forBackprop) {
            INDArray unitScale = sumPart.mul(alpha).addi(k);
            INDArray scale = Transforms.pow(unitScale, -beta, true);
            Nd4j.getExecutioner().exec((Op) new OldMulOp(this.input, scale, activations));
            return new Triple<>(activations, unitScale, scale);
        }
        // Inference: compute the result in place in `activations`, with no extra temporaries.
        sumPart.muli(alpha, activations).addi(k);
        Transforms.pow(activations, -beta, false);
        activations.muli(this.input);
        return new Triple<>(activations, null, null);
    }

    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    @Override
    public void clearNoiseWeightParams() {
    }

    @Override
    public LayerHelper getHelper() {
        return this.helper;
    }

    /** This layer has no parameters; always returns {@code null}. */
    @Override
    public INDArray params() {
        return null;
    }

    /** This layer has no parameters; always returns {@code null}. */
    @Override
    public INDArray getParam(String param) {
        return this.params();
    }

    /** No-op: this layer has no parameters. */
    @Override
    public void setParams(INDArray params) {
    }
}

