/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.accum;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import onnx.OnnxProto3;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

public class Mmul
extends DynamicCustomOp {
    protected MMulTranspose mMulTranspose;

    public Mmul(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, MMulTranspose mMulTranspose) {
        super(null, sameDiff, new SDVariable[]{i_v1, i_v2});
        this.mMulTranspose = mMulTranspose;
        this.addIArgument(ArrayUtil.fromBoolean((boolean)mMulTranspose.isTransposeA()), ArrayUtil.fromBoolean((boolean)mMulTranspose.isTransposeB()));
    }

    public Mmul(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
        this(sameDiff, i_v1, i_v2, MMulTranspose.allFalse());
    }

    public Mmul(INDArray x, INDArray y, INDArray z, MMulTranspose mMulTranspose) {
        INDArray[] iNDArrayArray;
        INDArray[] iNDArrayArray2 = new INDArray[]{x, y};
        if (z == null) {
            iNDArrayArray = null;
        } else {
            INDArray[] iNDArrayArray3 = new INDArray[1];
            iNDArrayArray = iNDArrayArray3;
            iNDArrayArray3[0] = z;
        }
        super(null, iNDArrayArray2, iNDArrayArray);
        if (mMulTranspose != null) {
            this.mMulTranspose = mMulTranspose;
            this.addIArgument(ArrayUtil.fromBoolean((boolean)mMulTranspose.isTransposeA()), ArrayUtil.fromBoolean((boolean)mMulTranspose.isTransposeB()));
        }
    }

    public Mmul() {
    }

    @Override
    public List<int[]> calculateOutputShape() {
        int[] bShape;
        if (this.mMulTranspose == null) {
            this.mMulTranspose = MMulTranspose.allFalse();
        }
        ArrayList<int[]> ret = new ArrayList<int[]>(1);
        int[] aShape = this.mMulTranspose.isTransposeA() ? ArrayUtil.reverseCopy((int[])this.larg().getShape()) : this.larg().getShape();
        int[] nArray = bShape = this.mMulTranspose.isTransposeB() ? ArrayUtil.reverseCopy((int[])this.rarg().getShape()) : this.rarg().getShape();
        if (Shape.isPlaceholderShape(aShape) || Shape.isPlaceholderShape(bShape)) {
            return Collections.emptyList();
        }
        if (aShape != null && bShape != null) {
            int[] shape = Shape.getMatrixMultiplyShape(aShape, bShape);
            ret.add(shape);
        }
        if (!ret.isEmpty()) {
            for (int i = 0; i < ((int[])ret.get(0)).length; ++i) {
                if (((int[])ret.get(0))[i] >= 1) continue;
                throw new ND4JIllegalStateException("Invalid shape computed at index " + i);
            }
        }
        return ret;
    }

    @Override
    public String onnxName() {
        return "MatMul";
    }

    @Override
    public String tensorflowName() {
        return "MatMul";
    }

    @Override
    public String opName() {
        return "mmul";
    }

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        SDVariable[] args;
        MMulTranspose mMulTranspose;
        super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
        boolean isTransposeA = attributesForNode.get("transpose_a").getB();
        boolean isTransposeB = attributesForNode.get("transpose_b").getB();
        this.mMulTranspose = mMulTranspose = MMulTranspose.builder().transposeA(isTransposeA).transposeB(isTransposeB).build();
        for (SDVariable arg : args = this.args()) {
            if (!this.sameDiff.isPlaceHolder(arg.getVarName()) && arg.getShape() != null) continue;
            this.sameDiff.addPropertyToResolve(this, arg.getVarName());
        }
    }

    @Override
    public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
        MMulTranspose mMulTranspose;
        boolean isTransposeA;
        boolean bl = !attributesForNode.containsKey("transA") ? false : (isTransposeA = attributesForNode.get("transA").getI() > 0L);
        boolean isTransposeB = !attributesForNode.containsKey("transB") ? false : attributesForNode.get("transB").getI() > 0L;
        this.mMulTranspose = mMulTranspose = MMulTranspose.builder().transposeA(isTransposeA).transposeB(isTransposeB).build();
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> i_v1) {
        ArrayList<SDVariable> ret = new ArrayList<SDVariable>();
        SDVariable setup = this.sameDiff.setupFunction(i_v1.get(0));
        SDVariable gradWrtX = this.sameDiff.setupFunction(this.f().reshape(this.f().mmul(setup, this.rarg(), MMulTranspose.builder().transposeB(!this.mMulTranspose.isTransposeB()).transposeResult(this.mMulTranspose.isTransposeA()).build()), this.larg().getShape()));
        SDVariable gradWrtY = this.sameDiff.setupFunction(this.f().reshape(this.f().mmul(this.larg(), setup, MMulTranspose.builder().transposeA(!this.mMulTranspose.isTransposeA()).transposeResult(this.mMulTranspose.isTransposeB()).build()), this.rarg().getShape()));
        ret.add(gradWrtX);
        ret.add(gradWrtY);
        return ret;
    }

    @Override
    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
        HashMap<String, Map<String, PropertyMapping>> ret = new HashMap<String, Map<String, PropertyMapping>>();
        HashMap<String, PropertyMapping> map = new HashMap<String, PropertyMapping>();
        PropertyMapping transposeA = PropertyMapping.builder().onnxAttrName("transA").tfAttrName("transpose_a").propertyNames(new String[]{"transposeA"}).build();
        PropertyMapping transposeB = PropertyMapping.builder().onnxAttrName("transB").tfAttrName("transpose_b").propertyNames(new String[]{"transposeB"}).build();
        map.put("transposeA", transposeA);
        map.put("transposeB", transposeB);
        ret.put(this.tensorflowName(), map);
        ret.put(this.onnxName(), map);
        return ret;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        Mmul mmul = (Mmul)o;
        return this.mMulTranspose != null ? this.mMulTranspose.equals(mmul.mMulTranspose) : mmul.mMulTranspose == null;
    }

    @Override
    public int hashCode() {
        int result = super.hashCode();
        result = 31 * result + (this.mMulTranspose != null ? this.mMulTranspose.hashCode() : 0);
        return result;
    }
}

