/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.rule;

import com.google.common.collect.ImmutableMap;
import com.yahoo.api.annotations.Beta;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

/**
 * A ranking expression node which evaluates and serializes a tensor function over {@link Reference} names.
 * <p>
 * The children of this node are the arguments of the wrapped function. Arguments which are wrapped
 * ranking expressions ({@link ExpressionTensorFunction}) are exposed as those expressions. Any other
 * argument is exposed as a new {@code TensorFunctionNode}.
 */
@Beta
public class TensorFunctionNode extends CompositeNode {

    private final TensorFunction<Reference> function;

    public TensorFunctionNode(TensorFunction<Reference> function) {
        this.function = function;
    }

    /** Returns the tensor function wrapped by this node. */
    public TensorFunction<Reference> function() {
        return function;
    }

    /** Returns the arguments of the wrapped function as expression nodes (see the class comment). */
    @Override
    public List<ExpressionNode> children() {
        return function.arguments().stream()
                .map(this::toExpressionNode)
                .collect(Collectors.toList());
    }

    /** Unwraps an {@link ExpressionTensorFunction} argument, or wraps any other function in a new node. */
    private ExpressionNode toExpressionNode(TensorFunction<Reference> f) {
        if (f instanceof ExpressionTensorFunction) {
            return ((ExpressionTensorFunction) f).expression;
        }
        return new TensorFunctionNode(f);
    }

    /**
     * Returns a new node whose function is this node's function with the given children as arguments.
     * Each child is wrapped in an {@link ExpressionTensorFunction} before being passed on.
     */
    @Override
    public CompositeNode setChildren(List<ExpressionNode> children) {
        List<TensorFunction<Reference>> wrappedChildren = children.stream()
                .map(ExpressionTensorFunction::new)
                .collect(Collectors.toList());
        return new TensorFunctionNode(function.withArguments(wrappedChildren));
    }

    /**
     * Appends the primitive form of the wrapped function to {@code string}.
     * This node, not {@code parent}, is recorded as the parent of the serialization context.
     */
    @Override
    public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) {
        // Rooting the ToStringContext chain in an ExpressionToStringContext lets wrapped sub-expressions
        // find it again (see outermostExpressionContext) and serialize with the same serialization state.
        return string.append(function.toPrimitive().toString(new ExpressionToStringContext(context, path, this)));
    }

    @Override
    public TensorType type(TypeContext<Reference> context) {
        return function.type(context);
    }

    /** Evaluates the wrapped function in the given context and returns the result as a tensor value. */
    @Override
    public Value evaluate(Context context) {
        return new TensorValue(function.evaluate(context));
    }

    /** Wraps a ranking expression node as a tensor function. */
    public static ExpressionTensorFunction wrap(ExpressionNode node) {
        return new ExpressionTensorFunction(node);
    }

    /** Wraps each value of the given map as a scalar function, keeping keys and iteration order. */
    public static Map<TensorAddress, ScalarFunction<Reference>> wrapScalars(Map<TensorAddress, ExpressionNode> nodes) {
        Map<TensorAddress, ScalarFunction<Reference>> functions = new LinkedHashMap<>();
        for (Map.Entry<TensorAddress, ExpressionNode> entry : nodes.entrySet()) {
            functions.put(entry.getKey(), wrapScalar(entry.getValue()));
        }
        return functions;
    }

    /**
     * Wraps the values of one dense subspace as scalar functions and puts them into {@code receivingMap}.
     * <p>
     * The nodes are consumed in the iteration order of the indexed dimensions of {@code type}, ordered as in
     * {@code dimensionOrder}. Each address has one label per dimension of {@code type}. Indexed dimensions get
     * the current dense index. Mapped dimensions get {@code mappedDimensionLabel}.
     *
     * @throws IllegalArgumentException if the number of nodes differs from the size of the dense subspace
     */
    public static void wrapScalarBlock(TensorType type, List<String> dimensionOrder, String mappedDimensionLabel, List<ExpressionNode> nodes, Map<TensorAddress, ScalarFunction<Reference>> receivingMap) {
        TensorType denseSubtype = new TensorType(type.valueType(),
                                                 type.dimensions().stream()
                                                     .filter(TensorType.Dimension::isIndexed)
                                                     .collect(Collectors.toList()));
        List<String> denseDimensionOrder = new ArrayList<>(dimensionOrder);
        denseDimensionOrder.retainAll(denseSubtype.dimensionNames());
        IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(denseSubtype, denseDimensionOrder);
        if (indexes.size() != nodes.size()) {
            throw new IllegalArgumentException("At '" + mappedDimensionLabel + "': Need " + indexes.size() + " values to fill a dense subspace of " + type + " but got " + nodes.size());
        }
        for (ExpressionNode node : nodes) {
            indexes.next();
            // Interleave dense indexes and the mapped label in the dimension order of the full type
            String[] labels = new String[type.rank()];
            int indexedDimensionsIndex = 0;
            int allDimensionsIndex = 0;
            for (TensorType.Dimension dimension : type.dimensions()) {
                if (dimension.isIndexed()) {
                    labels[allDimensionsIndex++] = String.valueOf(indexes.indexesForReading()[indexedDimensionsIndex++]);
                } else {
                    labels[allDimensionsIndex++] = mappedDimensionLabel;
                }
            }
            receivingMap.put(TensorAddress.of(labels), wrapScalar(node));
        }
    }

    /**
     * Wraps the given nodes as scalar functions filling the indexed tensor {@code type}.
     * The returned list follows the iteration order of {@code IndexedTensor.Indexes.of(type, dimensionOrder)}.
     * At each step it takes the node at that step's {@code toSourceValueIndex()}.
     *
     * @throws IllegalArgumentException if the number of nodes differs from the number of cells of {@code type}
     */
    public static List<ScalarFunction<Reference>> wrapScalars(TensorType type, List<String> dimensionOrder, List<ExpressionNode> nodes) {
        IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(type, dimensionOrder);
        if (indexes.size() != nodes.size()) {
            throw new IllegalArgumentException("Need " + indexes.size() + " values to fill " + type + " but got " + nodes.size());
        }
        List<ScalarFunction<Reference>> wrapped = new ArrayList<>(nodes.size());
        while (indexes.hasNext()) {
            indexes.next();
            // Fits in an int: the size check above bounds the index by nodes.size()
            wrapped.add(wrapScalar(nodes.get((int) indexes.toSourceValueIndex())));
        }
        return wrapped;
    }

    /** Wraps a ranking expression node as a scalar function. */
    public static ScalarFunction<Reference> wrapScalar(ExpressionNode node) {
        return new ExpressionScalarFunction(node);
    }

    @Override
    public int hashCode() {
        return function.hashCode();
    }

    /**
     * Walks the parent chain of {@code context} to its root. Returns that root if it is an
     * {@link ExpressionToStringContext}, i.e. if serialization started at a {@code TensorFunctionNode}.
     */
    private static Optional<ExpressionToStringContext> outermostExpressionContext(ToStringContext<Reference> context) {
        ToStringContext<Reference> outermost = context;
        while (outermost.parent() != null) {
            outermost = outermost.parent();
        }
        if (outermost instanceof ExpressionToStringContext) {
            return Optional.of((ExpressionToStringContext) outermost);
        }
        return Optional.empty();
    }

    /** A tensor function which wraps a ranking expression, so expressions can be used as tensor function arguments. */
    public static class ExpressionTensorFunction extends PrimitiveTensorFunction<Reference> {

        private final ExpressionNode expression;

        public ExpressionTensorFunction(ExpressionNode expression) {
            this.expression = expression;
        }

        /**
         * Returns the children of the wrapped expression, each wrapped in turn.
         * Returns an empty list if the expression is not a composite.
         */
        @Override
        public List<TensorFunction<Reference>> arguments() {
            if (expression instanceof CompositeNode) {
                return ((CompositeNode) expression).children().stream()
                        .map(ExpressionTensorFunction::new)
                        .collect(Collectors.toList());
            }
            return Collections.emptyList();
        }

        /**
         * Returns a function wrapping this expression with its children replaced by the unwrapped arguments.
         * Returns this if there are no arguments.
         * When arguments are given, each must be an ExpressionTensorFunction and the wrapped expression must be
         * a CompositeNode. Otherwise a ClassCastException is thrown.
         */
        @Override
        public TensorFunction<Reference> withArguments(List<TensorFunction<Reference>> arguments) {
            if (arguments.isEmpty()) {
                return this;
            }
            List<ExpressionNode> unwrappedChildren = arguments.stream()
                    .map(arg -> ((ExpressionTensorFunction) arg).expression)
                    .collect(Collectors.toList());
            return new ExpressionTensorFunction(((CompositeNode) expression).setChildren(unwrappedChildren));
        }

        @Override
        public PrimitiveTensorFunction<Reference> toPrimitive() {
            return this;
        }

        @Override
        public TensorType type(TypeContext<Reference> context) {
            return expression.type(context);
        }

        /** Returns the wrapped expression as a scalar function. */
        public Optional<ScalarFunction<Reference>> asScalarFunction() {
            return Optional.of(new ExpressionScalarFunction(expression));
        }

        /** Evaluates the wrapped expression. Throws ClassCastException unless {@code context} is a {@link Context}. */
        @Override
        public Tensor evaluate(EvaluationContext<Reference> context) {
            return expression.evaluate((Context) context).asTensor();
        }

        @Override
        public String toString() {
            return toString(ExpressionToStringContext.empty);
        }

        @Override
        public int hashCode() {
            return expression.hashCode();
        }

        /**
         * Serializes the wrapped expression within the outermost expression context when there is one.
         * Otherwise falls back to the expression's plain string form.
         */
        @Override
        public String toString(ToStringContext<Reference> c) {
            return outermostExpressionContext(c)
                    .map(context -> context.serializeChained(expression, c))
                    .orElseGet(() -> expression.toString());
        }
    }

    /**
     * A serialization context which is also a tensor {@link ToStringContext}. It carries the ranking expression
     * serialization state (plus path and parent node) through tensor function serialization.
     */
    private static class ExpressionToStringContext extends SerializationContext implements ToStringContext<Reference> {

        public static final ExpressionToStringContext empty = new ExpressionToStringContext(new SerializationContext(), null, null);

        private final ToStringContext<Reference> wrappedToStringContext;
        private final SerializationContext wrappedSerializationContext;
        private final Deque<String> path;
        private final CompositeNode parent;

        ExpressionToStringContext(SerializationContext wrappedSerializationContext, Deque<String> path, CompositeNode parent) {
            this(wrappedSerializationContext, null, path, parent);
        }

        ExpressionToStringContext(SerializationContext wrappedSerializationContext, ToStringContext<Reference> wrappedToStringContext, Deque<String> path, CompositeNode parent) {
            this.wrappedSerializationContext = wrappedSerializationContext;
            this.wrappedToStringContext = wrappedToStringContext;
            this.path = path;
            this.parent = parent;
        }

        /**
         * Serializes {@code root} using this context's serialization state, path and parent. Chains
         * {@code innermost} as the wrapped ToStringContext, so its bindings are consulted first by getBinding.
         */
        String serializeChained(ExpressionNode root, ToStringContext<Reference> innermost) {
            ExpressionToStringContext chained = new ExpressionToStringContext(wrappedSerializationContext, innermost, path, parent);
            return root.toString(new StringBuilder(), chained, path, parent).toString();
        }

        @Override
        public void addFunctionSerialization(String name, String expressionString) {
            wrappedSerializationContext.addFunctionSerialization(name, expressionString);
        }

        @Override
        public void addArgumentTypeSerialization(String functionName, String argumentName, TensorType type) {
            wrappedSerializationContext.addArgumentTypeSerialization(functionName, argumentName, type);
        }

        @Override
        public void addFunctionTypeSerialization(String functionName, TensorType type) {
            wrappedSerializationContext.addFunctionTypeSerialization(functionName, type);
        }

        @Override
        public Map<String, String> serializedFunctions() {
            return wrappedSerializationContext.serializedFunctions();
        }

        @Override
        public ExpressionFunction getFunction(String name) {
            return wrappedSerializationContext.getFunction(name);
        }

        @Override
        public Optional<TypeContext<Reference>> typeContext() {
            return wrappedSerializationContext.typeContext();
        }

        @Override
        @Deprecated(forRemoval=true, since="7")
        protected ImmutableMap<String, ExpressionFunction> functions() {
            return ImmutableMap.copyOf(wrappedSerializationContext.getFunctions());
        }

        @Override
        protected Map<String, ExpressionFunction> getFunctions() {
            return wrappedSerializationContext.getFunctions();
        }

        /** Returns the wrapped tensor ToStringContext, or null if there is none. */
        @Override
        public ToStringContext<Reference> parent() {
            return wrappedToStringContext;
        }

        /**
         * Returns the binding from the wrapped tensor ToStringContext if it has a non-null one.
         * Otherwise returns the binding from the wrapped serialization context.
         */
        @Override
        public String getBinding(String name) {
            if (wrappedToStringContext != null) {
                String binding = wrappedToStringContext.getBinding(name);
                if (binding != null) {
                    return binding;
                }
            }
            return wrappedSerializationContext.getBinding(name);
        }

        /**
         * Returns a copy whose serialization context has the given bindings.
         * The wrapped ToStringContext is kept, so its bindings still take precedence.
         */
        @Override
        public ExpressionToStringContext withBindings(Map<String, String> bindings) {
            SerializationContext serializationContext = new SerializationContext(getFunctions(), bindings, typeContext(), serializedFunctions());
            return new ExpressionToStringContext(serializationContext, wrappedToStringContext, path, parent);
        }

        /**
         * Returns a copy with no bindings. Both sources consulted by getBinding are dropped: the serialization
         * context gets null bindings, and the wrapped ToStringContext is not carried over.
         */
        @Override
        public SerializationContext withoutBindings() {
            SerializationContext serializationContext = new SerializationContext(getFunctions(), null, typeContext(), serializedFunctions());
            return new ExpressionToStringContext(serializationContext, null, path, parent);
        }

        @Override
        public String toString() {
            return "TensorFunctionNode.ExpressionToStringContext with wrapped serialization context: " + wrappedSerializationContext;
        }
    }

    /** A scalar function which evaluates a wrapped ranking expression to a double. */
    private static class ExpressionScalarFunction implements ScalarFunction<Reference> {

        private final ExpressionNode expression;

        public ExpressionScalarFunction(ExpressionNode expression) {
            this.expression = expression;
        }

        /** Evaluates the wrapped expression against {@code context}, adapted to a ranking expression Context. */
        public Double apply(EvaluationContext<Reference> context) {
            return expression.evaluate(new ContextWrapper(context)).asDouble();
        }

        /** Returns the wrapped expression as a tensor function. */
        public Optional<TensorFunction<Reference>> asTensorFunction() {
            return Optional.of(new ExpressionTensorFunction(expression));
        }

        @Override
        public String toString() {
            return toString(ExpressionToStringContext.empty);
        }

        /**
         * Serializes the wrapped expression within the outermost expression context when there is one.
         * If needed it is first wrapped in an EmbracedNode (see embraceIfNeeded). Otherwise falls back to the
         * expression's plain string form.
         */
        public String toString(ToStringContext<Reference> c) {
            return outermostExpressionContext(c)
                    .map(context -> context.serializeChained(embraceIfNeeded(expression), c))
                    .orElseGet(() -> expression.toString());
        }

        /**
         * Wraps composite expressions in an EmbracedNode unless they already are one or are identifier references.
         * Presumably this keeps the scalar a single operand in the enclosing serialization (TODO confirm).
         */
        private ExpressionNode embraceIfNeeded(ExpressionNode node) {
            boolean needsEmbracing = node instanceof CompositeNode
                                     && !(node instanceof EmbracedNode)
                                     && !isIdentifierReference(node);
            return needsEmbracing ? new EmbracedNode(node) : node;
        }

        /** Returns whether the node is a ReferenceNode whose reference is an identifier. */
        private boolean isIdentifierReference(ExpressionNode node) {
            if (!(node instanceof ReferenceNode)) {
                return false;
            }
            return ((ReferenceNode) node).reference().isIdentifier();
        }
    }

    /** Adapts a tensor {@link EvaluationContext} to the ranking expression {@link Context} API. */
    private static class ContextWrapper extends Context {

        private final EvaluationContext<Reference> delegate;

        public ContextWrapper(EvaluationContext<Reference> delegate) {
            this.delegate = delegate;
        }

        /** Returns the delegate's tensor with the given name, as a tensor value. */
        @Override
        public Value get(String name) {
            return new TensorValue(delegate.getTensor(name));
        }

        /** Returns the type of the given reference, as known by the delegate. */
        public TensorType getType(Reference name) {
            // The delegate is typed on Reference, so pass the name as-is (a cast to Name does not match getType(Reference))
            return delegate.getType(name);
        }
    }
}

