/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.ExceptNode;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.IntersectNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.SetOperationNode;
import com.facebook.presto.spi.plan.UnionNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizerResult;
import com.facebook.presto.sql.planner.optimizations.SetOperationNodeUtils;
import com.facebook.presto.sql.planner.plan.AssignmentUtils;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.relational.Expressions;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Maps;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

/**
 * Plan optimizer that replaces every {@link IntersectNode} and {@link ExceptNode} with an
 * equivalent plan built from a {@link UnionNode}, a grouped count aggregation and a filter.
 *
 * <p>Each (already rewritten) source is projected with one extra boolean marker column per
 * source: the marker belonging to that source is {@code TRUE}, all other markers are
 * {@code NULL}. After the union, rows are grouped by the set operation's output columns and
 * {@code count(marker_i)} is computed for every source {@code i}. Because {@code count}
 * skips nulls, each count is the number of rows source {@code i} contributed to the group.
 * <ul>
 *   <li>INTERSECT keeps groups where every count is {@code >= 1}.</li>
 *   <li>EXCEPT keeps groups where the first count is {@code >= 1} and every other count is
 *       {@code = 0}.</li>
 * </ul>
 * A final identity projection restores exactly the set operation's output columns. Since
 * rows are grouped by all output columns, each qualifying row appears once in the result.
 */
public class ImplementIntersectAndExceptAsUnion
implements PlanOptimizer {
    private final FunctionAndTypeManager functionAndTypeManager;

    public ImplementIntersectAndExceptAsUnion(FunctionAndTypeManager functionAndTypeManager) {
        this.functionAndTypeManager = Objects.requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
    }

    @Override
    public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) {
        Objects.requireNonNull(plan, "plan is null");
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(types, "types is null");
        Objects.requireNonNull(variableAllocator, "variableAllocator is null");
        Objects.requireNonNull(idAllocator, "idAllocator is null");
        Rewriter rewriter = new Rewriter(session, this.functionAndTypeManager, idAllocator, variableAllocator);
        PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(rewriter, plan);
        return PlanOptimizerResult.optimizerResult(rewrittenPlan, rewriter.isPlanChanged());
    }

    /**
     * Plan rewriter performing the INTERSECT/EXCEPT to UNION transformation described on the
     * enclosing class. It records whether any node was rewritten in {@link #isPlanChanged()}.
     */
    private static class Rewriter
    extends SimplePlanRewriter<Void> {
        private static final String MARKER = "marker";
        private final Session session;
        private final StandardFunctionResolution functionResolution;
        private final PlanNodeIdAllocator idAllocator;
        private final VariableAllocator variableAllocator;
        private boolean planChanged;

        private Rewriter(Session session, FunctionAndTypeManager functionAndTypeManager, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator) {
            Objects.requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
            this.session = Objects.requireNonNull(session, "session is null");
            this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.variableAllocator = Objects.requireNonNull(variableAllocator, "variableAllocator is null");
        }

        /** INTERSECT: keep a row only if every source contributed it at least once. */
        @Override
        public PlanNode visitIntersect(IntersectNode node, SimplePlanRewriter.RewriteContext<Void> rewriteContext) {
            return rewriteAsUnion(node, rewriteContext, (aggregation, counts) -> addFilterForIntersect(aggregation));
        }

        /** EXCEPT: keep a row only if the first source contributed it and no other source did. */
        @Override
        public PlanNode visitExcept(ExceptNode node, SimplePlanRewriter.RewriteContext<Void> rewriteContext) {
            return rewriteAsUnion(node, rewriteContext,
                    (aggregation, counts) -> addFilterForExcept(aggregation, counts.get(0), counts.subList(1, counts.size())));
        }

        /**
         * Shared rewrite pipeline for INTERSECT and EXCEPT.
         *
         * @param node the set operation to replace
         * @param rewriteContext context used to rewrite the node's sources first
         * @param filterFactory builds the operation-specific filter from the count aggregation
         *        and the per-source count variables (in source order)
         * @return an identity projection of {@code node}'s output variables over the filter
         */
        private PlanNode rewriteAsUnion(
                SetOperationNode node,
                SimplePlanRewriter.RewriteContext<Void> rewriteContext,
                BiFunction<AggregationNode, List<VariableReferenceExpression>, FilterNode> filterFactory) {
            List<PlanNode> sources = node.getSources().stream()
                    .map(rewriteContext::rewrite)
                    .collect(Collectors.toList());
            List<VariableReferenceExpression> markers = allocateVariables(sources.size(), MARKER, BooleanType.BOOLEAN);

            // Identity projection of each source's fields plus the marker columns.
            List<PlanNode> withMarkers = appendMarkers(markers, sources, node);

            // The union's outputs are the set operation's outputs followed by the markers.
            List<VariableReferenceExpression> outputs = node.getOutputVariables();
            UnionNode union = union(withMarkers, ImmutableList.copyOf(Iterables.concat(outputs, markers)));

            // One count(marker_i) per source, grouped by the original output columns.
            List<VariableReferenceExpression> aggregationOutputs = allocateVariables(markers.size(), "count", BigintType.BIGINT);
            AggregationNode aggregation = computeCounts(union, outputs, markers, aggregationOutputs);
            FilterNode filterNode = filterFactory.apply(aggregation, aggregationOutputs);

            planChanged = true;
            return project(filterNode, outputs);
        }

        /** Allocates {@code count} fresh variables of the given type using {@code nameHint}. */
        private List<VariableReferenceExpression> allocateVariables(int count, String nameHint, Type type) {
            ImmutableList.Builder<VariableReferenceExpression> variables = ImmutableList.builder();
            for (int i = 0; i < count; ++i) {
                variables.add(variableAllocator.newVariable(nameHint, type));
            }
            return variables.build();
        }

        /** Wraps source {@code i} of {@code node} in a projection that appends the marker columns. */
        private List<PlanNode> appendMarkers(List<VariableReferenceExpression> markers, List<PlanNode> nodes, SetOperationNode node) {
            ImmutableList.Builder<PlanNode> result = ImmutableList.builder();
            for (int i = 0; i < nodes.size(); ++i) {
                result.add(appendMarkers(nodes.get(i), i, markers, node.sourceVariableMap(i)));
            }
            return result.build();
        }

        /**
         * Projects {@code source} onto fresh copies of the set operation's output variables,
         * then adds one boolean column per marker.
         *
         * @param markerIndex index of the marker that is {@code TRUE} for this source; every
         *        other marker is a {@code NULL} boolean constant
         * @param projections output variable of the set operation to the source variable that
         *        feeds it
         */
        private PlanNode appendMarkers(PlanNode source, int markerIndex, List<VariableReferenceExpression> markers, Map<VariableReferenceExpression, VariableReferenceExpression> projections) {
            Assignments.Builder assignments = Assignments.builder();
            for (Map.Entry<VariableReferenceExpression, VariableReferenceExpression> entry : projections.entrySet()) {
                VariableReferenceExpression output = entry.getKey();
                VariableReferenceExpression copy = variableAllocator.newVariable(output.getSourceLocation(), output.getName(), output.getType());
                assignments.put(copy, entry.getValue());
            }
            for (int i = 0; i < markers.size(); ++i) {
                VariableReferenceExpression marker = markers.get(i);
                ConstantExpression value = i == markerIndex ? LogicalRowExpressions.TRUE_CONSTANT : new ConstantExpression(null, BooleanType.BOOLEAN);
                assignments.put(variableAllocator.newVariable(marker.getSourceLocation(), marker.getName(), BooleanType.BOOLEAN), value);
            }
            return new ProjectNode(idAllocator.getNextId(), source, assignments.build());
        }

        /**
         * Builds a union over {@code nodes}, mapping the i-th output variable of every source
         * to {@code outputs.get(i)} by position.
         */
        private UnionNode union(List<PlanNode> nodes, List<VariableReferenceExpression> outputs) {
            ImmutableListMultimap.Builder<VariableReferenceExpression, VariableReferenceExpression> outputsToInputs = ImmutableListMultimap.builder();
            for (PlanNode source : nodes) {
                List<VariableReferenceExpression> sourceOutputs = source.getOutputVariables();
                for (int i = 0; i < sourceOutputs.size(); ++i) {
                    outputsToInputs.put(outputs.get(i), sourceOutputs.get(i));
                }
            }
            ListMultimap<VariableReferenceExpression, VariableReferenceExpression> mapping = outputsToInputs.build();
            return new UnionNode(
                    nodes.get(0).getSourceLocation(),
                    idAllocator.getNextId(),
                    nodes,
                    ImmutableList.copyOf(mapping.keySet()),
                    SetOperationNodeUtils.fromListMultimap(mapping));
        }

        /**
         * Groups the union by {@code originalColumns} and computes {@code count(markers[i])}
         * into {@code aggregationOutputs[i]} for every marker.
         */
        private AggregationNode computeCounts(UnionNode sourceNode, List<VariableReferenceExpression> originalColumns, List<VariableReferenceExpression> markers, List<VariableReferenceExpression> aggregationOutputs) {
            ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
            for (int i = 0; i < markers.size(); ++i) {
                VariableReferenceExpression output = aggregationOutputs.get(i);
                VariableReferenceExpression marker = markers.get(i);
                CallExpression count = new CallExpression(
                        output.getSourceLocation(),
                        "count",
                        functionResolution.countFunction(marker.getType()),
                        BigintType.BIGINT,
                        ImmutableList.<RowExpression>of(marker));
                aggregations.put(output, new AggregationNode.Aggregation(count, Optional.empty(), Optional.empty(), false, Optional.empty()));
            }
            return new AggregationNode(
                    sourceNode.getSourceLocation(),
                    idAllocator.getNextId(),
                    sourceNode,
                    aggregations.build(),
                    AggregationNode.singleGroupingSet(originalColumns),
                    ImmutableList.of(),
                    AggregationNode.Step.SINGLE,
                    Optional.empty(),
                    Optional.empty(),
                    Optional.empty());
        }

        /** Keeps rows where every count produced by {@code aggregation} is {@code >= 1}. */
        private FilterNode addFilterForIntersect(AggregationNode aggregation) {
            ImmutableList.Builder<RowExpression> predicates = ImmutableList.builder();
            for (VariableReferenceExpression count : aggregation.getAggregations().keySet()) {
                predicates.add(atLeastOne(count));
            }
            return new FilterNode(aggregation.getSourceLocation(), idAllocator.getNextId(), aggregation, LogicalRowExpressions.and(predicates.build()));
        }

        /** Keeps rows where {@code firstSource >= 1} and every count in {@code remainingSources} is {@code 0}. */
        private FilterNode addFilterForExcept(AggregationNode aggregation, VariableReferenceExpression firstSource, List<VariableReferenceExpression> remainingSources) {
            ImmutableList.Builder<RowExpression> predicates = ImmutableList.builder();
            predicates.add(atLeastOne(firstSource));
            for (VariableReferenceExpression variable : remainingSources) {
                predicates.add(Expressions.comparisonExpression(functionResolution, OperatorType.EQUAL, variable, new ConstantExpression(0L, BigintType.BIGINT)));
            }
            return new FilterNode(aggregation.getSourceLocation(), idAllocator.getNextId(), aggregation, LogicalRowExpressions.and(predicates.build()));
        }

        /** Builds the predicate {@code count >= 1} over a BIGINT count variable. */
        private RowExpression atLeastOne(VariableReferenceExpression count) {
            return Expressions.comparisonExpression(functionResolution, OperatorType.GREATER_THAN_OR_EQUAL, count, new ConstantExpression(1L, BigintType.BIGINT));
        }

        /** Identity projection of {@code columns} over {@code node}. */
        private ProjectNode project(PlanNode node, List<VariableReferenceExpression> columns) {
            return new ProjectNode(idAllocator.getNextId(), node, AssignmentUtils.identityAssignments(columns));
        }

        /** Returns {@code true} once at least one INTERSECT or EXCEPT node has been rewritten. */
        public boolean isPlanChanged() {
            return planChanged;
        }
    }
}

