/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops;

import java.util.ArrayList;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for ops that apply a single scalar operand to an input array.
 *
 * <p>The scalar operand is kept in the inherited {@code scalarValue} field. If a scalar
 * {@code y} array is attached, {@link #scalar()} returns it instead of {@code scalarValue}.
 */
public abstract class BaseScalarOp extends BaseOp implements ScalarOp {
    private static final Logger log = LoggerFactory.getLogger(BaseScalarOp.class);

    /**
     * Creates an op with no inputs attached; the scalar operand defaults to {@code 0.0f}.
     * NOTE(review): presumably used for reflective instantiation — confirm with callers.
     */
    public BaseScalarOp() {
        this.scalarValue = Nd4j.scalar(0.0f);
    }

    /**
     * Creates an array-backed op.
     *
     * @param x   input array; decompressed in place if it is compressed
     * @param y   optional second input (a scalar {@code y} overrides {@code num}, see {@link #scalar()})
     * @param z   output array
     * @param num scalar operand, stored with {@code x}'s data type
     */
    public BaseScalarOp(INDArray x, INDArray y, INDArray z, Number num) {
        super(x, y, z);
        decompressIfNeeded(x);
        this.scalarValue = detachedScalar(x.dataType(), num);
    }

    /**
     * Creates an array-backed op with only an input array.
     *
     * @param x   input array; decompressed in place if it is compressed
     * @param num scalar operand, stored with {@code x}'s data type
     */
    public BaseScalarOp(INDArray x, Number num) {
        super(x);
        decompressIfNeeded(x);
        this.scalarValue = detachedScalar(x.dataType(), num);
    }

    /**
     * Creates an array-backed op with an explicit output array and no {@code y}.
     *
     * @param x   input array; decompressed in place if it is compressed
     * @param z   output array
     * @param set scalar operand, stored with {@code x}'s data type
     */
    public BaseScalarOp(INDArray x, INDArray z, Number set) {
        super(x, null, z);
        decompressIfNeeded(x);
        this.scalarValue = detachedScalar(x.dataType(), set);
    }

    /** SameDiff constructor: not in-place, no extra args. */
    public BaseScalarOp(SameDiff sameDiff, SDVariable i_v, Number scalar) {
        this(sameDiff, i_v, scalar, false, null);
    }

    /** SameDiff constructor with no extra args. */
    public BaseScalarOp(SameDiff sameDiff, SDVariable i_v, Number scalar, boolean inPlace) {
        this(sameDiff, i_v, scalar, inPlace, null);
    }

    /**
     * Creates a SameDiff op and registers {@code i_v} as its single input.
     *
     * @param sameDiff  graph that owns this op
     * @param i_v       input variable; must not be null
     * @param scalar    scalar operand, stored with {@code i_v}'s data type
     * @param inPlace   forwarded to {@link BaseOp}
     * @param extraArgs forwarded to {@link BaseOp}
     * @throws NullPointerException if {@code i_v} is null
     */
    public BaseScalarOp(SameDiff sameDiff, @NonNull SDVariable i_v, Number scalar, boolean inPlace, Object[] extraArgs) {
        super(sameDiff, inPlace, extraArgs);
        if (i_v == null) {
            throw new NullPointerException("i_v is marked @NonNull but is null");
        }
        // Allocate detached, like the array constructors, so the operand does not depend on
        // whichever workspace happens to be open while the graph is being built.
        this.scalarValue = detachedScalar(i_v.dataType(), scalar);
        this.xVertexId = i_v.getVarName();
        sameDiff.addArgsFor(new String[]{this.xVertexId}, this);
        if (Shape.isPlaceholderShape(i_v.getShape())) {
            // The input shape is not known yet; ask the graph to resolve it later.
            sameDiff.addPropertyToResolve(this, i_v.getVarName());
        }
        this.f().validateDifferentialFunctionsameDiff(i_v);
    }

    /** SameDiff constructor: not in-place, with extra args. */
    public BaseScalarOp(SameDiff sameDiff, SDVariable i_v, Number scalar, Object[] extraArgs) {
        this(sameDiff, i_v, scalar, false, extraArgs);
    }

    /** Decompresses {@code arr} in place if it is compressed. */
    private static void decompressIfNeeded(INDArray arr) {
        if (arr.isCompressed()) {
            Nd4j.getCompressor().decompressi(arr);
        }
    }

    /** Creates a scalar of the given type outside any active memory workspace. */
    private static INDArray detachedScalar(DataType dataType, Number value) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return Nd4j.scalar(dataType, value);
        }
    }

    /** @return the output array field {@code z} (may be null if not yet assigned) */
    @Override
    public INDArray z() {
        return this.z;
    }

    /**
     * Computes the single output shape.
     *
     * <p>The shape and input data type come from {@code x} when it is set, otherwise from the
     * SameDiff input variable. Taking both from the same source means array-only ops (no
     * SameDiff graph) never need {@code arg()}. The output type is
     * {@code Shape.pickPairwiseDataType(inputType, scalarValue.dataType())}.
     *
     * @return a one-element list with the output shape descriptor
     */
    @Override
    public List<LongShapeDescriptor> calculateOutputShape() {
        long[] shape;
        DataType inputType;
        if (this.x != null) {
            shape = this.x.shape();
            inputType = this.x.dataType();
        } else {
            SDVariable input = this.arg();
            shape = input.getShape();
            inputType = input.dataType();
        }
        DataType outputType = Shape.pickPairwiseDataType(inputType, this.scalarValue.dataType());
        List<LongShapeDescriptor> ret = new ArrayList<>(1);
        ret.add(LongShapeDescriptor.fromShape(shape, outputType));
        return ret;
    }

    /** @return {@link Op.Type#SCALAR} */
    @Override
    public Op.Type opType() {
        return Op.Type.SCALAR;
    }

    /**
     * Replaces the scalar operand, allocated outside any active workspace.
     *
     * <p>The value is stored with {@code x}'s data type when {@code x} is set. Otherwise it uses
     * the current scalar operand's data type, for example on a SameDiff-built op before arrays
     * are attached. Previously that case threw a NullPointerException.
     *
     * @param scalar new scalar operand
     */
    @Override
    public void setScalar(Number scalar) {
        DataType dataType = this.x != null ? this.x.dataType() : this.scalarValue.dataType();
        this.scalarValue = detachedScalar(dataType, scalar);
    }

    /**
     * Replaces the scalar operand with the given array, used as-is (no copy, no type conversion).
     *
     * @param scalar new scalar operand
     */
    @Override
    public void setScalar(INDArray scalar) {
        this.scalarValue = scalar;
    }

    /**
     * @return {@code y} if it is attached and is a scalar; otherwise the stored scalar operand
     */
    @Override
    public INDArray scalar() {
        INDArray yArr = this.y();
        if (yArr != null && yArr.isScalar()) {
            return yArr;
        }
        return this.scalarValue;
    }

    /** @return the inherited {@code dimensions} field */
    @Override
    public int[] getDimension() {
        return this.dimensions;
    }

    /** Delegates to the inherited {@code defineDimensions}. */
    @Override
    public void setDimension(int... dimension) {
        this.defineDimensions(dimension);
    }

    /**
     * Checks operand/output data types.
     *
     * <ul>
     *   <li>If any input is floating point, {@code z} must be floating point.</li>
     *   <li>If {@code y} is present and {@code experimentalMode} is false, {@code x} and
     *       {@code y} must share a data type, unless {@code y} is BOOL.</li>
     * </ul>
     *
     * @param experimentalMode when true, skips the x/y same-type requirement
     * @return {@code true} when all checks pass
     * @throws IllegalArgumentException (via {@code Preconditions}) when a check fails
     */
    @Override
    public boolean validateDataTypes(boolean experimentalMode) {
        if (this.y() != null) {
            if (this.y().isR() || this.x().isR()) {
                Preconditions.checkArgument(this.z().isR(), "Op.Z must have floating point type, since one of operands is floating point: x.dataType=%s, y.dataType=%s, z.dataType=%s, op=%s", this.x.dataType(), this.y.dataType(), this.z.dataType(), this.getClass().getName());
            }
            if (!experimentalMode) {
                boolean compatible = this.x.dataType() == this.y.dataType() || this.y.dataType() == DataType.BOOL;
                Preconditions.checkArgument(compatible, "Op.X must have same data type as Op.Y");
            }
        } else if (this.x().isR()) {
            Preconditions.checkArgument(this.z().isR(), "Op.Z must have floating point type, since one of operands is floating point: x.dataType=%s, z.dataType=%s, op=%s", this.x.dataType(), this.z.dataType(), this.getClass().getName());
        }
        return true;
    }

    /** @return {@link Op.Type#SCALAR} */
    @Override
    public Op.Type getOpType() {
        return Op.Type.SCALAR;
    }

    /**
     * Output data types: the input list is returned unchanged (same list instance).
     *
     * @param dataTypes input data types; must contain exactly one element
     * @return {@code dataTypes}
     * @throws IllegalStateException (via {@code Preconditions}) if {@code dataTypes} is null or
     *         does not have exactly one element
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        Preconditions.checkState(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 input datatype %s, got input %s", this.getClass(), dataTypes);
        return dataTypes;
    }
}

