/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.cost.StatsCalculator;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.UnnestNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.planner.PlannerUtils;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizerResult;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.relational.Expressions;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

public class ShardJoins
implements PlanOptimizer {
    private final Metadata metadata;
    private final FunctionAndTypeManager functionAndTypeManager;
    private final StatsCalculator statsCalculator;
    private boolean isEnabledForTesting;

    public ShardJoins(Metadata metadata, FunctionAndTypeManager functionAndTypeManager, StatsCalculator statsCalculator) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        this.functionAndTypeManager = Objects.requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
        this.statsCalculator = Objects.requireNonNull(statsCalculator, "statsCalculator is null");
    }

    @Override
    public void setEnabledForTesting(boolean isSet) {
        this.isEnabledForTesting = isSet;
    }

    @Override
    public boolean isEnabled(Session session) {
        return this.isEnabledForTesting || !SystemSessionProperties.getShardedJoinStrategy(session).equals((Object)FeaturesConfig.ShardedJoinStrategy.DISABLED);
    }

    @Override
    public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) {
        if (this.isEnabled(session)) {
            Rewriter rewriter = new Rewriter(session, this.metadata, this.functionAndTypeManager, idAllocator, variableAllocator);
            PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(rewriter, plan, new HashSet());
            return PlanOptimizerResult.optimizerResult(rewrittenPlan, rewriter.isPlanChanged());
        }
        return PlanOptimizerResult.optimizerResult(plan, false);
    }

    private static class Rewriter
    extends SimplePlanRewriter<Set<VariableReferenceExpression>> {
        private final Session session;
        private final Metadata metadata;
        private final FunctionAndTypeManager functionAndTypeManager;
        private final PlanNodeIdAllocator planNodeIdAllocator;
        private final VariableAllocator planVariableAllocator;
        private boolean planChanged;

        private Rewriter(Session session, Metadata metadata, FunctionAndTypeManager functionAndTypeManager, PlanNodeIdAllocator planNodeIdAllocator, VariableAllocator planVariableAllocator) {
            this.session = Objects.requireNonNull(session, "session is null");
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
            this.functionAndTypeManager = Objects.requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
            this.planNodeIdAllocator = Objects.requireNonNull(planNodeIdAllocator, "planNodeIdAllocator is null");
            this.planVariableAllocator = Objects.requireNonNull(planVariableAllocator, "planVariableAllocator is null");
        }

        public boolean isPlanChanged() {
            return this.planChanged;
        }

        public PlanNode visitJoin(JoinNode joinNode, SimplePlanRewriter.RewriteContext<Set<VariableReferenceExpression>> context) {
            if (this.isApplicable(joinNode)) {
                long numShards = this.getNumberOfShards();
                CallExpression randomNumber = Expressions.call(this.functionAndTypeManager, "random", (Type)BigintType.BIGINT, new RowExpression[]{Expressions.constant(numShards, (Type)BigintType.BIGINT)});
                VariableReferenceExpression leftShardVariable = this.planVariableAllocator.newVariable("shard", (Type)BigintType.BIGINT);
                VariableReferenceExpression rightShardVariable = this.planVariableAllocator.newVariable("shard", (Type)BigintType.BIGINT);
                PlanNode newLeftChild = PlannerUtils.addProjections(joinNode.getLeft(), this.planNodeIdAllocator, this.planVariableAllocator, (List<RowExpression>)ImmutableList.of((Object)randomNumber), (List<VariableReferenceExpression>)ImmutableList.of((Object)leftShardVariable));
                PlanNode newRightChild = this.shardInput(numShards, joinNode.getRight(), rightShardVariable);
                EquiJoinClause shardEquality = new EquiJoinClause(leftShardVariable, rightShardVariable);
                ArrayList<EquiJoinClause> joinCriteria = new ArrayList<EquiJoinClause>();
                joinCriteria.addAll(joinNode.getCriteria());
                joinCriteria.add(shardEquality);
                JoinNode result = new JoinNode(joinNode.getSourceLocation(), joinNode.getId(), joinNode.getStatsEquivalentPlanNode(), joinNode.getType(), newLeftChild, newRightChild, joinCriteria, joinNode.getOutputVariables(), joinNode.getFilter(), joinNode.getLeftHashVariable(), joinNode.getRightHashVariable(), joinNode.getDistributionType(), joinNode.getDynamicFilters());
                this.planChanged = true;
                return context.defaultRewrite((PlanNode)result);
            }
            return context.defaultRewrite((PlanNode)joinNode);
        }

        private boolean isApplicable(JoinNode joinNode) {
            return joinNode.getType() != JoinType.FULL && joinNode.getType() != JoinType.RIGHT && !PlannerUtils.isBroadcastJoin(joinNode) && (SystemSessionProperties.getShardedJoinStrategy(this.session).equals((Object)FeaturesConfig.ShardedJoinStrategy.ALWAYS) || SystemSessionProperties.getShardedJoinStrategy(this.session).equals((Object)FeaturesConfig.ShardedJoinStrategy.COST_BASED) && this.shouldShardJoin(joinNode));
        }

        private boolean shouldShardJoin(JoinNode joinNode) {
            return false;
        }

        private PlanNode shardInput(long numShards, PlanNode source, VariableReferenceExpression shardVariable) {
            Preconditions.checkState((numShards > 1L ? 1 : 0) != 0);
            CallExpression sequenceExpression = Expressions.call(this.functionAndTypeManager, "sequence", (Type)new ArrayType((Type)BigintType.BIGINT), new RowExpression[]{Expressions.constant(0L, (Type)BigintType.BIGINT), Expressions.constant(numShards - 1L, (Type)BigintType.BIGINT)});
            VariableReferenceExpression sequenceVariable = this.planVariableAllocator.newVariable((RowExpression)sequenceExpression);
            PlanNode projectSequence = PlannerUtils.addProjections(source, this.planNodeIdAllocator, this.planVariableAllocator, (List<RowExpression>)ImmutableList.of((Object)sequenceExpression), (List<VariableReferenceExpression>)ImmutableList.of((Object)sequenceVariable));
            UnnestNode unnest = new UnnestNode(source.getSourceLocation(), this.planNodeIdAllocator.getNextId(), projectSequence, projectSequence.getOutputVariables(), (Map)ImmutableMap.of((Object)sequenceVariable, (Object)ImmutableList.of((Object)shardVariable)), Optional.empty());
            return unnest;
        }

        private int getNumberOfShards() {
            return SystemSessionProperties.getJoinShardCount(this.session);
        }
    }
}

