/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.accum;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import onnx.OnnxProto3;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

public class Mmul
extends DynamicCustomOp {
    protected MMulTranspose mt;

    public Mmul(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, MMulTranspose mt) {
        super(null, sameDiff, new SDVariable[]{i_v1, i_v2});
        this.mt = mt;
        this.addIArgument(ArrayUtil.fromBoolean((boolean)mt.isTransposeA()), ArrayUtil.fromBoolean((boolean)mt.isTransposeB()), ArrayUtil.fromBoolean((boolean)mt.isTransposeResult()));
    }

    public Mmul(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
        this(sameDiff, i_v1, i_v2, MMulTranspose.allFalse());
    }

    public Mmul(INDArray x, INDArray y, INDArray z, MMulTranspose mt) {
        INDArray[] iNDArrayArray;
        INDArray[] iNDArrayArray2 = new INDArray[]{x, y};
        if (z == null) {
            iNDArrayArray = null;
        } else {
            INDArray[] iNDArrayArray3 = new INDArray[1];
            iNDArrayArray = iNDArrayArray3;
            iNDArrayArray3[0] = z;
        }
        super(null, iNDArrayArray2, iNDArrayArray);
        if (mt != null) {
            this.mt = mt;
            this.addIArgument(ArrayUtil.fromBoolean((boolean)mt.isTransposeA()), ArrayUtil.fromBoolean((boolean)mt.isTransposeB()), ArrayUtil.fromBoolean((boolean)mt.isTransposeResult()));
        }
    }

    public Mmul() {
    }

    public long[] transposeShapeArray(long[] shape) {
        if (shape.length == 2) {
            return ArrayUtil.reverseCopy((long[])shape);
        }
        if (shape.length == 3) {
            return new long[]{shape[0], shape[2], shape[1]};
        }
        throw new IllegalArgumentException("Matrix input has to be of length 2 or 3, got: " + shape.length);
    }

    @Override
    public List<long[]> calculateOutputShape() {
        if (this.mt == null) {
            this.mt = MMulTranspose.allFalse();
        }
        long[] aShape = this.larg().getShape();
        long[] bShape = this.rarg().getShape();
        if (Shape.isPlaceholderShape(aShape) || Shape.isPlaceholderShape(bShape)) {
            return Collections.emptyList();
        }
        aShape = this.mt.isTransposeA() ? this.transposeShapeArray(aShape) : aShape;
        bShape = this.mt.isTransposeB() ? this.transposeShapeArray(bShape) : bShape;
        long[] shape = Shape.getMatrixMultiplyShape(aShape, bShape);
        if (this.mt.isTransposeResult()) {
            shape = this.transposeShapeArray(shape);
        }
        for (int i = 0; i < shape.length; ++i) {
            if (shape[i] >= 1L) continue;
            throw new ND4JIllegalStateException("Invalid shape computed at index " + i + ": shape " + Arrays.toString(shape));
        }
        return Collections.singletonList(shape);
    }

    @Override
    public String onnxName() {
        return "MatMul";
    }

    @Override
    public String tensorflowName() {
        return "MatMul";
    }

    @Override
    public String opName() {
        return "mmul";
    }

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        SDVariable[] args;
        MMulTranspose mMulTranspose;
        super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
        boolean isTransposeA = attributesForNode.get("transpose_a").getB();
        boolean isTransposeB = attributesForNode.get("transpose_b").getB();
        this.mt = mMulTranspose = MMulTranspose.builder().transposeA(isTransposeA).transposeB(isTransposeB).build();
        for (SDVariable arg : args = this.args()) {
            if (!this.sameDiff.isPlaceHolder(arg.getVarName()) && arg.getShape() != null) continue;
            this.sameDiff.addPropertyToResolve(this, arg.getVarName());
        }
    }

    @Override
    public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
        MMulTranspose mMulTranspose;
        boolean isTransposeA;
        boolean bl = !attributesForNode.containsKey("transA") ? false : (isTransposeA = attributesForNode.get("transA").getI() > 0L);
        boolean isTransposeB = !attributesForNode.containsKey("transB") ? false : attributesForNode.get("transB").getI() > 0L;
        this.mt = mMulTranspose = MMulTranspose.builder().transposeA(isTransposeA).transposeB(isTransposeB).build();
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> i_v1) {
        ArrayList<SDVariable> ret = new ArrayList<SDVariable>();
        SDVariable dLdOut = i_v1.get(0);
        SDVariable dLdx = this.sameDiff.mmul(dLdOut, this.rarg(), MMulTranspose.builder().transposeA(this.mt.isTransposeResult()).transposeB(!this.mt.isTransposeB()).transposeResult(this.mt.isTransposeA()).build());
        SDVariable dLdy = this.sameDiff.mmul(this.larg(), dLdOut, MMulTranspose.builder().transposeA(!this.mt.isTransposeA()).transposeB(this.mt.isTransposeResult()).transposeResult(this.mt.isTransposeB()).build());
        ret.add(dLdx);
        ret.add(dLdy);
        return ret;
    }

    @Override
    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
        HashMap<String, Map<String, PropertyMapping>> ret = new HashMap<String, Map<String, PropertyMapping>>();
        HashMap<String, PropertyMapping> map = new HashMap<String, PropertyMapping>();
        PropertyMapping transposeA = PropertyMapping.builder().onnxAttrName("transA").tfAttrName("transpose_a").propertyNames(new String[]{"transposeA"}).build();
        PropertyMapping transposeB = PropertyMapping.builder().onnxAttrName("transB").tfAttrName("transpose_b").propertyNames(new String[]{"transposeB"}).build();
        map.put("transposeA", transposeA);
        map.put("transposeB", transposeB);
        ret.put(this.tensorflowName(), map);
        ret.put(this.onnxName(), map);
        return ret;
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof Mmul)) {
            return false;
        }
        Mmul other = (Mmul)o;
        if (!other.canEqual(this)) {
            return false;
        }
        MMulTranspose this$mt = this.mt;
        MMulTranspose other$mt = other.mt;
        return !(this$mt == null ? other$mt != null : !((Object)this$mt).equals(other$mt));
    }

    protected boolean canEqual(Object other) {
        return other instanceof Mmul;
    }

    @Override
    public int hashCode() {
        int PRIME = 59;
        int result = 1;
        MMulTranspose $mt = this.mt;
        result = result * 59 + ($mt == null ? 43 : ((Object)$mt).hashCode());
        return result;
    }
}

