/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative;

import com.facebook.presto.cost.PlanCostEstimate;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.spi.plan.LogicalProperties;
import com.facebook.presto.spi.plan.LogicalPropertiesProvider;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.iterative.GroupReference;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Plans;
import com.google.common.base.Preconditions;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multiset;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;

/**
 * Stores a plan as a set of numbered groups for the iterative optimizer.
 * <p>
 * Each group holds exactly one plan node (its "membership"). The sources of that node have
 * been rewritten into {@link GroupReference}s that point at other groups. A group may also
 * carry logical properties (when a {@link LogicalPropertiesProvider} was supplied) and cached
 * stats/cost estimates.
 * <p>
 * Every group records which groups reference it in a multiset of incoming references. The
 * root group additionally holds the synthetic reference {@link #ROOT_GROUP_REF}, so its
 * reference set never becomes empty. Whenever a group's reference set becomes empty, the
 * group is removed and the references it held on its own children are released, which can
 * cascade further down the plan.
 */
public class Memo {
    /** Pseudo group id used as the incoming reference that keeps the root group alive. */
    private static final int ROOT_GROUP_REF = 0;
    private final PlanNodeIdAllocator idAllocator;
    private final int rootGroup;
    private final Map<Integer, Group> groups = new HashMap<>();
    private final Optional<LogicalPropertiesProvider> logicalPropertiesProvider;
    // Real group ids start right after the id reserved for the root reference.
    private int nextGroupId = ROOT_GROUP_REF + 1;

    /**
     * Creates a memo for {@code plan} without computing logical properties.
     *
     * @param idAllocator allocator used for the ids of newly created {@link GroupReference}s
     * @param plan the plan to insert; its root becomes the root group
     */
    public Memo(PlanNodeIdAllocator idAllocator, PlanNode plan) {
        this(idAllocator, plan, Optional.empty());
    }

    /**
     * Creates a memo for {@code plan}.
     *
     * @param idAllocator allocator used for the ids of newly created {@link GroupReference}s
     * @param plan the plan to insert; its root becomes the root group
     * @param logicalPropertiesProvider when present, used to compute logical properties for
     *        every group as it is inserted or replaced
     */
    public Memo(PlanNodeIdAllocator idAllocator, PlanNode plan, Optional<LogicalPropertiesProvider> logicalPropertiesProvider) {
        this.idAllocator = idAllocator;
        this.logicalPropertiesProvider = Objects.requireNonNull(logicalPropertiesProvider, "logicalPropertiesProvider is null");
        this.rootGroup = insertRecursive(Objects.requireNonNull(plan, "plan is null"));
        // Pin the root group so reference counting never deletes it.
        groups.get(rootGroup).incomingReferences.add(ROOT_GROUP_REF);
    }

    /** Returns the id of the group that holds the root of the plan. */
    public int getRootGroup() {
        return rootGroup;
    }

    /**
     * Looks up a live group.
     *
     * @throws IllegalArgumentException if {@code group} is not in this memo
     */
    private Group getGroup(int group) {
        Preconditions.checkArgument(groups.containsKey(group), "Invalid group: %s", group);
        return groups.get(group);
    }

    /**
     * Returns the logical properties stored for {@code group}. The result is empty when no
     * provider was configured.
     *
     * @throws IllegalArgumentException if {@code group} is not in this memo
     */
    public Optional<LogicalProperties> getLogicalProperties(int group) {
        return getGroup(group).logicalProperties;
    }

    /**
     * Returns the current member node of {@code group}. Its sources are {@link GroupReference}s.
     *
     * @throws IllegalArgumentException if {@code group} is not in this memo
     */
    public PlanNode getNode(int group) {
        return getGroup(group).membership;
    }

    /** Returns the current member node of the group that {@code groupReference} points to. */
    public PlanNode resolve(GroupReference groupReference) {
        return getNode(groupReference.getGroupId());
    }

    /**
     * Returns the plan rooted at the root group. Group references are resolved against this
     * memo via {@link Plans#resolveGroupReferences}.
     */
    public PlanNode extract() {
        return extract(getNode(rootGroup));
    }

    private PlanNode extract(PlanNode node) {
        return Plans.resolveGroupReferences(node, Lookup.from(planNode -> Stream.of(resolve((GroupReference) planNode))));
    }

    /**
     * Replaces the member of {@code group} with {@code node}.
     * <p>
     * If {@code node} is itself a {@link GroupReference}, the referenced group's current member
     * becomes the new member. Otherwise, the children of {@code node} are inserted as groups
     * first. Afterwards, the group's logical properties (when a provider is configured) are
     * recomputed from the new member. Groups that are no longer referenced are deleted.
     * Finally, cached stats and cost are evicted for this group and every group above it.
     *
     * @param group id of the group to update
     * @param node replacement node; must produce the same set of output variables as the old member
     * @param reason description of the transformation, used in the error message
     * @return the node now stored as the group's member
     * @throws IllegalArgumentException if the output variables differ, or if a group is unknown
     */
    public PlanNode replace(int group, PlanNode node, String reason) {
        PlanNode old = getGroup(group).membership;
        Preconditions.checkArgument(
                new HashSet<>(old.getOutputVariables()).equals(new HashSet<>(node.getOutputVariables())),
                "%s: transformed expression doesn't produce same outputs: %s vs %s",
                reason,
                old.getOutputVariables(),
                node.getOutputVariables());

        PlanNode rewritten = node instanceof GroupReference
                ? getNode(((GroupReference) node).getGroupId())
                : insertChildrenAndRewrite(node);

        // Take the new references before releasing the old ones. Otherwise a child shared by
        // the old and new members could briefly drop to zero references and be deleted.
        incrementReferenceCounts(rewritten, group);
        getGroup(group).membership = rewritten;
        if (logicalPropertiesProvider.isPresent()) {
            // The previous properties are overwritten with those of the new member.
            LogicalProperties newLogicalProperties = rewritten.computeLogicalProperties(logicalPropertiesProvider.get());
            getGroup(group).logicalProperties = Optional.of(newLogicalProperties);
        }
        decrementReferenceCounts(old, group);
        evictStatisticsAndCost(group);
        return rewritten;
    }

    /**
     * Clears cached stats and cost for {@code group} and, recursively, for every group that
     * references it. The recursion stops at the synthetic root reference. There is no visited
     * set, so a group reachable along several parent paths is cleared more than once.
     */
    private void evictStatisticsAndCost(int group) {
        Group target = getGroup(group);
        target.stats = null;
        target.cost = null;
        for (int parentGroup : target.incomingReferences.elementSet()) {
            if (parentGroup != ROOT_GROUP_REF) {
                evictStatisticsAndCost(parentGroup);
            }
        }
    }

    /** Returns the cached stats for {@code group}, or empty if none are stored. */
    public Optional<PlanNodeStatsEstimate> getStats(int group) {
        return Optional.ofNullable(getGroup(group).stats);
    }

    /**
     * Caches {@code stats} for the group.
     * <p>
     * If the group already had stats, that group and all groups above it are evicted first,
     * because their cached estimates may have been derived from the previous value.
     *
     * @throws NullPointerException if {@code stats} is null
     */
    public void storeStats(int groupId, PlanNodeStatsEstimate stats) {
        Group group = getGroup(groupId);
        if (group.stats != null) {
            evictStatisticsAndCost(groupId);
        }
        group.stats = Objects.requireNonNull(stats, "stats is null");
    }

    /** Returns the cached cost for {@code group}, or empty if none is stored. */
    public Optional<PlanCostEstimate> getCost(int group) {
        return Optional.ofNullable(getGroup(group).cost);
    }

    /**
     * Caches {@code cost} for {@code group}. Unlike {@link #storeStats}, nothing is evicted.
     *
     * @throws NullPointerException if {@code cost} is null
     */
    public void storeCost(int group, PlanCostEstimate cost) {
        getGroup(group).cost = Objects.requireNonNull(cost, "cost is null");
    }

    /** Records {@code fromGroup} as an incoming reference on each distinct child group of {@code fromNode}. */
    private void incrementReferenceCounts(PlanNode fromNode, int fromGroup) {
        for (int child : getAllReferences(fromNode)) {
            groups.get(child).incomingReferences.add(fromGroup);
        }
    }

    /**
     * Removes one {@code fromGroup} reference from each distinct child group of
     * {@code fromNode}. Any child left with no references is deleted.
     *
     * @throws IllegalStateException if a child has no reference from {@code fromGroup}
     */
    private void decrementReferenceCounts(PlanNode fromNode, int fromGroup) {
        for (int child : getAllReferences(fromNode)) {
            Group childGroup = groups.get(child);
            Preconditions.checkState(childGroup.incomingReferences.remove(fromGroup), "Reference to remove not found");
            if (childGroup.incomingReferences.isEmpty()) {
                deleteGroup(child);
            }
        }
    }

    /**
     * Returns the distinct group ids referenced by the sources of {@code node}. Every source
     * must be a {@link GroupReference}; otherwise the cast throws {@link ClassCastException}.
     */
    private Set<Integer> getAllReferences(PlanNode node) {
        return node.getSources().stream()
                .map(GroupReference.class::cast)
                .map(GroupReference::getGroupId)
                .collect(Collectors.toSet());
    }

    /**
     * Removes an unreferenced group and releases the references its member held.
     *
     * @throws IllegalArgumentException if the group is unknown or still referenced
     */
    private void deleteGroup(int group) {
        Preconditions.checkArgument(getGroup(group).incomingReferences.isEmpty(), "Cannot delete group that has incoming references");
        PlanNode deletedNode = groups.remove(group).membership;
        decrementReferenceCounts(deletedNode, group);
    }

    /**
     * Inserts each source of {@code node} as a group and returns a copy of {@code node} whose
     * sources are {@link GroupReference}s to those groups.
     */
    private PlanNode insertChildrenAndRewrite(PlanNode node) {
        return node.replaceChildren(node.getSources().stream()
                .map(child -> {
                    int childId = insertRecursive(child);
                    return new GroupReference(
                            node.getSourceLocation(),
                            idAllocator.getNextId(),
                            childId,
                            child.getOutputVariables(),
                            groups.get(childId).logicalProperties);
                })
                .collect(Collectors.toList()));
    }

    /**
     * Inserts {@code node} and its subtree, returning the id of the group that holds it.
     * <p>
     * A {@link GroupReference} is returned as its existing group id. For any other node, the
     * node's group id is allocated before its children are inserted.
     */
    private int insertRecursive(PlanNode node) {
        if (node instanceof GroupReference) {
            return ((GroupReference) node).getGroupId();
        }
        int group = nextGroupId();
        PlanNode rewritten = insertChildrenAndRewrite(node);
        groups.put(group, new Group(rewritten, logicalPropertiesProvider.map(provider -> rewritten.computeLogicalProperties(provider))));
        incrementReferenceCounts(rewritten, group);
        return group;
    }

    private int nextGroupId() {
        return nextGroupId++;
    }

    /** Returns the number of live groups in the memo. */
    public int getGroupCount() {
        return groups.size();
    }

    /** Per-group state: current member, incoming references, and cached derived data. */
    private static final class Group {
        // Ids of groups (or ROOT_GROUP_REF) referencing this group; one entry per reference.
        private final Multiset<Integer> incomingReferences = HashMultiset.create();
        private PlanNode membership;
        private Optional<LogicalProperties> logicalProperties;
        @Nullable
        private PlanNodeStatsEstimate stats;
        @Nullable
        private PlanCostEstimate cost;

        private Group(PlanNode member, Optional<LogicalProperties> logicalProperties) {
            this.membership = Objects.requireNonNull(member, "member is null");
            this.logicalProperties = logicalProperties;
        }
    }
}

