/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.metadata.FunctionKind;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.BooleanType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeSignature;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.IntersectNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.planner.plan.UnionNode;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.GenericLiteral;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.util.ImmutableCollectors;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Plan optimizer that replaces every {@link IntersectNode} with an equivalent plan made of a
 * {@link UnionNode}, an {@link AggregationNode}, a {@link FilterNode} and a {@link ProjectNode}.
 * <p>
 * For an INTERSECT of sources {@code S1..Sn} producing columns {@code c}, the rewrite is roughly:
 * <pre>
 * SELECT c FROM (
 *     SELECT c, count(marker_1) AS count_1, ..., count(marker_n) AS count_n
 *     FROM (
 *         SELECT c, true AS marker_1, ..., CAST(NULL AS boolean) AS marker_n FROM S1
 *         UNION
 *         ...
 *         UNION
 *         SELECT c, CAST(NULL AS boolean) AS marker_1, ..., true AS marker_n FROM Sn
 *     )
 *     GROUP BY c
 * )
 * WHERE count_1 &gt;= 1 AND ... AND count_n &gt;= 1
 * </pre>
 * Marker {@code i} is {@code true} only on rows coming from source {@code i} and NULL elsewhere,
 * so {@code count(marker_i)} (which ignores NULLs) is the number of rows source {@code i}
 * contributed to a group. A group is kept only if every source contributed at least one row.
 */
public class ImplementIntersectAsUnion
implements PlanOptimizer {
    /**
     * Rewrites all intersect nodes reachable from {@code plan}.
     *
     * @param plan root of the plan to rewrite
     * @param session current session (only null-checked here)
     * @param types symbol types (only null-checked here; types are read from {@code symbolAllocator})
     * @param symbolAllocator allocator used for the new marker, count and projection symbols
     * @param idAllocator allocator used for the ids of the new plan nodes
     * @return the rewritten plan
     * @throws NullPointerException if any argument is null
     */
    @Override
    public PlanNode optimize(PlanNode plan, Session session, Map<Symbol, Type> types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) {
        Objects.requireNonNull(plan, "plan is null");
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(types, "types is null");
        Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        Objects.requireNonNull(idAllocator, "idAllocator is null");
        return SimplePlanRewriter.rewriteWith(new Rewriter(idAllocator, symbolAllocator), plan);
    }

    private static class Rewriter
    extends SimplePlanRewriter<Void> {
        /** Name hint for the per-source boolean marker symbols. */
        private static final String INTERSECT_MARKER = "intersect_marker";
        /** Signature of {@code count(boolean) -> bigint}, used to count non-null markers. */
        private static final Signature COUNT_AGGREGATION = new Signature(
                "count",
                FunctionKind.AGGREGATE,
                TypeSignature.parseTypeSignature("bigint"),
                TypeSignature.parseTypeSignature("boolean"));

        private final PlanNodeIdAllocator idAllocator;
        private final SymbolAllocator symbolAllocator;

        private Rewriter(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator) {
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        }

        /**
         * Replaces one intersect node (after rewriting its sources) with
         * union -&gt; aggregation -&gt; filter -&gt; project.
         */
        @Override
        public PlanNode visitIntersect(IntersectNode node, SimplePlanRewriter.RewriteContext<Void> rewriteContext) {
            List<PlanNode> sources = node.getSources().stream()
                    .map(rewriteContext::rewrite)
                    .collect(Collectors.toList());
            List<Symbol> markers = allocateMarkers(sources.size());
            List<PlanNode> withMarkers = appendMarkers(markers, sources, node);
            List<Symbol> outputs = node.getOutputSymbols();
            // The union produces the intersect's columns followed by one marker per source.
            UnionNode union = union(withMarkers, ImmutableList.copyOf(Iterables.concat(outputs, markers)));
            AggregationNode aggregation = computeCounts(union, outputs, markers);
            FilterNode filterNode = addFilter(aggregation);
            return project(filterNode, outputs);
        }

        /** Allocates {@code count} fresh boolean marker symbols, one per intersect source. */
        private List<Symbol> allocateMarkers(int count) {
            ImmutableList.Builder<Symbol> markers = ImmutableList.builder();
            for (int i = 0; i < count; i++) {
                markers.add(symbolAllocator.newSymbol(INTERSECT_MARKER, BooleanType.BOOLEAN));
            }
            return markers.build();
        }

        /** Wraps each (already rewritten) source in a projection that also emits the marker columns. */
        private List<PlanNode> appendMarkers(List<Symbol> markers, List<PlanNode> nodes, IntersectNode intersect) {
            List<Symbol> outputs = intersect.getOutputSymbols();
            ImmutableList.Builder<PlanNode> result = ImmutableList.builder();
            for (int i = 0; i < nodes.size(); i++) {
                result.add(appendMarkersToSource(nodes.get(i), i, markers, outputs, intersect.sourceSymbolMap(i)));
            }
            return result.build();
        }

        /**
         * Builds a projection over {@code source} that emits a fresh copy of every intersect output
         * column (in {@code outputs} order), followed by all markers. Marker {@code markerIndex} is
         * {@code TRUE}; every other marker is {@code CAST(NULL AS boolean)}.
         *
         * @throws IllegalStateException if {@code projections} has no input for some output symbol
         */
        private PlanNode appendMarkersToSource(PlanNode source, int markerIndex, List<Symbol> markers, List<Symbol> outputs, Map<Symbol, SymbolReference> projections) {
            ImmutableMap.Builder<Symbol, Expression> assignments = ImmutableMap.builder();
            // Iterate the intersect's declared output order explicitly. union() pairs columns by
            // position, so the projection must not depend on the iteration order of `projections`.
            for (Symbol output : outputs) {
                SymbolReference input = projections.get(output);
                if (input == null) {
                    throw new IllegalStateException("No input from intersect source " + markerIndex + " for output symbol " + output);
                }
                Symbol symbol = symbolAllocator.newSymbol(output.getName(), symbolAllocator.getTypes().get(output));
                assignments.put(symbol, input);
            }
            for (int i = 0; i < markers.size(); i++) {
                // Must be typed Expression: the two branches are a BooleanLiteral and a Cast.
                Expression marker = (i == markerIndex)
                        ? BooleanLiteral.TRUE_LITERAL
                        : new Cast(new NullLiteral(), "boolean");
                assignments.put(symbolAllocator.newSymbol(markers.get(i).getName(), BooleanType.BOOLEAN), marker);
            }
            return new ProjectNode(idAllocator.getNextId(), source, assignments.build());
        }

        /**
         * Unions {@code nodes}, mapping the i-th output symbol of every source to {@code outputs.get(i)}.
         */
        private UnionNode union(List<PlanNode> nodes, List<Symbol> outputs) {
            ImmutableListMultimap.Builder<Symbol, Symbol> outputsToInputs = ImmutableListMultimap.builder();
            for (PlanNode source : nodes) {
                List<Symbol> sourceOutputs = source.getOutputSymbols();
                for (int i = 0; i < sourceOutputs.size(); i++) {
                    outputsToInputs.put(outputs.get(i), sourceOutputs.get(i));
                }
            }
            return new UnionNode(idAllocator.getNextId(), nodes, outputsToInputs.build(), outputs);
        }

        /**
         * Groups the union by the original intersect columns and computes {@code count(marker)}
         * for each marker, i.e. the number of rows each source contributed to the group.
         */
        private AggregationNode computeCounts(UnionNode sourceNode, List<Symbol> originalColumns, List<Symbol> markers) {
            ImmutableMap.Builder<Symbol, Signature> signatures = ImmutableMap.builder();
            ImmutableMap.Builder<Symbol, FunctionCall> aggregations = ImmutableMap.builder();
            for (Symbol marker : markers) {
                Symbol output = symbolAllocator.newSymbol("count", BigintType.BIGINT);
                aggregations.put(output, new FunctionCall(QualifiedName.of("count"), ImmutableList.<Expression>of(marker.toSymbolReference())));
                signatures.put(output, COUNT_AGGREGATION);
            }
            return new AggregationNode(
                    idAllocator.getNextId(),
                    sourceNode,
                    originalColumns,
                    aggregations.build(),
                    signatures.build(),
                    ImmutableMap.<Symbol, Symbol>of(),
                    ImmutableList.<List<Symbol>>of(originalColumns),
                    AggregationNode.Step.SINGLE,
                    Optional.empty(),
                    1.0,
                    Optional.empty());
        }

        /** Keeps only groups where every per-source count is at least 1. */
        private FilterNode addFilter(AggregationNode aggregation) {
            List<Expression> predicates = aggregation.getAggregations().keySet().stream()
                    .<Expression>map(column -> new ComparisonExpression(
                            ComparisonExpression.Type.GREATER_THAN_OR_EQUAL,
                            column.toSymbolReference(),
                            new GenericLiteral("BIGINT", "1")))
                    .collect(ImmutableCollectors.toImmutableList());
            return new FilterNode(idAllocator.getNextId(), aggregation, ExpressionUtils.and(predicates));
        }

        /**
         * Projects {@code node} down to {@code columns} as identity assignments.
         * The assignments are insertion-ordered, so the projection follows {@code columns} order.
         * A HashMap-based collector would leave the order undefined. NOTE(review): this assumes
         * ProjectNode derives its output order from the assignments' iteration order; confirm.
         */
        private ProjectNode project(PlanNode node, List<Symbol> columns) {
            ImmutableMap.Builder<Symbol, Expression> assignments = ImmutableMap.builder();
            for (Symbol column : columns) {
                assignments.put(column, column.toSymbolReference());
            }
            return new ProjectNode(idAllocator.getNextId(), node, assignments.build());
        }
    }
}

