/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.rule;

import com.google.common.annotations.Beta;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
import java.util.Collections;
import java.util.Deque;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * An expression node which represents a {@link TensorFunction} inside a ranking expression.
 * <p>
 * The arguments of the wrapped function are exposed as the children of this node. Arguments which are
 * wrapped expressions ({@link ExpressionTensorFunction}) are unwrapped back to their expression, while any
 * other tensor function argument is exposed as a nested {@code TensorFunctionNode}.
 */
@Beta
public class TensorFunctionNode extends CompositeNode {

    /** The tensor function represented by this node. */
    private final TensorFunction<Reference> function;

    /**
     * Creates a node representing the given tensor function.
     *
     * @param function the tensor function this node wraps
     */
    public TensorFunctionNode(TensorFunction<Reference> function) {
        this.function = function;
    }

    /** Returns the tensor function represented by this node. */
    public TensorFunction<Reference> function() {
        return this.function;
    }

    /**
     * Returns the arguments of the wrapped function as expression nodes.
     *
     * @return a new list with one expression node per argument of the wrapped function
     */
    @Override
    public List<ExpressionNode> children() {
        return this.function.arguments().stream()
                .map(TensorFunctionNode::toExpressionNode)
                .collect(Collectors.toList());
    }

    /**
     * Converts a tensor function to an expression node.
     * <p>
     * A wrapped expression is unwrapped to the expression it holds. Any other tensor function is wrapped
     * in a new {@code TensorFunctionNode}.
     */
    private static ExpressionNode toExpressionNode(TensorFunction<Reference> f) {
        if (f instanceof ExpressionTensorFunction) {
            return ((ExpressionTensorFunction) f).expression;
        }
        return new TensorFunctionNode(f);
    }

    /**
     * Returns a new node whose function is this node's function with the given children as arguments.
     * Each child is wrapped in an {@link ExpressionTensorFunction}.
     *
     * @param children the new children, in argument order
     * @return a new node; this node is not modified
     */
    @Override
    public CompositeNode setChildren(List<ExpressionNode> children) {
        List<TensorFunction<Reference>> wrappedChildren = children.stream()
                .map(ExpressionTensorFunction::new)
                .collect(Collectors.toList());
        return new TensorFunctionNode(this.function.withArguments(wrappedChildren));
    }

    /**
     * Appends the serialized form of the primitive version of the wrapped function.
     * <p>
     * Any wrapped subexpressions are serialized with the given context and path, and with this node
     * (rather than {@code parent}) as their parent.
     */
    @Override
    public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) {
        ToStringContext toStringContext = new ExpressionToStringContext(context, path, this);
        return string.append(this.function.toPrimitive().toString(toStringContext));
    }

    /** Returns the type of the wrapped function as resolved in the given context. */
    @Override
    public TensorType type(TypeContext<Reference> context) {
        return this.function.type(context);
    }

    /** Evaluates the wrapped function in the given context and returns the result as a tensor value. */
    @Override
    public Value evaluate(Context context) {
        return new TensorValue(this.function.evaluate(context));
    }

    /**
     * Wraps an expression node as a tensor function.
     *
     * @param node the expression to wrap
     * @return a tensor function which evaluates {@code node}
     */
    public static ExpressionTensorFunction wrap(ExpressionNode node) {
        return new ExpressionTensorFunction(node);
    }

    /**
     * Wraps each expression in a map as a scalar function, keeping the keys.
     *
     * @param nodes the expressions to wrap, keyed by tensor address
     * @return a new map, in the iteration order of {@code nodes}, from each address to its wrapped expression
     */
    public static Map<TensorAddress, ScalarFunction<Reference>> wrapScalars(Map<TensorAddress, ExpressionNode> nodes) {
        Map<TensorAddress, ScalarFunction<Reference>> functions = new LinkedHashMap<>();
        nodes.forEach((address, node) -> functions.put(address, wrapScalar(node)));
        return functions;
    }

    /**
     * Wraps each expression in a list as a scalar function.
     *
     * @param nodes the expressions to wrap
     * @return a new list with the wrapped expressions, in the same order
     */
    public static List<ScalarFunction<Reference>> wrapScalars(List<ExpressionNode> nodes) {
        return nodes.stream().map(TensorFunctionNode::wrapScalar).collect(Collectors.toList());
    }

    /**
     * Wraps an expression node as a scalar function.
     *
     * @param node the expression to wrap
     * @return a scalar function returning the value of {@code node} as a double
     */
    public static ScalarFunction<Reference> wrapScalar(ExpressionNode node) {
        return new ExpressionScalarFunction(node);
    }

    /**
     * Serializes a wrapped expression.
     * <p>
     * When {@code c} is an {@link ExpressionToStringContext}, the serialization context, path and parent
     * it carries are used. Otherwise the expression's plain {@code toString()} is used.
     */
    private static String serializeWrapped(ExpressionNode expression, ToStringContext c) {
        if (c instanceof ExpressionToStringContext) {
            ExpressionToStringContext context = (ExpressionToStringContext) c;
            return expression.toString(new StringBuilder(), context.context, context.path, context.parent).toString();
        }
        return expression.toString();
    }

    /**
     * Adapts an {@link EvaluationContext} to the {@link Context} type expected by expression nodes.
     * Values are looked up in the delegate as tensors.
     */
    private static final class ContextWrapper extends Context {

        private final EvaluationContext<Reference> delegate;

        public ContextWrapper(EvaluationContext<Reference> delegate) {
            this.delegate = delegate;
        }

        /** Returns the delegate's tensor with the given name, as a tensor value. */
        @Override
        public Value get(String name) {
            return new TensorValue(this.delegate.getTensor(name));
        }

        /** Returns the type the delegate resolves for the given reference. */
        public TensorType getType(Reference name) {
            return this.delegate.getType(name);
        }
    }

    /**
     * Carries ranking expression serialization state (context, path and parent node) through tensor function
     * {@code toString} calls, so that wrapped expressions can be serialized with it.
     */
    private static final class ExpressionToStringContext implements ToStringContext {

        final SerializationContext context;
        final Deque<String> path;
        final CompositeNode parent;

        /** A context with a fresh {@link SerializationContext} and no path or parent. */
        public static final ExpressionToStringContext empty = new ExpressionToStringContext(new SerializationContext(), null, null);

        public ExpressionToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) {
            this.context = context;
            this.path = path;
            this.parent = parent;
        }
    }

    /** A primitive tensor function which evaluates a wrapped ranking expression node. */
    public static class ExpressionTensorFunction extends PrimitiveTensorFunction<Reference> {

        private final ExpressionNode expression;

        public ExpressionTensorFunction(ExpressionNode expression) {
            this.expression = expression;
        }

        /**
         * Returns the children of the wrapped expression, each wrapped as a tensor function.
         * Returns an empty list if the expression is not a composite node.
         */
        public List<TensorFunction<Reference>> arguments() {
            if (!(this.expression instanceof CompositeNode)) {
                return Collections.emptyList();
            }
            return ((CompositeNode) this.expression).children().stream()
                    .map(ExpressionTensorFunction::new)
                    .collect(Collectors.toList());
        }

        /**
         * Returns a function wrapping a copy of the expression with the given arguments as its children.
         * <p>
         * Arguments which are wrapped expressions are unwrapped. Any other tensor function argument is
         * wrapped in a {@link TensorFunctionNode}; previously such arguments caused a
         * {@code ClassCastException}.
         *
         * @param arguments the new arguments; if empty, this function is returned unchanged
         * @throws IllegalArgumentException if arguments are given but the wrapped expression is not a composite node
         */
        public TensorFunction<Reference> withArguments(List<TensorFunction<Reference>> arguments) {
            if (arguments.isEmpty()) {
                return this;
            }
            if (!(this.expression instanceof CompositeNode)) {
                throw new IllegalArgumentException("Cannot set " + arguments.size() + " argument(s) on expression '" +
                                                   this.expression + "', which is not a composite node");
            }
            List<ExpressionNode> unwrappedChildren = arguments.stream()
                    .map(TensorFunctionNode::toExpressionNode)
                    .collect(Collectors.toList());
            return new ExpressionTensorFunction(((CompositeNode) this.expression).setChildren(unwrappedChildren));
        }

        /** Returns this, as a wrapped expression is already primitive. */
        public PrimitiveTensorFunction<Reference> toPrimitive() {
            return this;
        }

        /** Returns the type of the wrapped expression as resolved in the given context. */
        public TensorType type(TypeContext<Reference> context) {
            return this.expression.type(context);
        }

        /**
         * Evaluates the wrapped expression and returns the result as a tensor.
         * <p>
         * The given context is used directly if it already is a {@link Context}. Otherwise it is adapted with
         * a {@link ContextWrapper}, as {@link ExpressionScalarFunction} does, instead of failing on a cast.
         */
        public Tensor evaluate(EvaluationContext<Reference> context) {
            Context expressionContext = (context instanceof Context) ? (Context) context : new ContextWrapper(context);
            return this.expression.evaluate(expressionContext).asTensor();
        }

        @Override
        public String toString() {
            return this.toString(ExpressionToStringContext.empty);
        }

        /** Serializes the wrapped expression, using the serialization state in {@code c} when it carries any. */
        public String toString(ToStringContext c) {
            return TensorFunctionNode.serializeWrapped(this.expression, c);
        }
    }

    /** A scalar function which evaluates a wrapped ranking expression node to a double. */
    private static final class ExpressionScalarFunction implements ScalarFunction<Reference> {

        private final ExpressionNode expression;

        public ExpressionScalarFunction(ExpressionNode expression) {
            this.expression = expression;
        }

        /** Evaluates the wrapped expression in an adapted view of {@code context} and returns it as a double. */
        public Double apply(EvaluationContext<Reference> context) {
            return this.expression.evaluate(new ContextWrapper(context)).asDouble();
        }

        @Override
        public String toString() {
            return this.toString(ExpressionToStringContext.empty);
        }

        /** Serializes the wrapped expression, using the serialization state in {@code c} when it carries any. */
        public String toString(ToStringContext c) {
            return TensorFunctionNode.serializeWrapped(this.expression, c);
        }
    }
}

