/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.evaluation.tensoroptimization;

import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport;
import com.yahoo.searchlib.rankingexpression.evaluation.Optimizer;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.ReduceJoin;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.List;

/**
 * Optimizer pass that rewrites tensor sub-expressions of a ranking expression.
 *
 * <p>Currently it performs a single rewrite: a {@link TensorFunctionNode} whose function is a
 * {@link Reduce} and whose only child is a {@link TensorFunctionNode} computing a {@link Join}
 * is replaced by a single node computing a {@link ReduceJoin} of the two. Each such replacement
 * increments the {@code "Replaced reduce->join"} metric on the supplied report.
 *
 * <p>The report passed to {@link #optimize(RankingExpression, ContextIndex, OptimizationReport)}
 * is kept in an instance field for the duration of the pass. Because of that, an instance should
 * presumably not be shared across concurrent optimize calls (NOTE(review): confirm with callers).
 */
public class TensorOptimizer
extends Optimizer {
    private OptimizationReport report;

    /**
     * Rewrites the tensor sub-expressions of {@code expression} in place, if this optimizer is enabled.
     *
     * @param expression the expression whose root is replaced by the optimized tree
     * @param context    passed through the recursive pass (not otherwise consulted here)
     * @param report     receives a metric per rewrite and a completion note
     */
    @Override
    public void optimize(RankingExpression expression, ContextIndex context, OptimizationReport report) {
        if (this.isEnabled()) {
            this.report = report;
            ExpressionNode optimizedRoot = this.optimize(expression.getRoot(), context);
            expression.setRoot(optimizedRoot);
            report.note("Tensor expression optimization done");
        }
    }

    /**
     * Optimizes one node, then (if the result has children) each of its children.
     * The rewrite is attempted on a node before descending into its children.
     */
    private ExpressionNode optimize(ExpressionNode node, ContextIndex context) {
        ExpressionNode rewritten = this.optimizeReduceJoin(node);
        return rewritten instanceof CompositeNode
               ? this.optimizeChildren((CompositeNode) rewritten, context)
               : rewritten;
    }

    /** Returns the result of handing {@code node} a list of its children, each optimized in turn. */
    private ExpressionNode optimizeChildren(CompositeNode node, ContextIndex context) {
        List<ExpressionNode> originals = node.children();
        List<ExpressionNode> optimized = new ArrayList<>(originals.size());
        originals.forEach(child -> optimized.add(this.optimize(child, context)));
        return node.setChildren(optimized);
    }

    /**
     * Replaces a reduce applied directly to a join with one combined {@link ReduceJoin} node.
     *
     * @return the replacement node, or {@code node} itself when the pattern does not match
     */
    private ExpressionNode optimizeReduceJoin(ExpressionNode node) {
        if (!(node instanceof TensorFunctionNode)) {
            return node;
        }
        TensorFunctionNode reduceNode = (TensorFunctionNode) node;
        TensorFunction outerFunction = reduceNode.function();
        if (!(outerFunction instanceof Reduce)) {
            return node;
        }
        Join join = soleJoinArgumentOf(reduceNode);
        if (join == null) {
            return node;
        }
        this.report.incMetric("Replaced reduce->join", 1);
        return new TensorFunctionNode((TensorFunction) new ReduceJoin((Reduce) outerFunction, join));
    }

    /**
     * Returns the {@link Join} computed by the single child of {@code parent}.
     * Returns {@code null} if {@code parent} does not have exactly one child, or if that child is
     * not a {@link TensorFunctionNode} whose function is a {@code Join}.
     */
    private static Join soleJoinArgumentOf(TensorFunctionNode parent) {
        List<ExpressionNode> arguments = parent.children();
        if (arguments.size() != 1) {
            return null;
        }
        ExpressionNode onlyArgument = arguments.get(0);
        if (!(onlyArgument instanceof TensorFunctionNode)) {
            return null;
        }
        TensorFunction innerFunction = ((TensorFunctionNode) onlyArgument).function();
        return innerFunction instanceof Join ? (Join) innerFunction : null;
    }
}

