/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.rule;

import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.Deque;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.stream.Collectors;

public class LambdaFunctionNode
extends CompositeNode {
    private final List<String> arguments;
    private final ExpressionNode functionExpression;

    public LambdaFunctionNode(List<String> arguments, ExpressionNode functionExpression) {
        if (!arguments.containsAll(LambdaFunctionNode.featuresAccessedIn(functionExpression))) {
            throw new IllegalArgumentException("Lambda " + String.valueOf(functionExpression) + " accesses features outside its scope: " + LambdaFunctionNode.featuresAccessedIn(functionExpression).stream().filter(f -> !arguments.contains(f)).collect(Collectors.joining(", ")));
        }
        this.arguments = List.copyOf(arguments);
        this.functionExpression = functionExpression;
    }

    public String singleArgumentName() {
        if (this.arguments.size() != 1) {
            throw new IllegalArgumentException("Cannot apply " + String.valueOf(this) + " in map, must have a single argument");
        }
        return this.arguments.get(0);
    }

    @Override
    public List<ExpressionNode> children() {
        return List.of(this.functionExpression);
    }

    @Override
    public CompositeNode setChildren(List<ExpressionNode> children) {
        if (children.size() != 1) {
            throw new IllegalArgumentException("A lambda function must have a single child expression");
        }
        return new LambdaFunctionNode(this.arguments, children.get(0));
    }

    @Override
    public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) {
        string.append("f(").append(this.commaSeparated(this.arguments)).append(")(");
        return this.functionExpression.toString(string, context, path, this).append(")");
    }

    private String commaSeparated(List<String> list) {
        StringBuilder b = new StringBuilder();
        for (String element : list) {
            b.append(element).append(",");
        }
        if (b.length() > 0) {
            b.setLength(b.length() - 1);
        }
        return b.toString();
    }

    @Override
    public TensorType type(TypeContext<Reference> context) {
        return TensorType.empty;
    }

    @Override
    public Value evaluate(Context context) {
        return this.functionExpression.evaluate(context);
    }

    public DoubleUnaryOperator asDoubleUnaryOperator() {
        if (this.arguments.size() > 1) {
            throw new IllegalStateException("Cannot apply " + String.valueOf(this) + " as a DoubleUnaryOperator: Must have at most one argument  but has " + String.valueOf(this.arguments));
        }
        return new DoubleUnaryLambda();
    }

    public DoubleBinaryOperator asDoubleBinaryOperator() {
        if (this.arguments.size() > 2) {
            throw new IllegalStateException("Cannot apply " + String.valueOf(this) + " as a DoubleBinaryOperator: Must have at most two argument  but has " + String.valueOf(this.arguments));
        }
        return this.getDirectEvaluator().orElseGet(() -> new DoubleBinaryLambda());
    }

    private Optional<DoubleBinaryOperator> getDirectEvaluator() {
        ReferenceNode lhs;
        ExpressionNode expressionNode;
        OperationNode node;
        block17: {
            block16: {
                ExpressionNode expressionNode2 = this.functionExpression;
                if (!(expressionNode2 instanceof OperationNode)) {
                    return Optional.empty();
                }
                node = (OperationNode)expressionNode2;
                expressionNode = node.children().get(0);
                if (!(expressionNode instanceof ReferenceNode)) break block16;
                lhs = (ReferenceNode)expressionNode;
                expressionNode = node.children().get(1);
                if (expressionNode instanceof ReferenceNode) break block17;
            }
            return Optional.empty();
        }
        ReferenceNode rhs = (ReferenceNode)expressionNode;
        if (!lhs.getName().equals(this.arguments.get(0)) || !rhs.getName().equals(this.arguments.get(1))) {
            return Optional.empty();
        }
        if (node.operators().size() != 1) {
            return Optional.empty();
        }
        Operator operator = node.operators().get(0);
        return switch (operator) {
            case Operator.or -> this.asFunctionExpression((left, right) -> left != 0.0 || right != 0.0 ? 1.0 : 0.0);
            case Operator.and -> this.asFunctionExpression((left, right) -> left != 0.0 && right != 0.0 ? 1.0 : 0.0);
            case Operator.plus -> this.asFunctionExpression((left, right) -> left + right);
            case Operator.minus -> this.asFunctionExpression((left, right) -> left - right);
            case Operator.multiply -> this.asFunctionExpression((left, right) -> left * right);
            case Operator.divide -> this.asFunctionExpression((left, right) -> left / right);
            case Operator.modulo -> this.asFunctionExpression((left, right) -> left % right);
            case Operator.power -> this.asFunctionExpression(Math::pow);
            default -> Optional.empty();
        };
    }

    private Optional<DoubleBinaryOperator> asFunctionExpression(final DoubleBinaryOperator operator) {
        return Optional.of(new DoubleBinaryOperator(){

            @Override
            public double applyAsDouble(double left, double right) {
                return operator.applyAsDouble(left, right);
            }

            public String toString() {
                return LambdaFunctionNode.this.toString();
            }
        });
    }

    private static Set<String> featuresAccessedIn(ExpressionNode node) {
        HashSet<String> features = new HashSet<String>();
        new FeatureFinder(features).process(node);
        return features;
    }

    @Override
    public int hashCode() {
        return Objects.hash("lambdaFunction", this.arguments, this.functionExpression);
    }

    private class DoubleUnaryLambda
    implements DoubleUnaryOperator {
        private DoubleUnaryLambda() {
        }

        @Override
        public double applyAsDouble(double operand) {
            MapContext context = new MapContext();
            if (LambdaFunctionNode.this.arguments.size() > 0) {
                context.put(LambdaFunctionNode.this.arguments.get(0), operand);
            }
            return LambdaFunctionNode.this.evaluate(context).asDouble();
        }

        public String toString() {
            return LambdaFunctionNode.this.toString();
        }
    }

    private static class FeatureFinder {
        private final Set<String> target;
        private final Set<String> localVariables = new HashSet<String>();

        FeatureFinder(Set<String> target) {
            this.target = target;
        }

        void process(ExpressionNode node) {
            TensorFunctionNode t;
            TensorFunction<Reference> fun;
            if (node instanceof ReferenceNode) {
                ReferenceNode refNode = (ReferenceNode)node;
                String featureName = refNode.reference().toString();
                if (!this.localVariables.contains(featureName)) {
                    this.target.add(featureName);
                }
                return;
            }
            Optional<FeatureFinder> subProcessor = Optional.empty();
            if (node instanceof TensorFunctionNode && (fun = (t = (TensorFunctionNode)node).function()) instanceof Generate) {
                Generate g = (Generate)fun;
                FeatureFinder ff = new FeatureFinder(this.target);
                TensorType genType = g.type(null);
                for (TensorType.Dimension dim : genType.dimensions()) {
                    ff.localVariables.add(dim.name());
                }
                subProcessor = Optional.of(ff);
            }
            if (node instanceof CompositeNode) {
                CompositeNode composite = (CompositeNode)node;
                FeatureFinder processor = subProcessor.orElse(this);
                composite.children().forEach(child -> processor.process((ExpressionNode)child));
            }
        }
    }

    private class DoubleBinaryLambda
    implements DoubleBinaryOperator {
        private DoubleBinaryLambda() {
        }

        @Override
        public double applyAsDouble(double left, double right) {
            MapContext context = new MapContext();
            if (LambdaFunctionNode.this.arguments.size() > 0) {
                context.put(LambdaFunctionNode.this.arguments.get(0), left);
            }
            if (LambdaFunctionNode.this.arguments.size() > 1) {
                context.put(LambdaFunctionNode.this.arguments.get(1), right);
            }
            return LambdaFunctionNode.this.evaluate(context).asDouble();
        }

        public String toString() {
            return LambdaFunctionNode.this.toString();
        }
    }
}

