/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.extensions.sql.impl.transform;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.beam.sdk.extensions.sql.BeamSqlSeekableTable;
import org.apache.beam.sdk.extensions.sql.impl.utils.SerializableRexFieldAccess;
import org.apache.beam.sdk.extensions.sql.impl.utils.SerializableRexInputRef;
import org.apache.beam.sdk.extensions.sql.impl.utils.SerializableRexNode;
import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexCall;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexInputRef;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexNode;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.util.Pair;
import org.checkerframework.checker.initialization.qual.Initialized;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.UnknownKeyFor;

/**
 * Join helpers for Beam SQL: resolving join-key expressions into {@link FieldAccessDescriptor}s,
 * concatenating a pair of matched rows, and {@link JoinAsLookup}, which joins each input row against
 * a {@link BeamSqlSeekableTable}.
 */
public class BeamJoinTransforms {

    /**
     * Builds a descriptor that selects the join-key columns of one side of a join.
     *
     * @param isLeft take the left ({@code true}) or right ({@code false}) expression of each pair
     * @param joinColumns (left, right) pairs of join-key expressions
     * @param leftRowColumnCount offset subtracted from the top-level input-ref index of every key
     *     expression (presumably 0 for the left side and the left input's width for the right
     *     side -- confirm against callers)
     * @param schema schema every per-column descriptor is resolved against
     * @return the union of the resolved per-column descriptors
     */
    public static FieldAccessDescriptor getJoinColumns(
            boolean isLeft,
            List<Pair<RexNode, RexNode>> joinColumns,
            int leftRowColumnCount,
            Schema schema) {
        List<SerializableRexNode> joinColumnsBuilt = joinColumns.stream()
                .map(pair -> SerializableRexNode.builder(isLeft ? pair.left : pair.right).build())
                .collect(Collectors.toList());
        List<FieldAccessDescriptor> resolved = joinColumnsBuilt.stream()
                .map(node -> getJoinColumn(node, leftRowColumnCount).resolve(schema))
                .collect(Collectors.toList());
        return FieldAccessDescriptor.union(resolved);
    }

    /**
     * Converts one serialized join-key expression into a field access descriptor.
     *
     * <p>A plain input reference selects a single top-level field. Any other node is treated as a
     * {@link SerializableRexFieldAccess}: its first index selects the top-level field and each
     * following index selects a field nested inside the previous one. Only the top-level index is
     * shifted by {@code leftRowColumnCount}.
     */
    private static FieldAccessDescriptor getJoinColumn(
            SerializableRexNode serializableRexNode, int leftRowColumnCount) {
        if (serializableRexNode instanceof SerializableRexInputRef) {
            SerializableRexInputRef inputRef = (SerializableRexInputRef) serializableRexNode;
            return FieldAccessDescriptor.withFieldIds(inputRef.getIndex() - leftRowColumnCount);
        }
        // The only other node kind produced for join keys is a (possibly nested) field access.
        List<Integer> indexes = ((SerializableRexFieldAccess) serializableRexNode).getIndexes();
        FieldAccessDescriptor fieldAccessDescriptor =
                FieldAccessDescriptor.withFieldIds(indexes.get(0) - leftRowColumnCount);
        for (int i = 1; i < indexes.size(); ++i) {
            fieldAccessDescriptor =
                    FieldAccessDescriptor.withFieldIds(fieldAccessDescriptor, indexes.get(i));
        }
        return fieldAccessDescriptor;
    }

    /**
     * Concatenates two rows into one row of {@code outputSchema}.
     *
     * @param swap when {@code true}, {@code rightRow}'s values are placed first
     */
    private static Row combineTwoRowsIntoOne(
            Row leftRow, Row rightRow, boolean swap, Schema outputSchema) {
        return swap
                ? combineTwoRowsIntoOneHelper(rightRow, leftRow, outputSchema)
                : combineTwoRowsIntoOneHelper(leftRow, rightRow, outputSchema);
    }

    /** Builds a row of {@code outputSchema} holding {@code first}'s values followed by {@code second}'s. */
    private static Row combineTwoRowsIntoOneHelper(Row first, Row second, Schema outputSchema) {
        return Row.withSchema(outputSchema)
                .addValues(first.getBaseValues())
                .addValues(second.getBaseValues())
                .build();
    }

    /**
     * Equi-joins every input ("fact") row against a {@link BeamSqlSeekableTable}.
     *
     * <p>For each fact row, the join-key values are packed into a row of the lookup key schema and
     * passed to {@link BeamSqlSeekableTable#seekRow}. Each returned lookup row is concatenated with
     * the fact row and emitted as a row of {@code outputSchema}.
     */
    public static class JoinAsLookup extends PTransform<PCollection<Row>, PCollection<Row>> {
        private final BeamSqlSeekableTable seekableTable;
        private final Schema lkpSchema;
        /** Column offset of the fact side in the joined row; non-zero means the fact side is on the right. */
        private final int factColOffset;
        /** Schema of the key row passed to {@link BeamSqlSeekableTable#seekRow}. */
        private final Schema joinSubsetType;
        private final Schema outputSchema;
        /** Fact-row field index of each join key, in join-condition order. */
        private final List<Integer> factJoinIdx;

        /**
         * @param joinCondition a single {@code =} call, or an {@code AND} of {@code =} calls, whose
         *     operands are input references into the joined row
         * @param seekableTable the table looked up for every fact row
         * @param lkpSchema schema of the lookup table rows
         * @param outputSchema schema of the emitted joined rows
         * @param factColOffset position of the first fact column in the joined row
         * @param lkpColOffset position of the first lookup column in the joined row
         * @throws UnsupportedOperationException if the condition is not an equality or a
         *     conjunction of equalities
         */
        public JoinAsLookup(
                RexNode joinCondition,
                BeamSqlSeekableTable seekableTable,
                Schema lkpSchema,
                Schema outputSchema,
                int factColOffset,
                int lkpColOffset) {
            this.seekableTable = seekableTable;
            this.lkpSchema = lkpSchema;
            this.outputSchema = outputSchema;
            this.factColOffset = factColOffset;

            List<Integer> factIndexes = new ArrayList<>();
            List<Schema.Field> lkpJoinFields = new ArrayList<>();
            for (RexCall equality : equalityConjuncts(joinCondition)) {
                int factRef = inputRefIndex(equality, 0);
                int lkpRef = inputRefIndex(equality, 1);
                // Accept "lkp.col = fact.col" as well as "fact.col = lkp.col". Swap only when the
                // second operand cannot be a lookup column but the first one can.
                if (!isLookupColumn(lkpRef, lkpColOffset, lkpSchema)
                        && isLookupColumn(factRef, lkpColOffset, lkpSchema)) {
                    int tmp = factRef;
                    factRef = lkpRef;
                    lkpRef = tmp;
                }
                factIndexes.add(factRef - factColOffset);
                lkpJoinFields.add(lkpSchema.getField(lkpRef - lkpColOffset));
            }
            this.factJoinIdx = factIndexes;
            this.joinSubsetType = Schema.builder().addFields(lkpJoinFields).build();
        }

        /** Splits the join condition into its equality calls, rejecting any other operator. */
        private static List<RexCall> equalityConjuncts(RexNode joinCondition) {
            RexCall call = (RexCall) joinCondition;
            String operator = call.getOperator().getName();
            List<RexCall> conjuncts = new ArrayList<>();
            if ("AND".equals(operator)) {
                for (RexNode operand : call.getOperands()) {
                    if (!(operand instanceof RexCall)
                            || !"=".equals(((RexCall) operand).getOperator().getName())) {
                        throw new UnsupportedOperationException(
                                "Join condition " + operand + " is not supported; only conjunctions of equalities are allowed");
                    }
                    conjuncts.add((RexCall) operand);
                }
            } else if ("=".equals(operator)) {
                conjuncts.add(call);
            } else {
                throw new UnsupportedOperationException(
                        "Operator " + operator + " is not supported in join condition");
            }
            return conjuncts;
        }

        /** Returns the joined-row index referenced by operand {@code position} of an equality call. */
        private static int inputRefIndex(RexCall equality, int position) {
            return ((RexInputRef) equality.getOperands().get(position)).getIndex();
        }

        /** Whether a joined-row index falls inside the lookup table's column range. */
        private static boolean isLookupColumn(int index, int lkpColOffset, Schema lkpSchema) {
            return index >= lkpColOffset && index < lkpColOffset + lkpSchema.getFieldCount();
        }

        @Override
        public PCollection<Row> expand(PCollection<Row> input) {
            return input
                    .apply(
                            "join_as_lookup",
                            ParDo.of(
                                    new DoFn<Row, Row>() {
                                        @DoFn.Setup
                                        public void setup() {
                                            seekableTable.setUp();
                                        }

                                        @DoFn.ProcessElement
                                        public void processElement(ProcessContext context) {
                                            Row factRow = context.element();
                                            Row joinSubRow = extractJoinSubRow(factRow);
                                            List<Row> lookupRows = seekableTable.seekRow(joinSubRow);
                                            for (Row lr : lookupRows) {
                                                // A non-zero fact offset means the fact side is the right input,
                                                // so the lookup row's values must come first.
                                                context.output(
                                                        combineTwoRowsIntoOne(
                                                                factRow, lr, factColOffset != 0, outputSchema));
                                            }
                                        }

                                        @DoFn.Teardown
                                        public void teardown() {
                                            seekableTable.tearDown();
                                        }

                                        /** Packs the fact row's join-key values into a row of the lookup key schema. */
                                        private Row extractJoinSubRow(Row factRow) {
                                            List<Object> joinSubsetValues = factJoinIdx.stream()
                                                    .map(i -> factRow.getBaseValue(i, Object.class))
                                                    .collect(Collectors.toList());
                                            return Row.withSchema(joinSubsetType).addValues(joinSubsetValues).build();
                                        }
                                    }))
                    // Emitted rows are built with outputSchema, so the PCollection must carry it too.
                    .setRowSchema(outputSchema);
        }
    }
}

