/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.Slice;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

/**
 * Matrix multiplication of two imported operations.
 *
 * <p>The leading ("batch") dimensions of the two inputs are aligned from the right and broadcast
 * against each other. The product is formed by joining both inputs with multiplication and summing
 * over the last dimension of the first input. An input of rank 1 contributes no row/column dimension
 * to the result. (This appears to mirror numpy/ONNX {@code matmul} semantics — confirm against the
 * importer that creates this operation.)
 */
public class MatMul extends IntermediateOperation {

    public MatMul(String modelName, String nodeName, List<IntermediateOperation> inputs) {
        super(modelName, nodeName, inputs);
    }

    /**
     * Computes the result type. It consists of the batch dimensions of the higher-rank input, each
     * replaced by the aligned dimension of the lower-rank input when that one is larger. These are
     * followed by the second-to-last dimension of A (only if A has rank at least 2) and the last
     * dimension of B (only if B has rank at least 2).
     *
     * <p>Batch dimension sizes are read with {@code Optional.get()}, so compared dimensions must be
     * bound.
     *
     * @return the result type, or null if either input type is not yet known
     * @throws IllegalArgumentException if either input has rank 0
     */
    @Override
    protected OrderedTensorType lazyGetType() {
        if ( ! allInputTypesPresent(2)) return null;

        OrderedTensorType typeA = inputs.get(0).type().get();
        OrderedTensorType typeB = inputs.get(1).type().get();
        if (typeA.type().rank() < 1 || typeB.type().rank() < 1)
            throw new IllegalArgumentException("Tensors in matmul must have rank of at least 1");

        OrderedTensorType.Builder result = new OrderedTensorType.Builder(resultValueType());
        boolean aIsLarger = typeA.rank() >= typeB.rank();
        OrderedTensorType larger = aIsLarger ? typeA : typeB;
        OrderedTensorType smaller = aIsLarger ? typeB : typeA;
        // Non-positive shift that right-aligns the smaller input's dimensions with the larger one's.
        int offset = smaller.rank() - larger.rank();

        for (int i = 0; i < larger.rank() - 2; i++) {
            TensorType.Dimension batchDim = larger.dimensions().get(i);
            int aligned = i + offset;
            if (aligned >= 0) {
                TensorType.Dimension candidate = smaller.dimensions().get(aligned);
                // Broadcasting: the larger extent wins. handleBroadcasting slices a size-1 side.
                if (candidate.size().get() > batchDim.size().get())
                    batchDim = candidate;
            }
            result.add(batchDim);
        }
        if (typeA.rank() >= 2) result.add(typeA.dimensions().get(typeA.rank() - 2));
        if (typeB.rank() >= 2) result.add(typeB.dimensions().get(typeB.rank() - 1));
        return result.build();
    }

    /**
     * Builds {@code reduce(join(A', B', multiply), sum, <last dimension of A>)}. Here A' and B' are
     * the inputs after {@link #handleBroadcasting} has been applied to each against the other.
     *
     * @return the tensor function, or null until both input types and input functions are available
     */
    @Override
    protected TensorFunction lazyGetFunction() {
        if ( ! allInputTypesPresent(2) || ! allInputFunctionsPresent(2)) return null;

        OrderedTensorType typeA = inputs.get(0).type().get();
        OrderedTensorType typeB = inputs.get(1).type().get();
        TensorFunction a = handleBroadcasting(inputs.get(0).function().get(), typeA, typeB);
        TensorFunction b = handleBroadcasting(inputs.get(1).function().get(), typeB, typeA);

        TensorFunction product = new Join(a, b, ScalarFunctions.multiply());
        String summedDimension = typeA.dimensions().get(typeA.rank() - 1).name();
        return new Reduce(product, Reduce.Aggregator.sum, summedDimension);
    }

    /**
     * Selects index 0 of every batch dimension of {@code ownType} that has size 1 while the
     * right-aligned dimension of {@code otherType} is larger. The apparent intent is that the join
     * then takes that dimension's extent from the other input.
     *
     * @param function the function producing a tensor of {@code ownType}
     * @param ownType the type of {@code function}
     * @param otherType the type of the operand it will be joined with
     * @return {@code function} unchanged if nothing needs slicing, otherwise a {@link Slice} of it
     */
    private TensorFunction handleBroadcasting(TensorFunction function, OrderedTensorType ownType, OrderedTensorType otherType) {
        List<Slice.DimensionValue> slices = new ArrayList<>();
        int offset = otherType.rank() - ownType.rank();
        for (int i = 0; i < ownType.rank() - 2; i++) {
            long ownSize = ownType.dimensions().get(i).size().get();
            String ownName = ownType.dimensionNames().get(i);
            int aligned = i + offset;
            if (aligned < 0) continue;

            long otherSize = otherType.dimensions().get(aligned).size().get();
            if (ownSize == 1 && otherSize > ownSize) {
                ExpressionNode zeroIndex = new EmbracedNode(new ConstantNode(DoubleValue.zero));
                slices.add(new Slice.DimensionValue(Optional.of(ownName), TensorFunctionNode.wrapScalar(zeroIndex)));
            }
        }
        return slices.isEmpty() ? function : new Slice(function, slices);
    }

    /**
     * Registers the dimension-name constraints needed for this product with the renamer:
     * <ul>
     *   <li>The last dimension of A equals the second-to-last dimension of B. For a rank-1 B this is
     *       B's only dimension.</li>
     *   <li>Ordering constraints among the two trailing dimensions of each input.</li>
     *   <li>Constraints on each input's batch dimensions, added by {@link #constrainBatchDimensions}
     *       for A against B, then for B against A.</li>
     * </ul>
     * Does nothing until both input types are known.
     */
    @Override
    public void addDimensionNameConstraints(DimensionRenamer renamer) {
        if ( ! allInputTypesPresent(2)) return;

        OrderedTensorType typeA = inputs.get(0).type().get();
        OrderedTensorType typeB = inputs.get(1).type().get();

        String lastDimA = typeA.dimensions().get(typeA.rank() - 1).name();
        String lastDimB = typeB.dimensions().get(typeB.rank() - 1).name();
        String secondLastDimA = typeA.dimensions().get(Math.max(0, typeA.rank() - 2)).name();
        String secondLastDimB = typeB.dimensions().get(Math.max(0, typeB.rank() - 2)).name();

        // The summed (inner) dimension must be shared by both operands.
        renamer.addConstraint(lastDimA, secondLastDimB, DimensionRenamer.Constraint.equal(false), this);

        boolean aIsMatrix = typeA.rank() >= 2;
        boolean bIsMatrix = typeB.rank() >= 2;
        if (aIsMatrix && bIsMatrix)
            renamer.addConstraint(secondLastDimA, lastDimB, DimensionRenamer.Constraint.lessThan(false), this);
        if (aIsMatrix)
            renamer.addConstraint(secondLastDimA, lastDimA, DimensionRenamer.Constraint.lessThan(true), this);
        if (bIsMatrix)
            renamer.addConstraint(secondLastDimB, lastDimB, DimensionRenamer.Constraint.greaterThan(true), this);

        constrainBatchDimensions(typeA, typeB, renamer);
        constrainBatchDimensions(typeB, typeA, renamer);
    }

    /**
     * Adds three kinds of constraint for each batch dimension of {@code batched}, i.e. each
     * dimension before its last two:
     * <ul>
     *   <li>{@code lessThan(false)} against every later dimension of {@code batched}.</li>
     *   <li>{@code notEqual(false)} against the (up to) two trailing dimensions of {@code other}.</li>
     *   <li>{@code equal(false)} against the right-aligned dimension of {@code other}, if it has
     *       one.</li>
     * </ul>
     */
    private void constrainBatchDimensions(OrderedTensorType batched, OrderedTensorType other, DimensionRenamer renamer) {
        int offset = other.rank() - batched.rank();
        for (int i = 0; i < batched.rank() - 2; i++) {
            String batchDim = batched.dimensionNames().get(i);

            for (int k = i + 1; k < batched.rank(); k++)
                renamer.addConstraint(batchDim, batched.dimensionNames().get(k), DimensionRenamer.Constraint.lessThan(false), this);

            for (int k = Math.max(0, other.rank() - 2); k < other.rank(); k++)
                renamer.addConstraint(batchDim, other.dimensionNames().get(k), DimensionRenamer.Constraint.notEqual(false), this);

            int aligned = i + offset;
            if (aligned >= 0)
                renamer.addConstraint(batchDim, other.dimensionNames().get(aligned), DimensionRenamer.Constraint.equal(false), this);
        }
    }

    @Override
    public MatMul withInputs(List<IntermediateOperation> inputs) {
        return new MatMul(modelName(), name(), inputs);
    }

    @Override
    public String operationName() {
        return "MatMul";
    }

}

