/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.EmptyParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.regularization.Regularization;

/**
 * Configuration for a local response normalization (LRN) layer.
 *
 * <p>The layer has no trainable parameters (it uses {@link EmptyParamInitializer}) and only
 * accepts CNN-type input, which it passes through with an unchanged {@link InputType}.
 *
 * <p>The hyperparameters {@code n}, {@code k}, {@code alpha} and {@code beta} are stored here and
 * handed to the runtime layer through the {@link NeuralNetConfiguration}; how they enter the
 * normalization itself is defined by the runtime implementation, which is not visible in this file.
 */
public class LocalResponseNormalization extends Layer {
    // Defaults shared by the layer and its Builder so the two cannot drift apart.
    private static final double DEFAULT_N = 5.0;
    private static final double DEFAULT_K = 2.0;
    private static final double DEFAULT_BETA = 0.75;
    private static final double DEFAULT_ALPHA = 1.0E-4;
    private static final boolean DEFAULT_CUDNN_ALLOW_FALLBACK = true;

    protected double n = DEFAULT_N;
    protected double k = DEFAULT_K;
    protected double beta = DEFAULT_BETA;
    protected double alpha = DEFAULT_ALPHA;
    // NOTE(review): presumably controls falling back to the built-in implementation when cuDNN
    // fails; the flag is only stored here - confirm against the runtime layer.
    protected boolean cudnnAllowFallback = DEFAULT_CUDNN_ALLOW_FALLBACK;

    /**
     * Creates a configuration with the default hyperparameters (n=5, k=2, beta=0.75, alpha=1e-4,
     * cuDNN fallback allowed).
     */
    public LocalResponseNormalization() {
    }

    /**
     * Copies all hyperparameters from the given builder.
     *
     * @param builder source of the layer settings
     */
    private LocalResponseNormalization(Builder builder) {
        super(builder);
        this.k = builder.k;
        this.n = builder.n;
        this.alpha = builder.alpha;
        this.beta = builder.beta;
        this.cudnnAllowFallback = builder.cudnnAllowFallback;
    }

    /**
     * Returns a copy produced by {@code super.clone()}; this class adds only primitive fields, so
     * no further copying is done here.
     */
    @Override
    public LocalResponseNormalization clone() {
        return (LocalResponseNormalization) super.clone();
    }

    /**
     * Builds the runtime LRN layer for this configuration and wires it to the supplied listeners,
     * index, parameter view and configuration.
     *
     * @param conf              network configuration passed to the runtime layer
     * @param trainingListeners listeners to attach to the runtime layer
     * @param layerIndex        index assigned to the runtime layer
     * @param layerParamsView   parameter view array handed to the runtime layer and the initializer
     * @param initializeParams  forwarded to the parameter initializer
     * @param networkDataType   data type passed to the runtime layer's constructor
     * @return the configured runtime layer
     */
    @Override
    public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
        org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization ret = new org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization(conf, networkDataType);
        ret.setListeners(trainingListeners);
        ret.setIndex(layerIndex);
        ret.setParamsViewArray(layerParamsView);
        Map<String, INDArray> paramTable = this.initializer().init(conf, layerParamsView, initializeParams);
        ret.setParamTable(paramTable);
        ret.setConf(conf);
        return ret;
    }

    /**
     * @return the shared {@link EmptyParamInitializer} instance
     */
    @Override
    public ParamInitializer initializer() {
        return EmptyParamInitializer.getInstance();
    }

    /**
     * Validates that the input is CNN-type and returns it unchanged.
     *
     * @param layerIndex index of this layer, used only in the error message
     * @param inputType  input type to validate
     * @return {@code inputType} itself
     * @throws IllegalStateException if {@code inputType} is null or not of type CNN
     */
    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        if (inputType == null || inputType.getType() != InputType.Type.CNN) {
            throw new IllegalStateException("Invalid input type for LRN layer (layer index = " + layerIndex + ", layer name = \"" + this.getLayerName() + "\"): Expected input of type CNN, got " + inputType);
        }
        return inputType;
    }

    /**
     * Intentionally a no-op: this configuration stores no input-size field to set.
     */
    @Override
    public void setNIn(InputType inputType, boolean override) {
    }

    /**
     * Delegates to {@link InputTypeUtil#getPreProcessorForInputTypeCnnLayers} to obtain the
     * preprocessor for the given input type.
     *
     * @param inputType input type feeding this layer
     * @return whatever the CNN-layer helper returns for this input type
     * @throws IllegalStateException if {@code inputType} is null
     */
    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        if (inputType == null) {
            throw new IllegalStateException("Invalid input type for LRN layer (layer name = \"" + this.getLayerName() + "\"): null");
        }
        return InputTypeUtil.getPreProcessorForInputTypeCnnLayers(inputType, this.getLayerName());
    }

    /**
     * @return always {@code null}; callers must null-check the result
     */
    @Override
    public List<Regularization> getRegularizationByParam(String paramName) {
        return null;
    }

    /**
     * @return always {@code false}
     */
    @Override
    public boolean isPretrainParam(String paramName) {
        return false;
    }

    /**
     * @return always {@link GradientNormalization#None}
     */
    @Override
    public GradientNormalization getGradientNormalization() {
        return GradientNormalization.None;
    }

    /**
     * @return always {@code 0.0}
     */
    @Override
    public double getGradientNormalizationThreshold() {
        return 0.0;
    }

    /**
     * Builds the memory report for this layer: no standard memory, working memory proportional to
     * the number of activation elements per example, and no cache memory.
     *
     * @param inputType input type feeding this layer (also used as the output type)
     * @return the memory report, attributed to this layer class
     * @throws IllegalStateException if {@code inputType} is null
     */
    @Override
    public LayerMemoryReport getMemoryReport(InputType inputType) {
        if (inputType == null) {
            throw new IllegalStateException("Invalid input type for LRN layer (layer name = \"" + this.getLayerName() + "\"): null");
        }
        long actElementsPerEx = inputType.arrayElementsPerExample();
        // The report is attributed to this class (it previously named DenseLayer by mistake).
        // NOTE(review): workingMemory arguments are presumably (fixed inference, per-example
        // inference, fixed training, per-example training) - confirm against LayerMemoryReport.Builder.
        return new LayerMemoryReport.Builder(this.layerName, LocalResponseNormalization.class, inputType, inputType)
                .standardMemory(0L, 0L)
                .workingMemory(0L, 2L * actElementsPerEx, 0L, 3L * actElementsPerEx)
                .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS)
                .build();
    }

    /** @return the {@code n} hyperparameter */
    public double getN() {
        return this.n;
    }

    /** @return the {@code k} hyperparameter */
    public double getK() {
        return this.k;
    }

    /** @return the {@code beta} hyperparameter */
    public double getBeta() {
        return this.beta;
    }

    /** @return the {@code alpha} hyperparameter */
    public double getAlpha() {
        return this.alpha;
    }

    /** @return the stored cuDNN fallback flag */
    public boolean isCudnnAllowFallback() {
        return this.cudnnAllowFallback;
    }

    /** @param n new value of the {@code n} hyperparameter */
    public void setN(double n) {
        this.n = n;
    }

    /** @param k new value of the {@code k} hyperparameter */
    public void setK(double k) {
        this.k = k;
    }

    /** @param beta new value of the {@code beta} hyperparameter */
    public void setBeta(double beta) {
        this.beta = beta;
    }

    /** @param alpha new value of the {@code alpha} hyperparameter */
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /** @param cudnnAllowFallback new value of the cuDNN fallback flag */
    public void setCudnnAllowFallback(boolean cudnnAllowFallback) {
        this.cudnnAllowFallback = cudnnAllowFallback;
    }

    @Override
    public String toString() {
        return "LocalResponseNormalization(super=" + super.toString() + ", n=" + this.getN() + ", k=" + this.getK() + ", beta=" + this.getBeta() + ", alpha=" + this.getAlpha() + ", cudnnAllowFallback=" + this.isCudnnAllowFallback() + ")";
    }

    /**
     * Two instances are equal when the superclass state and all five hyperparameters match.
     * Doubles are compared with {@link Double#compare}, so NaN equals NaN and 0.0 differs from -0.0.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof LocalResponseNormalization)) {
            return false;
        }
        LocalResponseNormalization other = (LocalResponseNormalization) o;
        // canEqual lets subclasses opt out of equality with this class, keeping equals symmetric.
        if (!other.canEqual(this)) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        return Double.compare(this.getN(), other.getN()) == 0
                && Double.compare(this.getK(), other.getK()) == 0
                && Double.compare(this.getBeta(), other.getBeta()) == 0
                && Double.compare(this.getAlpha(), other.getAlpha()) == 0
                && this.isCudnnAllowFallback() == other.isCudnnAllowFallback();
    }

    /**
     * @param other candidate for equality
     * @return {@code true} if {@code other} is a {@code LocalResponseNormalization}
     */
    @Override
    protected boolean canEqual(Object other) {
        return other instanceof LocalResponseNormalization;
    }

    /**
     * Hash consistent with {@link #equals(Object)}: combines the superclass hash with the five
     * hyperparameters.
     */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = super.hashCode();
        result = result * prime + hashDouble(this.getN());
        result = result * prime + hashDouble(this.getK());
        result = result * prime + hashDouble(this.getBeta());
        result = result * prime + hashDouble(this.getAlpha());
        result = result * prime + (this.isCudnnAllowFallback() ? 79 : 97);
        return result;
    }

    /** Folds the 64-bit representation of a double into an int by XOR-ing its two halves. */
    private static int hashDouble(double value) {
        long bits = Double.doubleToLongBits(value);
        return (int) (bits >>> 32 ^ bits);
    }

    /**
     * Fluent builder for {@link LocalResponseNormalization}. Any value not set explicitly keeps
     * the same default as the layer itself.
     */
    public static class Builder
    extends Layer.Builder<Builder> {
        private double k = DEFAULT_K;
        private double n = DEFAULT_N;
        private double alpha = DEFAULT_ALPHA;
        private double beta = DEFAULT_BETA;
        protected boolean cudnnAllowFallback = DEFAULT_CUDNN_ALLOW_FALLBACK;

        /** Creates a builder holding the default hyperparameters. */
        public Builder() {
        }

        /**
         * Creates a builder with explicit {@code k}, {@code alpha} and {@code beta}; {@code n} and
         * the cuDNN fallback flag keep their defaults.
         */
        public Builder(double k, double alpha, double beta) {
            // Delegate rather than call the overridable setters from a constructor.
            this(k, DEFAULT_N, alpha, beta, DEFAULT_CUDNN_ALLOW_FALLBACK);
        }

        /**
         * Creates a builder with explicit {@code k}, {@code n}, {@code alpha} and {@code beta};
         * the cuDNN fallback flag keeps its default.
         */
        public Builder(double k, double n, double alpha, double beta) {
            this(k, n, alpha, beta, DEFAULT_CUDNN_ALLOW_FALLBACK);
        }

        /** Creates a builder with every hyperparameter given explicitly. */
        public Builder(double k, double n, double alpha, double beta, boolean cudnnAllowFallback) {
            this.k = k;
            this.n = n;
            this.alpha = alpha;
            this.beta = beta;
            this.cudnnAllowFallback = cudnnAllowFallback;
        }

        /**
         * @param k value of the {@code k} hyperparameter
         * @return this builder
         */
        public Builder k(double k) {
            this.setK(k);
            return this;
        }

        /**
         * @param n value of the {@code n} hyperparameter
         * @return this builder
         */
        public Builder n(double n) {
            this.setN(n);
            return this;
        }

        /**
         * @param alpha value of the {@code alpha} hyperparameter
         * @return this builder
         */
        public Builder alpha(double alpha) {
            this.setAlpha(alpha);
            return this;
        }

        /**
         * @param beta value of the {@code beta} hyperparameter
         * @return this builder
         */
        public Builder beta(double beta) {
            this.setBeta(beta);
            return this;
        }

        /**
         * @param allowFallback value of the cuDNN fallback flag
         * @return this builder
         */
        public Builder cudnnAllowFallback(boolean allowFallback) {
            this.setCudnnAllowFallback(allowFallback);
            return this;
        }

        /** @return a new layer configuration carrying this builder's current values */
        @Override
        public LocalResponseNormalization build() {
            return new LocalResponseNormalization(this);
        }

        /** @return the {@code k} hyperparameter */
        public double getK() {
            return this.k;
        }

        /** @return the {@code n} hyperparameter */
        public double getN() {
            return this.n;
        }

        /** @return the {@code alpha} hyperparameter */
        public double getAlpha() {
            return this.alpha;
        }

        /** @return the {@code beta} hyperparameter */
        public double getBeta() {
            return this.beta;
        }

        /** @return the stored cuDNN fallback flag */
        public boolean isCudnnAllowFallback() {
            return this.cudnnAllowFallback;
        }

        /** @param k new value of the {@code k} hyperparameter */
        public void setK(double k) {
            this.k = k;
        }

        /** @param n new value of the {@code n} hyperparameter */
        public void setN(double n) {
            this.n = n;
        }

        /** @param alpha new value of the {@code alpha} hyperparameter */
        public void setAlpha(double alpha) {
            this.alpha = alpha;
        }

        /** @param beta new value of the {@code beta} hyperparameter */
        public void setBeta(double beta) {
            this.beta = beta;
        }

        /** @param cudnnAllowFallback new value of the cuDNN fallback flag */
        public void setCudnnAllowFallback(boolean cudnnAllowFallback) {
            this.cudnnAllowFallback = cudnnAllowFallback;
        }
    }
}

