/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;

import com.google.common.collect.ImmutableList;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.ImportResult;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorConverter;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TypedTensorFunction;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.Map;
import com.yahoo.tensor.functions.Matmul;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.Softmax;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.List;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.NodeDef;

class OperationMapper {
    private TensorConverter tensorConverter = new TensorConverter();

    OperationMapper() {
    }

    TypedTensorFunction join(List<TypedTensorFunction> arguments, DoubleBinaryOperator doubleFunction) {
        this.ensureArguments(2, arguments, "join");
        TypedTensorFunction a = arguments.get(0);
        TypedTensorFunction b = arguments.get(1);
        if (a.type().rank() < b.type().rank()) {
            throw new IllegalArgumentException("Attempt to join " + a.type() + " and " + b.type() + ", but this is not supported when the second argument has a higher rank");
        }
        TensorFunction bFunction = b.function();
        if (a.type().rank() > b.type().rank()) {
            ArrayList<String> renameFrom = new ArrayList<String>();
            ArrayList<String> renameTo = new ArrayList<String>();
            int sizeDifference = a.type().rank() - b.type().rank();
            for (int i = 0; i < b.type().rank(); ++i) {
                renameFrom.add(((TensorType.Dimension)b.type().dimensions().get(i)).name());
                renameTo.add("d" + (sizeDifference + i));
            }
            bFunction = new Rename(bFunction, renameFrom, renameTo);
        }
        Join function = new Join(a.function(), bFunction, doubleFunction);
        return new TypedTensorFunction(a.type(), (TensorFunction)function);
    }

    TypedTensorFunction map(List<TypedTensorFunction> arguments, DoubleUnaryOperator doubleFunction) {
        this.ensureArguments(1, arguments, "apply");
        TypedTensorFunction a = arguments.get(0);
        TensorType resultType = Map.outputType((TensorType)a.type());
        Map function = new Map(a.function(), doubleFunction);
        return new TypedTensorFunction(resultType, (TensorFunction)function);
    }

    TypedTensorFunction placeholder(NodeDef tfNode, ImportResult result) {
        String name = tfNode.getName();
        TensorType type = result.arguments().get(name);
        if (type == null) {
            throw new IllegalArgumentException("An placeholder operation node is referencing input '" + name + "', but there is no such input");
        }
        return new TypedTensorFunction(type, (TensorFunction)new VariableTensor(name));
    }

    TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, ImportResult result) {
        if (!tfNode.getName().endsWith("/read")) {
            throw new IllegalArgumentException("Encountered identity node " + tfNode.getName() + ", but identify nodes are only supported when reading variables");
        }
        if (tfNode.getInputList().size() != 1) {
            throw new IllegalArgumentException("A Variable/read node must have one input but has " + tfNode.getInputList().size());
        }
        String name = tfNode.getInput(0);
        AttrValue shapes = (AttrValue)tfNode.getAttrMap().get("_output_shapes");
        if (shapes == null) {
            throw new IllegalArgumentException("Referenced variable '" + name + "' is missing a tensor output shape");
        }
        Session.Runner fetched = model.session().runner().fetch(name);
        List importedTensors = fetched.run();
        if (importedTensors.size() != 1) {
            throw new IllegalStateException("Expected 1 tensor from reading Variable " + name + ", but got " + importedTensors.size());
        }
        com.yahoo.tensor.Tensor constant = this.tensorConverter.toVespaTensor((Tensor)importedTensors.get(0));
        result.set(name, constant);
        return new TypedTensorFunction(constant.type(), (TensorFunction)new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")")));
    }

    TypedTensorFunction matmul(List<TypedTensorFunction> arguments) {
        this.ensureArguments(2, arguments, "matmul");
        TypedTensorFunction a = arguments.get(0);
        TypedTensorFunction b = arguments.get(1);
        if (a.type().rank() < 2 || b.type().rank() < 2) {
            throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
        }
        if (a.type().rank() != b.type().rank()) {
            throw new IllegalArgumentException("Tensors in matmul must have the same rank");
        }
        String afterLastDim = "d" + (a.type().rank() + 1);
        Rename renamedB = new Rename(b.function(), (List)ImmutableList.of((Object)"d0", (Object)"d1"), (List)ImmutableList.of((Object)"d1", (Object)afterLastDim));
        Matmul matmul = new Matmul(a.function(), (TensorFunction)renamedB, "d1");
        return new TypedTensorFunction(Matmul.outputType((TensorType)a.type(), (TensorType)b.type(), (String)"d1"), (TensorFunction)new Rename((TensorFunction)matmul, afterLastDim, "d1"));
    }

    TypedTensorFunction softmax(List<TypedTensorFunction> arguments) {
        this.ensureArguments(1, arguments, "softmax");
        TypedTensorFunction a = arguments.get(0);
        String dimension = "d" + (a.type().rank() - 1);
        Softmax softmax = new Softmax(a.function(), dimension);
        return new TypedTensorFunction(Softmax.outputType((TensorType)a.type(), (String)dimension), (TensorFunction)softmax);
    }

    private void ensureArguments(int count, List<TypedTensorFunction> arguments, String operationName) {
        if (arguments.size() != count) {
            throw new IllegalArgumentException("Expected " + count + " arguments to " + operationName + ", but got " + arguments.size());
        }
    }
}

