/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rel.rules;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import org.apache.beam.sdks.java.extensions.sql.repackaged.com.google.common.collect.ImmutableList;
import org.apache.beam.sdks.java.extensions.sql.repackaged.com.google.common.collect.Lists;
import org.apache.beam.sdks.java.extensions.sql.repackaged.com.google.common.collect.Ordering;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.plan.RelOptCost;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.plan.RelOptTable;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.plan.RelOptUtil;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rel.RelNode;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rel.core.JoinInfo;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rel.core.SemiJoin;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rel.metadata.RelColumnOrigin;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rel.metadata.RelMdUtil;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rel.rules.LoptMultiJoin;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rex.RexBuilder;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rex.RexCall;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rex.RexInputRef;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rex.RexNode;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rex.RexUtil;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.sql.SqlKind;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.sql.SqlOperator;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.util.ImmutableBitSet;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.util.ImmutableIntList;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.util.Util;

/**
 * Decides which semi-joins, if any, should be applied to the join factors of a
 * {@link LoptMultiJoin}.
 *
 * <p>Typical usage: call {@link #makePossibleSemiJoins(LoptMultiJoin)} once to
 * collect the candidate semi-joins, then call
 * {@link #chooseBestSemiJoin(LoptMultiJoin)} repeatedly until it returns
 * {@code false}; the (possibly semi-join-wrapped) factors are then available
 * through {@link #getChosenSemiJoin(int)}.
 *
 * <p>NOTE: the index-related helper types nested at the bottom of this class
 * ({@code LcsTable}, {@code LcsTableScan}, {@code LcsIndexOptimizer},
 * {@code FemLocalIndex}, {@code LucidDbSpecialOperators}) are empty stand-ins.
 * As written, no {@code LcsTable} instance can exist and the stand-in index
 * optimizer never finds an index, so
 * {@code findSemiJoinIndexByCost} always returns {@code null} and no candidate
 * semi-join is ever registered.
 */
public class LoptSemiJoinOptimizer {
    /** A candidate semi-join must score strictly above this value to be chosen. */
    private static final int THRESHOLD_SCORE = 10;

    private final RexBuilder rexBuilder;
    private final RelMetadataQuery mq;

    /** Per factor: the factor itself, or the semi-join chosen to filter it. */
    private final RelNode[] chosenSemiJoins;

    /**
     * Candidate semi-joins keyed by fact factor index, then by dimension factor
     * index. Starts empty so that {@link #chooseBestSemiJoin(LoptMultiJoin)} is
     * safe to call even before {@link #makePossibleSemiJoins(LoptMultiJoin)}.
     */
    private Map<Integer, Map<Integer, SemiJoin>> possibleSemiJoins =
        new HashMap<Integer, Map<Integer, SemiJoin>>();

    /** Orders factor indexes by the cumulative cost of their current rel. */
    private final Ordering<Integer> factorCostOrdering = Ordering.from(new FactorCostComparator());

    /**
     * Creates an optimizer whose chosen rel for every factor is initially the
     * factor itself.
     *
     * @param mq metadata query used for costs, row counts and column origins
     * @param multiJoin the multi-join whose factors are being optimized
     * @param rexBuilder builder used to create and rewrite join conditions
     */
    public LoptSemiJoinOptimizer(RelMetadataQuery mq, LoptMultiJoin multiJoin, RexBuilder rexBuilder) {
        this.mq = mq;
        this.rexBuilder = rexBuilder;
        int nJoinFactors = multiJoin.getNumJoinFactors();
        this.chosenSemiJoins = new RelNode[nJoinFactors];
        for (int i = 0; i < nJoinFactors; ++i) {
            this.chosenSemiJoins[i] = multiJoin.getJoinFactor(i);
        }
    }

    /**
     * Rebuilds the table of candidate semi-joins for every factor of the
     * multi-join. Nothing is registered for a full outer join, nor for filters
     * that involve a null-generating factor.
     *
     * @param multiJoin the multi-join to analyze
     */
    public void makePossibleSemiJoins(LoptMultiJoin multiJoin) {
        this.possibleSemiJoins = new HashMap<Integer, Map<Integer, SemiJoin>>();
        if (multiJoin.getMultiJoinRel().isFullOuterJoin()) {
            return;
        }
        int nJoinFactors = multiJoin.getNumJoinFactors();
        for (int factIdx = 0; factIdx < nJoinFactors; ++factIdx) {
            // Group the usable equi-join filters by the dimension factor they reference.
            Map<Integer, List<RexNode>> dimFilters = new HashMap<Integer, List<RexNode>>();
            for (RexNode joinFilter : multiJoin.getJoinFilters()) {
                int dimIdx = this.isSuitableFilter(multiJoin, joinFilter, factIdx);
                if (dimIdx == -1
                        || multiJoin.isNullGenerating(factIdx)
                        || multiJoin.isNullGenerating(dimIdx)) {
                    continue;
                }
                List<RexNode> currDimFilters = dimFilters.get(dimIdx);
                if (currDimFilters == null) {
                    currDimFilters = new ArrayList<RexNode>();
                    dimFilters.put(dimIdx, currDimFilters);
                }
                currDimFilters.add(joinFilter);
            }

            // Try to turn each group of filters into a semi-join.
            Map<Integer, SemiJoin> semiJoinMap = new HashMap<Integer, SemiJoin>();
            for (Map.Entry<Integer, List<RexNode>> entry : dimFilters.entrySet()) {
                int dimIdx = entry.getKey();
                SemiJoin semiJoin =
                    this.findSemiJoinIndexByCost(multiJoin, entry.getValue(), factIdx, dimIdx);
                if (semiJoin != null) {
                    semiJoinMap.put(dimIdx, semiJoin);
                }
            }
            if (!semiJoinMap.isEmpty()) {
                this.possibleSemiJoins.put(factIdx, semiJoinMap);
            }
        }
    }

    /**
     * Checks whether a join filter is a plain column-to-column equality that
     * involves the given factor.
     *
     * @param multiJoin the multi-join the filter belongs to
     * @param joinFilter the filter to inspect
     * @param factIdx index of the factor that must be one side of the filter
     * @return index of the factor on the other side of the equality, or -1 if
     *     the filter is not usable for {@code factIdx}
     */
    private int isSuitableFilter(LoptMultiJoin multiJoin, RexNode joinFilter, int factIdx) {
        if (joinFilter.getKind() != SqlKind.EQUALS) {
            return -1;
        }
        List<RexNode> operands = ((RexCall) joinFilter).getOperands();
        boolean bothInputRefs =
            operands.get(0) instanceof RexInputRef && operands.get(1) instanceof RexInputRef;
        if (!bothInputRefs) {
            return -1;
        }
        ImmutableBitSet joinRefs = multiJoin.getFactorsRefByJoinFilter(joinFilter);
        assert joinRefs.cardinality() == 2;
        int factor1 = joinRefs.nextSetBit(0);
        int factor2 = joinRefs.nextSetBit(factor1 + 1);
        if (factor1 == factIdx) {
            return factor2;
        }
        if (factor2 == factIdx) {
            return factor1;
        }
        return -1;
    }

    /**
     * Builds a semi-join between a fact factor and a dimension factor from the
     * given equality filters, provided the keys validate and an index is found
     * for them.
     *
     * @param multiJoin the multi-join containing both factors
     * @param joinFilters equality filters between the two factors
     * @param factIdx index of the fact (left) factor
     * @param dimIdx index of the dimension (right) factor
     * @return the semi-join, or {@code null} if the keys do not validate or no
     *     index is found
     */
    private SemiJoin findSemiJoinIndexByCost(LoptMultiJoin multiJoin, List<RexNode> joinFilters, int factIdx, int dimIdx) {
        RexNode semiJoinCondition = RexUtil.composeConjunction(this.rexBuilder, joinFilters, true);

        // Shift the fact-side references so that they are relative to the fact factor alone.
        int leftAdjustment = 0;
        for (int i = 0; i < factIdx; ++i) {
            leftAdjustment -= multiJoin.getNumFieldsInJoinFactor(i);
        }
        semiJoinCondition =
            this.adjustSemiJoinCondition(multiJoin, leftAdjustment, semiJoinCondition, factIdx, dimIdx);

        RelNode factRel = multiJoin.getJoinFactor(factIdx);
        RelNode dimRel = multiJoin.getJoinFactor(dimIdx);
        JoinInfo joinInfo = JoinInfo.of(factRel, dimRel, semiJoinCondition);
        assert joinInfo.leftKeys.size() > 0;

        List<Integer> leftKeys = Lists.newArrayList(joinInfo.leftKeys);
        List<Integer> rightKeys = Lists.newArrayList(joinInfo.rightKeys);
        List<Integer> actualLeftKeys = new ArrayList<Integer>();
        LcsTable factTable = this.validateKeys(factRel, leftKeys, rightKeys, actualLeftKeys);
        if (factTable == null) {
            return null;
        }

        List<Integer> bestKeyOrder = new ArrayList<Integer>();
        LcsTableScan tmpFactRel = (LcsTableScan)((Object)factTable.toRel(RelOptUtil.getContext(factRel.getCluster())));
        LcsIndexOptimizer indexOptimizer = new LcsIndexOptimizer(tmpFactRel);
        FemLocalIndex bestIndex =
            indexOptimizer.findSemiJoinIndexByCost(dimRel, actualLeftKeys, rightKeys, bestKeyOrder);
        if (bestIndex == null) {
            return null;
        }

        List<Integer> truncatedLeftKeys;
        List<Integer> truncatedRightKeys;
        if (actualLeftKeys.size() == bestKeyOrder.size()) {
            truncatedLeftKeys = leftKeys;
            truncatedRightKeys = rightKeys;
        } else {
            // Only a subset of the keys is covered by the index: keep those
            // keys and strip the filters on the remaining ones.
            truncatedLeftKeys = new ArrayList<Integer>();
            truncatedRightKeys = new ArrayList<Integer>();
            for (int key : bestKeyOrder) {
                truncatedLeftKeys.add(leftKeys.get(key));
                truncatedRightKeys.add(rightKeys.get(key));
            }
            semiJoinCondition = this.removeExtraFilters(
                truncatedLeftKeys, multiJoin.getNumFieldsInJoinFactor(factIdx), semiJoinCondition);
        }
        return SemiJoin.create(factRel, dimRel, semiJoinCondition,
            ImmutableIntList.copyOf(truncatedLeftKeys), ImmutableIntList.copyOf(truncatedRightKeys));
    }

    /**
     * Rewrites the input references of a semi-join condition, which are
     * expressed against the concatenated fields of the whole multi-join, so
     * that the left factor's fields start at 0 and the right factor's fields
     * follow immediately after them.
     *
     * @param multiJoin the multi-join the condition was taken from
     * @param leftAdjustment offset to add to references into the left factor
     *     (zero or negative)
     * @param semiJoinCondition condition to rewrite
     * @param leftIdx index of the left factor
     * @param rightIdx index of the right factor
     * @return the rewritten condition, or the original one when no reference
     *     needs to move
     */
    private RexNode adjustSemiJoinCondition(LoptMultiJoin multiJoin, int leftAdjustment, RexNode semiJoinCondition, int leftIdx, int rightIdx) {
        int rightStart = 0;
        for (int i = 0; i < rightIdx; ++i) {
            rightStart += multiJoin.getNumFieldsInJoinFactor(i);
        }
        int numFieldsLeftIdx = multiJoin.getNumFieldsInJoinFactor(leftIdx);
        int numFieldsRightIdx = multiJoin.getNumFieldsInJoinFactor(rightIdx);

        // The right factor always lands directly after the left factor's
        // fields. This must be computed unconditionally: folding it into a
        // short-circuit test would skip it whenever leftAdjustment is non-zero
        // and make the right references overlap the left ones.
        int rightAdjustment = numFieldsLeftIdx - rightStart;

        if (leftAdjustment == 0 && rightAdjustment == 0) {
            return semiJoinCondition;
        }

        int[] adjustments = new int[multiJoin.getNumTotalFields()];
        if (leftAdjustment != 0) {
            int leftStart = -leftAdjustment;
            for (int i = leftStart; i < leftStart + numFieldsLeftIdx; ++i) {
                adjustments[i] = leftAdjustment;
            }
        }
        if (rightAdjustment != 0) {
            for (int i = rightStart; i < rightStart + numFieldsRightIdx; ++i) {
                adjustments[i] = rightAdjustment;
            }
        }
        return semiJoinCondition.accept(
            new RelOptUtil.RexInputConverter(this.rexBuilder, multiJoin.getMultiJoinFields(), adjustments));
    }

    /**
     * Filters the semi-join keys down to those whose fact-side column
     * originates from a single {@code LcsTable}.
     *
     * <p>{@code leftKeys} and {@code rightKeys} are modified in place: a key
     * pair is dropped when the column has no origin, is a rid column, or does
     * not come from an {@code LcsTable}.
     *
     * @param factRel the fact factor
     * @param leftKeys fact-side key ordinals; pruned in place
     * @param rightKeys dimension-side key ordinals; pruned in step with
     *     {@code leftKeys}
     * @param actualLeftKeys output list receiving the origin column ordinals of
     *     the surviving keys
     * @return the table the surviving keys originate from, or {@code null} if
     *     no key survived
     */
    private LcsTable validateKeys(RelNode factRel, List<Integer> leftKeys, List<Integer> rightKeys, List<Integer> actualLeftKeys) {
        int keyIdx = 0;
        RelOptTable theTable = null;
        ListIterator<Integer> keyIter = leftKeys.listIterator();
        while (keyIter.hasNext()) {
            boolean removeKey = false;
            RelColumnOrigin colOrigin = this.mq.getColumnOrigin(factRel, keyIter.next());
            if (colOrigin == null
                    || LucidDbSpecialOperators.isLcsRidColumnId(colOrigin.getOriginColumnOrdinal())) {
                removeKey = true;
            } else {
                RelOptTable table = colOrigin.getOriginTable();
                if (theTable == null) {
                    if (table instanceof LcsTable) {
                        theTable = table;
                    } else {
                        removeKey = true;
                    }
                } else {
                    assert table == theTable;
                }
            }
            if (removeKey) {
                keyIter.remove();
                // keyIdx is a primitive int, so this removes by position, not by value.
                rightKeys.remove(keyIdx);
            } else {
                actualLeftKeys.add(colOrigin.getOriginColumnOrdinal());
                ++keyIdx;
            }
        }
        if (actualLeftKeys.isEmpty()) {
            return null;
        }
        return (LcsTable)theTable;
    }

    /**
     * Strips from a semi-join condition every equality whose fact-side (or,
     * failing that, dimension-side) column is not one of the retained keys.
     *
     * <p>All operands of a conjunction are examined, not just the first two,
     * so a conjunction with three or more equalities is handled correctly.
     *
     * @param keys key ordinals that are being kept
     * @param nFields number of fields in the fact factor; a first operand with
     *     an index below this is treated as the fact-side reference
     * @param condition an equality, or a conjunction of such conditions
     * @return the reduced condition, or {@code null} if nothing is left
     */
    private RexNode removeExtraFilters(List<Integer> keys, int nFields, RexNode condition) {
        assert condition instanceof RexCall;
        RexCall call = (RexCall)condition;
        if (condition.isA(SqlKind.AND)) {
            List<RexNode> kept = new ArrayList<RexNode>();
            for (RexNode operand : call.getOperands()) {
                RexNode reduced = this.removeExtraFilters(keys, nFields, operand);
                if (reduced != null) {
                    kept.add(reduced);
                }
            }
            if (kept.isEmpty()) {
                return null;
            }
            if (kept.size() == 1) {
                return kept.get(0);
            }
            return this.rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.AND, kept);
        }

        assert call.getOperator() == SqlStdOperatorTable.EQUALS;
        List<RexNode> operands = call.getOperands();
        assert operands.get(0) instanceof RexInputRef;
        assert operands.get(1) instanceof RexInputRef;
        int idx = ((RexInputRef)operands.get(0)).getIndex();
        if (idx >= nFields) {
            // First operand is not a fact-side column; test the second operand instead.
            idx = ((RexInputRef)operands.get(1)).getIndex();
        }
        return keys.contains(idx) ? condition : null;
    }

    /**
     * Picks at most one semi-join to apply: factors are visited in ascending
     * cost order, and for the first factor that has a candidate scoring above
     * the threshold, the best-scoring candidate is applied and removed from the
     * candidate table (in both directions).
     *
     * @param multiJoin the multi-join being optimized
     * @return {@code true} if a semi-join was chosen, {@code false} if none
     *     qualified
     */
    public boolean chooseBestSemiJoin(LoptMultiJoin multiJoin) {
        int nJoinFactors = multiJoin.getNumJoinFactors();
        List<Integer> sortedFactors =
            this.factorCostOrdering.immutableSortedCopy(Util.range(nJoinFactors));
        for (Integer factIdx : sortedFactors) {
            RelNode factRel = this.chosenSemiJoins[factIdx];
            Map<Integer, SemiJoin> possibleDimensions = this.possibleSemiJoins.get(factIdx);
            if (possibleDimensions == null) {
                continue;
            }

            double bestScore = 0.0;
            int bestDimIdx = -1;
            for (Map.Entry<Integer, SemiJoin> candidate : possibleDimensions.entrySet()) {
                SemiJoin semiJoin = candidate.getValue();
                if (semiJoin == null) {
                    continue;
                }
                int dimIdx = candidate.getKey();
                double score = this.computeScore(factRel, this.chosenSemiJoins[dimIdx], semiJoin);
                if (score > THRESHOLD_SCORE && score > bestScore) {
                    bestDimIdx = dimIdx;
                    bestScore = score;
                }
            }
            if (bestDimIdx == -1) {
                continue;
            }

            // Re-create the semi-join on top of the currently chosen rels,
            // which may themselves already be semi-joins.
            SemiJoin semiJoin = possibleDimensions.get(bestDimIdx);
            SemiJoin chosenSemiJoin = SemiJoin.create(factRel, this.chosenSemiJoins[bestDimIdx],
                semiJoin.getCondition(), semiJoin.getLeftKeys(), semiJoin.getRightKeys());
            this.chosenSemiJoins[factIdx.intValue()] = chosenSemiJoin;

            this.removeJoin(multiJoin, chosenSemiJoin, factIdx, bestDimIdx);
            this.removePossibleSemiJoin(possibleDimensions, factIdx, bestDimIdx);
            this.removePossibleSemiJoin(this.possibleSemiJoins.get(bestDimIdx), bestDimIdx, factIdx);
            return true;
        }
        return false;
    }

    /**
     * Scores a candidate semi-join as estimated row savings on the fact side
     * divided by the dimension's row cost.
     *
     * @param factRel fact side of the semi-join
     * @param dimRel dimension side of the semi-join
     * @param semiJoin the candidate
     * @return the score; 0 when selectivity exceeds 0.5 or required metadata is
     *     unavailable
     */
    private double computeScore(RelNode factRel, RelNode dimRel, SemiJoin semiJoin) {
        ImmutableBitSet dimCols = ImmutableBitSet.of(semiJoin.getRightKeys());
        double selectivity = RelMdUtil.computeSemiJoinSelectivity(this.mq, factRel, dimRel, semiJoin);
        if (selectivity > 0.5) {
            return 0.0;
        }
        RelOptCost factCost = this.mq.getCumulativeCost(factRel);
        if (factCost == null) {
            return 0.0;
        }
        double savings = (1.0 - Math.sqrt(selectivity)) * Math.max(1.0, factCost.getRows());
        boolean uniq = RelMdUtil.areColumnsDefinitelyUniqueWhenNullsFiltered(this.mq, dimRel, dimCols);
        if (uniq) {
            // Unique dimension keys double the estimated savings.
            savings *= 2.0;
        }

        // Both pieces of dimension metadata must be available. The row count
        // is checked for null BEFORE any use, so that a missing row count
        // yields a zero score instead of an unboxing NullPointerException.
        Double dimSortCost = this.mq.getRowCount(dimRel);
        RelOptCost dimCost = this.mq.getCumulativeCost(dimRel);
        if (dimSortCost == null || dimCost == null) {
            return 0.0;
        }
        double dimRows = Math.max(1.0, dimCost.getRows());
        return savings / dimRows;
    }

    /**
     * Marks the dimension factor as removable from the join when the chosen
     * semi-join makes joining with it redundant: its semi-join keys are unique
     * and nothing other than those keys is projected or referenced by join
     * filters. If, additionally, nothing at all is projected from the dimension
     * and each key is referenced at most once, the fact factor's join reference
     * counts for the semi-join keys are decremented.
     *
     * @param multiJoin the multi-join being optimized
     * @param semiJoin the semi-join that was chosen
     * @param factIdx index of the fact factor
     * @param dimIdx index of the dimension factor
     */
    private void removeJoin(LoptMultiJoin multiJoin, SemiJoin semiJoin, int factIdx, int dimIdx) {
        // Nothing to do if the dimension is already marked for removal.
        if (multiJoin.getJoinRemovalFactor(dimIdx) != null) {
            return;
        }
        ImmutableBitSet dimKeys = ImmutableBitSet.of(semiJoin.getRightKeys());
        RelNode dimRel = multiJoin.getJoinFactor(dimIdx);
        if (!RelMdUtil.areColumnsDefinitelyUniqueWhenNullsFiltered(this.mq, dimRel, dimKeys)) {
            return;
        }

        // A null projection set is treated as "all fields projected".
        ImmutableBitSet dimProjRefs = multiJoin.getProjFields(dimIdx);
        if (dimProjRefs == null) {
            int nDimFields = multiJoin.getNumFieldsInJoinFactor(dimIdx);
            dimProjRefs = ImmutableBitSet.range(0, nDimFields);
        }
        if (!dimKeys.contains(dimProjRefs)) {
            return;
        }

        // Every dimension field referenced by a join filter must be a semi-join key.
        int[] dimJoinRefCounts = multiJoin.getJoinFieldRefCounts(dimIdx);
        for (int i = 0; i < dimJoinRefCounts.length; ++i) {
            if (dimJoinRefCounts[i] > 0 && !dimKeys.get(i)) {
                return;
            }
        }

        multiJoin.setJoinRemovalFactor(dimIdx, factIdx);
        multiJoin.setJoinRemovalSemiJoin(dimIdx, semiJoin);

        // The fact-side reference counts are only released when nothing is
        // projected from the dimension and each key is referenced exactly once.
        if (dimProjRefs.cardinality() != 0) {
            return;
        }
        for (int i = 0; i < dimJoinRefCounts.length; ++i) {
            int refCount = dimJoinRefCounts[i];
            if (refCount > 1 || (refCount == 1 && !dimKeys.get(i))) {
                return;
            }
        }
        int[] factJoinRefCounts = multiJoin.getJoinFieldRefCounts(factIdx);
        for (int key : semiJoin.getLeftKeys()) {
            --factJoinRefCounts[key];
        }
    }

    /**
     * Removes one dimension from a fact factor's candidate map, dropping the
     * fact factor's entry altogether once it has no candidates left.
     *
     * @param possibleDimensions candidate map of {@code factIdx}; may be
     *     {@code null}, in which case nothing happens
     * @param factIdx index of the fact factor owning the map
     * @param dimIdx index of the dimension factor to remove
     */
    private void removePossibleSemiJoin(Map<Integer, SemiJoin> possibleDimensions, Integer factIdx, Integer dimIdx) {
        if (possibleDimensions == null) {
            return;
        }
        possibleDimensions.remove(dimIdx);
        if (possibleDimensions.isEmpty()) {
            this.possibleSemiJoins.remove(factIdx);
        } else {
            this.possibleSemiJoins.put(factIdx, possibleDimensions);
        }
    }

    /**
     * Returns the rel currently chosen for a factor.
     *
     * @param factIdx index of the factor
     * @return the original factor, or the semi-join chosen to filter it
     */
    public RelNode getChosenSemiJoin(int factIdx) {
        return this.chosenSemiJoins[factIdx];
    }

    /** Stand-in for rid-column detection; reports that no column is a rid column. */
    private static class LucidDbSpecialOperators {
        private LucidDbSpecialOperators() {
        }

        public static boolean isLcsRidColumnId(int originColumnOrdinal) {
            return false;
        }
    }

    /** Stand-in index descriptor; never instantiated in this file. */
    private static class FemLocalIndex {
        private FemLocalIndex() {
        }
    }

    /** Stand-in index optimizer; never finds an index. */
    private static class LcsIndexOptimizer {
        LcsIndexOptimizer(LcsTableScan rel) {
        }

        public FemLocalIndex findSemiJoinIndexByCost(RelNode dimRel, List<Integer> actualLeftKeys, List<Integer> rightKeys, List<Integer> bestKeyOrder) {
            return null;
        }
    }

    /** Stand-in table scan; never instantiated in this file. */
    private static class LcsTableScan {
        private LcsTableScan() {
        }
    }

    /** Stand-in table type; has no concrete subclass in this file. */
    private static abstract class LcsTable
    implements RelOptTable {
        private LcsTable() {
        }
    }

    /**
     * Compares factor indexes by the cumulative cost of the rel currently
     * chosen for each factor. Factors whose cost is unknown ({@code null}) sort
     * first and compare equal to each other, which keeps the ordering
     * consistent with the {@link Comparator} contract.
     */
    private class FactorCostComparator
    implements Comparator<Integer> {
        private FactorCostComparator() {
        }

        @Override
        public int compare(Integer rel1Idx, Integer rel2Idx) {
            RelOptCost c1 = LoptSemiJoinOptimizer.this.mq.getCumulativeCost(LoptSemiJoinOptimizer.this.chosenSemiJoins[rel1Idx]);
            RelOptCost c2 = LoptSemiJoinOptimizer.this.mq.getCumulativeCost(LoptSemiJoinOptimizer.this.chosenSemiJoins[rel2Idx]);
            if (c1 == null) {
                return c2 == null ? 0 : -1;
            }
            if (c2 == null) {
                return 1;
            }
            return c1.isLt(c2) ? -1 : (c1.equals(c2) ? 0 : 1);
        }
    }
}

