/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.schema.expressiontransforms;

import com.yahoo.schema.expressiontransforms.RankProfileTransformContext;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.IfNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.Slice;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.List;

/**
 * Expression transformer which rewrites the token features {@code tokenInputIds},
 * {@code tokenTypeIds} and {@code tokenAttentionMask} into tensor generate expressions.
 *
 * <p>Each feature takes the length of the sequence to generate as its first argument,
 * followed by one or more references holding token tensors, e.g.
 * {@code tokenInputIds(128, query(q), attribute(t))}. The generated sequence is laid out
 * along dimension {@code d1} of a {@code tensor<float>(d0[1],d1[length])} as
 * {@code CLS, a..., SEP, b..., SEP, 0, 0, ...}.
 *
 * <p>As a side effect, a helper function computing the token count of every reference
 * argument is added to the rank profile (unless one with the same name already exists).
 */
public class TokenTransformer extends ExpressionTransformer<RankProfileTransformContext> {

    /** Padding value; also the type id of the first segment and the "not attended" mask value. */
    private static final ConstantNode ZERO = new ConstantNode(new DoubleValue(0.0));

    /** Type id of the later segments, the "attended" mask value, and the length of one CLS/SEP token. */
    private static final ConstantNode ONE = new ConstantNode(new DoubleValue(1.0));

    /** Token id placed at the start of the sequence (presumably the BERT [CLS] id — confirm). */
    private static final ConstantNode CLS = new ConstantNode(new DoubleValue(101.0));

    /** Token id placed after each argument's tokens (presumably the BERT [SEP] id — confirm). */
    private static final ConstantNode SEP = new ConstantNode(new DoubleValue(102.0));

    /**
     * Rewrites the token features found in the given expression tree.
     *
     * @param node the expression node to transform
     * @param context the transform context, giving access to the rank profile
     * @return the transformed node; {@code node} itself when nothing applies
     */
    @Override
    public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
        if (node instanceof ReferenceNode) {
            // NOTE(review): references that are not rewritten are returned as-is, so their
            // arguments are not visited; confirm nested token features are not expected.
            return transformFeature((ReferenceNode) node, context);
        }
        if (node instanceof CompositeNode) {
            return super.transformChildren((CompositeNode) node, context);
        }
        return node;
    }

    /** Dispatches a reference to the matching token feature rewrite, when it should be rewritten. */
    private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
        switch (feature.getName()) {
            case "tokenInputIds":
                return shouldTransform(feature, context) ? transformTokenInputIds(feature, context) : feature;
            case "tokenTypeIds":
                return shouldTransform(feature, context) ? transformTokenTypeIds(feature, context) : feature;
            case "tokenAttentionMask":
                return shouldTransform(feature, context) ? transformTokenAttentionMask(feature, context) : feature;
            default:
                return feature;
        }
    }

    /**
     * Transforms {@code tokenInputIds(128, a, b, ...)} into a generate expression that concatenates
     * the arguments, separated by CLS/SEP tokens and 0-padded up to the given length:
     *
     * <pre>
     *     tensor&lt;float&gt;(d0[1],d1[128])(
     *         if (d1 &lt; 1, 101,
     *         if (d1 &lt; 1 + len_a, a{d0:(d1 - (1))},
     *         if (d1 &lt; 1 + len_a + 1, 102,
     *         if (d1 &lt; 1 + len_a + 1 + len_b, b{d0:(d1 - (1 + len_a + 1))},
     *         if (d1 &lt; 1 + len_a + 1 + len_b + 1, 102,
     *         0)))))
     * </pre>
     */
    private ExpressionNode transformTokenInputIds(ReferenceNode feature, RankProfileTransformContext context) {
        TensorType type = prepare(feature, context);
        return generate(type, createTokenSequenceExpr(0, createTokenSequence(feature)));
    }

    /**
     * Transforms {@code tokenTypeIds(128, a, b, ...)} into a generate expression which is 0 over
     * {@code CLS, a..., SEP}, 1 over the remaining arguments and their SEP tokens, and 0 for padding.
     */
    private ExpressionNode transformTokenTypeIds(ReferenceNode feature, RankProfileTransformContext context) {
        TensorType type = prepare(feature, context);
        List<ExpressionNode> tokenSequence = createTokenSequence(feature);
        // Elements 0..2 of the token sequence are CLS, the first argument and its SEP.
        ExpressionNode firstSegmentLength = createLengthExpr(2, tokenSequence);
        ExpressionNode totalLength = createLengthExpr(tokenSequence.size() - 1, tokenSequence);
        ExpressionNode expr = new IfNode(d1LessThan(firstSegmentLength),
                                         ZERO,
                                         new IfNode(d1LessThan(totalLength), ONE, ZERO));
        return generate(type, expr);
    }

    /**
     * Transforms {@code tokenAttentionMask(128, a, b, ...)} into a generate expression which is 1
     * over the whole CLS/argument/SEP sequence and 0 for padding.
     */
    private ExpressionNode transformTokenAttentionMask(ReferenceNode feature, RankProfileTransformContext context) {
        TensorType type = prepare(feature, context);
        List<ExpressionNode> tokenSequence = createTokenSequence(feature);
        ExpressionNode totalLength = createLengthExpr(tokenSequence.size() - 1, tokenSequence);
        return generate(type, new IfNode(d1LessThan(totalLength), ONE, ZERO));
    }

    /**
     * Shared preamble of all token feature rewrites: validates the reference arguments, resolves the
     * generated tensor type from the length argument, and registers the token length functions.
     *
     * @return the type of the tensor to generate
     * @throws IllegalArgumentException if an argument is invalid
     */
    private TensorType prepare(ReferenceNode feature, RankProfileTransformContext context) {
        checkArguments(feature);
        TensorType type = createTensorType(feature.getName(), feature.getArguments().expressions().get(0));
        createTokenLengthFunctions(feature, context);
        return type;
    }

    /** Wraps a scalar expression (over d1) in a bound tensor generate function of the given type. */
    private static ExpressionNode generate(TensorType type, ExpressionNode scalarExpr) {
        return new TensorFunctionNode(Generate.bound(type, TensorFunctionNode.wrapScalar(scalarExpr)));
    }

    /** Returns the comparison {@code d1 < bound}. */
    private static ComparisonNode d1LessThan(ExpressionNode bound) {
        return new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, bound);
    }

    /**
     * Returns whether a token feature should be rewritten: it must not be shadowed by a rank
     * profile function of the same name, and must have a length plus at least one token argument.
     */
    private boolean shouldTransform(ReferenceNode feature, RankProfileTransformContext context) {
        return ! context.rankProfile().getFunctions().containsKey(feature.getName())
               && feature.getArguments().size() >= 2;
    }

    /**
     * Checks that every argument after the first (the length) is a reference.
     *
     * @throws IllegalArgumentException naming the first offending argument
     */
    private void checkArguments(ReferenceNode feature) {
        for (int i = 1; i < feature.getArguments().size(); ++i) {
            ExpressionNode arg = feature.getArguments().expressions().get(i);
            if ( ! (arg instanceof ReferenceNode)) {
                throw new IllegalArgumentException("Invalid argument " + i + " to " + feature.getName() +
                                                   ": the argument must be a reference. Got " + arg.toString());
            }
        }
    }

    /**
     * Creates the type {@code tensor<float>(d0[1],d1[length])} of a generated token tensor.
     *
     * @param featureName the feature name, used in error messages
     * @param argument the length argument, which must print as a positive integer
     * @throws IllegalArgumentException if the argument is not a positive integer
     */
    public static TensorType createTensorType(String featureName, ExpressionNode argument) {
        int length;
        try {
            length = Integer.parseInt(argument.toString());
        }
        catch (NumberFormatException ex) {
            throw new IllegalArgumentException("Invalid argument to " + featureName + ": the first argument must be the length to the token sequence to generate. Got " + argument, ex);
        }
        if (length < 1) {
            throw new IllegalArgumentException("Invalid argument to " + featureName +
                                               ": the length of the token sequence to generate must be positive. Got " + argument);
        }
        return new TensorType.Builder(TensorType.Value.FLOAT).indexed("d0", 1L).indexed("d1", (long) length).build();
    }

    /**
     * Returns the name of the generated function holding the token count of the given reference.
     * NOTE(review): derived from hashCode(), so two distinct references with colliding hash codes
     * would share one length function; confirm this is acceptable.
     */
    private String lengthFunctionName(ReferenceNode arg) {
        return "__token_length@" + arg.hashCode();
    }

    /**
     * Returns the element list {@code CLS, arg1, SEP, arg2, SEP, ...}, where the args are the
     * feature arguments after the length.
     */
    private List<ExpressionNode> createTokenSequence(ReferenceNode feature) {
        int argumentCount = feature.getArguments().size();
        List<ExpressionNode> sequence = new ArrayList<>(Math.max(1, 2 * argumentCount - 1));
        sequence.add(CLS);
        for (int i = 1; i < argumentCount; ++i) {
            sequence.add(feature.getArguments().expressions().get(i));
            sequence.add(SEP);
        }
        return sequence;
    }

    /**
     * Adds to the rank profile, for each reference argument, a function counting its strictly positive
     * values; token tensors are therefore presumably 0-padded. Existing functions are left untouched.
     * Requires {@link #checkArguments} to have accepted the feature.
     */
    private void createTokenLengthFunctions(ReferenceNode feature, RankProfileTransformContext context) {
        for (int i = 1; i < feature.getArguments().size(); ++i) {
            ReferenceNode ref = (ReferenceNode) feature.getArguments().expressions().get(i);
            String functionName = lengthFunctionName(ref);
            if ( ! context.rankProfile().getFunctions().containsKey(functionName)) {
                context.rankProfile().addFunction(functionName, List.of(), "sum(map(" + ref + ", f(x)(x > 0)))", false);
            }
        }
    }

    /**
     * Builds the nested if-expression giving the token at position d1 for the sequence elements from
     * index {@code iter} onwards. Element i contributes
     * {@code if (d1 < length of elements 0..i, token_i, <rest>)}, and positions past the last
     * element fall through to 0 (padding).
     */
    private ExpressionNode createTokenSequenceExpr(int iter, List<ExpressionNode> sequence) {
        // Assemble inside-out: start from the padding fallback and wrap one element at a time.
        ExpressionNode expr = ZERO;
        for (int i = sequence.size() - 1; i >= iter; --i) {
            ExpressionNode element = sequence.get(i);
            ExpressionNode token = element instanceof ReferenceNode ? createTokenExtractExpr(i, sequence) : element;
            expr = new IfNode(d1LessThan(createLengthExpr(i, sequence)), token, expr);
        }
        return expr;
    }

    /**
     * Creates the sum of the lengths of sequence elements 0..iter (inclusive): 1 for each CLS/SEP
     * constant and the generated length function for each reference.
     */
    private ExpressionNode createLengthExpr(int iter, List<ExpressionNode> sequence) {
        List<ExpressionNode> terms = new ArrayList<>(iter + 1);
        List<ArithmeticOperator> operators = new ArrayList<>(iter);
        for (int i = 0; i <= iter; ++i) {
            ExpressionNode element = sequence.get(i);
            if (element instanceof ConstantNode) {
                terms.add(ONE);
            } else if (element instanceof ReferenceNode) {
                terms.add(new ReferenceNode(lengthFunctionName((ReferenceNode) element)));
            } else {
                // Unreachable for sequences from createTokenSequence after checkArguments; fail loudly
                // rather than build an ArithmeticNode with mismatched terms and operators.
                throw new IllegalStateException("Unexpected token sequence element " + element);
            }
            if (i > 0) {
                operators.add(ArithmeticOperator.PLUS);
            }
        }
        return new ArithmeticNode(terms, operators);
    }

    /**
     * Creates {@code arg{d0:(d1 - (offset))}}, reading the reference at sequence index {@code iter}
     * shifted by the combined length of all preceding elements.
     * NOTE(review): assumes argument token tensors are indexed by a dimension named d0 — confirm.
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // Slice/DimensionValue are used with raw types here
    private ExpressionNode createTokenExtractExpr(int iter, List<ExpressionNode> sequence) {
        ExpressionNode index;
        if (iter >= 1) {
            ExpressionNode offset = new EmbracedNode(createLengthExpr(iter - 1, sequence));
            index = new EmbracedNode(new ArithmeticNode(new ReferenceNode("d1"), ArithmeticOperator.MINUS, offset));
        } else {
            // Index 0 is always CLS in sequences from createTokenSequence, so this is defensive only.
            index = new ReferenceNode("d1");
        }
        List<Slice.DimensionValue> slices = List.of(new Slice.DimensionValue("d0", TensorFunctionNode.wrapScalar(index)));
        TensorFunctionNode.ExpressionTensorFunction argument = new TensorFunctionNode.ExpressionTensorFunction(sequence.get(iter));
        return new TensorFunctionNode(new Slice(argument, slices));
    }
}

