package org.apache.calcite.rel.metadata;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.volcano.RelSubset;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Calc;
import org.apache.calcite.rel.core.Exchange;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.core.TableModify;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rel.metadata.BuiltInMetadata;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexTableInputRef;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.validate.SqlValidatorUtil;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;

/* loaded from: input_file:org/apache/calcite/rel/metadata/RelMdExpressionLineage.class */
/**
 * Metadata handler for {@link BuiltInMetadata.ExpressionLineage}.
 *
 * <p>Given an expression over the output row of a relational expression, each handler
 * rewrites the expression in terms of its input(s), recursively via
 * {@link RelMetadataQuery#getExpressionLineage}, until input references are replaced by
 * {@link RexTableInputRef} column references on scanned tables. Because a column may
 * originate from more than one source (for example through a {@link Union}), the result is a
 * set of expressions. A {@code null} result means the lineage could not be determined.
 */
public class RelMdExpressionLineage implements MetadataHandler<BuiltInMetadata.ExpressionLineage> {
    /** Provider exposing this handler through reflective dispatch. */
    public static final RelMetadataProvider SOURCE =
        ReflectiveRelMetadataProvider.reflectiveSource(
            new RelMdExpressionLineage(), BuiltInMetadata.ExpressionLineage.Handler.class);

    /**
     * Shuttle that replaces every {@link RexInputRef} with the expression bound to it in the
     * supplied map.
     */
    public static class RexReplacer extends RexShuttle {
        private final Map<RexInputRef, RexNode> replacementValues;

        RexReplacer(Map<RexInputRef, RexNode> replacementValues) {
            this.replacementValues = replacementValues;
        }

        /**
         * Returns the replacement bound to {@code inputRef}.
         *
         * @throws NullPointerException if no replacement is bound to {@code inputRef}
         */
        @Override
        public RexNode visitInputRef(RexInputRef inputRef) {
            return Objects.requireNonNull(replacementValues.get(inputRef),
                () -> "no replacement found for inputRef " + inputRef);
        }
    }

    protected RelMdExpressionLineage() {
    }

    @Override
    public MetadataDef<BuiltInMetadata.ExpressionLineage> getDef() {
        return BuiltInMetadata.ExpressionLineage.DEF;
    }

    /**
     * Fallback overload for relational expressions that no more specific overload handles.
     * Always returns {@code null} (lineage unknown).
     */
    public Set<RexNode> getExpressionLineage(RelNode relNode, RelMetadataQuery relMetadataQuery,
            RexNode rexNode) {
        return null;
    }

    /**
     * Delegates to the subset's best expression, or to its original expression when no best
     * one has been chosen yet. Returns {@code null} if neither exists.
     */
    public Set<RexNode> getExpressionLineage(RelSubset relSubset, RelMetadataQuery relMetadataQuery,
            RexNode rexNode) {
        final RelNode rel = Util.first(relSubset.getBest(), relSubset.getOriginal());
        if (rel == null) {
            return null;
        }
        return relMetadataQuery.getExpressionLineage(rel, rexNode);
    }

    /**
     * Maps every input reference of {@code rexNode} to the corresponding column of the scanned
     * table. The table reference uses entity number 0.
     *
     * <p>If the table can be unwrapped to an
     * {@link BuiltInMetadata.ExpressionLineage.Handler}, that handler is used instead.
     */
    public Set<RexNode> getExpressionLineage(TableScan tableScan, RelMetadataQuery relMetadataQuery,
            RexNode rexNode) {
        final BuiltInMetadata.ExpressionLineage.Handler handler =
            tableScan.getTable().unwrap(BuiltInMetadata.ExpressionLineage.Handler.class);
        if (handler != null) {
            return handler.getExpressionLineage(tableScan, relMetadataQuery, rexNode);
        }
        final RexBuilder rexBuilder = tableScan.getCluster().getRexBuilder();
        final RexTableInputRef.RelTableRef tableRef =
            RexTableInputRef.RelTableRef.of(tableScan.getTable(), 0);
        final Map<RexInputRef, Set<RexNode>> mapping = new LinkedHashMap<>();
        for (int idx : extractInputRefs(rexNode)) {
            final RexInputRef inputRef = RexInputRef.of(idx, tableScan.getRowType());
            final Set<RexNode> originalExprs = ImmutableSet.of(RexTableInputRef.of(tableRef, inputRef));
            mapping.put(inputRef, originalExprs);
        }
        return createAllPossibleExpressions(rexBuilder, rexNode, mapping);
    }

    /**
     * Traces references to grouping-key columns back through the aggregate's input.
     * Returns {@code null} if {@code rexNode} references any column at or beyond
     * {@link Aggregate#getGroupCount()}, or if the lineage of any referenced key is unknown.
     */
    public Set<RexNode> getExpressionLineage(Aggregate aggregate, RelMetadataQuery relMetadataQuery,
            RexNode rexNode) {
        final RelNode input = aggregate.getInput();
        final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        final ImmutableBitSet inputFieldsUsed = extractInputRefs(rexNode);

        // Check every reference before issuing any metadata query: only grouping-key
        // columns can be traced.
        for (int idx : inputFieldsUsed) {
            if (idx >= aggregate.getGroupCount()) {
                return null;
            }
        }

        final Map<RexInputRef, Set<RexNode>> mapping = new LinkedHashMap<>();
        for (int idx : inputFieldsUsed) {
            // Output column idx is the idx-th grouping key; nth() gives its input position.
            final RexInputRef inputRef =
                RexInputRef.of(aggregate.getGroupSet().nth(idx), input.getRowType());
            final Set<RexNode> originalExprs = relMetadataQuery.getExpressionLineage(input, inputRef);
            if (originalExprs == null) {
                return null;
            }
            mapping.put(RexInputRef.of(idx, aggregate.getRowType()), originalExprs);
        }
        return createAllPossibleExpressions(rexBuilder, rexNode, mapping);
    }

    /**
     * Traces references through both inputs of a {@link Join}.
     *
     * <p>References into the left input are traced unchanged. References into the right input
     * are traced, and their table references are renumbered. This keeps a table that appears
     * on both sides distinguishable.
     *
     * <p>Outer joins are restricted: a LEFT join may only reference left columns and a RIGHT
     * join only right columns. Any other outer join type returns {@code null}.
     */
    public Set<RexNode> getExpressionLineage(Join join, RelMetadataQuery relMetadataQuery,
            RexNode rexNode) {
        final RexBuilder rexBuilder = join.getCluster().getRexBuilder();
        final RelNode left = join.getLeft();
        final RelNode right = join.getRight();
        final int nLeftColumns = left.getRowType().getFieldList().size();
        final ImmutableBitSet inputFieldsUsed = extractInputRefs(rexNode);

        final JoinRelType joinType = join.getJoinType();
        if (joinType.isOuterJoin()) {
            if (joinType == JoinRelType.LEFT) {
                final ImmutableBitSet rightFields =
                    ImmutableBitSet.range(nLeftColumns, join.getRowType().getFieldCount());
                if (inputFieldsUsed.intersects(rightFields)) {
                    return null;
                }
            } else if (joinType == JoinRelType.RIGHT) {
                if (inputFieldsUsed.intersects(ImmutableBitSet.range(0, nLeftColumns))) {
                    return null;
                }
            } else {
                return null;
            }
        }

        final Set<RexTableInputRef.RelTableRef> leftTableRefs =
            relMetadataQuery.getTableReferences(left);
        if (leftTableRefs == null) {
            return null;
        }
        final Set<RexTableInputRef.RelTableRef> rightTableRefs =
            relMetadataQuery.getTableReferences(right);
        if (rightTableRefs == null) {
            return null;
        }

        final HashMultimap<List<String>, RexTableInputRef.RelTableRef> qualifiedNamesToRefs =
            HashMultimap.create();
        for (RexTableInputRef.RelTableRef leftRef : leftTableRefs) {
            qualifiedNamesToRefs.put(leftRef.getQualifiedName(), leftRef);
        }
        final Map<RexTableInputRef.RelTableRef, RexTableInputRef.RelTableRef> currentTablesMapping =
            shiftTableReferences(rightTableRefs, qualifiedNamesToRefs);

        final Map<RexInputRef, Set<RexNode>> mapping = new LinkedHashMap<>();
        for (int idx : inputFieldsUsed) {
            if (idx < nLeftColumns) {
                final Set<RexNode> originalExprs = relMetadataQuery.getExpressionLineage(left,
                    RexInputRef.of(idx, left.getRowType()));
                if (originalExprs == null) {
                    return null;
                }
                // Left-side table references are kept as-is.
                mapping.put(RexInputRef.of(idx, join.getRowType()), originalExprs);
            } else {
                final Set<RexNode> originalExprs = relMetadataQuery.getExpressionLineage(right,
                    RexInputRef.of(idx - nLeftColumns, right.getRowType()));
                if (originalExprs == null) {
                    return null;
                }
                // Rewrite right-side table references to their renumbered entities.
                final Set<RexNode> updatedExprs = ImmutableSet.copyOf(
                    Util.transform(originalExprs,
                        e -> RexUtil.swapTableReferences(rexBuilder, e, currentTablesMapping)));
                mapping.put(
                    RexInputRef.of(idx,
                        SqlValidatorUtil.createJoinType(rexBuilder.getTypeFactory(),
                            left.getRowType(), right.getRowType(), null, ImmutableList.of())),
                    updatedExprs);
            }
        }
        return createAllPossibleExpressions(rexBuilder, rexNode, mapping);
    }

    /**
     * Traces each referenced column through every input of a {@link Union} and merges the
     * results. Table references of later inputs are renumbered so that they do not clash
     * with references already produced by earlier inputs.
     *
     * <p>Returns {@code null} if the table references of any input are unknown, or if the
     * lineage of any referenced column in any input is unknown.
     */
    public Set<RexNode> getExpressionLineage(Union union, RelMetadataQuery relMetadataQuery,
            RexNode rexNode) {
        final RexBuilder rexBuilder = union.getCluster().getRexBuilder();
        final ImmutableBitSet inputFieldsUsed = extractInputRefs(rexNode);
        final HashMultimap<List<String>, RexTableInputRef.RelTableRef> qualifiedNamesToRefs =
            HashMultimap.create();
        final Map<RexInputRef, Set<RexNode>> mapping = new LinkedHashMap<>();
        for (RelNode input : union.getInputs()) {
            final Set<RexTableInputRef.RelTableRef> tableRefs =
                relMetadataQuery.getTableReferences(input);
            if (tableRefs == null) {
                return null;
            }
            final Map<RexTableInputRef.RelTableRef, RexTableInputRef.RelTableRef>
                currentTablesMapping = shiftTableReferences(tableRefs, qualifiedNamesToRefs);
            for (int idx : inputFieldsUsed) {
                final Set<RexNode> originalExprs = relMetadataQuery.getExpressionLineage(input,
                    RexInputRef.of(idx, input.getRowType()));
                if (originalExprs == null) {
                    return null;
                }
                final RexInputRef ref = RexInputRef.of(idx, union.getRowType());
                // The set is created here as a HashSet so later inputs can add to it.
                // Collectors.toSet() does not guarantee a mutable result.
                final Set<RexNode> exprs = mapping.computeIfAbsent(ref, k -> new HashSet<>());
                for (RexNode e : originalExprs) {
                    exprs.add(RexUtil.swapTableReferences(rexBuilder, e, currentTablesMapping));
                }
            }
            // Record this input's (renumbered) tables so subsequent inputs shift past them.
            for (RexTableInputRef.RelTableRef shifted : currentTablesMapping.values()) {
                qualifiedNamesToRefs.put(shifted.getQualifiedName(), shifted);
            }
        }
        return createAllPossibleExpressions(rexBuilder, rexNode, mapping);
    }

    /**
     * Replaces each referenced output column with the lineage of the corresponding projected
     * expression, evaluated against the project's input.
     */
    public Set<RexNode> getExpressionLineage(Project project, RelMetadataQuery relMetadataQuery,
            RexNode rexNode) {
        final RelNode input = project.getInput();
        final RexBuilder rexBuilder = project.getCluster().getRexBuilder();
        final Map<RexInputRef, Set<RexNode>> mapping = new LinkedHashMap<>();
        for (int idx : extractInputRefs(rexNode)) {
            final Set<RexNode> originalExprs =
                relMetadataQuery.getExpressionLineage(input, project.getProjects().get(idx));
            if (originalExprs == null) {
                return null;
            }
            mapping.put(RexInputRef.of(idx, project.getRowType()), originalExprs);
        }
        return createAllPossibleExpressions(rexBuilder, rexNode, mapping);
    }

    /** Passes {@code rexNode} unchanged to the filter's input. */
    public Set<RexNode> getExpressionLineage(Filter filter, RelMetadataQuery relMetadataQuery,
            RexNode rexNode) {
        return relMetadataQuery.getExpressionLineage(filter.getInput(), rexNode);
    }

    /** Passes {@code rexNode} unchanged to the sort's input. */
    public Set<RexNode> getExpressionLineage(Sort sort, RelMetadataQuery relMetadataQuery,
            RexNode rexNode) {
        return relMetadataQuery.getExpressionLineage(sort.getInput(), rexNode);
    }

    /** Passes {@code rexNode} unchanged to the table-modify's input. */
    public Set<RexNode> getExpressionLineage(TableModify tableModify,
            RelMetadataQuery relMetadataQuery, RexNode rexNode) {
        return relMetadataQuery.getExpressionLineage(tableModify.getInput(), rexNode);
    }

    /** Passes {@code rexNode} unchanged to the exchange's input. */
    public Set<RexNode> getExpressionLineage(Exchange exchange, RelMetadataQuery relMetadataQuery,
            RexNode rexNode) {
        return relMetadataQuery.getExpressionLineage(exchange.getInput(), rexNode);
    }

    /**
     * Replaces each referenced output column with the lineage of the corresponding projected
     * expression of the calc's program.
     *
     * <p>The key of {@code getProgram().split()} is indexed by output column here. It is
     * presumably the list of project expressions; confirm against {@code RexProgram#split}.
     */
    public Set<RexNode> getExpressionLineage(Calc calc, RelMetadataQuery relMetadataQuery,
            RexNode rexNode) {
        final RelNode input = calc.getInput();
        final RexBuilder rexBuilder = calc.getCluster().getRexBuilder();
        final ImmutableBitSet inputFieldsUsed = extractInputRefs(rexNode);
        final Map<RexInputRef, Set<RexNode>> mapping = new LinkedHashMap<>();
        final Pair<ImmutableList<RexNode>, ImmutableList<RexNode>> calcProjectsAndFilter =
            calc.getProgram().split();
        for (int idx : inputFieldsUsed) {
            final Set<RexNode> originalExprs = relMetadataQuery.getExpressionLineage(input,
                calcProjectsAndFilter.getKey().get(idx));
            if (originalExprs == null) {
                return null;
            }
            mapping.put(RexInputRef.of(idx, calc.getRowType()), originalExprs);
        }
        return createAllPossibleExpressions(rexBuilder, rexNode, mapping);
    }

    /**
     * Expands {@code rexNode} by substituting, for each of its input references, every
     * candidate expression listed in {@code map}. The result is every combination of
     * candidates (a cartesian product).
     *
     * @param rexBuilder builder used to copy {@code rexNode} before substitution
     * @param rexNode    expression to expand
     * @param map        candidate replacements per input reference. Entries are removed and
     *                   re-inserted during expansion, so the key set is the same afterwards but
     *                   the iteration order may differ.
     * @return {@code rexNode} alone if it references no inputs. Otherwise the set of expanded
     *         expressions, or {@code null} if expansion threw
     *         {@link UnsupportedOperationException}.
     */
    public static Set<RexNode> createAllPossibleExpressions(RexBuilder rexBuilder, RexNode rexNode,
            Map<RexInputRef, Set<RexNode>> map) {
        final ImmutableBitSet predFieldsUsed = extractInputRefs(rexNode);
        if (predFieldsUsed.isEmpty()) {
            // Nothing to substitute: the expression is its own lineage.
            return ImmutableSet.of(rexNode);
        }
        try {
            return createAllPossibleExpressions(rexBuilder, rexNode, predFieldsUsed, map,
                new HashMap<>());
        } catch (UnsupportedOperationException e) {
            // NOTE(review): this exception presumably comes from rexBuilder.copy on an
            // unsupported RexNode kind. It is deliberately treated as "lineage unknown".
            return null;
        }
    }

    /**
     * Binds one input reference from {@code mapping} to each of its candidates in turn, and
     * recurses until all references are bound.
     */
    private static Set<RexNode> createAllPossibleExpressions(RexBuilder rexBuilder, RexNode expr,
            ImmutableBitSet predFieldsUsed, Map<RexInputRef, Set<RexNode>> mapping,
            Map<RexInputRef, RexNode> singleMapping) {
        final RexInputRef inputRef = mapping.keySet().iterator().next();
        final Set<RexNode> replacements = Objects.requireNonNull(mapping.remove(inputRef),
            () -> "mapping.remove(inputRef) is null for " + inputRef);
        final Set<RexNode> result = new HashSet<>();
        assert !replacements.isEmpty();
        if (predFieldsUsed.indexOf(inputRef.getIndex()) != -1) {
            for (RexNode replacement : replacements) {
                singleMapping.put(inputRef, replacement);
                createExpressions(rexBuilder, expr, predFieldsUsed, mapping, singleMapping, result);
                singleMapping.remove(inputRef);
            }
        } else {
            createExpressions(rexBuilder, expr, predFieldsUsed, mapping, singleMapping, result);
        }
        // Re-insert so the caller's mapping again contains this key.
        mapping.put(inputRef, replacements);
        return result;
    }

    /**
     * Handles one step of the expansion, adding the expressions it produces to {@code result}.
     * If references remain unbound, it recurses. Otherwise it copies {@code expr} and
     * substitutes the bindings in {@code singleMapping}.
     */
    private static void createExpressions(RexBuilder rexBuilder, RexNode expr,
            ImmutableBitSet predFieldsUsed, Map<RexInputRef, Set<RexNode>> mapping,
            Map<RexInputRef, RexNode> singleMapping, Set<RexNode> result) {
        if (!mapping.isEmpty()) {
            result.addAll(
                createAllPossibleExpressions(rexBuilder, expr, predFieldsUsed, mapping, singleMapping));
            return;
        }
        final List<RexNode> updatedExprs = new ArrayList<>(1);
        updatedExprs.add(rexBuilder.copy(expr));
        new RexReplacer(singleMapping).mutate(updatedExprs);
        result.addAll(updatedExprs);
    }

    /**
     * Maps each table reference in {@code refs} to a copy whose entity number is increased by
     * the number of references already recorded in {@code seen} under the same qualified table
     * name.
     */
    private static Map<RexTableInputRef.RelTableRef, RexTableInputRef.RelTableRef>
            shiftTableReferences(Set<RexTableInputRef.RelTableRef> refs,
                HashMultimap<List<String>, RexTableInputRef.RelTableRef> seen) {
        final Map<RexTableInputRef.RelTableRef, RexTableInputRef.RelTableRef> shifted =
            new HashMap<>();
        for (RexTableInputRef.RelTableRef ref : refs) {
            // Multimap.get never returns null; an absent key yields an empty collection.
            final int shift = seen.get(ref.getQualifiedName()).size();
            shifted.put(ref,
                RexTableInputRef.RelTableRef.of(ref.getTable(), shift + ref.getEntityNumber()));
        }
        return shifted;
    }

    /** Returns the indices of all input references that occur in {@code expr}. */
    private static ImmutableBitSet extractInputRefs(RexNode expr) {
        final RelOptUtil.InputFinder inputFinder = new RelOptUtil.InputFinder(new LinkedHashSet<>());
        expr.accept(inputFinder);
        return inputFinder.build();
    }
}
