/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization;

import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleCompatibleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport;
import com.yahoo.searchlib.rankingexpression.evaluation.Optimizer;
import com.yahoo.searchlib.rankingexpression.evaluation.StringValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization.GBDTNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.IfNode;
import com.yahoo.searchlib.rankingexpression.rule.NegativeNode;
import com.yahoo.searchlib.rankingexpression.rule.NotNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.SetMembershipNode;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.yolean.Exceptions;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/**
 * Rewrites trees of {@code if} expressions into {@link GBDTNode}s, which hold the whole tree encoded
 * as a flat array of doubles.
 *
 * <p>Encoding, as written by this class:
 * <ul>
 *   <li>A leaf is its constant value. Its magnitude must be below {@link #MAX_LEAF_VALUE}.</li>
 *   <li>An {@code if} node is laid out as:
 *     <ol>
 *       <li>its condition encoding;</li>
 *       <li>one slot holding the size of the encoded true branch plus one;</li>
 *       <li>the encoded true branch;</li>
 *       <li>the encoded false branch.</li>
 *     </ol>
 *   </li>
 *   <li>A condition starts with a marker value: a per-kind base of at least {@link #MAX_LEAF_VALUE},
 *       plus the index of the tested feature in the {@link ContextIndex}. The operand(s) follow the
 *       marker.</li>
 * </ul>
 */
public class GBDTOptimizer extends Optimizer {

    /**
     * Leaf magnitudes must stay strictly below this value. Values from here up are condition markers.
     */
    private static final double MAX_LEAF_VALUE = 2.0E9;

    /**
     * Maximum number of features a context may hold.
     * This is also the width of each condition kind's marker band.
     */
    private static final int MAX_VARIABLES = 1000000;

    /** Marker base for {@code feature < constant}. */
    private static final double SMALLER_MARKER = MAX_LEAF_VALUE;

    /** Marker base for {@code feature == constant}. */
    private static final double EQUAL_MARKER = MAX_LEAF_VALUE + MAX_VARIABLES;

    /** Marker base for {@code feature in [ ... ]}. */
    private static final double SET_MEMBERSHIP_MARKER = MAX_LEAF_VALUE + 2.0 * MAX_VARIABLES;

    /** Marker base for {@code !(feature >= constant)}. */
    private static final double NOT_LARGER_OR_EQUAL_MARKER = MAX_LEAF_VALUE + 3.0 * MAX_VARIABLES;

    private OptimizationReport report;

    /**
     * Replaces the expression's root with its optimized form, if this optimizer is enabled.
     *
     * @param expression the expression to optimize in place
     * @param context    the feature index; tested features are encoded by their index in it
     * @param report     receives notes and metrics about the optimization
     */
    @Override
    public void optimize(RankingExpression expression, ContextIndex context, OptimizationReport report) {
        if (!isEnabled()) {
            return;
        }
        this.report = report;
        // Each condition kind owns a band of MAX_VARIABLES marker values. A larger context would let
        // the indexes of one kind spill into the next kind's band.
        if (context.size() > MAX_VARIABLES) {
            report.note("Can not optimize expressions referencing more than " + MAX_VARIABLES + " features: "
                        + expression + " has " + context.size());
            return;
        }
        expression.setRoot(optimize(expression.getRoot(), context));
        report.note("GBDT tree optimization done");
    }

    /**
     * Returns an optimized version of the given node.
     *
     * <ul>
     *   <li>If nodes reachable through chains of arithmetic operands are turned into GBDT nodes.</li>
     *   <li>Arithmetic nodes are rebuilt from their optimized operands.</li>
     *   <li>Any other node is returned unchanged.</li>
     * </ul>
     */
    private ExpressionNode optimize(ExpressionNode node, ContextIndex context) {
        if (node instanceof ArithmeticNode) {
            ArithmeticNode arithmetic = (ArithmeticNode) node;
            Iterator<ExpressionNode> children = arithmetic.children().iterator();
            ExpressionNode result = optimize(children.next(), context);
            Iterator<ArithmeticOperator> operators = arithmetic.operators().iterator();
            while (children.hasNext() && operators.hasNext()) {
                result = ArithmeticNode.resolve(result, operators.next(), optimize(children.next(), context));
            }
            return result;
        }
        if (node instanceof IfNode) {
            return createGBDTNode((IfNode) node, context);
        }
        return node;
    }

    /**
     * Encodes an if tree as a GBDT node.
     *
     * If part of the tree cannot be encoded, the tree is returned unchanged and the reason is noted
     * in the report. An IllegalStateException from getVariableIndex (a feature missing from the
     * context) is not caught and propagates to the caller.
     */
    private ExpressionNode createGBDTNode(IfNode cNode, ContextIndex context) {
        List<Double> values = new ArrayList<>();
        try {
            consumeNode(cNode, values, context);
        } catch (IllegalArgumentException e) {
            report.note("Skipped optimization: " + Exceptions.toMessageString(e) + ". Expression: " + cNode);
            return cNode;
        }
        // NOTE(review): "GDBT" looks like a typo, but this is a metric key that may be read
        // elsewhere, so it is kept verbatim.
        report.incMetric("Optimized GDBT trees", 1);
        return new GBDTNode(toArray(values));
    }

    /**
     * Appends the encoding of the given node, either an if tree or a constant leaf, to values.
     *
     * @return the number of values appended
     * @throws IllegalArgumentException if the node, or any part of it, cannot be encoded
     */
    private int consumeNode(ExpressionNode node, List<Double> values, ContextIndex context) {
        int beforeIndex = values.size();
        if (node instanceof IfNode) {
            IfNode ifNode = (IfNode) node;
            int jumpValueIndex = consumeIfCondition(ifNode.getCondition(), values, context);
            values.add(0.0); // placeholder, patched below once the true branch's size is known
            int jumpValue = consumeNode(ifNode.getTrueExpression(), values, context) + 1;
            values.set(jumpValueIndex, (double) jumpValue);
            consumeNode(ifNode.getFalseExpression(), values, context);
        } else {
            double value = toValue(node);
            // Values from MAX_LEAF_VALUE up are condition markers. A leaf of exactly MAX_LEAF_VALUE
            // would be indistinguishable from a "<" test on feature 0, so it is rejected as well.
            if (Math.abs(value) >= MAX_LEAF_VALUE) {
                throw new IllegalArgumentException("Leaf value is too large for optimization: " + value);
            }
            values.add(value);
        }
        return values.size() - beforeIndex;
    }

    /**
     * Appends the encoding of an if condition to values.
     *
     * The encoding is a marker identifying the condition kind and the tested feature, followed by
     * the condition's operand(s).
     *
     * @return the size of values after the condition, which is the index of the caller's jump slot
     * @throws IllegalArgumentException if the condition is not one of the supported forms:
     *         {@code f < c}, {@code f == c}, {@code f in [...]} or {@code !(f >= c)}
     */
    private int consumeIfCondition(ExpressionNode condition, List<Double> values, ContextIndex context) {
        if (condition instanceof ComparisonNode) {
            ComparisonNode comparison = (ComparisonNode) condition;
            if (comparison.getOperator() == TruthOperator.SMALLER) {
                values.add(SMALLER_MARKER + getVariableIndex(comparison.getLeftCondition(), context));
            } else if (comparison.getOperator() == TruthOperator.EQUAL) {
                values.add(EQUAL_MARKER + getVariableIndex(comparison.getLeftCondition(), context));
            } else {
                throw new IllegalArgumentException("Cannot optimize other conditions than < and ==, encountered: " + comparison.getOperator());
            }
            values.add(toValue(comparison.getRightCondition()));
        } else if (condition instanceof SetMembershipNode) {
            SetMembershipNode setMembership = (SetMembershipNode) condition;
            values.add(SET_MEMBERSHIP_MARKER + getVariableIndex(setMembership.getTestValue(), context));
            values.add((double) setMembership.getSetValues().size());
            for (ExpressionNode setElementNode : setMembership.getSetValues()) {
                values.add(toValue(setElementNode));
            }
        } else if (condition instanceof NotNode) {
            ComparisonNode comparison = negatedComparison((NotNode) condition);
            if (comparison == null) {
                // Appending nothing here would leave the caller's jump slot where the marker belongs
                // and corrupt the encoding, so any other negated form must be rejected.
                throw new IllegalArgumentException("Node condition could not be optimized: " + condition);
            }
            if (comparison.getOperator() != TruthOperator.LARGEREQUAL) {
                throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.getOperator());
            }
            values.add(NOT_LARGER_OR_EQUAL_MARKER + getVariableIndex(comparison.getLeftCondition(), context));
            values.add(toValue(comparison.getRightCondition()));
        } else {
            throw new IllegalArgumentException("Node condition could not be optimized: " + condition);
        }
        return values.size();
    }

    /**
     * Returns the comparison wrapped by a condition of the exact shape {@code !( comparison )}.
     * Returns null if the condition has any other shape.
     */
    private static ComparisonNode negatedComparison(NotNode notNode) {
        if (notNode.children().size() != 1 || !(notNode.children().get(0) instanceof EmbracedNode)) {
            return null;
        }
        EmbracedNode embraced = (EmbracedNode) notNode.children().get(0);
        if (embraced.children().size() != 1 || !(embraced.children().get(0) instanceof ComparisonNode)) {
            return null;
        }
        return (ComparisonNode) embraced.children().get(0);
    }

    /**
     * Returns the index in the context of the feature referenced by the given node.
     *
     * @throws IllegalArgumentException if the node is not a feature reference
     * @throws IllegalStateException    if the feature is not in the context
     */
    private double getVariableIndex(ExpressionNode node, ContextIndex context) {
        if (!(node instanceof ReferenceNode)) {
            throw new IllegalArgumentException("Contained a left-hand comparison expression which was not a feature value but was: " + node);
        }
        ReferenceNode fNode = (ReferenceNode) node;
        Integer index = context.getIndex(fNode.toString());
        if (index == null) {
            throw new IllegalStateException("The ranking expression contained feature '" + fNode.getName() + "', which is not known to " + context + ": The context must be created from the same ranking expression which is to be optimized");
        }
        return index.intValue();
    }

    /**
     * Returns the numeric value of a constant node, or of the negation of a constant node.
     *
     * For a non-negated constant, only double-compatible and string values are accepted.
     *
     * @throws IllegalArgumentException if the node is not such a constant
     */
    private double toValue(ExpressionNode node) {
        if (node instanceof ConstantNode) {
            Value value = ((ConstantNode) node).getValue();
            if (value instanceof DoubleCompatibleValue || value instanceof StringValue) {
                return value.asDouble();
            }
            throw new IllegalArgumentException("Cannot optimize a node containing a value of type " + value.getClass().getSimpleName() + " (" + value + ") in a set test: " + node);
        }
        if (node instanceof NegativeNode) {
            NegativeNode nNode = (NegativeNode) node;
            if (!(nNode.getValue() instanceof ConstantNode)) {
                throw new IllegalArgumentException("Contained a negation of a non-number: " + nNode.getValue());
            }
            return -((ConstantNode) nNode.getValue()).getValue().asDouble();
        }
        throw new IllegalArgumentException("Node could not be optimized: " + node);
    }

    /** Copies the boxed values into a primitive array. */
    private double[] toArray(List<Double> valueList) {
        double[] valueArray = new double[valueList.size()];
        for (int i = 0; i < valueArray.length; ++i) {
            valueArray[i] = valueList.get(i);
        }
        return valueArray;
    }
}

