/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.prestosql.metadata.FunctionId;
import io.prestosql.metadata.Metadata;
import io.prestosql.spi.type.BigintType;
import io.prestosql.spi.type.BooleanType;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.analyzer.TypeSignatureProvider;
import io.prestosql.sql.planner.PlanNodeIdAllocator;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.SymbolAllocator;
import io.prestosql.sql.planner.iterative.Lookup;
import io.prestosql.sql.planner.optimizations.PlanNodeDecorrelator;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.AssignUniqueId;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.CorrelatedJoinNode;
import io.prestosql.sql.planner.plan.DynamicFilterId;
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.tree.BooleanLiteral;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.QualifiedName;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/**
 * Rewrites a correlated join whose subquery is a scalar aggregation into an
 * aggregation placed on top of a LEFT join.
 *
 * <p>The produced plan has the shape:
 * <pre>
 * AggregationNode (grouped by all output symbols of the left join input)
 *   JoinNode (LEFT, filter = correlated predicates extracted from the subquery)
 *     AssignUniqueId (correlated join input + "unique" BIGINT symbol)
 *     ProjectNode (decorrelated aggregation source + "non_null" = TRUE)
 * </pre>
 */
public class ScalarAggregationToJoinRewriter {
    private final Metadata metadata;
    private final SymbolAllocator symbolAllocator;
    private final PlanNodeIdAllocator idAllocator;
    private final Lookup lookup;
    private final PlanNodeDecorrelator planNodeDecorrelator;

    /**
     * Creates a rewriter backed by the given planner services.
     *
     * @param metadata used to resolve the {@code count} function
     * @param symbolAllocator used to allocate the helper symbols and to look up symbol types
     * @param idAllocator used to allocate ids for the newly created plan nodes
     * @param lookup handed to the {@link PlanNodeDecorrelator} built by this constructor
     * @throws NullPointerException if any argument is {@code null}
     */
    public ScalarAggregationToJoinRewriter(Metadata metadata, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, Lookup lookup) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
        this.lookup = Objects.requireNonNull(lookup, "lookup is null");
        this.planNodeDecorrelator = new PlanNodeDecorrelator(metadata, symbolAllocator, lookup);
    }

    /**
     * Attempts to replace the correlated join with an aggregation over a LEFT join.
     *
     * @param correlatedJoinNode the correlated join to rewrite
     * @param aggregation the scalar aggregation found in the subquery of the correlated join
     * @return the rewritten plan, or {@code correlatedJoinNode} itself when the source of
     *         {@code aggregation} could not be decorrelated
     */
    public PlanNode rewriteScalarAggregation(CorrelatedJoinNode correlatedJoinNode, AggregationNode aggregation) {
        List<Symbol> correlation = correlatedJoinNode.getCorrelation();
        Optional<PlanNodeDecorrelator.DecorrelatedNode> source = planNodeDecorrelator.decorrelateFilters(aggregation.getSource(), correlation);
        if (source.isEmpty()) {
            // Decorrelation failed: hand the plan back untouched.
            return correlatedJoinNode;
        }
        PlanNodeDecorrelator.DecorrelatedNode decorrelated = source.get();
        PlanNode decorrelatedSource = decorrelated.getNode();

        // Project a constant TRUE next to the subquery columns. It is later used as the
        // argument of the rewritten count aggregation (see createAggregationNode).
        Symbol nonNull = symbolAllocator.newSymbol("non_null", BooleanType.BOOLEAN);
        Assignments scalarAggregationSourceAssignments = Assignments.builder()
                .putIdentities(decorrelatedSource.getOutputSymbols())
                .put(nonNull, BooleanLiteral.TRUE_LITERAL)
                .build();
        ProjectNode scalarAggregationSourceWithNonNullableSymbol = new ProjectNode(
                idAllocator.getNextId(),
                decorrelatedSource,
                scalarAggregationSourceAssignments);

        return rewriteScalarAggregation(
                correlatedJoinNode,
                aggregation,
                scalarAggregationSourceWithNonNullableSymbol,
                decorrelated.getCorrelatedPredicates(),
                nonNull);
    }

    /**
     * Builds the LEFT join between the uniquely tagged input and the decorrelated
     * aggregation source, then places the adjusted aggregation on top of it.
     */
    private PlanNode rewriteScalarAggregation(CorrelatedJoinNode correlatedJoinNode, AggregationNode scalarAggregation, PlanNode scalarAggregationSource, Optional<Expression> joinExpression, Symbol nonNull) {
        // The extra "unique" symbol becomes part of the grouping key built in createAggregationNode.
        AssignUniqueId inputWithUniqueColumns = new AssignUniqueId(
                idAllocator.getNextId(),
                correlatedJoinNode.getInput(),
                symbolAllocator.newSymbol("unique", BigintType.BIGINT));

        // No equi-join clauses: the correlated predicates are passed as the join filter.
        // The join reuses the id of the correlated join it replaces.
        JoinNode leftOuterJoin = new JoinNode(
                correlatedJoinNode.getId(),
                JoinNode.Type.LEFT,
                inputWithUniqueColumns,
                scalarAggregationSource,
                ImmutableList.of(),
                inputWithUniqueColumns.getOutputSymbols(),
                scalarAggregationSource.getOutputSymbols(),
                joinExpression,
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                ImmutableMap.of(),
                Optional.empty());

        return createAggregationNode(scalarAggregation, leftOuterJoin, nonNull);
    }

    /**
     * Re-creates the scalar aggregation on top of the LEFT join, grouped by every output
     * symbol of the join's left side.
     *
     * <p>Aggregations whose function id equals that of zero-argument {@code count} are
     * replaced by {@code count(nonNullableAggregationSourceSymbol)}; of the original
     * aggregation only the mask is carried over. Presumably this keeps rows that the LEFT
     * join pads with nulls from being counted. All other aggregations are copied unchanged.
     */
    private AggregationNode createAggregationNode(AggregationNode scalarAggregation, JoinNode leftOuterJoin, Symbol nonNullableAggregationSourceSymbol) {
        FunctionId countFunctionId = metadata.resolveFunction(QualifiedName.of("count"), ImmutableList.of()).getFunctionId();
        ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : scalarAggregation.getAggregations().entrySet()) {
            Symbol symbol = entry.getKey();
            AggregationNode.Aggregation aggregation = entry.getValue();
            if (aggregation.getResolvedFunction().getFunctionId().equals(countFunctionId)) {
                List<Type> scalarAggregationSourceTypes = ImmutableList.of(symbolAllocator.getTypes().get(nonNullableAggregationSourceSymbol));
                aggregations.put(symbol, new AggregationNode.Aggregation(
                        metadata.resolveFunction(QualifiedName.of("count"), TypeSignatureProvider.fromTypes(scalarAggregationSourceTypes)),
                        ImmutableList.of(nonNullableAggregationSourceSymbol.toSymbolReference()),
                        false,
                        Optional.empty(),
                        Optional.empty(),
                        aggregation.getMask()));
            } else {
                aggregations.put(symbol, aggregation);
            }
        }
        return new AggregationNode(
                scalarAggregation.getId(),
                leftOuterJoin,
                aggregations.build(),
                AggregationNode.singleGroupingSet(leftOuterJoin.getLeft().getOutputSymbols()),
                ImmutableList.of(),
                scalarAggregation.getStep(),
                scalarAggregation.getHashSymbol(),
                Optional.empty());
    }
}

