/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.normalization;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Local response normalization (LRN) layer.
 *
 * <p>For every element {@code a} of the input, the forward pass computes
 * <pre>
 *   unitScale = k + alpha * sum_j(a_j^2)
 *   scale     = unitScale^(-beta)
 *   b         = a * scale
 * </pre>
 * where {@code j} ranges over the positions within {@code floor(n / 2)} of the current index
 * along dimension 1, clipped to the array bounds. The input is indexed as a rank-4 array and
 * dimension 1 is treated as the channel axis (presumably {@code [minibatch, channels, height,
 * width]} — confirm against callers).
 *
 * <p>The layer has no trainable parameters: its L1/L2 terms are zero and
 * {@link #backpropGradient(INDArray)} returns an empty gradient. The backward pass uses
 * intermediates cached by the most recent call to {@link #activate(boolean)}. It must therefore
 * follow a forward pass on the same input.
 */
public class LocalResponseNormalization
extends BaseLayer<org.deeplearning4j.nn.conf.layers.LocalResponseNormalization> {
    /** Additive constant in the denominator; refreshed from the configuration on each forward pass. */
    private double k;
    /** Window size along the channel axis; refreshed from the configuration on each forward pass. */
    private double n;
    /** Multiplier applied to the windowed sum of squares. */
    private double alpha;
    /** Exponent applied (negated) to the denominator. */
    private double beta;
    /** Half-width of the channel window, {@code floor(n / 2)}. */
    private int halfN;
    /** Output of the last forward pass, cached for backprop. */
    private INDArray activations;
    /** {@code k + alpha * windowed sum of squares} from the last forward pass. */
    private INDArray unitScale;
    /** {@code unitScale^(-beta)} from the last forward pass. */
    private INDArray scale;

    public LocalResponseNormalization(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    public LocalResponseNormalization(NeuralNetConfiguration conf) {
        super(conf);
    }

    /** This layer has no parameters, so it contributes no L2 penalty. */
    @Override
    public double calcL2() {
        return 0.0;
    }

    /** This layer has no parameters, so it contributes no L1 penalty. */
    @Override
    public double calcL1() {
        return 0.0;
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.NORMALIZATION;
    }

    /** No-op: there is nothing to fit in this layer. */
    @Override
    public void fit(INDArray input) {
    }

    /**
     * Back-propagates {@code epsilon} (the gradient w.r.t. this layer's output) through the
     * normalization.
     *
     * <p>Differentiating {@code b_j = a_j * u_j^(-beta)} with respect to {@code a_i} gives
     * <pre>
     *   dL/da_i = eps_i * scale_i - 2 * alpha * beta * a_i * sum_j(eps_j * b_j / u_j)
     * </pre>
     * with {@code j} ranging over the channel window around {@code i}. Each neighbour's term is
     * divided by its own {@code u_j} before summing. Dividing the already-summed term by
     * {@code u_i} is only exact when {@code u} is constant across the window.
     *
     * @param epsilon gradient of the loss w.r.t. the output of the last {@link #activate(boolean)}
     * @return an empty gradient (no parameters) paired with the gradient w.r.t. the input
     * @throws IllegalStateException if no forward pass has been run yet
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        if (this.activations == null || this.unitScale == null || this.scale == null) {
            throw new IllegalStateException("activate() must be called before backpropGradient()");
        }
        int channel = this.input.shape()[1];
        DefaultGradient retGradient = new DefaultGradient();
        // Per-position term eps_j * b_j / u_j; divided BEFORE the window sum (see Javadoc).
        INDArray summand = this.activations.mul(epsilon).divi(this.unitScale);
        INDArray sumPart = sumOverChannelWindow(summand, channel);
        INDArray nextEpsilon = epsilon.mul(this.scale)
                .subi(this.input.mul(2.0 * this.alpha * this.beta).muli(sumPart));
        return new Pair<Gradient, INDArray>(retGradient, nextEpsilon);
    }

    /**
     * Runs the forward pass and caches the intermediates needed by
     * {@link #backpropGradient(INDArray)}.
     *
     * @param training unused by this layer
     * @return {@code input * (k + alpha * windowed sum of input^2)^(-beta)}
     */
    @Override
    public INDArray activate(boolean training) {
        org.deeplearning4j.nn.conf.layers.LocalResponseNormalization lrnConf =
                (org.deeplearning4j.nn.conf.layers.LocalResponseNormalization) this.layerConf();
        this.k = lrnConf.getK();
        this.n = lrnConf.getN();
        this.alpha = lrnConf.getAlpha();
        this.beta = lrnConf.getBeta();
        this.halfN = (int) this.n / 2;
        int channel = this.input.shape()[1];
        INDArray activitySqr = this.input.mul(this.input);
        INDArray sumPart = sumOverChannelWindow(activitySqr, channel);
        this.unitScale = sumPart.mul(this.alpha).add(this.k);
        this.scale = Transforms.pow(this.unitScale, -this.beta);
        this.activations = this.input.mul(this.scale);
        return this.activations;
    }

    /**
     * Adds to each entry of {@code values} its neighbours along dimension 1 at offsets
     * {@code 1..halfN} on both sides. Neighbours that fall outside {@code [0, channel)} are
     * omitted.
     *
     * @param values rank-4 array; not modified
     * @param channel size of dimension 1 of {@code values}
     * @return a new array of the same shape holding the windowed sums
     */
    private INDArray sumOverChannelWindow(INDArray values, int channel) {
        INDArray windowSum = values.dup();
        // Offsets >= channel have no in-bounds neighbour. Bounding the loop avoids building
        // empty or negative-length intervals when n exceeds the channel count.
        int maxOffset = Math.min(this.halfN, channel - 1);
        for (int i = 1; i <= maxOffset; ++i) {
            // windowSum[:, i:] += values[:, :channel - i]
            windowSum.put(channelRange(i, channel),
                    windowSum.get(channelRange(i, channel)).addi(values.get(channelRange(0, channel - i))));
            // windowSum[:, :channel - i] += values[:, i:]
            windowSum.put(channelRange(0, channel - i),
                    windowSum.get(channelRange(0, channel - i)).addi(values.get(channelRange(i, channel))));
        }
        return windowSum;
    }

    /**
     * Builds a fresh index selecting {@code [from, to)} along dimension 1 and everything along
     * the other three dimensions of a rank-4 array.
     */
    private static INDArrayIndex[] channelRange(int from, int to) {
        return new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(from, to), NDArrayIndex.all(), NDArrayIndex.all()};
    }

    @Override
    public Layer transpose() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public Gradient calcGradient(Gradient layerError, INDArray indArray) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public void merge(Layer layer, int batchSize) {
        throw new UnsupportedOperationException();
    }
}

