/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;

public abstract class IntermediateOperation {
    public static final String FUNCTION_PREFIX = "imported_ml_function_";
    protected final String name;
    protected final String modelName;
    protected final List<IntermediateOperation> inputs;
    protected final List<IntermediateOperation> outputs = new ArrayList<IntermediateOperation>();
    protected OrderedTensorType type;
    protected TensorFunction function;
    protected TensorFunction rankingExpressionFunction = null;
    protected boolean exportAsRankingFunction = false;
    private final List<String> importWarnings = new ArrayList<String>();
    private Value constantValue = null;
    private List<IntermediateOperation> controlInputs = Collections.emptyList();
    protected Function<OrderedTensorType, Value> constantValueFunction = null;

    IntermediateOperation(String modelName, String name, List<IntermediateOperation> inputs) {
        this.name = name;
        this.modelName = modelName;
        this.inputs = new ArrayList<IntermediateOperation>(inputs);
        this.inputs.forEach(i -> i.outputs.add(this));
    }

    protected abstract OrderedTensorType lazyGetType();

    protected abstract TensorFunction lazyGetFunction();

    public String modelName() {
        return this.modelName;
    }

    public Optional<OrderedTensorType> type() {
        if (this.type == null) {
            this.type = this.lazyGetType();
        }
        return Optional.ofNullable(this.type);
    }

    public Optional<TensorFunction> function() {
        if (this.function == null) {
            if (this.isConstant()) {
                ReferenceNode constant = new ReferenceNode(Reference.simple((String)"constant", (String)this.vespaName()));
                this.function = new TensorFunctionNode.ExpressionTensorFunction((ExpressionNode)constant);
            } else if (this.outputs.size() > 1 || this.exportAsRankingFunction) {
                this.rankingExpressionFunction = this.lazyGetFunction();
                this.function = new VariableTensor(this.rankingExpressionFunctionName(), this.type.type());
            } else {
                this.function = this.lazyGetFunction();
            }
        }
        return Optional.ofNullable(this.function);
    }

    public String name() {
        return this.name;
    }

    public List<IntermediateOperation> inputs() {
        return this.inputs;
    }

    public List<IntermediateOperation> outputs() {
        return Collections.unmodifiableList(this.outputs);
    }

    public Optional<TensorFunction> rankingExpressionFunction() {
        return Optional.ofNullable(this.rankingExpressionFunction);
    }

    public void addDimensionNameConstraints(DimensionRenamer renamer) {
    }

    protected void addConstraintsFrom(OrderedTensorType type, DimensionRenamer renamer) {
        for (int i = 0; i < type.dimensions().size(); ++i) {
            renamer.addDimension(type.dimensions().get(i).name());
            for (int j = i + 1; j < type.dimensions().size(); ++j) {
                renamer.addConstraint(type.dimensions().get(i).name(), type.dimensions().get(j).name(), DimensionRenamer.Constraint.notEqual(false), this);
            }
        }
    }

    public void renameDimensions(DimensionRenamer renamer) {
        this.type = this.type.rename(renamer);
    }

    public boolean isInput() {
        return false;
    }

    public boolean isConstant() {
        return this.inputs.stream().allMatch(IntermediateOperation::isConstant);
    }

    public void setConstantValue(Value value) {
        this.constantValue = value;
    }

    public Optional<Value> getConstantValue() {
        if (this.constantValue != null) {
            return Optional.of(this.constantValue);
        }
        if (this.constantValueFunction != null) {
            return Optional.of(this.constantValueFunction.apply(this.type().orElse(null)));
        }
        return Optional.empty();
    }

    public void setConstantValueFunction(Function<OrderedTensorType, Value> func) {
        this.constantValueFunction = func;
    }

    public void setControlInputs(List<IntermediateOperation> inputs) {
        this.controlInputs = inputs;
    }

    public List<IntermediateOperation> getControlInputs() {
        return Collections.unmodifiableList(this.controlInputs);
    }

    public String vespaName() {
        return this.vespaName(this.name);
    }

    public String vespaName(String name) {
        return name != null ? IntermediateOperation.namePartOf(name).replace('/', '_').replace('.', '_') : null;
    }

    public String rankingExpressionFunctionName() {
        return this.vespaName() != null ? FUNCTION_PREFIX + this.modelName + "_" + this.vespaName() : null;
    }

    public List<String> warnings() {
        return Collections.unmodifiableList(this.importWarnings);
    }

    public void warning(String warning) {
        this.importWarnings.add(warning);
    }

    boolean verifyInputs(int expected, Function<IntermediateOperation, Optional<?>> func) {
        if (this.inputs.size() != expected) {
            throw new IllegalArgumentException("Expected " + expected + " inputs for '" + this.name + "', got " + this.inputs.size());
        }
        return this.inputs.stream().map(func).allMatch(Optional::isPresent);
    }

    boolean allInputTypesPresent(int expected) {
        return this.verifyInputs(expected, IntermediateOperation::type);
    }

    boolean allInputFunctionsPresent(int expected) {
        return this.verifyInputs(expected, IntermediateOperation::function);
    }

    public Value evaluateAsConstant(OrderedTensorType type) {
        if (!this.isConstant()) {
            throw new IllegalArgumentException("Attempted to evaluate non-constant operation as a constant.");
        }
        Value val = this.evaluateAsConstant((Context)new MapContext((Value)DoubleValue.NaN));
        if (type != null && !val.asTensor().type().equals((Object)type.type())) {
            throw new IllegalArgumentException("Constant evaluation in " + this.name + " resulted in wrong type. Expected: " + type.type() + " Got: " + val.asTensor().type());
        }
        return val;
    }

    private Value evaluateAsConstant(Context context) {
        String constantName = "constant(" + this.vespaName() + ")";
        Value result = context.get(constantName);
        if (result == DoubleValue.NaN) {
            if (this.constantValue != null) {
                result = this.constantValue;
            } else if (this.inputs.size() == 0) {
                if (this.getConstantValue().isEmpty()) {
                    throw new IllegalArgumentException("Error in evaluating constant for " + this.name);
                }
                result = this.getConstantValue().get();
            } else {
                this.inputs.forEach(i -> i.evaluateAsConstant(context));
                result = new TensorValue(this.lazyGetFunction().evaluate((EvaluationContext)context));
            }
            context.put(constantName, result);
            if (this.outputs.size() > 1 || this.exportAsRankingFunction) {
                context.put(this.rankingExpressionFunctionName(), result);
            }
        }
        return result;
    }

    public void insert(IntermediateOperation operationToInsert, int inputNumber) {
        if (operationToInsert.inputs.size() > 0) {
            throw new IllegalArgumentException("Operation to insert to '" + this.name + "' has existing inputs which is not supported.");
        }
        IntermediateOperation previousInputOperation = this.inputs.get(inputNumber);
        int outputNumber = this.findOutputNumber(previousInputOperation, this);
        if (outputNumber == -1) {
            throw new IllegalArgumentException("Input '" + previousInputOperation.name + "' to '" + this.name + "' does not have '" + this.name + "' as output.");
        }
        previousInputOperation.outputs.set(outputNumber, operationToInsert);
        operationToInsert.inputs.add(previousInputOperation);
        operationToInsert.outputs.add(this);
        this.inputs.set(inputNumber, operationToInsert);
    }

    public void uninsert(int inputNumber) {
        IntermediateOperation operationToRemove = this.inputs.get(inputNumber);
        IntermediateOperation newInputOperation = operationToRemove.inputs.get(0);
        int outputNumber = this.findOutputNumber(newInputOperation, operationToRemove);
        newInputOperation.outputs.set(outputNumber, this);
        this.inputs.set(inputNumber, newInputOperation);
    }

    private int findOutputNumber(IntermediateOperation output, IntermediateOperation op) {
        for (int i = 0; i < output.outputs.size(); ++i) {
            if (!output.outputs.get(i).equals(op)) continue;
            return i;
        }
        return -1;
    }

    TensorType.Value resultValueType() {
        return TensorType.Value.largestOf(this.inputs.stream().map(input -> input.type().get().type().valueType()).collect(Collectors.toList()));
    }

    public abstract IntermediateOperation withInputs(List<IntermediateOperation> var1);

    String asString(Optional<OrderedTensorType> type) {
        return type.map(t -> t.toString()).orElse("(unknown)");
    }

    public static String namePartOf(String name) {
        name = name.startsWith("^") ? name.substring(1) : name;
        return name.split(":")[0];
    }

    public static int indexPartOf(String name) {
        int i = name.indexOf(":");
        return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
    }

    public abstract String operationName();

    public String toString() {
        return this.operationName() + "(" + this.inputs().stream().map(input -> this.asString(input.type())).collect(Collectors.joining(", ")) + ")";
    }

    public String toFullString() {
        return "\t" + this.type + ":\t" + this.operationName() + "(" + this.inputs().stream().map(input -> input.toFullString()).collect(Collectors.joining(", ")) + ")";
    }

    public static interface AttributeMap {
        public Optional<Value> get(String var1);

        public Optional<Value> get(String var1, OrderedTensorType var2);

        public Optional<List<Value>> getList(String var1);
    }
}

