/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers.samediff;

import java.util.Map;
import java.util.Objects;
import org.deeplearning4j.nn.api.TrainingConfig;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.layers.samediff.SameDiffGraphVertex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Base configuration for a user-defined graph vertex whose computation is defined with SameDiff.
 *
 * <p>Subclasses declare their parameters and inputs in {@link #defineParametersAndInputs(SDVertexParams)},
 * define the computation in {@link #defineVertex(SameDiff, Map, Map)}, and supply initial parameter values
 * in {@link #initializeParameters(Map)}. At runtime the configuration is turned into a
 * {@link SameDiffGraphVertex} by {@link #instantiate}.
 *
 * <p>The vertex also carries per-vertex training settings (L1/L2 coefficients, updaters, gradient
 * normalization). A setting left unset ({@code NaN} for doubles, {@code null} for objects) is filled from
 * the global builder by {@link #applyGlobalConfig(NeuralNetConfiguration.Builder)}.
 */
public abstract class SameDiffVertex
extends org.deeplearning4j.nn.conf.graph.GraphVertex
implements TrainingConfig {
    /** Parameter/input declarations, created lazily by {@link #getVertexParams()}. */
    private SDVertexParams vertexParams;
    /** Vertex name; recorded by {@link #instantiate} or set via {@link #setName(String)}. */
    private String name;
    // Per-vertex training overrides. NaN (doubles) / null (objects) means "not configured on this vertex";
    // applyGlobalConfig(...) replaces such values with the global builder's settings.
    protected double l1 = Double.NaN;
    protected double l2 = Double.NaN;
    protected double l1Bias = Double.NaN;
    protected double l2Bias = Double.NaN;
    protected IUpdater updater;
    protected IUpdater biasUpdater;
    protected GradientNormalization gradientNormalization;
    protected double gradientNormalizationThreshold = Double.NaN;

    /**
     * Defines this vertex's computation in the given SameDiff instance.
     *
     * @param sameDiff   the SameDiff instance to define the vertex in
     * @param layerInput input variables keyed by name (NOTE(review): presumably the inputs declared in
     *                   {@link #defineParametersAndInputs(SDVertexParams)} — confirm against SameDiffGraphVertex)
     * @param paramTable parameter variables keyed by name (NOTE(review): presumably the parameters declared
     *                   in {@link #defineParametersAndInputs(SDVertexParams)} — confirm)
     * @return the variable produced by this vertex
     */
    public abstract SDVariable defineVertex(SameDiff sameDiff, Map<String, SDVariable> layerInput, Map<String, SDVariable> paramTable);

    /**
     * Declares this vertex's parameters and inputs on {@code params}.
     * Invoked by {@link #getVertexParams()} on a freshly created {@link SDVertexParams} the first time the
     * declarations are needed.
     *
     * @param params the declaration object to populate
     */
    public abstract void defineParametersAndInputs(SDVertexParams params);

    /**
     * Writes initial values into the vertex's parameter arrays.
     *
     * @param params parameter arrays (NOTE(review): presumably keyed by parameter name — confirm)
     */
    public abstract void initializeParameters(Map<String, INDArray> params);

    /**
     * Returns the parameter/input declarations, creating them on first use by calling
     * {@link #defineParametersAndInputs(SDVertexParams)}. Not synchronized.
     */
    public SDVertexParams getVertexParams() {
        if (this.vertexParams == null) {
            this.vertexParams = new SDVertexParams();
            this.defineParametersAndInputs(this.vertexParams);
        }
        return this.vertexParams;
    }

    /**
     * Cloning is not supported for this vertex type.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public org.deeplearning4j.nn.conf.graph.GraphVertex clone() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Returns the total number of parameter values in this vertex: the sum of the element counts of every
     * parameter shape declared in {@link #getVertexParams()}.
     *
     * @param backprop ignored; the count does not depend on it
     * @return total parameter count
     */
    @Override
    public long numParams(boolean backprop) {
        long count = 0L;
        for (long[] shape : this.getVertexParams().getParamShapes().values()) {
            count += ArrayUtil.prodLong(shape);
        }
        // Return the full long: the former (int) cast silently truncated counts above Integer.MAX_VALUE.
        return count;
    }

    /** At least one input is required. */
    @Override
    public int minVertexInputs() {
        return 1;
    }

    /** Returns -1 (NOTE(review): presumably "no upper limit" per the GraphVertex contract — confirm). */
    @Override
    public int maxVertexInputs() {
        return -1;
    }

    /**
     * Creates the runtime vertex. As a side effect, records {@code name} on this configuration, so
     * {@link #getLayerName()}, {@link #equals(Object)} and {@link #hashCode()} reflect it afterwards.
     */
    @Override
    public GraphVertex instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams) {
        this.name = name;
        return new SameDiffGraphVertex(this, graph, name, idx, paramsView, initializeParams);
    }

    /**
     * Output type inference is not supported for this vertex type.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public InputType getOutputType(int layerIndex, InputType ... vertexInputs) throws InvalidInputTypeException {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Memory reporting is not implemented for this vertex type: always returns {@code null}.
     * NOTE(review): callers must tolerate a null report.
     */
    @Override
    public MemoryReport getMemoryReport(InputType ... inputTypes) {
        return null;
    }

    /**
     * Returns the array ordering used when reshaping the named parameter; {@code 'c'} for every parameter.
     * Subclasses may override.
     */
    public char paramReshapeOrder(String paramName) {
        return 'c';
    }

    /**
     * Fills every training setting not configured on this vertex (NaN / null) from the global builder,
     * then calls {@link #applyGlobalConfigToLayer(NeuralNetConfiguration.Builder)} so subclasses can read
     * further global settings. Settings already configured on the vertex are left untouched.
     *
     * @param b the global configuration builder
     */
    public void applyGlobalConfig(NeuralNetConfiguration.Builder b) {
        if (Double.isNaN(this.l1)) {
            this.l1 = b.getL1();
        }
        if (Double.isNaN(this.l2)) {
            this.l2 = b.getL2();
        }
        if (Double.isNaN(this.l1Bias)) {
            this.l1Bias = b.getL1Bias();
        }
        if (Double.isNaN(this.l2Bias)) {
            this.l2Bias = b.getL2Bias();
        }
        if (this.updater == null) {
            this.updater = b.getIUpdater();
        }
        if (this.biasUpdater == null) {
            this.biasUpdater = b.getBiasUpdater();
        }
        if (this.gradientNormalization == null) {
            this.gradientNormalization = b.getGradientNormalization();
        }
        if (Double.isNaN(this.gradientNormalizationThreshold)) {
            this.gradientNormalizationThreshold = b.getGradientNormalizationThreshold();
        }
        this.applyGlobalConfigToLayer(b);
    }

    /** Hook for subclasses to apply additional global settings; does nothing by default. */
    public void applyGlobalConfigToLayer(NeuralNetConfiguration.Builder globalConfig) {
    }

    /** Returns the vertex name (the same value as {@link #getName()}). */
    @Override
    public String getLayerName() {
        return this.name;
    }

    /** This vertex never takes part in pretraining. */
    @Override
    public boolean isPretrain() {
        return false;
    }

    /**
     * Returns the L1 coefficient for the named parameter: {@code l1} for weights, {@code l1Bias} for biases.
     * When both coefficients are exactly 0, returns 0 without validating the name.
     *
     * @throws IllegalStateException if the name is neither a weight nor a bias parameter
     */
    @Override
    public double getL1ByParam(String paramName) {
        return this.regularizationByParam(paramName, this.l1, this.l1Bias);
    }

    /**
     * Returns the L2 coefficient for the named parameter: {@code l2} for weights, {@code l2Bias} for biases.
     * When both coefficients are exactly 0, returns 0 without validating the name.
     *
     * @throws IllegalStateException if the name is neither a weight nor a bias parameter
     */
    @Override
    public double getL2ByParam(String paramName) {
        return this.regularizationByParam(paramName, this.l2, this.l2Bias);
    }

    /** Picks the weight or bias coefficient for {@code paramName}; shared by the L1 and L2 lookups. */
    private double regularizationByParam(String paramName, double weightCoeff, double biasCoeff) {
        // Fast path: nothing to regularize, so the parameter kind is irrelevant (and not validated).
        if (weightCoeff == 0.0 && biasCoeff == 0.0) {
            return 0.0;
        }
        SDVertexParams params = this.getVertexParams();
        if (params.isWeightParam(paramName)) {
            return weightCoeff;
        }
        if (params.isBiasParam(paramName)) {
            return biasCoeff;
        }
        throw this.unknownParamException(paramName);
    }

    /** Builds the error raised when a parameter name is neither a declared weight nor a declared bias. */
    private IllegalStateException unknownParamException(String paramName) {
        SDVertexParams params = this.getVertexParams();
        return new IllegalStateException("Unknown parameter name: " + paramName + " - not in weights (" + params.getWeightParameterKeys() + ") or biases (" + params.getBiasParameterKeys() + ")");
    }

    /** No parameter of this vertex is a pretrain parameter. */
    @Override
    public boolean isPretrainParam(String paramName) {
        return false;
    }

    /**
     * Returns the updater for the named parameter. Weights use {@code updater}. Biases use
     * {@code biasUpdater}, or fall back to {@code updater} when no bias updater is set.
     *
     * @throws IllegalStateException if the name is neither a weight nor a bias parameter
     */
    @Override
    public IUpdater getUpdaterByParam(String paramName) {
        SDVertexParams params = this.getVertexParams();
        if (params.isWeightParam(paramName)) {
            return this.updater;
        }
        if (params.isBiasParam(paramName)) {
            return this.biasUpdater != null ? this.biasUpdater : this.updater;
        }
        throw this.unknownParamException(paramName);
    }

    @Override
    public GradientNormalization getGradientNormalization() {
        return this.gradientNormalization;
    }

    @Override
    public double getGradientNormalizationThreshold() {
        return this.gradientNormalizationThreshold;
    }

    /** No-op: pretraining is not supported for this vertex (see {@link #isPretrain()}). */
    @Override
    public void setPretrain(boolean pretrain) {
    }

    public String getName() {
        return this.name;
    }

    public double getL1() {
        return this.l1;
    }

    public double getL2() {
        return this.l2;
    }

    public double getL1Bias() {
        return this.l1Bias;
    }

    public double getL2Bias() {
        return this.l2Bias;
    }

    public IUpdater getUpdater() {
        return this.updater;
    }

    public IUpdater getBiasUpdater() {
        return this.biasUpdater;
    }

    public void setVertexParams(SDVertexParams vertexParams) {
        this.vertexParams = vertexParams;
    }

    public void setName(String name) {
        this.name = name;
    }

    public void setL1(double l1) {
        this.l1 = l1;
    }

    public void setL2(double l2) {
        this.l2 = l2;
    }

    public void setL1Bias(double l1Bias) {
        this.l1Bias = l1Bias;
    }

    public void setL2Bias(double l2Bias) {
        this.l2Bias = l2Bias;
    }

    public void setUpdater(IUpdater updater) {
        this.updater = updater;
    }

    public void setBiasUpdater(IUpdater biasUpdater) {
        this.biasUpdater = biasUpdater;
    }

    public void setGradientNormalization(GradientNormalization gradientNormalization) {
        this.gradientNormalization = gradientNormalization;
    }

    public void setGradientNormalizationThreshold(double gradientNormalizationThreshold) {
        this.gradientNormalizationThreshold = gradientNormalizationThreshold;
    }

    /**
     * Field-wise equality over vertex params, name and all training settings. Doubles are compared with
     * {@link Double#compare}, so two NaN (unset) values are equal. Calling this may trigger the lazy
     * {@link #getVertexParams()} initialization on either instance.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof SameDiffVertex)) {
            return false;
        }
        SameDiffVertex other = (SameDiffVertex) o;
        if (!other.canEqual(this)) {
            return false;
        }
        return Objects.equals(this.getVertexParams(), other.getVertexParams())
                && Objects.equals(this.getName(), other.getName())
                && Double.compare(this.getL1(), other.getL1()) == 0
                && Double.compare(this.getL2(), other.getL2()) == 0
                && Double.compare(this.getL1Bias(), other.getL1Bias()) == 0
                && Double.compare(this.getL2Bias(), other.getL2Bias()) == 0
                && Objects.equals(this.getUpdater(), other.getUpdater())
                && Objects.equals(this.getBiasUpdater(), other.getBiasUpdater())
                && Objects.equals(this.getGradientNormalization(), other.getGradientNormalization())
                && Double.compare(this.getGradientNormalizationThreshold(), other.getGradientNormalizationThreshold()) == 0;
    }

    /** Symmetry guard for {@link #equals(Object)}: subclasses adding state should override. */
    protected boolean canEqual(Object other) {
        return other instanceof SameDiffVertex;
    }

    /**
     * Hash over the same fields as {@link #equals(Object)} (multiplier 59, null contributes 43).
     * Note that the fields are mutable, so the hash changes if configuration or name changes.
     */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        SDVertexParams params = this.getVertexParams();
        result = result * prime + (params == null ? 43 : params.hashCode());
        String vertexName = this.getName();
        result = result * prime + (vertexName == null ? 43 : vertexName.hashCode());
        result = result * prime + hashDouble(this.getL1());
        result = result * prime + hashDouble(this.getL2());
        result = result * prime + hashDouble(this.getL1Bias());
        result = result * prime + hashDouble(this.getL2Bias());
        IUpdater weightUpdater = this.getUpdater();
        result = result * prime + (weightUpdater == null ? 43 : weightUpdater.hashCode());
        IUpdater biasUpd = this.getBiasUpdater();
        result = result * prime + (biasUpd == null ? 43 : biasUpd.hashCode());
        GradientNormalization gradNorm = this.getGradientNormalization();
        result = result * prime + (gradNorm == null ? 43 : gradNorm.hashCode());
        result = result * prime + hashDouble(this.getGradientNormalizationThreshold());
        return result;
    }

    /** Folds the high and low 32 bits of a double's bit pattern (same result as {@code Double.hashCode(double)}). */
    private static int hashDouble(double value) {
        long bits = Double.doubleToLongBits(value);
        return (int) (bits >>> 32 ^ bits);
    }

    @Override
    public String toString() {
        return "SameDiffVertex(vertexParams=" + this.getVertexParams() + ", name=" + this.getName() + ", l1=" + this.getL1() + ", l2=" + this.getL2() + ", l1Bias=" + this.getL1Bias() + ", l2Bias=" + this.getL2Bias() + ", updater=" + this.getUpdater() + ", biasUpdater=" + this.getBiasUpdater() + ", gradientNormalization=" + this.getGradientNormalization() + ", gradientNormalizationThreshold=" + this.getGradientNormalizationThreshold() + ")";
    }
}

