/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers.samediff;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.params.SameDiffParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.NetworkUtils;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.learning.regularization.WeightDecay;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base configuration class for SameDiff-based layers.
 * <p>
 * Holds per-layer regularization, updater and gradient-normalization settings. Any of these left
 * unset on the layer are filled in from the global {@link NeuralNetConfiguration.Builder} by
 * {@link #applyGlobalConfig(NeuralNetConfiguration.Builder)}. Parameter definitions are created
 * lazily by {@link #getLayerParams()} via the subclass's {@link #defineParameters(SDLayerParams)}.
 */
public abstract class AbstractSameDiffLayer extends Layer {
    private static final Logger log = LoggerFactory.getLogger(AbstractSameDiffLayer.class);

    /** Regularization applied to weight parameters. */
    protected List<Regularization> regularization;
    /** Regularization applied to bias parameters. */
    protected List<Regularization> regularizationBias;
    /** Updater for all parameters (and for biases when {@link #biasUpdater} is null). */
    protected IUpdater updater;
    /** Optional separate updater for bias parameters. */
    protected IUpdater biasUpdater;
    protected GradientNormalization gradientNormalization;
    /** NaN means "not set"; replaced with the global value in {@link #applyGlobalConfig}. */
    protected double gradientNormalizationThreshold = Double.NaN;
    /** Lazily created by {@link #getLayerParams()}; may be null until then. */
    private SDLayerParams layerParams;

    /**
     * Returns the regularization to apply to the named parameter.
     *
     * @param paramName name of the parameter
     * @return the weight regularization for weight parameters, the bias regularization for bias
     *         parameters, or {@code null} if the name is neither
     */
    @Override
    public List<Regularization> getRegularizationByParam(String paramName) {
        // Use the lazy accessor: the layerParams field may not have been initialized yet
        // (e.g. after construction through the no-arg constructor), which previously caused an NPE.
        SDLayerParams params = getLayerParams();
        if (params.isWeightParam(paramName)) {
            return this.regularization;
        }
        if (params.isBiasParam(paramName)) {
            return this.regularizationBias;
        }
        return null;
    }

    /**
     * Creates the layer configuration from a builder, copying its regularization and updater
     * settings.
     * <p>
     * Also logs a warning if the concrete subclass has no no-arg constructor.
     *
     * @param builder the builder holding the layer configuration
     */
    protected AbstractSameDiffLayer(Builder builder) {
        super(builder);
        this.regularization = builder.regularization;
        this.regularizationBias = builder.regularizationBias;
        this.updater = builder.updater;
        this.biasUpdater = builder.biasUpdater;
        warnIfMissingNoArgConstructor();
    }

    /** No-arg constructor (fields are expected to be populated via setters). */
    protected AbstractSameDiffLayer() {
    }

    /**
     * Logs a warning when the concrete layer class declares no zero-argument constructor.
     * <p>
     * This is a best-effort diagnostic only: if reflection is denied, the check is skipped.
     */
    private void warnIfMissingNoArgConstructor() {
        try {
            getClass().getDeclaredConstructor();
        } catch (NoSuchMethodException e) {
            log.warn("***SameDiff layer {} does not have a zero argument (no-arg) constructor.***\n"
                    + "A no-arg constructor is required for JSON deserialization, which is used for both model saving and distributed (Spark) training.\n"
                    + "A no-arg constructor (private, protected or public) as well as setters (or simply a Lombok @Data annotation) should be added to avoid JSON errors later.",
                    getClass().getName());
        } catch (SecurityException e) {
            // Reflection not permitted (e.g. under a security manager): the check is purely
            // advisory, so skip it rather than failing layer construction.
            log.debug("Unable to check for a no-arg constructor on SameDiff layer {}", getClass().getName(), e);
        }
    }

    /**
     * Returns this layer's parameter definitions.
     * <p>
     * On first access they are created by calling {@link #defineParameters(SDLayerParams)}, then
     * cached.
     *
     * @return the (cached) parameter definitions; never null
     */
    public SDLayerParams getLayerParams() {
        if (this.layerParams == null) {
            this.layerParams = new SDLayerParams();
            defineParameters(this.layerParams);
        }
        return this.layerParams;
    }

    /** No-op by default; subclasses may override to infer input sizes. */
    @Override
    public void setNIn(InputType inputType, boolean override) {
    }

    /** No preprocessor by default. */
    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        return null;
    }

    /**
     * Hook for subclasses to read additional global settings.
     * <p>
     * It is called at the end of {@link #applyGlobalConfig}. No-op by default.
     *
     * @param globalConfig the global configuration builder
     */
    public void applyGlobalConfigToLayer(NeuralNetConfiguration.Builder globalConfig) {
    }

    /**
     * Declares this layer's parameters.
     *
     * @param params the parameter definitions to populate
     */
    public abstract void defineParameters(SDLayerParams params);

    /**
     * Initializes the parameter arrays.
     *
     * @param params map of parameter name to array, to be initialized in place
     */
    public abstract void initializeParameters(Map<String, INDArray> params);

    @Override
    public abstract org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
            Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
            boolean initializeParams, DataType networkDataType);

    /** @return the shared {@link SameDiffParamInitializer} instance */
    @Override
    public ParamInitializer initializer() {
        return SameDiffParamInitializer.getInstance();
    }

    /**
     * Returns the updater to use for the named parameter.
     *
     * @param paramName name of the parameter
     * @return the bias updater for bias parameters when one is set, otherwise the main updater
     * @throws IllegalStateException if the parameter is neither a weight nor a bias parameter
     */
    @Override
    public IUpdater getUpdaterByParam(String paramName) {
        ParamInitializer init = initializer();
        boolean isBias = init.isBiasParam(this, paramName);
        if (isBias && this.biasUpdater != null) {
            return this.biasUpdater;
        }
        if (isBias || init.isWeightParam(this, paramName)) {
            return this.updater;
        }
        throw new IllegalStateException("Unknown parameter key: " + paramName);
    }

    /** SameDiff layers have no pretrain-only parameters. */
    @Override
    public boolean isPretrainParam(String paramName) {
        return false;
    }

    /** @return an empty memory report (no estimate is computed here) */
    @Override
    public LayerMemoryReport getMemoryReport(InputType inputType) {
        return new LayerMemoryReport();
    }

    /**
     * Returns the array order used when reshaping the given parameter.
     *
     * @param param parameter name (unused by this default)
     * @return {@code 'c'} by default
     */
    public char paramReshapeOrder(String param) {
        return 'c';
    }

    /**
     * Initializes {@code array} in place with the given weight-init scheme.
     * <p>
     * The array's own shape and {@code paramReshapeOrder(null)} are used.
     */
    protected void initWeights(int fanIn, int fanOut, WeightInit weightInit, INDArray array) {
        WeightInitUtil.initWeights((double) fanIn, (double) fanOut, array.shape(), weightInit, null,
                paramReshapeOrder(null), array);
    }

    /**
     * Fills in settings not configured on this layer from the global builder.
     * <p>
     * Empty or null regularization lists, null updaters or gradient normalization, and a NaN
     * threshold are replaced. Then {@link #applyGlobalConfigToLayer} is called.
     *
     * @param b the global configuration builder
     */
    public void applyGlobalConfig(NeuralNetConfiguration.Builder b) {
        if (this.regularization == null || this.regularization.isEmpty()) {
            this.regularization = b.getRegularization();
        }
        if (this.regularizationBias == null || this.regularizationBias.isEmpty()) {
            this.regularizationBias = b.getRegularizationBias();
        }
        if (this.updater == null) {
            this.updater = b.getIUpdater();
        }
        if (this.biasUpdater == null) {
            this.biasUpdater = b.getBiasUpdater();
        }
        if (this.gradientNormalization == null) {
            this.gradientNormalization = b.getGradientNormalization();
        }
        if (Double.isNaN(this.gradientNormalizationThreshold)) {
            this.gradientNormalizationThreshold = b.getGradientNormalizationThreshold();
        }
        applyGlobalConfigToLayer(b);
    }

    /**
     * Creates an all-ones mask matching {@code input}'s data type.
     * <p>
     * Shapes by input rank:
     * <ul>
     *   <li>rank 2: {@code [size(0), 1]}</li>
     *   <li>rank 3: {@code [size(0), size(2)]}</li>
     *   <li>rank 4: {@code [size(0), 1, 1, 1]}</li>
     *   <li>rank 5: {@code [size(0), 1, 1, 1, 1]}</li>
     * </ul>
     *
     * @throws IllegalStateException for any other rank; subclasses must override to support it
     */
    public INDArray onesMaskForInput(INDArray input) {
        DataType dataType = input.dataType();
        switch (input.rank()) {
            case 2:
                return Nd4j.ones(dataType, new long[]{input.size(0), 1L});
            case 3:
                return Nd4j.ones(dataType, new long[]{input.size(0), input.size(2)});
            case 4:
                return Nd4j.ones(dataType, new long[]{input.size(0), 1L, 1L, 1L});
            case 5:
                return Nd4j.ones(dataType, new long[]{input.size(0), 1L, 1L, 1L, 1L});
            default:
                throw new IllegalStateException("When using masking with rank 1 or 6+ inputs, the onesMaskForInput method must be implemented, in order to determine the correct mask shape for this layer");
        }
    }

    public List<Regularization> getRegularization() {
        return this.regularization;
    }

    public List<Regularization> getRegularizationBias() {
        return this.regularizationBias;
    }

    public IUpdater getUpdater() {
        return this.updater;
    }

    public IUpdater getBiasUpdater() {
        return this.biasUpdater;
    }

    @Override
    public GradientNormalization getGradientNormalization() {
        return this.gradientNormalization;
    }

    @Override
    public double getGradientNormalizationThreshold() {
        return this.gradientNormalizationThreshold;
    }

    public void setRegularization(List<Regularization> regularization) {
        this.regularization = regularization;
    }

    public void setRegularizationBias(List<Regularization> regularizationBias) {
        this.regularizationBias = regularizationBias;
    }

    public void setUpdater(IUpdater updater) {
        this.updater = updater;
    }

    public void setBiasUpdater(IUpdater biasUpdater) {
        this.biasUpdater = biasUpdater;
    }

    public void setGradientNormalization(GradientNormalization gradientNormalization) {
        this.gradientNormalization = gradientNormalization;
    }

    public void setGradientNormalizationThreshold(double gradientNormalizationThreshold) {
        this.gradientNormalizationThreshold = gradientNormalizationThreshold;
    }

    public void setLayerParams(SDLayerParams layerParams) {
        this.layerParams = layerParams;
    }

    /**
     * Describes this layer's configuration.
     * <p>
     * The layerParams field is read directly rather than via {@link #getLayerParams()}. Printing a
     * layer therefore does not trigger and cache {@code defineParameters}, which would otherwise
     * change the result of {@link #equals}/{@link #hashCode}.
     */
    @Override
    public String toString() {
        return "AbstractSameDiffLayer(regularization=" + this.regularization
                + ", regularizationBias=" + this.regularizationBias
                + ", updater=" + this.updater
                + ", biasUpdater=" + this.biasUpdater
                + ", gradientNormalization=" + this.gradientNormalization
                + ", gradientNormalizationThreshold=" + this.gradientNormalizationThreshold
                + ", layerParams=" + this.layerParams + ")";
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof AbstractSameDiffLayer)) {
            return false;
        }
        AbstractSameDiffLayer other = (AbstractSameDiffLayer) o;
        return other.canEqual(this)
                && super.equals(o)
                && Double.compare(this.gradientNormalizationThreshold, other.gradientNormalizationThreshold) == 0
                && Objects.equals(this.regularization, other.regularization)
                && Objects.equals(this.regularizationBias, other.regularizationBias)
                && Objects.equals(this.updater, other.updater)
                && Objects.equals(this.biasUpdater, other.biasUpdater)
                && Objects.equals(this.gradientNormalization, other.gradientNormalization)
                && Objects.equals(this.layerParams, other.layerParams);
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof AbstractSameDiffLayer;
    }

    @Override
    public int hashCode() {
        final int prime = 59;
        int result = super.hashCode();
        result = result * prime + Double.hashCode(this.gradientNormalizationThreshold);
        result = result * prime + hashOrDefault(this.regularization);
        result = result * prime + hashOrDefault(this.regularizationBias);
        result = result * prime + hashOrDefault(this.updater);
        result = result * prime + hashOrDefault(this.biasUpdater);
        result = result * prime + hashOrDefault(this.gradientNormalization);
        result = result * prime + hashOrDefault(this.layerParams);
        return result;
    }

    /** Hash contribution of a possibly-null field (43 for null, matching the previous scheme). */
    private static int hashOrDefault(Object value) {
        return value == null ? 43 : value.hashCode();
    }

    /**
     * Base builder for SameDiff layer configurations.
     *
     * @param <T> the concrete builder type, returned from fluent setters
     */
    public static abstract class Builder<T extends Builder<T>> extends Layer.Builder<T> {
        protected List<Regularization> regularization = new ArrayList<Regularization>();
        protected List<Regularization> regularizationBias = new ArrayList<Regularization>();
        protected IUpdater updater = null;
        protected IUpdater biasUpdater = null;

        /**
         * Sets the L1 regularization coefficient for weights.
         * <p>
         * Any existing L1 entry is removed first; a value {@code <= 0} leaves none.
         */
        @SuppressWarnings("unchecked")
        public T l1(double l1) {
            NetworkUtils.removeInstances(this.regularization, L1Regularization.class);
            if (l1 > 0.0) {
                this.regularization.add(new L1Regularization(l1));
            }
            return (T) this;
        }

        /**
         * Sets the L2 regularization coefficient for weights.
         * <p>
         * Any existing L2 entry is removed first. A positive value also removes any WeightDecay
         * entry, with a warning.
         */
        @SuppressWarnings("unchecked")
        public T l2(double l2) {
            NetworkUtils.removeInstances(this.regularization, L2Regularization.class);
            if (l2 > 0.0) {
                NetworkUtils.removeInstancesWithWarning(this.regularization, WeightDecay.class, "WeightDecay regularization removed: incompatible with added L2 regularization");
                this.regularization.add(new L2Regularization(l2));
            }
            return (T) this;
        }

        /**
         * Sets the L1 regularization coefficient for biases.
         * <p>
         * Any existing L1 bias entry is removed first; a value {@code <= 0} leaves none.
         */
        @SuppressWarnings("unchecked")
        public T l1Bias(double l1Bias) {
            NetworkUtils.removeInstances(this.regularizationBias, L1Regularization.class);
            if (l1Bias > 0.0) {
                this.regularizationBias.add(new L1Regularization(l1Bias));
            }
            return (T) this;
        }

        /**
         * Sets the L2 regularization coefficient for biases.
         * <p>
         * Any existing L2 bias entry is removed first. A positive value also removes any
         * WeightDecay bias entry, with a warning.
         */
        @SuppressWarnings("unchecked")
        public T l2Bias(double l2Bias) {
            NetworkUtils.removeInstances(this.regularizationBias, L2Regularization.class);
            if (l2Bias > 0.0) {
                NetworkUtils.removeInstancesWithWarning(this.regularizationBias, WeightDecay.class, "WeightDecay bias regularization removed: incompatible with added L2 regularization");
                this.regularizationBias.add(new L2Regularization(l2Bias));
            }
            return (T) this;
        }

        /** Equivalent to {@code weightDecay(coefficient, true)}. */
        public Builder weightDecay(double coefficient) {
            return weightDecay(coefficient, true);
        }

        /**
         * Sets weight decay for weights.
         * <p>
         * Any existing WeightDecay entry is removed first. A positive coefficient also removes any
         * L2 entry, with a warning. {@code applyLR} is passed through to {@link WeightDecay}.
         */
        public Builder weightDecay(double coefficient, boolean applyLR) {
            NetworkUtils.removeInstances(this.regularization, WeightDecay.class);
            if (coefficient > 0.0) {
                NetworkUtils.removeInstancesWithWarning(this.regularization, L2Regularization.class, "L2 regularization removed: incompatible with added WeightDecay regularization");
                this.regularization.add(new WeightDecay(coefficient, applyLR));
            }
            return this;
        }

        /** Equivalent to {@code weightDecayBias(coefficient, true)}. */
        public Builder weightDecayBias(double coefficient) {
            return weightDecayBias(coefficient, true);
        }

        /**
         * Sets weight decay for biases.
         * <p>
         * Any existing WeightDecay bias entry is removed first. A positive coefficient also removes
         * any L2 bias entry, with a warning. {@code applyLR} is passed through to
         * {@link WeightDecay}.
         */
        public Builder weightDecayBias(double coefficient, boolean applyLR) {
            NetworkUtils.removeInstances(this.regularizationBias, WeightDecay.class);
            if (coefficient > 0.0) {
                NetworkUtils.removeInstancesWithWarning(this.regularizationBias, L2Regularization.class, "L2 bias regularization removed: incompatible with added WeightDecay regularization");
                this.regularizationBias.add(new WeightDecay(coefficient, applyLR));
            }
            return this;
        }

        /**
         * Replaces the weight regularization list.
         * <p>
         * The list is stored by reference, not copied.
         */
        public Builder regularization(List<Regularization> regularization) {
            setRegularization(regularization);
            return this;
        }

        /**
         * Replaces the bias regularization list.
         * <p>
         * The list is stored by reference, not copied.
         */
        public Builder regularizationBias(List<Regularization> regularizationBias) {
            setRegularizationBias(regularizationBias);
            return this;
        }

        /** Sets the updater used for all parameters (unless a bias updater is set). */
        @SuppressWarnings("unchecked")
        public T updater(IUpdater updater) {
            setUpdater(updater);
            return (T) this;
        }

        /** Sets a separate updater for bias parameters. */
        @SuppressWarnings("unchecked")
        public T biasUpdater(IUpdater biasUpdater) {
            setBiasUpdater(biasUpdater);
            return (T) this;
        }

        public List<Regularization> getRegularization() {
            return this.regularization;
        }

        public List<Regularization> getRegularizationBias() {
            return this.regularizationBias;
        }

        public IUpdater getUpdater() {
            return this.updater;
        }

        public IUpdater getBiasUpdater() {
            return this.biasUpdater;
        }

        public void setRegularization(List<Regularization> regularization) {
            this.regularization = regularization;
        }

        public void setRegularizationBias(List<Regularization> regularizationBias) {
            this.regularizationBias = regularizationBias;
        }

        public void setUpdater(IUpdater updater) {
            this.updater = updater;
        }

        public void setBiasUpdater(IUpdater biasUpdater) {
            this.biasUpdater = biasUpdater;
        }
    }
}

