/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.extensions.sql.impl.transform;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.beam.repackaged.sql.org.apache.calcite.rel.core.JoinRelType;
import org.apache.beam.repackaged.sql.org.apache.calcite.rex.RexCall;
import org.apache.beam.repackaged.sql.org.apache.calcite.rex.RexInputRef;
import org.apache.beam.repackaged.sql.org.apache.calcite.rex.RexNode;
import org.apache.beam.repackaged.sql.org.apache.calcite.util.Pair;
import org.apache.beam.sdk.extensions.sql.BeamSqlSeekableTable;
import org.apache.beam.sdk.extensions.sql.impl.utils.SerializableRexFieldAccess;
import org.apache.beam.sdk.extensions.sql.impl.utils.SerializableRexInputRef;
import org.apache.beam.sdk.extensions.sql.impl.utils.SerializableRexNode;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.Row;

public class BeamJoinTransforms {
    private static Row combineTwoRowsIntoOne(Row leftRow, Row rightRow, boolean swap, Schema outputSchema) {
        if (swap) {
            return BeamJoinTransforms.combineTwoRowsIntoOneHelper(rightRow, leftRow, outputSchema);
        }
        return BeamJoinTransforms.combineTwoRowsIntoOneHelper(leftRow, rightRow, outputSchema);
    }

    private static Row combineTwoRowsIntoOneHelper(Row leftRow, Row rightRow, Schema ouputSchema) {
        return Row.withSchema((Schema)ouputSchema).addValues(leftRow.getValues()).addValues(rightRow.getValues()).build();
    }

    public static class JoinAsLookup
    extends PTransform<PCollection<Row>, PCollection<Row>> {
        private final BeamSqlSeekableTable seekableTable;
        private final Schema lkpSchema;
        private final int factColOffset;
        private Schema joinSubsetType;
        private final Schema outputSchema;
        private List<Integer> factJoinIdx;

        public JoinAsLookup(RexNode joinCondition, BeamSqlSeekableTable seekableTable, Schema lkpSchema, Schema outputSchema, int factColOffset, int lkpColOffset) {
            this.seekableTable = seekableTable;
            this.lkpSchema = lkpSchema;
            this.outputSchema = outputSchema;
            this.factColOffset = factColOffset;
            this.joinFieldsMapping(joinCondition, factColOffset, lkpColOffset);
        }

        private void joinFieldsMapping(RexNode joinCondition, int factColOffset, int lkpColOffset) {
            this.factJoinIdx = new ArrayList<Integer>();
            ArrayList<Schema.Field> lkpJoinFields = new ArrayList<Schema.Field>();
            RexCall call = (RexCall)joinCondition;
            if ("AND".equals(call.getOperator().getName())) {
                List<RexNode> operands = call.getOperands();
                for (RexNode rexNode : operands) {
                    this.factJoinIdx.add(((RexInputRef)((RexCall)rexNode).getOperands().get(0)).getIndex() - factColOffset);
                    int lkpJoinIdx = ((RexInputRef)((RexCall)rexNode).getOperands().get(1)).getIndex() - lkpColOffset;
                    lkpJoinFields.add(this.lkpSchema.getField(lkpJoinIdx));
                }
            } else if ("=".equals(call.getOperator().getName())) {
                this.factJoinIdx.add(((RexInputRef)call.getOperands().get(0)).getIndex() - factColOffset);
                int lkpJoinIdx = ((RexInputRef)call.getOperands().get(1)).getIndex() - lkpColOffset;
                lkpJoinFields.add(this.lkpSchema.getField(lkpJoinIdx));
            } else {
                throw new UnsupportedOperationException("Operator " + call.getOperator().getName() + " is not supported in join condition");
            }
            this.joinSubsetType = Schema.builder().addFields(lkpJoinFields).build();
        }

        public PCollection<Row> expand(PCollection<Row> input) {
            return ((PCollection)input.apply("join_as_lookup", (PTransform)ParDo.of((DoFn)new DoFn<Row, Row>(){

                @DoFn.Setup
                public void setup() {
                    seekableTable.setUp();
                }

                @DoFn.ProcessElement
                public void processElement(DoFn.ProcessContext context) {
                    Row factRow = (Row)context.element();
                    Row joinSubRow = this.extractJoinSubRow(factRow);
                    List<Row> lookupRows = seekableTable.seekRow(joinSubRow);
                    for (Row lr : lookupRows) {
                        context.output((Object)BeamJoinTransforms.combineTwoRowsIntoOne(factRow, lr, factColOffset != 0, outputSchema));
                    }
                }

                @DoFn.Teardown
                public void teardown() {
                    seekableTable.tearDown();
                }

                private Row extractJoinSubRow(Row factRow) {
                    List joinSubsetValues = factJoinIdx.stream().map(arg_0 -> ((Row)factRow).getValue(arg_0)).collect(Collectors.toList());
                    return Row.withSchema((Schema)joinSubsetType).addValues(joinSubsetValues).build();
                }
            }))).setRowSchema(this.joinSubsetType);
        }
    }

    public static class JoinParts2WholeRow
    extends SimpleFunction<KV<Row, KV<Row, Row>>, Row> {
        private final Schema schema;

        public JoinParts2WholeRow(Schema schema) {
            this.schema = schema;
        }

        public Row apply(KV<Row, KV<Row, Row>> input) {
            KV parts = (KV)input.getValue();
            Row leftRow = (Row)parts.getKey();
            Row rightRow = (Row)parts.getValue();
            return BeamJoinTransforms.combineTwoRowsIntoOne(leftRow, rightRow, false, this.schema);
        }
    }

    public static class SideInputJoinDoFn
    extends DoFn<KV<Row, Row>, Row> {
        private final PCollectionView<Map<Row, Iterable<Row>>> sideInputView;
        private final JoinRelType joinType;
        private final Row rightNullRow;
        private final boolean swap;
        private final Schema schema;

        public SideInputJoinDoFn(JoinRelType joinType, Row rightNullRow, PCollectionView<Map<Row, Iterable<Row>>> sideInputView, boolean swap, Schema schema) {
            this.joinType = joinType;
            this.rightNullRow = rightNullRow;
            this.sideInputView = sideInputView;
            this.swap = swap;
            this.schema = schema;
        }

        @DoFn.ProcessElement
        public void processElement(DoFn.ProcessContext context) {
            Row key = (Row)((KV)context.element()).getKey();
            Row leftRow = (Row)((KV)context.element()).getValue();
            Map key2Rows = (Map)context.sideInput(this.sideInputView);
            Iterable rightRowsIterable = (Iterable)key2Rows.get(key);
            if (rightRowsIterable != null && rightRowsIterable.iterator().hasNext()) {
                for (Row aRightRowsIterable : rightRowsIterable) {
                    context.output((Object)BeamJoinTransforms.combineTwoRowsIntoOne(leftRow, aRightRowsIterable, this.swap, this.schema));
                }
            } else if (this.joinType == JoinRelType.LEFT) {
                context.output((Object)BeamJoinTransforms.combineTwoRowsIntoOne(leftRow, this.rightNullRow, this.swap, this.schema));
            }
        }
    }

    public static class ExtractJoinFields
    extends SimpleFunction<Row, KV<Row, Row>> {
        private final List<SerializableRexNode> joinColumns;
        private final Schema schema;
        private int leftRowColumnCount;

        public ExtractJoinFields(boolean isLeft, List<Pair<RexNode, RexNode>> joinColumns, Schema schema, int leftRowColumnCount) {
            this.joinColumns = joinColumns.stream().map(pair -> SerializableRexNode.builder(isLeft ? (RexNode)pair.left : (RexNode)pair.right).build()).collect(Collectors.toList());
            this.schema = schema;
            this.leftRowColumnCount = leftRowColumnCount;
        }

        public KV<Row, Row> apply(Row input) {
            Row row = (Row)this.joinColumns.stream().map(v -> this.getValue((SerializableRexNode)v, input, this.leftRowColumnCount)).collect(Row.toRow((Schema)this.schema));
            return KV.of((Object)row, (Object)input);
        }

        private Schema.Field toField(Schema schema, Integer fieldIndex) {
            Schema.Field original = schema.getField(fieldIndex.intValue());
            return original.withName("c" + fieldIndex);
        }

        private Object getValue(SerializableRexNode serializableRexNode, Row input, int leftRowColumnCount) {
            if (serializableRexNode instanceof SerializableRexInputRef) {
                return input.getValue(((SerializableRexInputRef)serializableRexNode).getIndex() - leftRowColumnCount);
            }
            List<Integer> indexes = ((SerializableRexFieldAccess)serializableRexNode).getIndexes();
            Row rowField = (Row)input.getValue(indexes.get(0) - leftRowColumnCount);
            for (int i = 1; i < indexes.size() - 1; ++i) {
                rowField = rowField.getRow(indexes.get(i).intValue());
            }
            return rowField.getValue(indexes.get(indexes.size() - 1).intValue());
        }
    }
}

