/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.sql.planner.iterative.rule;

import com.google.common.base.Verify;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.prestosql.matching.Capture;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.Signature;
import io.prestosql.metadata.TableHandle;
import io.prestosql.spi.connector.AggregateFunction;
import io.prestosql.spi.connector.AggregationApplicationResult;
import io.prestosql.spi.connector.Assignment;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.SortItem;
import io.prestosql.spi.connector.SortOrder;
import io.prestosql.spi.expression.ConnectorExpression;
import io.prestosql.spi.expression.Variable;
import io.prestosql.sql.planner.ConnectorExpressionTranslator;
import io.prestosql.sql.planner.LiteralEncoder;
import io.prestosql.sql.planner.OrderingScheme;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.Patterns;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.planner.plan.TableScanNode;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.SymbolReference;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.IntStream;

/**
 * Optimizer rule that offers an {@link AggregationNode} sitting directly on a {@link TableScanNode}
 * to the connector via {@link Metadata#applyAggregation}.
 *
 * <p>The rule only matches aggregations whose arguments are all plain {@link SymbolReference}s,
 * that have at most one grouping set, and that use no masks. When the connector accepts the
 * pushdown, the aggregation is replaced by a {@link ProjectNode} over a new table scan:
 * aggregation outputs are bound to the connector-provided projections, and grouping keys are
 * bound to the (possibly remapped) grouping columns of the new scan.
 */
public class PushAggregationIntoTableScan
        implements Rule<AggregationNode> {
    private static final Capture<TableScanNode> TABLE_SCAN = Capture.newCapture();

    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation()
            .matching(PushAggregationIntoTableScan::allArgumentsAreSimpleReferences)
            // apply() hands the connector exactly one list of group-by columns
            .matching(node -> node.getGroupingSets().getGroupingSetCount() <= 1)
            .matching(PushAggregationIntoTableScan::hasNoMasks)
            .with(Patterns.source().matching(Patterns.tableScan().capturedAs(TABLE_SCAN)));

    private final Metadata metadata;

    /**
     * @param metadata used to resolve types and to offer the aggregation to the connector; must not be null
     */
    public PushAggregationIntoTableScan(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    /** Returns true when every argument of every aggregation is a bare {@link SymbolReference}. */
    private static boolean allArgumentsAreSimpleReferences(AggregationNode node) {
        return node.getAggregations().values().stream()
                .flatMap(aggregation -> aggregation.getArguments().stream())
                .allMatch(SymbolReference.class::isInstance);
    }

    /** Returns true when no aggregation carries a mask. */
    private static boolean hasNoMasks(AggregationNode node) {
        return node.getAggregations().values().stream()
                .noneMatch(aggregation -> aggregation.getMask().isPresent());
    }

    /**
     * Offers the matched aggregation to the connector and, if accepted, rewrites the plan into a
     * projection over a new table scan.
     *
     * @return {@link Rule.Result#empty()} when the connector declines the pushdown
     */
    @Override
    public Rule.Result apply(AggregationNode node, Captures captures, Rule.Context context) {
        TableScanNode tableScan = captures.get(TABLE_SCAN);

        // Column handles of the original scan, keyed by symbol name (the same names used for Variables).
        Map<String, ColumnHandle> assignments = tableScan.getAssignments().entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(entry -> entry.getKey().getName(), Map.Entry::getValue));

        // Snapshot the aggregations once so that functions and output symbols stay index-aligned.
        List<Map.Entry<Symbol, AggregationNode.Aggregation>> aggregations =
                ImmutableList.copyOf(node.getAggregations().entrySet());
        List<AggregateFunction> aggregateFunctions = aggregations.stream()
                .map(Map.Entry::getValue)
                .map(aggregation -> toAggregateFunction(context, aggregation))
                .collect(ImmutableList.toImmutableList());
        List<Symbol> aggregationOutputSymbols = aggregations.stream()
                .map(Map.Entry::getKey)
                .collect(ImmutableList.toImmutableList());

        AggregationNode.GroupingSetDescriptor groupingSets = node.getGroupingSets();
        List<ColumnHandle> groupByColumns = groupingSets.getGroupingKeys().stream()
                .map(groupByColumn -> assignments.get(groupByColumn.getName()))
                .collect(ImmutableList.toImmutableList());

        Optional<AggregationApplicationResult<TableHandle>> aggregationPushdownResult = metadata.applyAggregation(
                context.getSession(),
                tableScan.getTable(),
                aggregateFunctions,
                assignments,
                ImmutableList.of(groupByColumns));
        if (aggregationPushdownResult.isEmpty()) {
            return Rule.Result.empty();
        }
        AggregationApplicationResult<TableHandle> result = aggregationPushdownResult.get();

        // The new scan keeps all original outputs/assignments and adds one fresh symbol per
        // connector-provided column; variableMappings lets the connector's projections refer to them.
        ImmutableList.Builder<Symbol> newScanOutputs = ImmutableList.builder();
        newScanOutputs.addAll(tableScan.getOutputSymbols());
        ImmutableBiMap.Builder<Symbol, ColumnHandle> newScanAssignments = ImmutableBiMap.builder();
        newScanAssignments.putAll(tableScan.getAssignments());
        Map<String, Symbol> variableMappings = new HashMap<>();
        for (Assignment assignment : result.getAssignments()) {
            Symbol symbol = context.getSymbolAllocator().newSymbol(assignment.getVariable(), assignment.getType());
            newScanOutputs.add(symbol);
            newScanAssignments.put(symbol, assignment.getColumn());
            variableMappings.put(assignment.getVariable(), symbol);
        }

        LiteralEncoder literalEncoder = new LiteralEncoder(metadata);
        List<Expression> newProjections = result.getProjections().stream()
                .map(expression -> ConnectorExpressionTranslator.translate(expression, variableMappings, literalEncoder))
                .collect(ImmutableList.toImmutableList());

        // Projections are paired with aggregation outputs by position, so the counts must agree.
        Verify.verify(
                aggregationOutputSymbols.size() == newProjections.size(),
                "Connector returned %s projections for %s aggregations",
                newProjections.size(),
                aggregationOutputSymbols.size());

        Assignments.Builder assignmentBuilder = Assignments.builder();
        IntStream.range(0, aggregationOutputSymbols.size())
                .forEach(index -> assignmentBuilder.put(aggregationOutputSymbols.get(index), newProjections.get(index)));

        // NOTE(review): ImmutableBiMap.build() rejects duplicate values, so this throws if the connector
        // returns a column handle already assigned in the original scan — confirm that is intended.
        ImmutableBiMap<Symbol, ColumnHandle> scanAssignments = newScanAssignments.build();
        ImmutableBiMap<ColumnHandle, Symbol> columnHandleToSymbol = scanAssignments.inverse();

        // The projection must also expose every grouping key. If the connector remapped a grouping
        // column, point the key at the new column's symbol; otherwise keep the original column.
        for (Symbol groupBySymbol : groupingSets.getGroupingKeys()) {
            ColumnHandle originalColumnHandle = assignments.get(groupBySymbol.getName());
            ColumnHandle groupByColumnHandle = result.getGroupingColumnMapping()
                    .getOrDefault(originalColumnHandle, originalColumnHandle);
            Symbol scanSymbol = Verify.verifyNotNull(
                    columnHandleToSymbol.get(groupByColumnHandle),
                    "No scan output for grouping column %s",
                    groupByColumnHandle);
            assignmentBuilder.put(groupBySymbol, scanSymbol.toSymbolReference());
        }

        return Rule.Result.ofPlanNode(new ProjectNode(
                context.getIdAllocator().getNextId(),
                TableScanNode.newInstance(
                        context.getIdAllocator().getNextId(),
                        result.getHandle(),
                        newScanOutputs.build(),
                        scanAssignments),
                assignmentBuilder.build()));
    }

    /**
     * Converts a planner aggregation into the connector SPI {@link AggregateFunction}: arguments and
     * the optional filter become {@link Variable}s, and the optional ordering becomes {@link SortItem}s.
     */
    private AggregateFunction toAggregateFunction(Rule.Context context, AggregationNode.Aggregation aggregation) {
        Signature signature = aggregation.getResolvedFunction().getSignature();

        ImmutableList.Builder<ConnectorExpression> arguments = ImmutableList.builder();
        for (int i = 0; i < aggregation.getArguments().size(); i++) {
            // Safe cast: PATTERN only matches aggregations whose arguments are all SymbolReferences.
            SymbolReference argument = (SymbolReference) aggregation.getArguments().get(i);
            arguments.add(new Variable(argument.getName(), metadata.getType(signature.getArgumentTypes().get(i))));
        }

        Optional<OrderingScheme> orderingScheme = aggregation.getOrderingScheme();
        Optional<List<SortItem>> sortBy = orderingScheme.map(orderings -> orderings.getOrderBy().stream()
                .map(orderBy -> new SortItem(
                        orderBy.getName(),
                        // map onto the connector SPI SortOrder constant with the same name
                        SortOrder.valueOf(orderings.getOrderings().get(orderBy).name())))
                .collect(ImmutableList.toImmutableList()));

        Optional<ConnectorExpression> filter = aggregation.getFilter()
                .map(symbol -> new Variable(symbol.getName(), context.getSymbolAllocator().getTypes().get(symbol)));

        return new AggregateFunction(
                signature.getName(),
                metadata.getType(signature.getReturnType()),
                arguments.build(),
                sortBy.orElse(ImmutableList.of()),
                aggregation.isDistinct(),
                filter);
    }
}

