/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative;

import com.facebook.presto.cost.PlanNodeCostEstimate;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.iterative.GroupReference;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Plans;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.google.common.base.Preconditions;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multiset;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;

public class Memo {
    private static final int ROOT_GROUP_REF = 0;
    private final PlanNodeIdAllocator idAllocator;
    private final int rootGroup;
    private final Map<Integer, Group> groups = new HashMap<Integer, Group>();
    private int nextGroupId = 1;

    public Memo(PlanNodeIdAllocator idAllocator, PlanNode plan) {
        this.idAllocator = idAllocator;
        this.rootGroup = this.insertRecursive(plan);
        this.groups.get(this.rootGroup).incomingReferences.add((Object)0);
    }

    public int getRootGroup() {
        return this.rootGroup;
    }

    private Group getGroup(int group) {
        Preconditions.checkArgument((boolean)this.groups.containsKey(group), (String)"Invalid group: %s", (int)group);
        return this.groups.get(group);
    }

    public PlanNode getNode(int group) {
        return this.getGroup(group).membership;
    }

    public PlanNode resolve(GroupReference groupReference) {
        return this.getNode(groupReference.getGroupId());
    }

    public PlanNode extract() {
        return this.extract(this.getNode(this.rootGroup));
    }

    private PlanNode extract(PlanNode node) {
        return Plans.resolveGroupReferences(node, Lookup.from(planNode -> Stream.of(this.resolve((GroupReference)planNode))));
    }

    public PlanNode replace(int group, PlanNode node, String reason) {
        PlanNode old = this.getGroup(group).membership;
        Preconditions.checkArgument((boolean)new HashSet<Symbol>(old.getOutputSymbols()).equals(new HashSet<Symbol>(node.getOutputSymbols())), (String)"%s: transformed expression doesn't produce same outputs: %s vs %s", (Object)reason, old.getOutputSymbols(), node.getOutputSymbols());
        node = node instanceof GroupReference ? this.getNode(((GroupReference)node).getGroupId()) : this.insertChildrenAndRewrite(node);
        this.incrementReferenceCounts(node, group);
        this.getGroup(group).membership = node;
        this.decrementReferenceCounts(old, group);
        this.evictStatisticsAndCost(group);
        return node;
    }

    private void evictStatisticsAndCost(int group) {
        this.getGroup(group).stats = null;
        this.getGroup(group).cumulativeCost = null;
        Iterator iterator = this.getGroup(group).incomingReferences.elementSet().iterator();
        while (iterator.hasNext()) {
            int parentGroup = (Integer)iterator.next();
            if (parentGroup == 0) continue;
            this.evictStatisticsAndCost(parentGroup);
        }
    }

    public Optional<PlanNodeStatsEstimate> getStats(int group) {
        return Optional.ofNullable(this.getGroup(group).stats);
    }

    public void storeStats(int groupId, PlanNodeStatsEstimate stats) {
        Group group = this.getGroup(groupId);
        if (group.stats != null) {
            this.evictStatisticsAndCost(groupId);
        }
        group.stats = Objects.requireNonNull(stats, "stats is null");
    }

    public Optional<PlanNodeCostEstimate> getCumulativeCost(int group) {
        return Optional.ofNullable(this.getGroup(group).cumulativeCost);
    }

    public void storeCumulativeCost(int group, PlanNodeCostEstimate cost) {
        this.getGroup(group).cumulativeCost = Objects.requireNonNull(cost, "cost is null");
    }

    private void incrementReferenceCounts(PlanNode fromNode, int fromGroup) {
        Set<Integer> references = this.getAllReferences(fromNode);
        for (int group : references) {
            this.groups.get(group).incomingReferences.add((Object)fromGroup);
        }
    }

    private void decrementReferenceCounts(PlanNode fromNode, int fromGroup) {
        Set<Integer> references = this.getAllReferences(fromNode);
        for (int group : references) {
            Group childGroup = this.groups.get(group);
            Preconditions.checkState((boolean)childGroup.incomingReferences.remove((Object)fromGroup), (Object)"Reference to remove not found");
            if (!childGroup.incomingReferences.isEmpty()) continue;
            this.deleteGroup(group);
        }
    }

    private Set<Integer> getAllReferences(PlanNode node) {
        return node.getSources().stream().map(GroupReference.class::cast).map(GroupReference::getGroupId).collect(Collectors.toSet());
    }

    private void deleteGroup(int group) {
        Preconditions.checkArgument((boolean)this.getGroup(group).incomingReferences.isEmpty(), (Object)"Cannot delete group that has incoming references");
        PlanNode deletedNode = this.groups.remove(group).membership;
        this.decrementReferenceCounts(deletedNode, group);
    }

    private PlanNode insertChildrenAndRewrite(PlanNode node) {
        return node.replaceChildren(node.getSources().stream().map(child -> new GroupReference(this.idAllocator.getNextId(), this.insertRecursive((PlanNode)child), child.getOutputSymbols())).collect(Collectors.toList()));
    }

    private int insertRecursive(PlanNode node) {
        if (node instanceof GroupReference) {
            return ((GroupReference)node).getGroupId();
        }
        int group = this.nextGroupId();
        PlanNode rewritten = this.insertChildrenAndRewrite(node);
        this.groups.put(group, Group.withMember(rewritten));
        this.incrementReferenceCounts(rewritten, group);
        return group;
    }

    private int nextGroupId() {
        return this.nextGroupId++;
    }

    public int getGroupCount() {
        return this.groups.size();
    }

    private static final class Group {
        private PlanNode membership;
        private Multiset<Integer> incomingReferences = HashMultiset.create();
        @Nullable
        private PlanNodeStatsEstimate stats;
        @Nullable
        private PlanNodeCostEstimate cumulativeCost;

        static Group withMember(PlanNode member) {
            return new Group(member);
        }

        private Group(PlanNode member) {
            this.membership = Objects.requireNonNull(member, "member is null");
        }
    }
}

