/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.shape;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;

/**
 * SameDiff wrapper for the native custom op {@code "eye"} (imported from TensorFlow's
 * {@code "Eye"} op). Presumably produces an identity-style matrix of {@code numRows x numCols},
 * optionally batched; the actual computation is done by the native op.
 *
 * <p>The op can be configured in two ways:
 * <ul>
 *   <li>dynamically, with SDVariable inputs ({@code numRows}, optionally {@code numCols} and
 *       {@code batch_shape}) passed to the native op as op inputs; or</li>
 *   <li>statically, with integer args in the order {@code numRows}, {@code numCols}, then each
 *       batch dimension, plus one floating-point arg holding {@link DataType#toInt()} of the
 *       requested output type.</li>
 * </ul>
 */
public class Eye extends DynamicCustomOp {
    /** Output data type used when none (or {@code null}) is specified. */
    public static final DataType DEFAULT_DTYPE = DataType.FLOAT;

    private int numRows;
    private int numCols;
    /** Extra leading dimensions; never {@code null} when set via the constructors. */
    private int[] batchDimension = new int[0];
    private DataType dataType = DEFAULT_DTYPE;

    /** Creates an unconfigured instance; all fields keep their declared defaults. */
    public Eye() {
    }

    /**
     * Dynamic form with a single input.
     *
     * @param sameDiff owning graph
     * @param numRows  variable holding the number of rows
     */
    public Eye(SameDiff sameDiff, SDVariable numRows) {
        super(null, sameDiff, new SDVariable[]{numRows}, false);
    }

    /**
     * Dynamic form with row and column inputs.
     *
     * @param sameDiff owning graph
     * @param numRows  variable holding the number of rows
     * @param numCols  variable holding the number of columns
     */
    public Eye(SameDiff sameDiff, SDVariable numRows, SDVariable numCols) {
        super(null, sameDiff, new SDVariable[]{numRows, numCols}, false);
    }

    /**
     * Dynamic form with row, column and batch-shape inputs.
     *
     * @param sameDiff    owning graph
     * @param numRows     variable holding the number of rows
     * @param numCols     variable holding the number of columns
     * @param batch_shape variable holding the batch shape
     */
    public Eye(SameDiff sameDiff, SDVariable numRows, SDVariable numCols, SDVariable batch_shape) {
        super(null, sameDiff, new SDVariable[]{numRows, numCols, batch_shape}, false);
    }

    /**
     * Static square form: {@code numCols == numRows}, default data type, no batch dimensions.
     *
     * @param sameDiff owning graph
     * @param numRows  number of rows (and columns)
     */
    public Eye(SameDiff sameDiff, int numRows) {
        this(sameDiff, numRows, numRows);
    }

    /**
     * Static form with the default data type and no batch dimensions.
     *
     * @param sameDiff owning graph
     * @param numRows  number of rows
     * @param numCols  number of columns
     */
    public Eye(SameDiff sameDiff, int numRows, int numCols) {
        this(sameDiff, numRows, numCols, DEFAULT_DTYPE);
    }

    /**
     * Static form with an explicit data type and no batch dimensions.
     *
     * @param sameDiff owning graph
     * @param numRows  number of rows
     * @param numCols  number of columns
     * @param dataType output data type; {@code null} means {@link #DEFAULT_DTYPE}
     */
    public Eye(SameDiff sameDiff, int numRows, int numCols, DataType dataType) {
        this(sameDiff, numRows, numCols, dataType, new int[0]);
    }

    /**
     * Fully specified static form.
     *
     * @param sameDiff       owning graph
     * @param numRows        number of rows
     * @param numCols        number of columns
     * @param dataType       output data type; {@code null} means {@link #DEFAULT_DTYPE}
     * @param batchDimension extra batch dimensions; {@code null} means none. The array is copied.
     */
    public Eye(SameDiff sameDiff, int numRows, int numCols, DataType dataType, int[] batchDimension) {
        super(null, sameDiff, new SDVariable[0], false);
        this.numRows = numRows;
        this.numCols = numCols;
        // Copy so later mutation of the caller's array cannot silently change this op's config.
        this.batchDimension = batchDimension == null ? new int[0] : batchDimension.clone();
        // Normalize null here; previously a null type crashed with an NPE inside addArgs().
        this.dataType = dataType == null ? DEFAULT_DTYPE : dataType;
        this.addArgs();
    }

    /**
     * Rebuilds the native op arguments from the current fields.
     *
     * <p>Integer args are {@code numRows}, {@code numCols}, then each batch dimension. The single
     * floating-point arg is the integer code of the output data type.
     */
    protected void addArgs() {
        this.iArguments.clear();
        this.tArguments.clear();
        this.addIArgument(this.numRows);
        this.addIArgument(this.numCols);
        if (this.batchDimension != null) {
            for (int dim : this.batchDimension) {
                this.addIArgument(dim);
            }
        }
        this.addTArgument(this.outputDataType().toInt());
    }

    /** Returns the configured data type, falling back to {@link #DEFAULT_DTYPE} when unset. */
    private DataType outputDataType() {
        return this.dataType == null ? DEFAULT_DTYPE : this.dataType;
    }

    /** No ONNX mapping exists for this op. */
    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op opName found for " + this.opName());
    }

    @Override
    public String tensorflowName() {
        return "Eye";
    }

    @Override
    public String opName() {
        return "eye";
    }

    /**
     * Computes the output shape via the native op, then overrides the first output's data type
     * with the configured one (when non-null).
     */
    @Override
    public List<LongShapeDescriptor> calculateOutputShape() {
        List<LongShapeDescriptor> l = super.calculateOutputShape();
        if (this.dataType != null && l != null && !l.isEmpty()) {
            l.set(0, l.get(0).asDataType(this.dataType));
        }
        return l;
    }

    /**
     * Returns one gradient per op input.
     *
     * <p>The inputs (numRows / numCols / batch_shape, per the constructors) only size the output.
     * The output values do not vary smoothly with them, so each gradient is zero. The static form
     * has no inputs and therefore returns an empty list.
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> outGrad) {
        SDVariable[] inputs = this.args();
        if (inputs == null || inputs.length == 0) {
            return Collections.emptyList();
        }
        // The gradient list must line up with the inputs (the op may have 1, 2 or 3 of them).
        List<SDVariable> grads = new ArrayList<>(inputs.length);
        for (SDVariable input : inputs) {
            grads.add(this.sameDiff.zerosLike(input));
        }
        return grads;
    }

    /** The output type is the configured data type (or the default); input types are ignored. */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        return Collections.singletonList(this.outputDataType());
    }
}

