/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.metadata.FunctionType;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.spi.Domain;
import com.facebook.presto.spi.Marker;
import com.facebook.presto.spi.Range;
import com.facebook.presto.spi.SortedRangeSet;
import com.facebook.presto.spi.TupleDomain;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.DomainTranslator;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.LimitNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanRewriter;
import com.facebook.presto.sql.planner.plan.RowNumberNode;
import com.facebook.presto.sql.planner.plan.TopNRowNumberNode;
import com.facebook.presto.sql.planner.plan.WindowNode;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.Expression;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.primitives.Ints;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.stream.Collectors;

/**
 * Plan optimizer that exploits filters and limits placed above a {@code row_number()} window function.
 *
 * <p>The rewrites performed by {@link Rewriter} are:
 * <ul>
 *   <li>a {@link WindowNode} whose only function is {@code row_number()} and which has no ORDER BY
 *       is replaced by a {@link RowNumberNode};</li>
 *   <li>a {@link LimitNode} over such a row-number computation pushes its count down as a
 *       per-partition maximum (an ordered window becomes a {@link TopNRowNumberNode});</li>
 *   <li>a {@link FilterNode} whose predicate bounds the row number from above does the same, and the
 *       row-number part of the predicate is dropped when it was exactly {@code rn <= N}.</li>
 * </ul>
 */
public class WindowFilterPushDown
extends PlanOptimizer {
    private static final Signature ROW_NUMBER_SIGNATURE = new Signature("row_number", FunctionType.WINDOW, "bigint", ImmutableList.<String>of());
    private final Metadata metadata;

    /**
     * @param metadata metadata used to translate filter predicates into tuple domains; must not be null
     */
    public WindowFilterPushDown(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    @Override
    public PlanNode optimize(PlanNode plan, Session session, Map<Symbol, Type> types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) {
        Objects.requireNonNull(plan, "plan is null");
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(types, "types is null");
        Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        Objects.requireNonNull(idAllocator, "idAllocator is null");
        return PlanRewriter.rewriteWith(new Rewriter(idAllocator, this.metadata, session, types), plan, null);
    }

    /**
     * Bottom-up rewriter: each visit method first rewrites its source, then inspects the rewritten
     * source to decide whether a row-number optimization applies.
     */
    private static class Rewriter
    extends PlanRewriter<Void> {
        private final PlanNodeIdAllocator idAllocator;
        private final Metadata metadata;
        private final Session session;
        private final Map<Symbol, Type> types;

        private Rewriter(PlanNodeIdAllocator idAllocator, Metadata metadata, Session session, Map<Symbol, Type> types) {
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
            this.session = Objects.requireNonNull(session, "session is null");
            this.types = ImmutableMap.copyOf(Objects.requireNonNull(types, "types is null"));
        }

        /**
         * Replaces an unordered, {@code row_number()}-only window with a {@link RowNumberNode}
         * that has no per-partition limit; otherwise keeps the window over its rewritten source.
         */
        @Override
        public PlanNode visitWindow(WindowNode node, PlanRewriter.RewriteContext<Void> context) {
            PlanNode rewrittenSource = context.rewrite(node.getSource());
            if (canReplaceWithRowNumber(node)) {
                return new RowNumberNode(this.idAllocator.getNextId(), rewrittenSource, node.getPartitionBy(), getRowNumberSymbol(node), Optional.empty(), Optional.empty());
            }
            return context.replaceChildren(node, ImmutableList.of(rewrittenSource));
        }

        /**
         * Pushes a LIMIT into a row-number computation directly below it.
         *
         * <p>When the row-number computation is unpartitioned, the per-partition maximum already
         * bounds the total row count, so the limit node itself is removed.
         */
        @Override
        public PlanNode visitLimit(LimitNode node, PlanRewriter.RewriteContext<Void> context) {
            long count = node.getCount();
            // The pushed-down per-partition count must be a positive int. LIMIT 0 and counts that do
            // not fit an int are left untouched (not optimizing is always correct).
            if (count <= 0 || count > Integer.MAX_VALUE) {
                return context.defaultRewrite(node);
            }
            int limit = (int) count;
            PlanNode source = context.rewrite(node.getSource());

            if (source instanceof RowNumberNode) {
                RowNumberNode rowNumberNode = mergeLimit((RowNumberNode) source, limit);
                if (rowNumberNode.getPartitionBy().isEmpty()) {
                    return rowNumberNode;
                }
                source = rowNumberNode;
            }
            else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source)) {
                WindowNode windowNode = (WindowNode) source;
                // An unordered row_number() window would already have become a RowNumberNode in visitWindow.
                Verify.verify(!windowNode.getOrderBy().isEmpty());
                TopNRowNumberNode topNRowNumberNode = convertToTopNRowNumber(windowNode, limit);
                if (windowNode.getPartitionBy().isEmpty()) {
                    return topNRowNumberNode;
                }
                source = topNRowNumberNode;
            }
            return context.replaceChildren(node, ImmutableList.of(source));
        }

        /**
         * Pushes an upper bound on the row number found in the filter predicate into the
         * row-number computation directly below the filter, then simplifies the filter.
         */
        @Override
        public PlanNode visitFilter(FilterNode node, PlanRewriter.RewriteContext<Void> context) {
            PlanNode source = context.rewrite(node.getSource());
            DomainTranslator.ExtractionResult extractionResult = DomainTranslator.fromPredicate(this.metadata, this.session, node.getPredicate(), this.types);
            TupleDomain<Symbol> tupleDomain = extractionResult.getTupleDomain();

            if (source instanceof RowNumberNode) {
                RowNumberNode rowNumberNode = (RowNumberNode) source;
                Symbol rowNumberSymbol = rowNumberNode.getRowNumberSymbol();
                OptionalInt upperBound = extractUpperBound(tupleDomain, rowNumberSymbol);
                if (upperBound.isPresent()) {
                    PlanNode limited = mergeLimit(rowNumberNode, upperBound.getAsInt());
                    return rewriteFilterSource(node, extractionResult, limited, rowNumberSymbol, upperBound.getAsInt());
                }
            }
            else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source)) {
                WindowNode windowNode = (WindowNode) source;
                Symbol rowNumberSymbol = getRowNumberSymbol(windowNode);
                OptionalInt upperBound = extractUpperBound(tupleDomain, rowNumberSymbol);
                if (upperBound.isPresent()) {
                    PlanNode topN = convertToTopNRowNumber(windowNode, upperBound.getAsInt());
                    return rewriteFilterSource(node, extractionResult, topN, rowNumberSymbol, upperBound.getAsInt());
                }
            }
            return context.replaceChildren(node, ImmutableList.of(source));
        }

        /**
         * Rebuilds the filter above {@code source}.
         *
         * <p>If the row-number domain is exactly {@code (-inf, upperBound]}, the new source already
         * enforces it, so that domain is removed from the predicate. The filter is dropped entirely
         * when nothing else remains. Otherwise the original predicate is kept unchanged.
         *
         * @param filterNode the filter being rewritten (its id and predicate are reused)
         * @param extractionResult the translation of {@code filterNode}'s predicate
         * @param source the new, row-limited source
         * @param rowNumberSymbol the symbol holding the row number
         * @param upperBound the bound that was pushed into {@code source}
         */
        private PlanNode rewriteFilterSource(FilterNode filterNode, DomainTranslator.ExtractionResult extractionResult, PlanNode source, Symbol rowNumberSymbol, int upperBound) {
            TupleDomain<Symbol> tupleDomain = extractionResult.getTupleDomain();
            if (!isEqualRange(tupleDomain, rowNumberSymbol, upperBound)) {
                return new FilterNode(filterNode.getId(), source, filterNode.getPredicate());
            }
            Map<Symbol, Domain> remainingDomains = tupleDomain.getDomains().entrySet().stream()
                    .filter(entry -> !entry.getKey().equals(rowNumberSymbol))
                    .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
            TupleDomain<Symbol> remainingTupleDomain = TupleDomain.withColumnDomains(remainingDomains);
            Expression newPredicate = ExpressionUtils.combineConjuncts(
                    extractionResult.getRemainingExpression(),
                    DomainTranslator.toPredicate(remainingTupleDomain, this.types));
            if (newPredicate.equals(BooleanLiteral.TRUE_LITERAL)) {
                return source;
            }
            return new FilterNode(filterNode.getId(), source, newPredicate);
        }

        /**
         * Returns true when {@code symbol}'s domain in {@code tupleDomain} is exactly the single
         * range {@code (-inf, upperBound]}. Returns false for a "none" tuple domain or when
         * {@code symbol} has no domain.
         */
        private static boolean isEqualRange(TupleDomain<Symbol> tupleDomain, Symbol symbol, long upperBound) {
            if (tupleDomain.isNone()) {
                return false;
            }
            Domain domain = tupleDomain.getDomains().get(symbol);
            if (domain == null) {
                return false;
            }
            return domain.getRanges().equals(SortedRangeSet.of(Range.lessThanOrEqual(Long.valueOf(upperBound))));
        }

        /**
         * Extracts an inclusive upper bound on {@code symbol} from {@code tupleDomain}.
         *
         * <p>The bound is taken from the high end of the domain's span; an exclusive bound
         * {@code < N} becomes {@code N - 1}.
         *
         * @return the bound, or empty when there is no usable bound. That covers: a missing, "all"
         *     or "none" domain; an unbounded high end; a bound below 1; and a bound above
         *     {@link Integer#MAX_VALUE}
         */
        private static OptionalInt extractUpperBound(TupleDomain<Symbol> tupleDomain, Symbol symbol) {
            if (tupleDomain.isNone()) {
                return OptionalInt.empty();
            }
            Domain rowNumberDomain = tupleDomain.getDomains().get(symbol);
            if (rowNumberDomain == null) {
                return OptionalInt.empty();
            }
            SortedRangeSet ranges = rowNumberDomain.getRanges();
            if (ranges.isAll() || ranges.isNone() || ranges.getRangeCount() <= 0) {
                return OptionalInt.empty();
            }
            Range span = ranges.getSpan();
            if (span.getHigh().isUpperUnbounded()) {
                return OptionalInt.empty();
            }
            Verify.verify(rowNumberDomain.getType().equals(Long.class));
            long upperBound = (Long) span.getHigh().getValue();
            if (span.getHigh().getBound() == Marker.Bound.BELOW) {
                upperBound--;
            }
            // A bound below 1 (e.g. "rn <= 0" or "rn < 1") cannot become a positive per-partition row
            // count. NOTE(review): RowNumberNode/TopNRowNumberNode are presumed to reject non-positive
            // counts; leaving the filter in place is always correct either way.
            if (upperBound <= 0 || upperBound > Integer.MAX_VALUE) {
                return OptionalInt.empty();
            }
            return OptionalInt.of((int) upperBound);
        }

        /**
         * Returns a copy of {@code node} (same id) whose per-partition maximum is the smaller of
         * its existing maximum (if any) and {@code newRowCountPerPartition}.
         */
        private static RowNumberNode mergeLimit(RowNumberNode node, int newRowCountPerPartition) {
            int rowCount = newRowCountPerPartition;
            if (node.getMaxRowCountPerPartition().isPresent()) {
                rowCount = Math.min(node.getMaxRowCountPerPartition().get(), rowCount);
            }
            return new RowNumberNode(node.getId(), node.getSource(), node.getPartitionBy(), node.getRowNumberSymbol(), Optional.of(rowCount), node.getHashSymbol());
        }

        /**
         * Builds a {@link TopNRowNumberNode} over the window's source that keeps at most
         * {@code limit} rows per partition, using the window's partitioning and ordering.
         */
        private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limit) {
            return new TopNRowNumberNode(this.idAllocator.getNextId(), windowNode.getSource(), windowNode.getPartitionBy(), windowNode.getOrderBy(), windowNode.getOrderings(), getRowNumberSymbol(windowNode), limit, false, Optional.empty());
        }

        /** Returns the output symbol of the window's single function. */
        private static Symbol getRowNumberSymbol(WindowNode node) {
            return Iterables.getOnlyElement(node.getWindowFunctions().keySet());
        }

        /** True when the window computes only {@code row_number()} and has no ORDER BY. */
        private static boolean canReplaceWithRowNumber(WindowNode node) {
            return canOptimizeWindowFunction(node) && node.getOrderBy().isEmpty();
        }

        /** True when the window computes exactly one function and that function is {@code row_number()}. */
        private static boolean canOptimizeWindowFunction(WindowNode node) {
            if (node.getWindowFunctions().size() != 1) {
                return false;
            }
            return isRowNumberSignature(node.getSignatures().get(getRowNumberSymbol(node)));
        }

        /** Null-safe comparison against the {@code row_number(): bigint} window signature. */
        private static boolean isRowNumberSignature(Signature signature) {
            return ROW_NUMBER_SIGNATURE.equals(signature);
        }
    }
}

