/*
 * Decompiled with CFR 0.152.
 */
package com.hazelcast.org.apache.calcite.rel.rules;

import com.hazelcast.com.google.common.collect.ImmutableList;
import com.hazelcast.org.apache.calcite.plan.RelOptRule;
import com.hazelcast.org.apache.calcite.plan.RelOptRuleCall;
import com.hazelcast.org.apache.calcite.plan.RelOptRuleOperand;
import com.hazelcast.org.apache.calcite.plan.RelOptUtil;
import com.hazelcast.org.apache.calcite.rel.RelNode;
import com.hazelcast.org.apache.calcite.rel.core.Aggregate;
import com.hazelcast.org.apache.calcite.rel.core.AggregateCall;
import com.hazelcast.org.apache.calcite.rel.core.Join;
import com.hazelcast.org.apache.calcite.rel.core.JoinRelType;
import com.hazelcast.org.apache.calcite.rel.core.RelFactories;
import com.hazelcast.org.apache.calcite.rel.logical.LogicalAggregate;
import com.hazelcast.org.apache.calcite.rel.logical.LogicalJoin;
import com.hazelcast.org.apache.calcite.rel.rules.TransformationRule;
import com.hazelcast.org.apache.calcite.rex.RexNode;
import com.hazelcast.org.apache.calcite.rex.RexUtil;
import com.hazelcast.org.apache.calcite.tools.RelBuilder;
import com.hazelcast.org.apache.calcite.tools.RelBuilderFactory;
import com.hazelcast.org.apache.calcite.util.ImmutableBitSet;
import com.hazelcast.org.apache.calcite.util.mapping.Mappings;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class AggregateJoinJoinRemoveRule
extends RelOptRule
implements TransformationRule {
    public static final AggregateJoinJoinRemoveRule INSTANCE = new AggregateJoinJoinRemoveRule(LogicalAggregate.class, LogicalJoin.class, RelFactories.LOGICAL_BUILDER);

    public AggregateJoinJoinRemoveRule(Class<? extends Aggregate> aggregateClass, Class<? extends Join> joinClass, RelBuilderFactory relBuilderFactory) {
        super(AggregateJoinJoinRemoveRule.operand(aggregateClass, AggregateJoinJoinRemoveRule.operandJ(joinClass, null, join -> join.getJoinType() == JoinRelType.LEFT, AggregateJoinJoinRemoveRule.operandJ(joinClass, null, join -> join.getJoinType() == JoinRelType.LEFT, AggregateJoinJoinRemoveRule.any()), new RelOptRuleOperand[0]), new RelOptRuleOperand[0]), relBuilderFactory, null);
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        Aggregate aggregate = (Aggregate)call.rel(0);
        Join topJoin = (Join)call.rel(1);
        Join bottomJoin = (Join)call.rel(2);
        int leftBottomChildSize = bottomJoin.getLeft().getRowType().getFieldCount();
        Set<Integer> allFields = RelOptUtil.getAllFields(aggregate);
        if (allFields.stream().anyMatch(i -> i >= leftBottomChildSize && i < bottomJoin.getRowType().getFieldCount())) {
            return;
        }
        if (aggregate.getAggCallList().stream().anyMatch(aggregateCall -> !aggregateCall.isDistinct())) {
            return;
        }
        ArrayList<Integer> leftKeys = new ArrayList<Integer>();
        RelOptUtil.splitJoinCondition(topJoin.getLeft(), topJoin.getRight(), topJoin.getCondition(), leftKeys, new ArrayList<Integer>(), new ArrayList<Boolean>());
        if (leftKeys.stream().anyMatch(s -> s >= leftBottomChildSize)) {
            return;
        }
        ArrayList<Integer> leftChildKeys = new ArrayList<Integer>();
        RelOptUtil.splitJoinCondition(bottomJoin.getLeft(), bottomJoin.getRight(), bottomJoin.getCondition(), leftChildKeys, new ArrayList<Integer>(), new ArrayList<Boolean>());
        if (!leftKeys.equals(leftChildKeys)) {
            return;
        }
        int offset = bottomJoin.getRight().getRowType().getFieldCount();
        RelBuilder relBuilder = call.builder();
        RexNode condition = RexUtil.shift(topJoin.getCondition(), leftBottomChildSize, -offset);
        RelNode join = relBuilder.push(bottomJoin.getLeft()).push(topJoin.getRight()).join(topJoin.getJoinType(), condition).build();
        HashMap<Integer, Integer> map = new HashMap<Integer, Integer>();
        allFields.forEach(index -> map.put((Integer)index, index < leftBottomChildSize ? index : index - offset));
        ImmutableBitSet groupSet = aggregate.getGroupSet().permute(map);
        ImmutableList.Builder aggCalls = ImmutableList.builder();
        int sourceCount = aggregate.getInput().getRowType().getFieldCount();
        Mappings.TargetMapping targetMapping = Mappings.target(map, sourceCount, sourceCount);
        aggregate.getAggCallList().forEach(aggregateCall -> aggCalls.add(aggregateCall.transform(targetMapping)));
        RelNode newAggregate = relBuilder.push(join).aggregate(relBuilder.groupKey(groupSet), (List<AggregateCall>)((Object)aggCalls.build())).build();
        call.transformTo(newAggregate);
    }
}

