package com.facebook.presto.sql.planner.iterative;

import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* loaded from: input_file:com/facebook/presto/sql/planner/iterative/Memo.class */
/**
 * Stores a plan as a set of numbered groups. Each group maps to exactly one {@link PlanNode}, and
 * that node's children are {@link GroupReference}s pointing at other groups. Groups are
 * reference-counted: when no parent references a group anymore, the group is deleted and the
 * groups it referenced are released in turn.
 */
public class Memo {
    /** Supplies ids for the {@link GroupReference} nodes created when children are moved into groups. */
    private final PlanNodeIdAllocator idAllocator;
    private final int rootGroup;
    /** Group id -> the node currently representing that group (its children are GroupReferences). */
    private final Map<Integer, PlanNode> membership = new HashMap<>();
    /** Group id -> number of references to that group; the root group holds one synthetic reference. */
    private final Map<Integer, Integer> referenceCounts = new HashMap<>();
    private int nextGroupId;

    /**
     * Builds a memo by recursively inserting {@code planNode} and all of its descendants as groups.
     *
     * @param planNodeIdAllocator allocator used for ids of the GroupReference nodes this memo creates
     * @param planNode root of the plan to memoize
     */
    public Memo(PlanNodeIdAllocator planNodeIdAllocator, PlanNode planNode) {
        this.idAllocator = planNodeIdAllocator;
        this.rootGroup = insertRecursive(planNode);
        // No parent references the root, so pin it with one synthetic reference; otherwise its
        // count would stay 0 and nothing would distinguish it from an orphaned group.
        this.referenceCounts.put(this.rootGroup, 1);
    }

    /** Returns the id of the group holding the root of the plan. */
    public int getRootGroup() {
        return this.rootGroup;
    }

    /**
     * Returns the node currently stored for {@code group}.
     *
     * @throws IllegalArgumentException if no such group exists (e.g. it was deleted)
     */
    public PlanNode getNode(int group) {
        Preconditions.checkArgument(this.membership.containsKey(group), "Invalid group: %s", group);
        return this.membership.get(group);
    }

    /** Returns the node stored for the group that {@code groupReference} points to. */
    public PlanNode resolve(GroupReference groupReference) {
        return getNode(groupReference.getGroupId());
    }

    /** Materializes the whole plan, with every GroupReference replaced by its group's node. */
    public PlanNode extract() {
        return extract(getNode(this.rootGroup));
    }

    private PlanNode extract(PlanNode node) {
        return Plans.resolveGroupReferences(node, Lookup.from(groupReference -> ImmutableList.of(resolve(groupReference))));
    }

    /**
     * Replaces the node stored in {@code group} with {@code node}.
     *
     * <p>If {@code node} is itself a GroupReference, the referenced group's node is stored instead.
     * Otherwise the children of {@code node} are inserted as groups first. Groups that were
     * referenced only by the old node are deleted.
     *
     * @param group id of the group to update
     * @param node replacement node; must produce the same set of output symbols as the old node
     * @param reason label used in the error message when the outputs differ
     * @return the node actually stored in the group (children rewritten to GroupReferences)
     * @throws IllegalArgumentException if {@code group} does not exist or the outputs differ
     */
    public PlanNode replace(int group, PlanNode node, String reason) {
        // Going through getNode validates the id: an unknown group fails with a descriptive
        // IllegalArgumentException instead of an NPE on the next line.
        PlanNode old = getNode(group);
        Preconditions.checkArgument(
                new HashSet<>(old.getOutputSymbols()).equals(new HashSet<>(node.getOutputSymbols())),
                "%s: transformed expression doesn't produce same outputs: %s vs %s",
                reason,
                old.getOutputSymbols(),
                node.getOutputSymbols());

        PlanNode rewritten;
        if (node instanceof GroupReference) {
            rewritten = getNode(((GroupReference) node).getGroupId());
        } else {
            rewritten = insertChildrenAndRewrite(node);
        }

        // Increment before decrementing. Groups referenced by both the old and the new node must
        // never transiently hit zero, because that would delete them.
        incrementReferenceCounts(rewritten);
        this.membership.put(group, rewritten);
        decrementReferenceCounts(old);
        return rewritten;
    }

    /** Adds one reference to every distinct group referenced by {@code node}'s children. */
    private void incrementReferenceCounts(PlanNode node) {
        for (int group : getAllReferences(node)) {
            this.referenceCounts.compute(group, (key, count) -> count + 1);
        }
    }

    /**
     * Removes one reference from every distinct group referenced by {@code node}'s children.
     * Groups whose count drops to zero are deleted, and the release cascades to their children.
     */
    private void decrementReferenceCounts(PlanNode node) {
        for (int group : getAllReferences(node)) {
            int remaining = this.referenceCounts.compute(group, (key, count) -> count - 1);
            Preconditions.checkState(remaining >= 0, "Reference count became negative");
            if (remaining == 0) {
                // Read the node before deleting the group so its own references can be released.
                PlanNode orphan = this.membership.get(group);
                deleteGroup(group);
                decrementReferenceCounts(orphan);
            }
        }
    }

    /**
     * Returns the distinct group ids referenced by {@code node}'s children. Every child is cast to
     * GroupReference, which holds for nodes stored by this memo (see insertChildrenAndRewrite).
     */
    private Set<Integer> getAllReferences(PlanNode node) {
        return node.getSources().stream()
                .map(GroupReference.class::cast)
                .map(GroupReference::getGroupId)
                .collect(Collectors.toSet());
    }

    private void deleteGroup(int group) {
        this.membership.remove(group);
        this.referenceCounts.remove(group);
    }

    /**
     * Inserts each child of {@code node} as a group (recursively). Returns a copy of {@code node}
     * whose children are GroupReferences to those groups.
     */
    private PlanNode insertChildrenAndRewrite(PlanNode node) {
        List<PlanNode> children = node.getSources().stream()
                .map(child -> new GroupReference(this.idAllocator.getNextId(), insertRecursive(child), child.getOutputSymbols()))
                .collect(Collectors.toList());
        return node.replaceChildren(children);
    }

    /**
     * Inserts {@code node} and its descendants as new groups and returns the group id for
     * {@code node}. A GroupReference is not re-inserted; its existing group id is returned. The
     * returned group starts with a reference count of 0; the caller accounts for the reference.
     */
    private int insertRecursive(PlanNode node) {
        if (node instanceof GroupReference) {
            return ((GroupReference) node).getGroupId();
        }
        // Allocate the parent's id before descending, so parents get lower ids than their children.
        int group = nextGroupId();
        PlanNode rewritten = insertChildrenAndRewrite(node);
        this.membership.put(group, rewritten);
        this.referenceCounts.put(group, 0);
        incrementReferenceCounts(rewritten);
        return group;
    }

    private int nextGroupId() {
        return this.nextGroupId++;
    }

    /** Returns the number of groups currently stored (deleted groups are not counted). */
    public int getGroupCount() {
        return this.membership.size();
    }
}
