/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.shape;

import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * One-hot encoding custom op: native op name {@code "onehot"}, TensorFlow op name {@code "OneHot"}.
 * No ONNX mapping exists.
 *
 * <p>{@link #addArgs()} registers the op arguments: integer arguments {@code jaxis} then
 * {@code depth}, followed by floating-point arguments {@code on} then {@code off}. The output
 * data type is {@code outputType} when one has been set, otherwise {@link #DEFAULT_DTYPE}.
 *
 * <p>The field names {@code depth}, {@code on}, {@code off} and {@code jaxis} are referenced by
 * name in {@link #mappingsForFunction()}, so they must not be renamed.
 */
public class OneHot extends DynamicCustomOp {
    /** Output data type reported when no explicit {@code outputType} has been set. */
    public static final DataType DEFAULT_DTYPE = DataType.FLOAT;

    /** Second integer op argument; imported from TensorFlow input position 1. */
    private int depth;
    /**
     * First integer op argument, -1 unless overridden; imported from the TensorFlow
     * {@code axis} attribute. NOTE(review): -1 presumably means "last axis" in the native op.
     */
    private int jaxis = -1;
    /** First floating-point op argument; imported from TensorFlow input position 2. */
    private double on;
    /** Second floating-point op argument; imported from TensorFlow input position 3. */
    private double off;
    /** Explicit output data type, or {@code null} to fall back to {@link #DEFAULT_DTYPE}. */
    private DataType outputType;

    /** No-argument constructor; fields are filled in later (e.g. by {@code initFromTensorFlow}). */
    public OneHot() {
    }

    /**
     * SameDiff constructor with axis -1, on value 1.0, off value 0.0 and {@link #DEFAULT_DTYPE}.
     *
     * @param sameDiff owning SameDiff instance
     * @param indices  the op's single input variable
     * @param depth    value stored in {@code depth}
     */
    public OneHot(SameDiff sameDiff, SDVariable indices, int depth) {
        this(sameDiff, indices, depth, -1, 1.0, 0.0, DEFAULT_DTYPE);
    }

    /**
     * SameDiff constructor.
     *
     * @param sameDiff owning SameDiff instance
     * @param indices  the op's single input variable
     * @param depth    value stored in {@code depth}
     * @param axis     value stored in {@code jaxis}
     * @param on       value stored in {@code on}
     * @param off      value stored in {@code off}
     * @param dataType output data type reported by {@link #calculateOutputDataTypes(List)};
     *                 {@code null} falls back to {@link #DEFAULT_DTYPE}
     */
    public OneHot(SameDiff sameDiff, SDVariable indices, int depth, int axis, double on, double off, DataType dataType) {
        super(null, sameDiff, new SDVariable[]{indices}, false);
        assignEncodingParams(depth, axis, on, off);
        addArgs();
        this.outputType = dataType;
    }

    /**
     * Array-based constructor with axis -1, on value 1.0 and off value 0.0.
     *
     * @param indices input array
     * @param output  output array passed through to the superclass
     * @param depth   value stored in {@code depth}
     */
    public OneHot(INDArray indices, INDArray output, int depth) {
        this(indices, output, depth, -1, 1.0, 0.0);
    }

    /**
     * Array-based constructor. Does not set {@code outputType}, so
     * {@link #calculateOutputDataTypes(List)} reports {@link #DEFAULT_DTYPE}.
     *
     * @param indices input array
     * @param output  output array passed through to the superclass
     * @param depth   value stored in {@code depth}
     * @param axis    value stored in {@code jaxis}
     * @param on      value stored in {@code on}
     * @param off     value stored in {@code off}
     */
    public OneHot(INDArray indices, INDArray output, int depth, int axis, double on, double off) {
        super(null, indices, output, null, null);
        assignEncodingParams(depth, axis, on, off);
        addArgs();
    }

    /** Copies the four encoding parameters into their fields; does not register op arguments. */
    private void assignEncodingParams(int depthValue, int axisValue, double onValue, double offValue) {
        depth = depthValue;
        jaxis = axisValue;
        on = onValue;
        off = offValue;
    }

    /**
     * Registers the op arguments from the current field values: integer arguments
     * {@code jaxis} then {@code depth}, then floating-point arguments {@code on} then {@code off}.
     */
    protected void addArgs() {
        addIArgument(jaxis);
        addIArgument(depth);
        addTArgument(on);
        addTArgument(off);
    }

    /**
     * Initializes this op from a TensorFlow {@code OneHot} node: populates the mapped fields,
     * registers them as op arguments, and takes the output data type from the {@code T}
     * attribute when that attribute is present.
     */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        // Presumably driven by the property mappings declared in mappingsForFunction().
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        addArgs();

        if (!attributesForNode.containsKey("T")) {
            return;
        }
        outputType = TFGraphMapper.convertType(attributesForNode.get("T").getType());
    }

    /**
     * Declares how TensorFlow inputs/attributes map onto this op's fields: inputs 1, 2 and 3
     * feed {@code depth}, {@code on} and {@code off}; the {@code axis} attribute feeds
     * {@code jaxis}.
     *
     * @return a map keyed by {@link #tensorflowName()} whose value is an insertion-ordered
     *         map (depth, on, off, jaxis) of property mappings
     */
    @Override
    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
        Map<String, PropertyMapping> tfMappings = new LinkedHashMap<>();
        tfMappings.put("depth",
                PropertyMapping.builder().propertyNames(new String[]{"depth"}).tfInputPosition(1).build());
        tfMappings.put("on",
                PropertyMapping.builder().propertyNames(new String[]{"on"}).tfInputPosition(2).build());
        tfMappings.put("off",
                PropertyMapping.builder().propertyNames(new String[]{"off"}).tfInputPosition(3).build());
        tfMappings.put("jaxis",
                PropertyMapping.builder().propertyNames(new String[]{"jaxis"}).tfAttrName("axis").build());

        Map<String, Map<String, PropertyMapping>> byFramework = new HashMap<>();
        byFramework.put(tensorflowName(), tfMappings);
        return byFramework;
    }

    /** @return the TensorFlow op name {@code "OneHot"} */
    @Override
    public String tensorflowName() {
        return "OneHot";
    }

    /** Always throws: this op has no ONNX counterpart. */
    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx name found for " + opName());
    }

    /** @return the native op name {@code "onehot"} */
    @Override
    public String opName() {
        return "onehot";
    }

    /** Returns a single gradient, {@code sameDiff.zerosLike(arg())}, for the op's input. */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> i_v) {
        SDVariable zeroGrad = sameDiff.zerosLike(arg());
        return Collections.singletonList(zeroGrad);
    }

    /**
     * Reports the single output data type.
     *
     * @param dataTypes input data types; must contain between 1 and 4 entries
     * @return {@code outputType} if set, otherwise {@link #DEFAULT_DTYPE}
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        int inputCount = dataTypes.size();
        Preconditions.checkState(inputCount >= 1 && inputCount <= 4,
                "Expected list with 1 to 4 datatypes for %s, got %s", getClass(), dataTypes);
        DataType resolved = outputType == null ? DEFAULT_DTYPE : outputType;
        return Collections.singletonList(resolved);
    }
}

