/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.functions;

import com.rits.cloning.Cloner;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import onnx.OnnxProto3;
import org.nd4j.autodiff.functions.DifferentialFunctionFactory;
import org.nd4j.autodiff.functions.FunctionProperties;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * Base class for functions that are attached to a {@link SameDiff} instance.
 *
 * <p>It provides argument lookup through the owning {@code SameDiff}, reflective access to
 * the function's property fields, gradient bookkeeping ({@link #diff(List)}) and an identity
 * ({@code ownName}) under which the function is registered with the {@code SameDiff}.
 * Subclasses supply the op name, the output variables, the differentiation rule and the
 * TensorFlow/ONNX import hooks.
 */
public abstract class DifferentialFunction {
    private static final Logger log = LoggerFactory.getLogger(DifferentialFunction.class);

    /** The {@code SameDiff} instance used for argument lookup and registration; may be null. */
    @JsonIgnore
    protected SameDiff sameDiff;

    /** When true, {@code getZ()} returns the first input array instead of the output array. */
    @JsonIgnore
    protected boolean inPlace;

    /** Scalar value associated with this function; part of equals/hashCode. */
    @JsonIgnore
    protected Number scalarValue;

    /** Dimensions associated with this function; part of equals/hashCode. */
    @JsonIgnore
    protected int[] dimensions;

    /** Extra, op-specific arguments supplied at construction time or through the setter. */
    @JsonIgnore
    protected Object[] extraArgs;

    /** Identifier assigned once by {@link #setInstanceId()}. */
    @JsonIgnore
    private String ownName;

    /**
     * Creates a function that is not attached to any {@code SameDiff}; its own name is
     * therefore a random UUID (see {@link #setInstanceId()}).
     */
    public DifferentialFunction() {
        this.setInstanceId();
    }

    /**
     * Creates a function from a TensorFlow node by delegating to
     * {@link #initFromTensorFlow(NodeDef, SameDiff, Map, GraphDef)}.
     *
     * @param sameDiff          the SameDiff instance to attach to
     * @param nodeDef           the TensorFlow node
     * @param attributesForNode the attributes of the node
     * @param graph             the graph the node belongs to
     */
    public DifferentialFunction(SameDiff sameDiff, NodeDef nodeDef, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        this.sameDiff = sameDiff;
        this.setInstanceId();
        this.initFromTensorFlow(nodeDef, sameDiff, attributesForNode, graph);
    }

    /**
     * Creates a function from an ONNX node by delegating to
     * {@link #initFromOnnx(OnnxProto3.NodeProto, SameDiff, Map, OnnxProto3.GraphProto)}.
     *
     * @param sameDiff          the SameDiff instance to attach to
     * @param node              the ONNX node
     * @param attributesForNode the attributes of the node
     * @param graph             the graph the node belongs to
     */
    public DifferentialFunction(SameDiff sameDiff, OnnxProto3.NodeProto node, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
        this.sameDiff = sameDiff;
        this.setInstanceId();
        this.initFromOnnx(node, sameDiff, attributesForNode, graph);
    }

    /**
     * Returns the attribute adapters for this function.
     *
     * @return an empty map in this base implementation
     */
    public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
        return Collections.emptyMap();
    }

    /**
     * Returns the property mappings for this function.
     *
     * @return an empty map in this base implementation
     */
    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
        return Collections.emptyMap();
    }

    /**
     * Reads the current value of every field that {@link DifferentialFunctionClassHolder}
     * reports for this function.
     *
     * <p>A field that cannot be read is logged and left out of the result; the remaining
     * fields are still returned.
     *
     * @return field name to current value, in the iteration order of the reported field map
     */
    public Map<String, Object> propertiesForFunction() {
        Map<String, Field> fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);
        LinkedHashMap<String, Object> ret = new LinkedHashMap<String, Object>();
        for (Map.Entry<String, Field> entry : fields.entrySet()) {
            try {
                ret.put(entry.getKey(), entry.getValue().get(this));
            }
            catch (IllegalAccessException e) {
                log.error("Unable to read field {} of function {}", entry.getKey(), this.getOwnName(), e);
            }
        }
        return ret;
    }

    /**
     * Reads the given field from this instance.
     *
     * @param property the field to read
     * @return the field's value, or {@code null} if the field is not accessible
     */
    public Object getValue(Field property) {
        try {
            return property.get(this);
        }
        catch (IllegalAccessException e) {
            log.error("Unable to read field {} of function {}", property, this.getOwnName(), e);
            return null;
        }
    }

    /**
     * Sets the given field on this instance, first wrapping a single number into a
     * one-element array when the field is a numeric array type.
     *
     * <p>An inaccessible field is logged and otherwise ignored.
     *
     * @param target the field to set
     * @param value  the value to assign; must not be null
     * @throws ND4JIllegalStateException if {@code value} is null
     */
    public void setValueFor(Field target, Object value) {
        if (value == null) {
            throw new ND4JIllegalStateException("Unable to set field " + target + " using null value!");
        }
        value = this.ensureProperType(target, value);
        try {
            target.set(this, value);
        }
        catch (IllegalAccessException e) {
            log.error("Unable to set field {} of function {}", target, this.getOwnName(), e);
        }
    }

    /**
     * Adapts {@code value} to the declared type of {@code targetType}.
     *
     * <p>If the runtime class of the value already equals the field type, or the field is not
     * one of the supported numeric array types, the value is returned unchanged. Otherwise the
     * value is converted through {@link Number} and wrapped in a one-element array of the
     * field's type.
     *
     * @param targetType the field that will receive the value
     * @param value      the value to adapt; must not be null
     * @return the value itself or a one-element array holding the converted number
     * @throws ClassCastException if the field is a numeric array and {@code value} is not a
     *                            {@link Number}
     */
    private Object ensureProperType(Field targetType, Object value) {
        Class<?> firstClass = targetType.getType();
        if (firstClass.equals(value.getClass())) {
            return value;
        }
        if (firstClass.equals(int[].class)) {
            return new int[]{((Number)value).intValue()};
        }
        if (firstClass.equals(Integer[].class)) {
            return new Integer[]{((Number)value).intValue()};
        }
        if (firstClass.equals(long[].class)) {
            return new long[]{((Number)value).longValue()};
        }
        if (firstClass.equals(Long[].class)) {
            return new Long[]{((Number)value).longValue()};
        }
        if (firstClass.equals(double[].class)) {
            return new double[]{((Number)value).doubleValue()};
        }
        if (firstClass.equals(Double[].class)) {
            return new Double[]{((Number)value).doubleValue()};
        }
        if (firstClass.equals(float[].class)) {
            return new float[]{((Number)value).floatValue()};
        }
        if (firstClass.equals(Float[].class)) {
            return new Float[]{((Number)value).floatValue()};
        }
        return value;
    }

    /**
     * @return {@code false} in this base implementation
     */
    public boolean isConfigProperties() {
        return false;
    }

    /**
     * @return {@code null} in this base implementation
     */
    public String configFieldName() {
        return null;
    }

    /**
     * Builds a {@link FunctionProperties} from {@link #opName()} and
     * {@link #propertiesForFunction()}.
     *
     * @return the properties descriptor for this function
     */
    public FunctionProperties asProperties() {
        return FunctionProperties.builder().name(this.opName()).fieldNames(this.propertiesForFunction()).build();
    }

    /**
     * @param sameDiff  the SameDiff instance to attach to
     * @param inPlace   whether the function operates in place
     * @param extraArgs extra, op-specific arguments
     */
    public DifferentialFunction(SameDiff sameDiff, boolean inPlace, Object[] extraArgs) {
        this.sameDiff = sameDiff;
        this.inPlace = inPlace;
        this.setInstanceId();
        this.extraArgs = extraArgs;
    }

    /**
     * @param sameDiff  the SameDiff instance to attach to
     * @param extraArgs extra, op-specific arguments
     */
    public DifferentialFunction(SameDiff sameDiff, Object[] extraArgs) {
        this.sameDiff = sameDiff;
        this.setInstanceId();
        this.extraArgs = extraArgs;
    }

    /**
     * Equivalent to {@code DifferentialFunction(sameDiff, false, args)}.
     *
     * @param sameDiff the SameDiff instance to attach to
     * @param args     the input variables
     */
    public DifferentialFunction(SameDiff sameDiff, SDVariable[] args) {
        this(sameDiff, false, args);
    }

    /**
     * Registers {@code args} as the inputs of this function with {@code sameDiff} and, for
     * every placeholder argument, records its variable name as a property to resolve later.
     * Nothing is registered when {@code sameDiff} is null.
     *
     * @param sameDiff the SameDiff instance to attach to
     * @param inPlace  whether the function operates in place
     * @param args     the input variables
     */
    public DifferentialFunction(SameDiff sameDiff, boolean inPlace, SDVariable[] args) {
        this.sameDiff = sameDiff;
        this.inPlace = inPlace;
        this.setInstanceId();
        if (sameDiff != null) {
            sameDiff.addArgsFor(args, this);
            for (int i = 0; i < args.length; ++i) {
                if (!args[i].isPlaceHolder()) {
                    continue;
                }
                sameDiff.addPropertyToResolve(this, args[i].getVarName());
            }
        }
    }

    /**
     * Returns the output variables, using the own name as base name when it is set and
     * {@link #opName()} otherwise.
     *
     * @return the output variables of this function
     */
    public SDVariable[] outputVariables() {
        return this.outputVariables(this.getOwnName() != null ? this.getOwnName() : this.opName());
    }

    /**
     * @return the first element of {@link #outputVariables()}
     */
    public SDVariable outputVariable() {
        return this.outputVariables()[0];
    }

    /**
     * @return the variable names of {@link #outputVariables()}, in the same order
     */
    public String[] outputVariablesNames() {
        SDVariable[] outputVars = this.outputVariables();
        String[] out = new String[outputVars.length];
        for (int i = 0; i < out.length; ++i) {
            out[i] = outputVars[i].getVarName();
        }
        return out;
    }

    /**
     * Returns the output variables for the given base name.
     *
     * @param var1 the base name passed by {@link #outputVariables()}
     * @return the output variables
     */
    public abstract SDVariable[] outputVariables(String var1);

    /**
     * Computes the gradients for this function; called by {@link #diff(List)}, which pairs
     * the i-th returned element with the i-th argument.
     *
     * @param var1 the incoming gradients
     * @return the computed gradients; {@link #diff(List)} rejects a null result
     */
    public abstract List<SDVariable> doDiff(List<SDVariable> var1);

    /**
     * @return the function factory of the owning SameDiff
     */
    public DifferentialFunctionFactory f() {
        return this.sameDiff.f();
    }

    /**
     * Checks every argument of this function against the owning SameDiff's placeholder
     * registry.
     *
     * @return {@code true} if {@code sameDiff.hasPlaceHolderVariables} is true for at least
     *         one argument's variable name
     */
    public boolean hasPlaceHolderInputs() {
        for (SDVariable arg : this.args()) {
            // Test the argument being visited, not just the first one.
            if (this.sameDiff.hasPlaceHolderVariables(arg.getVarName())) {
                return true;
            }
        }
        return false;
    }

    /**
     * @return the input variables the owning SameDiff has recorded for this function
     */
    public SDVariable[] args() {
        return this.sameDiff.getInputVariablesForFunction(this);
    }

    /**
     * Returns the argument at the given position after validating the index.
     *
     * @param num zero-based argument index
     * @return the argument at {@code num}
     */
    public SDVariable arg(int num) {
        SDVariable[] args = this.args();
        Preconditions.checkNotNull((Object)args, "Arguments are null for function %s", (Object)this.getOwnName());
        Preconditions.checkArgument(num >= 0 && num < args.length, "Invalid index: must be 0 to numArgs (0 <= idx < %s)", args.length);
        return args[num];
    }

    /**
     * @return the variable names of {@link #args()}, in the same order
     */
    public String[] argNames() {
        SDVariable[] args = this.args();
        String[] out = new String[args.length];
        for (int i = 0; i < args.length; ++i) {
            out[i] = args[i].getVarName();
        }
        return out;
    }

    /**
     * Fills property fields from arrays held by the owning SameDiff.
     *
     * <p>For each property name the SameDiff lists for this function, the property is skipped
     * if it is not a known field or if {@link #propertiesForFunction()} already contains it.
     * Otherwise the array of the associated variable is written into the field: the full
     * buffer for {@code int[]} / {@code double[]} fields, the first element for {@code int} /
     * {@code double} fields. Fields of any other type are left untouched.
     *
     * @throws ND4JIllegalStateException if a property has to be set but its array is null
     */
    public void resolvePropertiesFromSameDiffBeforeExecution() {
        List<String> properties = this.sameDiff.propertiesToResolveForFunction(this);
        Map<String, Field> fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);
        Map<String, Object> currentFields = this.propertiesForFunction();
        for (String property : properties) {
            if (!fields.containsKey(property)) {
                continue;
            }
            String var = this.sameDiff.getVarNameForFieldAndFunction(this, property);
            Field field = fields.get(property);
            INDArray varArr = this.sameDiff.getArrForVarName(var);
            if (currentFields.containsKey(property)) {
                continue;
            }
            if (varArr == null) {
                throw new ND4JIllegalStateException("Unable to set null array!");
            }
            // Compare the field's declared type; a Field object never equals a Class.
            Class<?> fieldClass = field.getType();
            if (fieldClass.equals(int[].class)) {
                this.setValueFor(field, varArr.data().asInt());
            } else if (fieldClass.equals(double[].class)) {
                this.setValueFor(field, varArr.data().asDouble());
            } else if (fieldClass.equals(Integer.TYPE)) {
                this.setValueFor(field, varArr.getInt(0));
            } else if (fieldClass.equals(Double.TYPE)) {
                this.setValueFor(field, varArr.getDouble(0L));
            }
        }
    }

    /**
     * @return the first argument, or {@code null} when there are no arguments
     */
    public SDVariable arg() {
        // Look the arguments up once so the check and the access see the same array.
        SDVariable[] args = this.args();
        if (args == null || args.length == 0) {
            return null;
        }
        return args[0];
    }

    /**
     * Runs {@link #doDiff(List)} and records the resulting gradients on the owning SameDiff.
     *
     * <p>The i-th gradient is paired with the i-th argument. If that argument already has a
     * gradient, the two are added and the sum replaces both the list entry and the registered
     * gradient. Otherwise the new gradient is renamed to {@code <argName>-grad}, registered as
     * the argument's gradient and linked back to the argument as its forward variable.
     *
     * @param i_v1 the incoming gradients, passed straight to {@code doDiff}
     * @return the gradients; a copy of the {@code doDiff} result if any entry was replaced
     * @throws IllegalStateException if {@code doDiff} returns null
     */
    public List<SDVariable> diff(List<SDVariable> i_v1) {
        List<SDVariable> vals = this.doDiff(i_v1);
        if (vals == null) {
            throw new IllegalStateException("Error executing diff operation: doDiff returned null for op: " + this.opName());
        }
        SDVariable[] inputVars = this.args();
        boolean copied = false;
        for (int i = 0; i < vals.size(); ++i) {
            SDVariable gradVar;
            SDVariable var = inputVars[i];
            SDVariable grad = var.getGradient();
            if (grad != null) {
                if (!copied) {
                    // doDiff may have returned an unmodifiable list; copy before the first set().
                    vals = new ArrayList<SDVariable>(vals);
                    copied = true;
                }
                gradVar = this.f().add(grad, vals.get(i));
                vals.set(i, gradVar);
                this.sameDiff.setGradientForVariableName(var.getVarName(), gradVar);
                continue;
            }
            gradVar = vals.get(i);
            this.sameDiff.updateVariableNameAndReference(gradVar, var.getVarName() + "-grad");
            this.sameDiff.setGradientForVariableName(var.getVarName(), gradVar);
            this.sameDiff.setForwardVariableForVarName(gradVar.getVarName(), var);
        }
        return vals;
    }

    /**
     * Assigns {@code ownName} if it is not set yet.
     *
     * <p>Without a SameDiff the name is a random UUID. With one, names are requested from
     * {@code sameDiff.generateNewVarName(opName(), index)} until one is found for which
     * {@code sameDiff.functionExists} is false. The function is then registered under that
     * name unless it is an {@link SDVariable}.
     */
    protected void setInstanceId() {
        if (this.ownName == null) {
            if (this.sameDiff == null) {
                this.ownName = UUID.randomUUID().toString();
            } else {
                int argIndex = 0;
                String varName = this.sameDiff.generateNewVarName(this.opName(), argIndex);
                // NOTE(review): the first retry asks for index 0 again because the index is
                // incremented after the call; kept as-is since generateNewVarName is not visible here.
                while (this.sameDiff.functionExists(varName)) {
                    varName = this.sameDiff.generateNewVarName(this.opName(), argIndex);
                    ++argIndex;
                }
                this.ownName = varName;
            }
            if (this.sameDiff != null && !(this instanceof SDVariable)) {
                this.sameDiff.putFunctionForId(this.ownName, this);
            }
        }
    }

    /**
     * @return the op name
     * @throws UnsupportedOperationException always, in this base implementation
     */
    public String opName() {
        throw new UnsupportedOperationException();
    }

    /**
     * @return the op type
     * @throws UnsupportedOperationException always, in this base implementation
     */
    public Op.Type opType() {
        throw new UnsupportedOperationException();
    }

    /**
     * @return the op number
     * @throws UnsupportedOperationException always, in this base implementation
     */
    public int opNum() {
        throw new UnsupportedOperationException();
    }

    /**
     * @return the array the owning SameDiff holds for the first argument's variable name
     */
    @JsonIgnore
    private INDArray getX() {
        INDArray ret = this.sameDiff.getArrForVarName(this.args()[0].getVarName());
        return ret;
    }

    /**
     * @return the array for the second argument's variable name, or {@code null} when there
     *         are fewer than two arguments
     */
    @JsonIgnore
    private INDArray getY() {
        SDVariable[] args = this.args();
        if (args.length > 1) {
            return this.sameDiff.getArrForVarName(args[1].getVarName());
        }
        return null;
    }

    /**
     * @return the first input array when in place, otherwise the array of the first output
     *         variable
     */
    @JsonIgnore
    private INDArray getZ() {
        if (this.isInPlace()) {
            return this.getX();
        }
        SDVariable opId = this.outputVariables()[0];
        INDArray ret = opId.getArr();
        return ret;
    }

    /**
     * Initializes this function from a TensorFlow node; called by the TensorFlow constructor.
     *
     * @param var1 the TensorFlow node
     * @param var2 the SameDiff instance
     * @param var3 the attributes of the node
     * @param var4 the graph the node belongs to
     */
    public abstract void initFromTensorFlow(NodeDef var1, SameDiff var2, Map<String, AttrValue> var3, GraphDef var4);

    /**
     * Initializes this function from an ONNX node; called by the ONNX constructor.
     *
     * @param var1 the ONNX node
     * @param var2 the SameDiff instance
     * @param var3 the attributes of the node
     * @param var4 the graph the node belongs to
     */
    public abstract void initFromOnnx(OnnxProto3.NodeProto var1, SameDiff var2, Map<String, OnnxProto3.AttributeProto> var3, OnnxProto3.GraphProto var4);

    /**
     * @return the first argument
     * @throws ND4JIllegalStateException if there are no arguments
     */
    public SDVariable larg() {
        SDVariable[] args = this.args();
        if (args == null || args.length == 0) {
            throw new ND4JIllegalStateException("No arguments found.");
        }
        // Use the array that was just validated rather than looking it up again.
        return args[0];
    }

    /**
     * @return the second argument
     * @throws ND4JIllegalStateException if the function does not have exactly two arguments
     */
    public SDVariable rarg() {
        SDVariable[] args = this.args();
        if (args == null || args.length != 2) {
            throw new ND4JIllegalStateException("In order to use this function, the number of arguments for this function must be 2.");
        }
        return args[1];
    }

    /**
     * @return a deep clone of this function produced by {@code SameDiff.newCloner()}
     */
    public DifferentialFunction dup() {
        Cloner cloner = SameDiff.newCloner();
        return (DifferentialFunction)cloner.deepClone((Object)this);
    }

    /**
     * @return the output shapes
     * @throws UnsupportedOperationException always, in this base implementation
     */
    public List<long[]> calculateOutputShape() {
        throw new UnsupportedOperationException();
    }

    /**
     * Two functions are equal when they have the same runtime class and equal
     * {@code inPlace}, {@code scalarValue}, {@code dimensions} and {@code ownName}.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        DifferentialFunction that = (DifferentialFunction)o;
        if (this.inPlace != that.inPlace) {
            return false;
        }
        if (this.scalarValue != null ? !this.scalarValue.equals(that.scalarValue) : that.scalarValue != null) {
            return false;
        }
        if (!Arrays.equals(this.dimensions, that.dimensions)) {
            return false;
        }
        return this.ownName != null ? this.ownName.equals(that.ownName) : that.ownName == null;
    }

    /**
     * Hash over the same fields as {@link #equals(Object)}.
     */
    @Override
    public int hashCode() {
        int result = 31;
        result = 31 * result + (this.inPlace ? 1 : 0);
        result = 31 * result + (this.scalarValue != null ? this.scalarValue.hashCode() : 0);
        result = 31 * result + Arrays.hashCode(this.dimensions);
        result = 31 * result + (this.ownName != null ? this.ownName.hashCode() : 0);
        return result;
    }

    /**
     * @return a one-element array holding {@link #onnxName()}
     */
    public String[] onnxNames() {
        return new String[]{this.onnxName()};
    }

    /**
     * @return a one-element array holding {@link #tensorflowName()}
     */
    public String[] tensorflowNames() {
        return new String[]{this.tensorflowName()};
    }

    /**
     * @return the ONNX name of this function
     */
    public abstract String onnxName();

    /**
     * @return the TensorFlow name of this function
     */
    public abstract String tensorflowName();

    /**
     * @return {@code -1} in this base implementation
     *         (NOTE(review): presumably "not fixed/unknown" - confirm against callers)
     */
    public int getNumOutputs() {
        return -1;
    }

    /** @return the extra arguments (the internal array, not a copy) */
    public Object[] getExtraArgs() {
        return this.extraArgs;
    }

    /** @param extraArgs the extra arguments to store (stored by reference) */
    public void setExtraArgs(Object[] extraArgs) {
        this.extraArgs = extraArgs;
    }

    @Override
    public String toString() {
        return "DifferentialFunction(sameDiff=" + this.getSameDiff() + ", inPlace=" + this.isInPlace() + ", scalarValue=" + this.getScalarValue() + ", dimensions=" + Arrays.toString(this.getDimensions()) + ", extraArgs=" + Arrays.deepToString(this.getExtraArgs()) + ", ownName=" + this.getOwnName() + ")";
    }

    /** @return the owning SameDiff instance, possibly null */
    public SameDiff getSameDiff() {
        return this.sameDiff;
    }

    /** @param sameDiff the SameDiff instance to use from now on */
    public void setSameDiff(SameDiff sameDiff) {
        this.sameDiff = sameDiff;
    }

    /** @return the in-place flag */
    public boolean isInPlace() {
        return this.inPlace;
    }

    /** @param inPlace the new in-place flag */
    public void setInPlace(boolean inPlace) {
        this.inPlace = inPlace;
    }

    /** @return the scalar value, possibly null */
    public Number getScalarValue() {
        return this.scalarValue;
    }

    /** @param scalarValue the new scalar value */
    public void setScalarValue(Number scalarValue) {
        this.scalarValue = scalarValue;
    }

    /** @return the dimensions (the internal array, not a copy) */
    public int[] getDimensions() {
        return this.dimensions;
    }

    /** @param dimensions the new dimensions (stored by reference) */
    public void setDimensions(int[] dimensions) {
        this.dimensions = dimensions;
    }

    /** @return the own name assigned by {@link #setInstanceId()} or {@link #setOwnName(String)} */
    public String getOwnName() {
        return this.ownName;
    }

    /** @param ownName the new own name; the SameDiff registration is not updated here */
    public void setOwnName(String ownName) {
        this.ownName = ownName;
    }
}

