/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative;

import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.iterative.GroupReference;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Plans;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.google.common.base.Preconditions;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Stores a plan as a collection of numbered groups. Every node stored in a group has only
 * {@link GroupReference} children pointing at other groups.
 *
 * <p>For each group the memo tracks how many stored nodes reference it (each referencing node
 * counts once, even if it lists the same group several times). When that count drops to zero
 * the group is deleted, and the references held by its node are released in turn.
 */
public class Memo {
    private final PlanNodeIdAllocator idAllocator;
    private final int rootGroup;
    // group id -> node stored in that group; the node's sources are all GroupReferences
    private final Map<Integer, PlanNode> membership = new HashMap<Integer, PlanNode>();
    // group id -> number of stored nodes whose sources reference that group
    private final Map<Integer, Integer> referenceCounts = new HashMap<Integer, Integer>();
    private int nextGroupId;

    /**
     * Builds a memo from {@code plan}, inserting every node of the plan into its own group.
     *
     * @param idAllocator allocator used for the ids of the {@link GroupReference}s created here
     * @param plan the plan to store; its top node becomes the root group
     */
    public Memo(PlanNodeIdAllocator idAllocator, PlanNode plan) {
        this.idAllocator = idAllocator;
        this.rootGroup = this.insertRecursive(plan);
        // Give the root one reference so reference-count cleanup never deletes it.
        this.referenceCounts.put(this.rootGroup, 1);
    }

    /**
     * Returns the id of the group holding the top of the plan.
     */
    public int getRootGroup() {
        return this.rootGroup;
    }

    /**
     * Returns the node currently stored in {@code group}.
     *
     * @throws IllegalArgumentException if {@code group} is not a live group of this memo
     */
    public PlanNode getNode(int group) {
        Preconditions.checkArgument(this.membership.containsKey(group), "Invalid group: %s", group);
        return this.membership.get(group);
    }

    /**
     * Returns the node stored in the group that {@code groupReference} points to.
     */
    public PlanNode resolve(GroupReference groupReference) {
        return this.getNode(groupReference.getGroupId());
    }

    /**
     * Rebuilds the full plan tree from the root group, replacing group references with the
     * nodes they point to.
     */
    public PlanNode extract() {
        return this.extract(this.getNode(this.rootGroup));
    }

    private PlanNode extract(PlanNode node) {
        return Plans.resolveGroupReferences(node, Lookup.from(planNode -> Stream.of(this.resolve((GroupReference) planNode))));
    }

    /**
     * Replaces the node stored in {@code group} with {@code node}.
     *
     * <p>If {@code node} is a {@link GroupReference}, the node of the referenced group is stored
     * in {@code group}. Otherwise the children of {@code node} are inserted as new groups, unless
     * they are already group references. Groups that are no longer referenced afterwards are
     * deleted.
     *
     * @param group id of the group to update
     * @param node the replacement; must produce the same set of output symbols as the current node
     * @param reason description used in the error message when the outputs differ
     * @return the node now stored in {@code group}
     * @throws IllegalArgumentException if {@code group} does not exist, or if the output symbols
     *         of {@code node} differ from those of the current node
     */
    public PlanNode replace(int group, PlanNode node, String reason) {
        // Validate the group id up front rather than failing later with a bare NullPointerException.
        PlanNode old = this.getNode(group);
        Preconditions.checkArgument(
                new HashSet<Symbol>(old.getOutputSymbols()).equals(new HashSet<Symbol>(node.getOutputSymbols())),
                "%s: transformed expression doesn't produce same outputs: %s vs %s",
                reason,
                old.getOutputSymbols(),
                node.getOutputSymbols());

        PlanNode stored;
        if (node instanceof GroupReference) {
            // The referenced group's node already has group-reference children; reuse it as-is.
            stored = this.getNode(((GroupReference) node).getGroupId());
        }
        else {
            stored = this.insertChildrenAndRewrite(node);
        }

        // Increment before decrementing: a group referenced by both the old and the new node
        // must not transiently reach zero and be deleted.
        this.incrementReferenceCounts(stored);
        this.membership.put(group, stored);
        this.decrementReferenceCounts(old);
        return stored;
    }

    /**
     * Adds one reference to every distinct group that {@code node}'s sources point to.
     */
    private void incrementReferenceCounts(PlanNode node) {
        for (int group : this.getAllReferences(node)) {
            this.referenceCounts.compute(group, (g, count) -> count + 1);
        }
    }

    /**
     * Removes one reference from every distinct group that {@code node}'s sources point to,
     * deleting any group whose count reaches zero.
     *
     * @throws IllegalStateException if a reference count would become negative
     */
    private void decrementReferenceCounts(PlanNode node) {
        for (int group : this.getAllReferences(node)) {
            int newCount = this.referenceCounts.compute(group, (g, count) -> count - 1);
            Preconditions.checkState(newCount >= 0, "Reference count became negative");
            if (newCount == 0) {
                PlanNode child = this.membership.get(group);
                this.deleteGroup(group);
                // The deleted group's node no longer holds its own references; release them as well.
                this.decrementReferenceCounts(child);
            }
        }
    }

    /**
     * Returns the distinct group ids referenced by {@code node}'s sources. Every source is
     * expected to be a {@link GroupReference}; any other source fails the cast.
     */
    private Set<Integer> getAllReferences(PlanNode node) {
        return node.getSources().stream()
                .map(GroupReference.class::cast)
                .map(GroupReference::getGroupId)
                .collect(Collectors.toSet());
    }

    private void deleteGroup(int group) {
        this.membership.remove(group);
        this.referenceCounts.remove(group);
    }

    /**
     * Inserts each source of {@code node} into the memo and returns a copy of {@code node}
     * whose sources are {@link GroupReference}s to those groups.
     */
    private PlanNode insertChildrenAndRewrite(PlanNode node) {
        return node.replaceChildren(node.getSources().stream()
                .map(child -> new GroupReference(this.idAllocator.getNextId(), this.insertRecursive(child), child.getOutputSymbols()))
                .collect(Collectors.toList()));
    }

    /**
     * Inserts {@code node} and, recursively, its sources into new groups. An existing
     * {@link GroupReference} is not re-inserted; its group id is returned unchanged.
     *
     * @return the id of the group holding {@code node}
     */
    private int insertRecursive(PlanNode node) {
        if (node instanceof GroupReference) {
            return ((GroupReference) node).getGroupId();
        }
        int group = this.nextGroupId();
        PlanNode rewritten = this.insertChildrenAndRewrite(node);
        this.membership.put(group, rewritten);
        // The new group starts unreferenced; the caller's parent node accounts for it.
        this.referenceCounts.put(group, 0);
        this.incrementReferenceCounts(rewritten);
        return group;
    }

    private int nextGroupId() {
        return this.nextGroupId++;
    }

    /**
     * Returns the number of live groups in this memo.
     */
    public int getGroupCount() {
        return this.membership.size();
    }
}

