/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.rule;

import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.BooleanNode;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.Objects;
import java.util.function.Predicate;

/**
 * A boolean expression testing whether a value is a member of an explicit set of values,
 * serialized as {@code testValue in [v1, v2, ...]}.
 *
 * <p>If the test value evaluates to a {@link TensorValue}, the test is applied cell by cell and the
 * result is a tensor holding 1.0 for cells whose value equals one of the set values and 0.0
 * otherwise. For any other value the result is a {@link BooleanValue}.
 */
public class SetMembershipNode extends BooleanNode {

    /** The expression whose value is looked up in the set. */
    private final ExpressionNode testValue;

    /** The expressions making up the set, in the order given; an immutable copy. */
    private final List<ExpressionNode> setValues;

    /**
     * Creates a set membership test.
     *
     * @param testValue the expression to look up in the set
     * @param setValues the expressions making up the set; the list is copied
     * @throws NullPointerException if testValue or setValues is null, or setValues contains null
     */
    public SetMembershipNode(ExpressionNode testValue, List<ExpressionNode> setValues) {
        // Fail at construction rather than later during evaluation or serialization.
        this.testValue = Objects.requireNonNull(testValue, "testValue");
        this.setValues = List.copyOf(setValues);
    }

    /** Returns the expression whose value is looked up in the set. */
    public ExpressionNode getTestValue() {
        return testValue;
    }

    /** Returns the (immutable) list of set value expressions. */
    public List<ExpressionNode> getSetValues() {
        return setValues;
    }

    /**
     * Returns a new list holding the test value first, followed by the set values in order.
     * This is the order {@link #setChildren(List)} expects.
     */
    @Override
    public List<ExpressionNode> children() {
        List<ExpressionNode> children = new ArrayList<>(1 + setValues.size());
        children.add(testValue);
        children.addAll(setValues);
        return children;
    }

    /** Appends {@code testValue in [v1, v2, ...]} to the given builder and returns it. */
    @Override
    public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) {
        testValue.toString(string, context, path, this);
        string.append(" in [");
        String separator = "";
        for (ExpressionNode setValue : setValues) {
            string.append(separator);
            setValue.toString(string, context, path, this);
            separator = ", ";
        }
        string.append("]");
        return string;
    }

    /**
     * Returns the empty (scalar) tensor type.
     *
     * <p>NOTE(review): {@link #evaluate(Context)} returns a tensor when the test value evaluates to a
     * tensor, so this looks inconsistent for tensor inputs. Confirm whether something like
     * {@code testValue.type(context)} is intended.
     */
    @Override
    public TensorType type(TypeContext<Reference> context) {
        return TensorType.empty;
    }

    /**
     * Evaluates the test value and checks it against the set.
     *
     * @return a cell-wise 1.0/0.0 tensor when the test value is a tensor, otherwise a BooleanValue
     */
    @Override
    public Value evaluate(Context context) {
        Value value = testValue.evaluate(context);
        if (value instanceof TensorValue) {
            return evaluateTensor(((TensorValue) value).asTensor(), context);
        }
        return evaluateValue(value, context);
    }

    /** Tests a non-tensor value for membership using {@code value.equals(setValue)}. */
    private Value evaluateValue(Value value, Context context) {
        return new BooleanValue(testMembership(value::equals, context));
    }

    /**
     * Maps each tensor cell to 1.0 if it equals ({@code ==} on doubles) one of the set values, else 0.0.
     * The set values are evaluated once up front. They are loop-invariant, so re-evaluating them for
     * every cell only adds cost.
     */
    private Value evaluateTensor(Tensor tensor, Context context) {
        double[] members = evaluateSetValuesAsDoubles(context);
        return new TensorValue(tensor.map(cell -> contains(members, cell) ? 1.0 : 0.0));
    }

    /** Evaluates every set value expression in order and returns the results as doubles. */
    private double[] evaluateSetValuesAsDoubles(Context context) {
        double[] members = new double[setValues.size()];
        for (int i = 0; i < members.length; i++) {
            members[i] = setValues.get(i).evaluate(context).asDouble();
        }
        return members;
    }

    /** Returns whether value is {@code ==} to any member. NaN therefore never matches. */
    private static boolean contains(double[] members, double value) {
        for (double member : members) {
            if (member == value) {
                return true;
            }
        }
        return false;
    }

    /**
     * Evaluates the set values in order and returns true as soon as one satisfies the predicate.
     * Set values after the first match are not evaluated.
     */
    private boolean testMembership(Predicate<Value> test, Context context) {
        for (ExpressionNode setValue : setValues) {
            if (test.test(setValue.evaluate(context))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns a new node built from the given children. The first child becomes the test value and
     * the remaining ones become the set values, matching the order of {@link #children()}.
     *
     * @throws IllegalArgumentException if children is empty
     */
    @Override
    public SetMembershipNode setChildren(List<ExpressionNode> children) {
        if (children.isEmpty()) {
            throw new IllegalArgumentException("A set membership test must have at least 1 child");
        }
        return new SetMembershipNode(children.get(0), children.subList(1, children.size()));
    }

    /**
     * Hashes the node kind together with its test value and set values.
     *
     * <p>NOTE(review): no equals override is visible in this class. Presumably equality is defined by
     * a superclass consistently with this hash; confirm.
     */
    @Override
    public int hashCode() {
        return Objects.hash("setMembership", testValue, setValues);
    }
}

