/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;

import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import onnx.Onnx;

/**
 * Base class for a single operation (node) in an imported ONNX graph.
 *
 * <p>An operation keeps references to the operations it consumes ({@link #inputs()}) and to the
 * operations that consume it ({@link #outputs()}). Its tensor type and tensor function are resolved
 * lazily by subclasses through {@link #lazyGetType()} and {@link #lazyGetFunction()} and cached in
 * the {@code type} and {@code function} fields once a non-null result has been produced.
 */
public abstract class OnnxOperation {
    /** The ONNX node this operation was created from. */
    protected final Onnx.NodeProto node;
    /** Unmodifiable view of the operations feeding this one. */
    protected final List<OnnxOperation> inputs;
    /** Operations that list this operation among their inputs; populated by their constructors. */
    protected final List<OnnxOperation> outputs = new ArrayList<>();
    /** Warnings recorded through {@link #warning(String)}. */
    protected final List<String> importWarnings = new ArrayList<>();
    /** Cached result of {@link #lazyGetType()}; null until a non-null type has been resolved. */
    protected OrderedTensorType type;
    /** Cached result of {@link #function()}; null until a non-null function has been resolved. */
    protected TensorFunction function;
    /** Constant value of this operation, if a subclass has assigned one; null otherwise. */
    protected Value constantValue = null;

    /**
     * Creates the operation and registers it as an output of each of its inputs.
     *
     * @param node   the ONNX node this operation represents
     * @param inputs the operations feeding this one; the list is wrapped in an unmodifiable view,
     *               not copied, so later changes to the caller's list remain visible here
     * @throws NullPointerException if {@code inputs} is null
     */
    OnnxOperation(Onnx.NodeProto node, List<OnnxOperation> inputs) {
        this.node = node;
        this.inputs = Collections.unmodifiableList(inputs);
        // Build the reverse edges of the graph: every input learns about its new consumer.
        this.inputs.forEach(i -> i.outputs.add(this));
    }

    /**
     * Computes the type of this operation. Called by {@link #type()} for as long as no type is cached.
     *
     * @return the type, or null if it cannot be determined
     */
    protected abstract OrderedTensorType lazyGetType();

    /**
     * Computes the tensor function of this operation. Called by {@link #function()} for
     * non-constant operations for as long as no function is cached.
     *
     * @return the function, or null if it cannot be determined
     */
    protected abstract TensorFunction lazyGetFunction();

    /**
     * Returns the type of this operation, resolving it through {@link #lazyGetType()} when nothing
     * is cached. A null result is not cached, so resolution is attempted again on the next call.
     *
     * @return the type, or empty if it could not be resolved
     */
    public Optional<OrderedTensorType> type() {
        if (this.type == null) {
            this.type = this.lazyGetType();
        }
        return Optional.ofNullable(this.type);
    }

    /**
     * Returns the tensor function of this operation, resolving it when nothing is cached.
     *
     * <p>When {@link #isConstant()} is true the function is a reference created with
     * {@code Reference.simple("constant", vespaName())}; otherwise it is whatever
     * {@link #lazyGetFunction()} returns.
     *
     * @return the function, or empty if it could not be resolved
     */
    public Optional<TensorFunction> function() {
        if (this.function == null) {
            if (this.isConstant()) {
                ReferenceNode constant = new ReferenceNode(Reference.simple("constant", this.vespaName()));
                this.function = new TensorFunctionNode.TensorFunctionExpressionNode(constant);
            } else {
                this.function = this.lazyGetFunction();
            }
        }
        return Optional.ofNullable(this.function);
    }

    /** Returns the ONNX node this operation was created from. */
    public Onnx.NodeProto node() {
        return this.node;
    }

    /** Returns the unmodifiable list of operations feeding this one. */
    public List<OnnxOperation> inputs() {
        return this.inputs;
    }

    /** Returns an unmodifiable view of the operations that consume this one. */
    public List<OnnxOperation> outputs() {
        return Collections.unmodifiableList(this.outputs);
    }

    /**
     * Hook for adding dimension name constraints to the renamer.
     * The base implementation adds none.
     *
     * @param renamer the renamer to add constraints to
     */
    public void addDimensionNameConstraints(DimensionRenamer renamer) {
    }

    /**
     * Replaces the cached type with the result of renaming it with the given renamer.
     *
     * <p>Only the cached type is touched; type resolution is not triggered. If no type has been
     * cached there is nothing to rename and the call is a no-op rather than a
     * {@link NullPointerException}.
     *
     * @param renamer the renamer passed to the type's {@code rename}
     */
    public void renameDimensions(DimensionRenamer renamer) {
        if (this.type != null) {
            this.type = this.type.rename(renamer);
        }
    }

    /** Returns whether this operation is an input; the base implementation returns false. */
    public boolean isInput() {
        return false;
    }

    /**
     * Returns whether all inputs of this operation are constant.
     *
     * <p>Because {@code allMatch} on an empty stream is true, an operation without inputs is
     * reported as constant by this base implementation.
     */
    public boolean isConstant() {
        return this.inputs.stream().allMatch(OnnxOperation::isConstant);
    }

    /** Returns the constant value of this operation, or empty if none has been assigned. */
    public Optional<Value> getConstantValue() {
        return Optional.ofNullable(this.constantValue);
    }

    /** Returns the name of this operation's node, converted with {@link #vespaName(String)}. */
    public String vespaName() {
        return this.vespaName(this.node.getName());
    }

    /**
     * Converts a name by replacing every {@code '/'} and {@code ':'} with {@code '_'}.
     *
     * @param name the name to convert, may be null
     * @return the converted name, or null if {@code name} is null
     */
    public String vespaName(String name) {
        return name != null ? name.replace('/', '_').replace(':', '_') : null;
    }

    /** Returns an unmodifiable view of the warnings recorded on this operation. */
    public List<String> warnings() {
        return Collections.unmodifiableList(this.importWarnings);
    }

    /**
     * Records a warning on this operation.
     *
     * @param warning the warning text to append
     */
    public void warning(String warning) {
        this.importWarnings.add(warning);
    }

    /**
     * Checks the number of inputs and whether the given accessor yields a value for all of them.
     *
     * @param expected the exact number of inputs required
     * @param func     accessor applied to each input, e.g. {@code OnnxOperation::type}
     * @return true if {@code func} returns a present Optional for every input
     * @throws IllegalArgumentException if the number of inputs differs from {@code expected}
     */
    boolean verifyInputs(int expected, Function<OnnxOperation, Optional<?>> func) {
        if (this.inputs.size() != expected) {
            throw new IllegalArgumentException("Expected " + expected + " inputs for '" + this.node.getName() + "', got " + this.inputs.size());
        }
        return this.inputs.stream().map(func).allMatch(Optional::isPresent);
    }

    /**
     * Returns whether there are exactly {@code expected} inputs and all of them have a type.
     *
     * @throws IllegalArgumentException if the number of inputs differs from {@code expected}
     */
    boolean allInputTypesPresent(int expected) {
        return this.verifyInputs(expected, OnnxOperation::type);
    }

    /**
     * Returns whether there are exactly {@code expected} inputs and all of them have a function.
     *
     * @throws IllegalArgumentException if the number of inputs differs from {@code expected}
     */
    boolean allInputFunctionsPresent(int expected) {
        return this.verifyInputs(expected, OnnxOperation::function);
    }
}

