/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;

import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;

/**
 * Base class for a single operation in a model graph being imported, before it is turned into
 * a Vespa tensor function.
 *
 * <p>Each operation knows its inputs and (through their constructors) the operations consuming it.
 * Its tensor type and tensor function are computed lazily by subclasses via {@link #lazyGetType()}
 * and {@link #lazyGetFunction()}, then cached.
 */
public abstract class IntermediateOperation {

    /** Prefix of the ranking-expression function names generated for shared operations. */
    private static final String FUNCTION_PREFIX = "imported_ml_function_";

    protected final String name;
    protected final String modelName;
    /** Immutable snapshot of the operations feeding this one. */
    protected final List<IntermediateOperation> inputs;
    /** Operations consuming this one; each registers itself here from its constructor. */
    protected final List<IntermediateOperation> outputs = new ArrayList<>();

    /** Cached result of {@link #lazyGetType()}; null until resolved. */
    protected OrderedTensorType type;
    /** Cached result of {@link #function()}; null until resolved. */
    protected TensorFunction function;
    /** Set only when this operation has several outputs and is emitted as a named function. */
    protected TensorFunction rankingExpressionFunction = null;

    private final List<String> importWarnings = new ArrayList<>();
    private Value constantValue = null;
    private List<IntermediateOperation> controlInputs = Collections.emptyList();
    protected Function<OrderedTensorType, Value> constantValueFunction = null;

    /**
     * Creates an operation and registers it as an output of each of its inputs.
     *
     * @param modelName name of the model this operation belongs to
     * @param name      the operation's name in the source model
     * @param inputs    operations feeding this one; copied, so later changes to the list are not seen
     */
    IntermediateOperation(String modelName, String name, List<IntermediateOperation> inputs) {
        this.name = name;
        this.modelName = modelName;
        // Copy before wrapping: unmodifiableList alone is only a view of the caller's list.
        this.inputs = Collections.unmodifiableList(new ArrayList<>(inputs));
        for (IntermediateOperation input : this.inputs) {
            input.outputs.add(this);
        }
    }

    /** Computes this operation's type; may return null if it cannot be determined. */
    protected abstract OrderedTensorType lazyGetType();

    /** Computes this operation's tensor function; may return null if it cannot be determined. */
    protected abstract TensorFunction lazyGetFunction();

    /**
     * Returns this operation's type, computing and caching it on first successful call.
     *
     * @return the type, or empty if {@link #lazyGetType()} could not determine it
     */
    public Optional<OrderedTensorType> type() {
        if (type == null) {
            type = lazyGetType();
        }
        return Optional.ofNullable(type);
    }

    /**
     * Returns the tensor function computing this operation, building and caching it on first use.
     *
     * <p>Constant operations become a reference to {@code constant(vespaName())}. Operations with
     * more than one output are emitted once as a named ranking-expression function (available via
     * {@link #rankingExpressionFunction()}) and represented here by a variable referring to it.
     * All other operations use {@link #lazyGetFunction()} directly.
     *
     * @return the function, or empty if it (or, for shared operations, the type) cannot be resolved
     */
    public Optional<TensorFunction> function() {
        if (function == null) {
            if (isConstant()) {
                ReferenceNode constant = new ReferenceNode(Reference.simple("constant", vespaName()));
                function = new TensorFunctionNode.TensorFunctionExpressionNode(constant);
            } else if (outputs.size() > 1) {
                // Resolve the type through the lazy accessor: the raw field may still be null
                // if type() has not been called yet, which previously caused an NPE here.
                Optional<OrderedTensorType> resolvedType = type();
                if ( ! resolvedType.isPresent()) {
                    return Optional.empty();
                }
                rankingExpressionFunction = lazyGetFunction();
                function = new VariableTensor(rankingExpressionFunctionName(), resolvedType.get().type());
            } else {
                function = lazyGetFunction();
            }
        }
        return Optional.ofNullable(function);
    }

    /** Returns the operation's name as given in the source model. */
    public String name() {
        return name;
    }

    /** Returns an unmodifiable list of the operations feeding this one. */
    public List<IntermediateOperation> inputs() {
        return inputs;
    }

    /** Returns an unmodifiable view of the operations consuming this one. */
    public List<IntermediateOperation> outputs() {
        return Collections.unmodifiableList(outputs);
    }

    /**
     * Returns the function to emit as a separate named ranking-expression function, if any.
     * This is only set by {@link #function()} for operations with more than one output.
     */
    public Optional<TensorFunction> rankingExpressionFunction() {
        return Optional.ofNullable(rankingExpressionFunction);
    }

    /** Hook for adding dimension-name constraints to the renamer; does nothing by default. */
    public void addDimensionNameConstraints(DimensionRenamer renamer) {
    }

    /**
     * Replaces this operation's type with the renamer's result.
     * NOTE(review): reads the cached type field directly, so the type is assumed to be resolved.
     */
    public void renameDimensions(DimensionRenamer renamer) {
        type = type.rename(renamer);
    }

    /** Returns whether this operation is a model input; false by default. */
    public boolean isInput() {
        return false;
    }

    /**
     * Returns whether this operation is constant: by default, true when every input is constant
     * (and therefore also true when there are no inputs).
     */
    public boolean isConstant() {
        for (IntermediateOperation input : inputs) {
            if ( ! input.isConstant()) {
                return false;
            }
        }
        return true;
    }

    /** Sets an explicit constant value, which takes precedence over any constant value function. */
    public void setConstantValue(Value value) {
        constantValue = value;
    }

    /**
     * Returns this operation's constant value. An explicitly set value is used first. Otherwise the
     * constant value function is applied to this operation's (lazily resolved, possibly null) type.
     *
     * @return the constant value, or empty if neither a value nor a value function has been set
     */
    public Optional<Value> getConstantValue() {
        if (constantValue != null) {
            return Optional.of(constantValue);
        }
        if (constantValueFunction != null) {
            // Use the lazy accessor so the function never sees a not-yet-computed type.
            return Optional.of(constantValueFunction.apply(type().orElse(null)));
        }
        return Optional.empty();
    }

    /** Sets a function that produces the constant value from this operation's type. */
    public void setConstantValueFunction(Function<OrderedTensorType, Value> func) {
        constantValueFunction = func;
    }

    /**
     * Sets the control inputs of this operation. The list is copied, so later changes to the
     * caller's list are not seen.
     */
    public void setControlInputs(List<IntermediateOperation> inputs) {
        controlInputs = new ArrayList<>(inputs);
    }

    /** Returns an unmodifiable view of the control inputs. */
    public List<IntermediateOperation> getControlInputs() {
        return Collections.unmodifiableList(controlInputs);
    }

    /** Returns this operation's name converted by {@link #vespaName(String)}. */
    public String vespaName() {
        return vespaName(name);
    }

    /**
     * Converts a model name to a Vespa name: strips a leading {@code ^}, drops everything from the
     * first {@code :}, and replaces {@code /} with {@code _}.
     *
     * @return the converted name, or null if {@code name} is null
     */
    public String vespaName(String name) {
        if (name == null) {
            return null;
        }
        return namePartOf(name).replace('/', '_');
    }

    /**
     * Returns the name of the ranking-expression function generated for this operation,
     * {@code imported_ml_function_<modelName>_<vespaName>}, or null if {@link #vespaName()} is null.
     */
    public String rankingExpressionFunctionName() {
        String vespaName = vespaName();
        return vespaName != null ? FUNCTION_PREFIX + modelName + "_" + vespaName : null;
    }

    /** Returns an unmodifiable view of the warnings recorded while importing this operation. */
    public List<String> warnings() {
        return Collections.unmodifiableList(importWarnings);
    }

    /** Records an import warning for this operation. */
    public void warning(String warning) {
        importWarnings.add(warning);
    }

    /**
     * Checks that there are exactly {@code expected} inputs and that {@code func} yields a present
     * value for each of them.
     *
     * @throws IllegalArgumentException if the number of inputs differs from {@code expected}
     */
    boolean verifyInputs(int expected, Function<IntermediateOperation, Optional<?>> func) {
        if (inputs.size() != expected) {
            throw new IllegalArgumentException("Expected " + expected + " inputs for '" + name + "', got " + inputs.size());
        }
        return inputs.stream().map(func).allMatch(Optional::isPresent);
    }

    /** Returns whether there are exactly {@code expected} inputs and all have a resolvable type. */
    boolean allInputTypesPresent(int expected) {
        return verifyInputs(expected, IntermediateOperation::type);
    }

    /** Returns whether there are exactly {@code expected} inputs and all have a resolvable function. */
    boolean allInputFunctionsPresent(int expected) {
        return verifyInputs(expected, IntermediateOperation::function);
    }

    /**
     * Returns the part of a name before the first {@code :}, after stripping a leading {@code ^}.
     * Names without a colon are returned whole (minus the {@code ^}).
     */
    public static String namePartOf(String name) {
        String stripped = name.startsWith("^") ? name.substring(1) : name;
        // indexOf instead of split(":")[0]: split drops trailing empty strings, so an all-colon
        // name such as ":" produced an empty array and an ArrayIndexOutOfBoundsException.
        int colon = stripped.indexOf(':');
        return colon < 0 ? stripped : stripped.substring(0, colon);
    }

    /**
     * Returns the integer after the first {@code :} in a name, or 0 if there is no colon.
     *
     * @throws NumberFormatException if the text after the colon is not an integer
     */
    public static int indexPartOf(String name) {
        int i = name.indexOf(":");
        return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
    }

    /** Lookup of attribute values by name. */
    public interface AttributeMap {

        /** Returns the value of the named attribute, if present. */
        Optional<Value> get(String key);

        /** Returns the value of the named attribute interpreted with the given type, if present. */
        Optional<Value> get(String key, OrderedTensorType type);

        /** Returns the list value of the named attribute, if present. */
        Optional<List<Value>> getList(String key);
    }
}

