/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.rule;

import com.google.common.annotations.Beta;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * An {@link ExpressionNode} which wraps a {@link TensorFunction}, so that tensor functions can take part
 * in ranking expression trees.
 *
 * <p>The children of this node are the arguments of the wrapped function. Arguments that are wrapped
 * expressions ({@link ExpressionTensorFunction}) are unwrapped back to their expression node. Any other
 * argument is exposed as a nested {@code TensorFunctionNode}.
 */
@Beta
public class TensorFunctionNode
extends CompositeNode {
    private final TensorFunction function;

    /** Creates a node wrapping the given tensor function. */
    public TensorFunctionNode(TensorFunction function) {
        this.function = function;
    }

    /** Returns the wrapped tensor function. */
    public TensorFunction function() {
        return this.function;
    }

    /**
     * Returns the arguments of the wrapped function as expression nodes.
     * A new list is built on each call.
     */
    @Override
    public List<ExpressionNode> children() {
        return this.function.arguments().stream()
                .map(TensorFunctionNode::toExpressionNode)
                .collect(Collectors.toList());
    }

    /**
     * Converts a tensor function into an expression node. A wrapped expression is unwrapped rather than
     * wrapped a second time; any other function is wrapped in a new {@code TensorFunctionNode}.
     */
    private static ExpressionNode toExpressionNode(TensorFunction f) {
        if (f instanceof ExpressionTensorFunction) {
            return ((ExpressionTensorFunction) f).expression;
        }
        return new TensorFunctionNode(f);
    }

    /**
     * Renders an expression using the serialization state carried by an {@link ExpressionToStringContext}.
     * Falls back to the expression's plain {@code toString()} when given any other context type.
     */
    private static String expressionToString(ExpressionNode expression, ToStringContext c) {
        if (c instanceof ExpressionToStringContext) {
            ExpressionToStringContext context = (ExpressionToStringContext) c;
            return expression.toString(new StringBuilder(), context.context, context.path, context.parent).toString();
        }
        return expression.toString();
    }

    /**
     * Returns a new node whose function has the given children as arguments.
     * Each child is wrapped in an {@link ExpressionTensorFunction}. This node is left unchanged.
     */
    @Override
    public CompositeNode setChildren(List<ExpressionNode> children) {
        // Typed (not raw) so withArguments receives a checked List<TensorFunction>.
        List<TensorFunction> wrappedChildren = children.stream()
                .map(ExpressionTensorFunction::new)
                .collect(Collectors.toList());
        return new TensorFunctionNode(this.function.withArguments(wrappedChildren));
    }

    /**
     * Appends the primitive form of the wrapped function to {@code string}. The serialization context,
     * path and this node as parent are passed on, so that nested expressions are serialized consistently.
     */
    @Override
    public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) {
        return string.append(this.function.toPrimitive().toString(new ExpressionToStringContext(context, path, this)));
    }

    /** Returns the type the wrapped function produces in the given type context. */
    @Override
    public TensorType type(TypeContext<Reference> context) {
        return this.function.type(context);
    }

    /** Evaluates the wrapped function in the given context and returns the result as a {@link TensorValue}. */
    @Override
    public Value evaluate(Context context) {
        // NOTE(review): raw cast kept from the original; presumably Context is an EvaluationContext — confirm.
        return new TensorValue(this.function.evaluate((EvaluationContext) context));
    }

    /** Wraps an expression node so that it can be used as a tensor function argument. */
    public static ExpressionTensorFunction wrap(ExpressionNode node) {
        return new ExpressionTensorFunction(node);
    }

    /**
     * Wraps each expression value of the given map as a scalar function.
     * The key iteration order of {@code nodes} is preserved in the returned (mutable) map.
     */
    public static Map<TensorAddress, ScalarFunction> wrap(Map<TensorAddress, ExpressionNode> nodes) {
        Map<TensorAddress, ScalarFunction> functions = new LinkedHashMap<>();
        for (Map.Entry<TensorAddress, ExpressionNode> entry : nodes.entrySet()) {
            functions.put(entry.getKey(), new ExpressionScalarFunction(entry.getValue()));
        }
        return functions;
    }

    /** Wraps each expression node of the given list as a scalar function, in order, into a new mutable list. */
    public static List<ScalarFunction> wrap(List<ExpressionNode> nodes) {
        List<ScalarFunction> functions = new ArrayList<>(nodes.size());
        for (ExpressionNode node : nodes) {
            functions.add(new ExpressionScalarFunction(node));
        }
        return functions;
    }

    /** A ToStringContext carrying the ranking expression serialization state through tensor function rendering. */
    private static class ExpressionToStringContext
    implements ToStringContext {
        final SerializationContext context;
        final Deque<String> path;
        final CompositeNode parent;
        /** Context used when rendering outside any enclosing expression: new SerializationContext, no path, no parent. */
        public static final ExpressionToStringContext empty = new ExpressionToStringContext(new SerializationContext(), null, null);

        public ExpressionToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) {
            this.context = context;
            this.path = path;
            this.parent = parent;
        }
    }

    /** A primitive tensor function which evaluates a ranking expression node that returns a tensor. */
    public static class ExpressionTensorFunction
    extends PrimitiveTensorFunction {
        private final ExpressionNode expression;

        public ExpressionTensorFunction(ExpressionNode expression) {
            this.expression = expression;
        }

        /** Returns the children of the wrapped expression, each wrapped; empty if it is not a composite node. */
        @Override
        public List<TensorFunction> arguments() {
            if (this.expression instanceof CompositeNode) {
                return ((CompositeNode) this.expression).children().stream()
                        .map(ExpressionTensorFunction::new)
                        .collect(Collectors.toList());
            }
            return Collections.emptyList();
        }

        /**
         * Returns a function wrapping a copy of the expression that has the given arguments as children.
         * Returns this function itself if {@code arguments} is empty.
         * Arguments that are not wrapped expressions are converted into {@link TensorFunctionNode}s.
         *
         * @throws IllegalArgumentException if arguments are given but the expression is not a composite node
         */
        @Override
        public TensorFunction withArguments(List<TensorFunction> arguments) {
            if (arguments.isEmpty()) {
                return this;
            }
            if ( ! (this.expression instanceof CompositeNode)) {
                throw new IllegalArgumentException("Cannot set " + arguments.size() + " arguments on '" +
                                                   this.expression + "': it is not a composite node");
            }
            // Same unwrapping rule as TensorFunctionNode.children(), so that non-expression arguments
            // are wrapped rather than causing a ClassCastException.
            List<ExpressionNode> unwrappedChildren = arguments.stream()
                    .map(TensorFunctionNode::toExpressionNode)
                    .collect(Collectors.toList());
            return new ExpressionTensorFunction(((CompositeNode) this.expression).setChildren(unwrappedChildren));
        }

        @Override
        public PrimitiveTensorFunction toPrimitive() {
            return this;
        }

        @Override
        public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
            return this.expression.type(context);
        }

        /** Evaluates the expression; the context must be a ranking expression {@link Context} (cast below). */
        @Override
        public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
            return this.expression.evaluate((Context) context).asTensor();
        }

        @Override
        public String toString() {
            return this.toString(ExpressionToStringContext.empty);
        }

        @Override
        public String toString(ToStringContext c) {
            return TensorFunctionNode.expressionToString(this.expression, c);
        }
    }

    /** A scalar function which evaluates a ranking expression node to a double. */
    private static class ExpressionScalarFunction
    implements ScalarFunction {
        private final ExpressionNode expression;

        public ExpressionScalarFunction(ExpressionNode expression) {
            this.expression = expression;
        }

        /** Evaluates the expression; the context must be a ranking expression {@link Context} (cast below). */
        @Override
        public Double apply(EvaluationContext<?> context) {
            return this.expression.evaluate((Context) context).asDouble();
        }

        @Override
        public String toString() {
            return this.toString(ExpressionToStringContext.empty);
        }

        @Override
        public String toString(ToStringContext c) {
            return TensorFunctionNode.expressionToString(this.expression, c);
        }
    }
}

