package org.apache.flink.table.planner.plan.rules.logical;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.AbstractRelNode;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import org.apache.calcite.rex.RexProgramBuilder;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.validate.SqlValidatorUtil;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCorrelate;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
import org.apache.flink.table.planner.plan.rules.physical.stream.StreamPhysicalCorrelateRule;
import org.apache.flink.table.planner.plan.utils.PythonUtil;
import org.apache.flink.table.planner.plan.utils.RexDefaultVisitor;
import scala.collection.Iterator;
import scala.collection.mutable.ArrayBuffer;

/* loaded from: input_file:org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.class */
public class PythonCorrelateSplitRule extends RelOptRule {
    public static final PythonCorrelateSplitRule INSTANCE = new PythonCorrelateSplitRule();

    private PythonCorrelateSplitRule() {
        super(operand(FlinkLogicalCorrelate.class, any()), "PythonCorrelateSplitRule");
    }

    private FlinkLogicalTableFunctionScan createNewScan(FlinkLogicalTableFunctionScan flinkLogicalTableFunctionScan, ScalarFunctionSplitter scalarFunctionSplitter) {
        RexCall rexCall = (RexCall) flinkLogicalTableFunctionScan.getCall();
        return new FlinkLogicalTableFunctionScan(flinkLogicalTableFunctionScan.getCluster(), flinkLogicalTableFunctionScan.getTraitSet(), flinkLogicalTableFunctionScan.getInputs(), rexCall.clone(rexCall.getType(), (List) rexCall.getOperands().stream().map(rexNode -> {
            return (RexNode) rexNode.accept(scalarFunctionSplitter);
        }).collect(Collectors.toList())), flinkLogicalTableFunctionScan.getElementType(), flinkLogicalTableFunctionScan.getRowType(), flinkLogicalTableFunctionScan.getColumnMappings());
    }

    @Override // org.apache.calcite.plan.RelOptRule
    public boolean matches(RelOptRuleCall relOptRuleCall) {
        FlinkLogicalTableFunctionScan tableScan;
        RelNode currentRel = ((HepRelVertex) ((FlinkLogicalCorrelate) relOptRuleCall.rel(0)).getRight()).getCurrentRel();
        if (currentRel instanceof FlinkLogicalTableFunctionScan) {
            tableScan = (FlinkLogicalTableFunctionScan) currentRel;
        } else {
            if (!(currentRel instanceof FlinkLogicalCalc)) {
                return false;
            }
            tableScan = StreamPhysicalCorrelateRule.getTableScan((FlinkLogicalCalc) currentRel);
        }
        RexNode call = tableScan.getCall();
        if (call instanceof RexCall) {
            return (PythonUtil.isPythonCall(call, null) && PythonUtil.containsNonPythonCall(call)) || (PythonUtil.isNonPythonCall(call) && PythonUtil.containsPythonCall(call, null)) || (PythonUtil.isPythonCall(call, null) && RexUtil.containsFieldAccess(call));
        }
        return false;
    }

    private List<String> createNewFieldNames(RelDataType relDataType, final RexBuilder rexBuilder, int i, ArrayBuffer<RexNode> arrayBuffer, List<RexNode> list) {
        for (int i2 = 0; i2 < i; i2++) {
            list.add(RexInputRef.of(i2, relDataType));
        }
        RexDefaultVisitor<RexNode> rexDefaultVisitor = new RexDefaultVisitor<RexNode>() { // from class: org.apache.flink.table.planner.plan.rules.logical.PythonCorrelateSplitRule.1
            @Override // org.apache.flink.table.planner.plan.utils.RexDefaultVisitor, org.apache.calcite.rex.RexVisitor
            /* renamed from: visitFieldAccess */
            public RexNode mo5213visitFieldAccess(RexFieldAccess rexFieldAccess) {
                RexNode referenceExpr = rexFieldAccess.getReferenceExpr();
                if (!(referenceExpr instanceof RexCorrelVariable)) {
                    return rexBuilder.makeFieldAccess((RexNode) referenceExpr.accept(this), rexFieldAccess.getField().getIndex());
                }
                RelDataTypeField field = rexFieldAccess.getField();
                return new RexInputRef(field.getIndex(), field.getType());
            }

            /* JADX WARN: Can't rename method to resolve collision */
            @Override // org.apache.flink.table.planner.plan.utils.RexDefaultVisitor
            /* renamed from: visitNode */
            public RexNode mo5218visitNode(RexNode rexNode) {
                return rexNode;
            }
        };
        Iterator it = arrayBuffer.iterator();
        while (it.hasNext()) {
            RexNode rexNode = (RexNode) it.next();
            if (rexNode instanceof RexCall) {
                RexCall rexCall = (RexCall) rexNode;
                list.add(rexCall.clone(rexCall.getType(), (List) rexCall.getOperands().stream().map(rexNode2 -> {
                    return (RexNode) rexNode2.accept(rexDefaultVisitor);
                }).collect(Collectors.toList())));
            } else {
                list.add(rexNode);
            }
        }
        LinkedList linkedList = new LinkedList();
        for (int i3 = 0; i3 < i; i3++) {
            linkedList.add(relDataType.getFieldNames().get(i3));
        }
        Iterator it2 = arrayBuffer.indices().iterator();
        while (it2.hasNext()) {
            linkedList.add("f" + it2.next());
        }
        return SqlValidatorUtil.uniquify(linkedList, rexBuilder.getTypeFactory().getTypeSystem().isSchemaCaseSensitive());
    }

    private FlinkLogicalCalc createNewLeftCalc(RelNode relNode, RexBuilder rexBuilder, ArrayBuffer<RexNode> arrayBuffer, FlinkLogicalCorrelate flinkLogicalCorrelate) {
        LinkedList linkedList = new LinkedList();
        RelDataType rowType = relNode.getRowType();
        return new FlinkLogicalCalc(flinkLogicalCorrelate.getCluster(), flinkLogicalCorrelate.getTraitSet(), relNode, RexProgram.create(rowType, linkedList, (RexNode) null, createNewFieldNames(rowType, rexBuilder, rowType.getFieldCount(), arrayBuffer, linkedList), rexBuilder));
    }

    private FlinkLogicalCalc createTopCalc(int i, RexBuilder rexBuilder, ArrayBuffer<RexNode> arrayBuffer, RelDataType relDataType, FlinkLogicalCorrelate flinkLogicalCorrelate) {
        RexProgram program = new RexProgramBuilder(flinkLogicalCorrelate.getRowType(), rexBuilder).getProgram();
        int size = arrayBuffer.size() + i;
        return new FlinkLogicalCalc(flinkLogicalCorrelate.getCluster(), flinkLogicalCorrelate.getTraitSet(), flinkLogicalCorrelate, RexProgram.create(flinkLogicalCorrelate.getRowType(), (List<? extends RexNode>) program.getExprList().stream().filter(rexNode -> {
            return rexNode instanceof RexInputRef;
        }).filter(rexNode2 -> {
            int index = ((RexInputRef) rexNode2).getIndex();
            return index < i || index >= size;
        }).collect(Collectors.toList()), (RexNode) null, relDataType, rexBuilder));
    }

    private ScalarFunctionSplitter createScalarFunctionSplitter(RexProgram rexProgram, RexBuilder rexBuilder, int i, ArrayBuffer<RexNode> arrayBuffer, RexNode rexNode) {
        return new ScalarFunctionSplitter(rexProgram, rexBuilder, i, arrayBuffer, rexNode2 -> {
            return PythonUtil.isNonPythonCall(rexNode) ? Boolean.valueOf(PythonUtil.isPythonCall(rexNode2, null)) : PythonUtil.containsNonPythonCall(rexNode2) ? Boolean.valueOf(PythonUtil.isNonPythonCall(rexNode2)) : Boolean.valueOf(rexNode2 instanceof RexFieldAccess);
        }, new PythonRemoteCalcCallFinder());
    }

    @Override // org.apache.calcite.plan.RelOptRule
    public void onMatch(RelOptRuleCall relOptRuleCall) {
        AbstractRelNode copy;
        FlinkLogicalCorrelate flinkLogicalCorrelate;
        FlinkLogicalCorrelate flinkLogicalCorrelate2 = (FlinkLogicalCorrelate) relOptRuleCall.rel(0);
        RexBuilder rexBuilder = relOptRuleCall.builder().getRexBuilder();
        RelNode currentRel = ((HepRelVertex) flinkLogicalCorrelate2.getLeft()).getCurrentRel();
        RelNode currentRel2 = ((HepRelVertex) flinkLogicalCorrelate2.getRight()).getCurrentRel();
        int fieldCount = currentRel.getRowType().getFieldCount();
        ArrayBuffer<RexNode> arrayBuffer = new ArrayBuffer<>();
        if (currentRel2 instanceof FlinkLogicalTableFunctionScan) {
            FlinkLogicalTableFunctionScan flinkLogicalTableFunctionScan = (FlinkLogicalTableFunctionScan) currentRel2;
            copy = createNewScan(flinkLogicalTableFunctionScan, createScalarFunctionSplitter(null, rexBuilder, fieldCount, arrayBuffer, flinkLogicalTableFunctionScan.getCall()));
        } else {
            FlinkLogicalCalc flinkLogicalCalc = (FlinkLogicalCalc) currentRel2;
            FlinkLogicalTableFunctionScan tableScan = StreamPhysicalCorrelateRule.getTableScan(flinkLogicalCalc);
            FlinkLogicalCalc mergedCalc = StreamPhysicalCorrelateRule.getMergedCalc(flinkLogicalCalc);
            copy = mergedCalc.copy(mergedCalc.getTraitSet(), createNewScan(tableScan, createScalarFunctionSplitter(null, rexBuilder, fieldCount, arrayBuffer, tableScan.getCall())), mergedCalc.getProgram());
        }
        if (arrayBuffer.size() > 0) {
            flinkLogicalCorrelate = new FlinkLogicalCorrelate(flinkLogicalCorrelate2.getCluster(), flinkLogicalCorrelate2.getTraitSet(), createNewLeftCalc(currentRel, rexBuilder, arrayBuffer, flinkLogicalCorrelate2), copy, flinkLogicalCorrelate2.getCorrelationId(), flinkLogicalCorrelate2.getRequiredColumns(), flinkLogicalCorrelate2.getJoinType());
        } else {
            flinkLogicalCorrelate = new FlinkLogicalCorrelate(flinkLogicalCorrelate2.getCluster(), flinkLogicalCorrelate2.getTraitSet(), currentRel, copy, flinkLogicalCorrelate2.getCorrelationId(), flinkLogicalCorrelate2.getRequiredColumns(), flinkLogicalCorrelate2.getJoinType());
        }
        relOptRuleCall.transformTo(createTopCalc(fieldCount, rexBuilder, arrayBuffer, flinkLogicalCorrelate2.getRowType(), flinkLogicalCorrelate));
    }
}
