/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.airlift.log.Logger;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.Set;

public class RuntimeReorderJoinSides
implements Rule<JoinNode> {
    private static final Logger log = Logger.get(RuntimeReorderJoinSides.class);
    private static final Pattern<JoinNode> PATTERN = Patterns.join();

    @Override
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    @Override
    public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
        if (PlanNodeSearcher.searchFrom(joinNode, context.getLookup()).where(node -> node.getSources().isEmpty() && !(node instanceof TableScanNode)).matches()) {
            return Rule.Result.empty();
        }
        StatsProvider statsProvider = context.getStatsProvider();
        double leftOutputSizeInBytes = statsProvider.getStats(joinNode.getLeft()).getOutputSizeInBytes(joinNode.getLeft().getOutputVariables());
        double rightOutputSizeInBytes = statsProvider.getStats(joinNode.getRight()).getOutputSizeInBytes(joinNode.getRight().getOutputVariables());
        if (Double.isNaN(leftOutputSizeInBytes) || Double.isNaN(rightOutputSizeInBytes)) {
            return Rule.Result.empty();
        }
        if (rightOutputSizeInBytes <= leftOutputSizeInBytes) {
            return Rule.Result.empty();
        }
        if (!this.isSwappedJoinValid(joinNode)) {
            return Rule.Result.empty();
        }
        JoinNode swapped = joinNode.flipChildren();
        PlanNode newLeft = swapped.getLeft();
        PlanNode resolvedSwappedLeft = context.getLookup().resolve(newLeft);
        Optional<Object> leftHashVariable = swapped.getLeftHashVariable();
        if (resolvedSwappedLeft instanceof ExchangeNode && resolvedSwappedLeft.getSources().size() == 1) {
            newLeft = (PlanNode)resolvedSwappedLeft.getSources().get(0);
            if (swapped.getLeftHashVariable().isPresent()) {
                int hashVariableIndex = resolvedSwappedLeft.getOutputVariables().indexOf(swapped.getLeftHashVariable().get());
                leftHashVariable = Optional.of(((PlanNode)resolvedSwappedLeft.getSources().get(0)).getOutputVariables().get(hashVariableIndex));
                if (swapped.getOutputVariables().contains(swapped.getLeftHashVariable().get())) {
                    return Rule.Result.empty();
                }
            }
        }
        List buildJoinVariables = (List)swapped.getCriteria().stream().map(JoinNode.EquiJoinClause::getRight).collect(ImmutableList.toImmutableList());
        PlanNode newRight = swapped.getRight();
        if (this.needLocalExchange(swapped.getRight(), (Set<VariableReferenceExpression>)ImmutableSet.copyOf((Collection)buildJoinVariables), context)) {
            newRight = SystemSessionProperties.getTaskConcurrency(context.getSession()) > 1 ? ExchangeNode.systemPartitionedExchange(context.getIdAllocator().getNextId(), ExchangeNode.Scope.LOCAL, swapped.getRight(), buildJoinVariables, swapped.getRightHashVariable()) : ExchangeNode.gatheringExchange(context.getIdAllocator().getNextId(), ExchangeNode.Scope.LOCAL, swapped.getRight());
        }
        JoinNode newJoinNode = new JoinNode(swapped.getId(), swapped.getType(), newLeft, newRight, swapped.getCriteria(), swapped.getOutputVariables(), swapped.getFilter(), leftHashVariable, swapped.getRightHashVariable(), swapped.getDistributionType());
        log.debug(String.format("Probe size: %.2f is smaller than Build size: %.2f => invoke runtime join swapping on JoinNode ID: %s.", leftOutputSizeInBytes, rightOutputSizeInBytes, newJoinNode.getId()));
        return Rule.Result.ofPlanNode(newJoinNode);
    }

    private boolean isSwappedJoinValid(JoinNode join) {
        return !(join.getDistributionType().get() == JoinNode.DistributionType.REPLICATED && join.getType() == JoinNode.Type.LEFT || join.getDistributionType().get() == JoinNode.DistributionType.PARTITIONED && join.getCriteria().isEmpty() && join.getType() == JoinNode.Type.RIGHT);
    }

    private boolean needLocalExchange(PlanNode root, Set<VariableReferenceExpression> partitioningColumns, Rule.Context context) {
        PlanNode actual = context.getLookup().resolve(root);
        if (actual instanceof ExchangeNode) {
            if (!partitioningColumns.isEmpty() && ((ExchangeNode)actual).getPartitioningScheme().getPartitioning().getVariableReferences().equals(partitioningColumns)) {
                return false;
            }
            return !partitioningColumns.isEmpty() && ((ExchangeNode)actual).getType() != ExchangeNode.Type.GATHER;
        }
        if (actual.getSources().isEmpty()) {
            return true;
        }
        for (PlanNode child : actual.getSources()) {
            if (!this.needLocalExchange(child, partitioningColumns, context)) continue;
            return true;
        }
        return false;
    }
}

