/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;

import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import org.tensorflow.framework.NodeDef;

/**
 * Base class for an operation imported from a single TensorFlow graph node ({@link NodeDef}).
 *
 * <p>Subclasses supply the operation's output type and tensor function through
 * {@link #lazyGetType()} and {@link #lazyGetFunction()}. This class memoizes both, and the
 * constructor wires the operation into the graph by registering it as an output of each input.
 */
public abstract class TensorFlowOperation {
    /** Prefix prepended to {@link #vespaName()} to form {@link #macroName()}. */
    protected static final String MACRO_PREFIX = "tf_macro_";
    /** The TensorFlow node this operation was created from. */
    protected final NodeDef node;
    /** Stored as given to the constructor; presumably the node's output port — TODO confirm. */
    protected final int port;
    /** Data inputs in constructor order (an unmodifiable snapshot). */
    protected final List<TensorFlowOperation> inputs;
    /** Operations that list this one as an input; appended to by their constructors. */
    protected final List<TensorFlowOperation> outputs = new ArrayList<TensorFlowOperation>();
    /** Warning messages for this operation; exposed read-only through {@link #warnings()}. */
    protected final List<String> importWarnings = new ArrayList<String>();
    /** Memoized result of {@link #lazyGetType()}; null until computed (or if it returned null). */
    protected OrderedTensorType type;
    /** Memoized result of {@link #function()}; null until successfully computed. */
    protected TensorFunction function;
    /** Function body of a multi-output operation, referenced by name from {@link #function}. */
    protected TensorFunction macro = null;
    private Value constantValue = null;
    private List<TensorFlowOperation> controlInputs = Collections.emptyList();

    /**
     * Creates an operation for {@code node} and registers it as an output of each input.
     *
     * @param node   the TensorFlow node this operation represents
     * @param inputs the data inputs; copied, so later changes to the caller's list are not seen
     * @param port   stored in {@link #port}
     */
    TensorFlowOperation(NodeDef node, List<TensorFlowOperation> inputs, int port) {
        this.node = node;
        this.port = port;
        // Snapshot the list: unmodifiableList alone is only a view, so the caller could otherwise
        // change our inputs after the output back-links below have been registered.
        this.inputs = Collections.unmodifiableList(new ArrayList<TensorFlowOperation>(inputs));
        // Note: this hands `this` to the inputs before any subclass constructor has run.
        this.inputs.forEach(i -> i.outputs.add(this));
    }

    /**
     * Computes this operation's output type.
     *
     * @return the type, or null if it cannot be determined yet
     */
    protected abstract OrderedTensorType lazyGetType();

    /**
     * Computes the tensor function implementing this operation.
     *
     * @return the function, or null if it cannot be built yet
     */
    protected abstract TensorFunction lazyGetFunction();

    /**
     * Returns this operation's output type, computing it with {@link #lazyGetType()} on first use.
     * {@link OrderedTensorType#verifyType} is invoked on every call, including cached ones.
     *
     * @return the type, or empty if {@link #lazyGetType()} returned null (retried on the next call)
     */
    public Optional<OrderedTensorType> type() {
        if (this.type == null) {
            this.type = this.lazyGetType();
        }
        OrderedTensorType.verifyType(this.node, this.type);
        return Optional.ofNullable(this.type);
    }

    /**
     * Returns the tensor function for this operation, computing and caching it on first success.
     *
     * <ul>
     *   <li>Constant operations become a reference to {@code constant(<vespaName>)}.</li>
     *   <li>Operations with more than one output store their function in {@link #macro}. The
     *       returned function is a {@link VariableTensor} named {@link #macroName()}, so every
     *       consumer references the macro instead of inlining the expression.</li>
     *   <li>All other operations return {@link #lazyGetFunction()} directly.</li>
     * </ul>
     *
     * @return the function, or empty if it could not be built yet (retried on the next call)
     */
    public Optional<TensorFunction> function() {
        if (this.function == null) {
            if (this.isConstant()) {
                ReferenceNode constant = new ReferenceNode(Reference.simple("constant", this.vespaName()));
                this.function = new TensorFunctionNode.TensorFunctionExpressionNode(constant);
            } else if (this.outputs.size() > 1) {
                this.macro = this.lazyGetFunction();
                // Resolve the type on demand: reading the field directly threw an NPE when type()
                // had not been called yet.
                OrderedTensorType resolvedType = this.type != null ? this.type : this.type().orElse(null);
                // Only publish the variable reference once both the macro body and its type exist.
                // Otherwise leave function null so a later call can retry, instead of caching a
                // reference to a macro that was never produced.
                if (this.macro != null && resolvedType != null) {
                    this.function = new VariableTensor(this.macroName(), resolvedType.type());
                }
            } else {
                this.function = this.lazyGetFunction();
            }
        }
        return Optional.ofNullable(this.function);
    }

    /** Returns the TensorFlow node this operation was created from. */
    public NodeDef node() {
        return this.node;
    }

    /** Returns the (unmodifiable) data inputs of this operation. */
    public List<TensorFlowOperation> inputs() {
        return this.inputs;
    }

    /** Returns a read-only view of the operations that use this one as an input. */
    public List<TensorFlowOperation> outputs() {
        return Collections.unmodifiableList(this.outputs);
    }

    /** Returns the macro body, present only after {@link #function()} took the multi-output path. */
    public Optional<TensorFunction> macro() {
        return Optional.ofNullable(this.macro);
    }

    /** Hook for subclasses to add dimension-name constraints; does nothing by default. */
    public void addDimensionNameConstraints(DimensionRenamer renamer) {
    }

    /** Replaces the current type with the result of {@link OrderedTensorType#rename}. */
    public void renameDimensions(DimensionRenamer renamer) {
        this.type = OrderedTensorType.rename(this.type, renamer);
    }

    /** Returns whether this operation is a graph input; false unless overridden. */
    public boolean isInput() {
        return false;
    }

    /**
     * Returns whether this operation is constant. By default it is constant when every input is
     * constant, which is vacuously true for an operation with no inputs.
     */
    public boolean isConstant() {
        return this.inputs.stream().allMatch(TensorFlowOperation::isConstant);
    }

    /** Stores a precomputed constant value for this operation. */
    public void setConstantValue(Value value) {
        this.constantValue = value;
    }

    /** Returns the value stored by {@link #setConstantValue}, if any. */
    public Optional<Value> getConstantValue() {
        return Optional.ofNullable(this.constantValue);
    }

    /**
     * Sets the control inputs consulted by {@link #verifyInputs}.
     *
     * @param inputs the control inputs; copied, and null is treated as an empty list
     */
    public void setControlInputs(List<TensorFlowOperation> inputs) {
        // Copy to avoid aliasing the caller's list. Map null to empty so verifyInputs and
        // getControlInputs do not fail with a NullPointerException later.
        this.controlInputs = inputs == null
                ? Collections.<TensorFlowOperation>emptyList()
                : new ArrayList<TensorFlowOperation>(inputs);
    }

    /** Returns a read-only view of the control inputs. */
    public List<TensorFlowOperation> getControlInputs() {
        return Collections.unmodifiableList(this.controlInputs);
    }

    /** Returns the node name with every '/' replaced by '_', or null if the name is null. */
    public String vespaName() {
        return this.node.getName() != null ? this.node.getName().replace('/', '_') : null;
    }

    /** Returns {@link #MACRO_PREFIX} followed by {@link #vespaName()}, or null if that is null. */
    public String macroName() {
        return this.vespaName() != null ? MACRO_PREFIX + this.vespaName() : null;
    }

    /** Returns a read-only view of the warnings recorded for this operation. */
    public List<String> warnings() {
        return Collections.unmodifiableList(this.importWarnings);
    }

    /**
     * Checks that every control input, and exactly {@code expected} data inputs, yield a present
     * value under {@code func}.
     *
     * @param expected required number of data inputs
     * @param func     accessor applied to each input, e.g. {@code TensorFlowOperation::type}
     * @return false if any control input or data input yields an empty Optional, true otherwise
     * @throws IllegalArgumentException if the number of data inputs differs from {@code expected}
     *         (checked only after all control inputs are present)
     */
    boolean verifyInputs(int expected, Function<TensorFlowOperation, Optional<?>> func) {
        if (!this.controlInputs.stream().map(func).allMatch(Optional::isPresent)) {
            return false;
        }
        if (this.inputs.size() != expected) {
            throw new IllegalArgumentException("Expected " + expected + " inputs for '" + this.node.getName() + "', got " + this.inputs.size());
        }
        return this.inputs.stream().map(func).allMatch(Optional::isPresent);
    }

    /** Shorthand for {@link #verifyInputs} using {@link #type()}. */
    boolean allInputTypesPresent(int expected) {
        return this.verifyInputs(expected, TensorFlowOperation::type);
    }

    /** Shorthand for {@link #verifyInputs} using {@link #function()}. */
    boolean allInputFunctionsPresent(int expected) {
        return this.verifyInputs(expected, TensorFlowOperation::function);
    }
}

