/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.extensions.sql.impl.rel;

import java.io.Serializable;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.NavigableMap;
import java.util.TreeMap;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.ListCoder;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamCostModel;
import org.apache.beam.sdk.extensions.sql.impl.planner.NodeStats;
import org.apache.beam.sdk.extensions.sql.impl.rel.BeamRelNode;
import org.apache.beam.sdk.extensions.sql.impl.rel.BeamSortRel;
import org.apache.beam.sdk.extensions.sql.impl.rel.BeamSqlRelUtils;
import org.apache.beam.sdk.extensions.sql.impl.transform.BeamBuiltinAnalyticFunctions;
import org.apache.beam.sdk.extensions.sql.impl.transform.agg.AggregationCombineFnAdapter;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.transforms.Group;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.plan.RelOptCluster;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.plan.RelOptPlanner;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.plan.RelTraitSet;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.RelFieldCollation;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.RelNode;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.core.Window;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.type.RelDataType;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rex.RexInputRef;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rex.RexLiteral;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rex.RexNode;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;

public class BeamWindowRel
extends Window
implements BeamRelNode {
    public BeamWindowRel(RelOptCluster cluster, RelTraitSet traitSet, RelNode input, List<RexLiteral> constants, RelDataType rowType, List<Window.Group> groups) {
        super(cluster, traitSet, input, constants, rowType, groups);
    }

    @Override
    public PTransform<PCollectionList<Row>, PCollection<Row>> buildPTransform() {
        Schema outputSchema = CalciteUtils.toSchema(this.getRowType());
        ArrayList analyticFields = Lists.newArrayList();
        this.groups.stream().forEach(anAnalyticGroup -> {
            List partitionKeysDef = anAnalyticGroup.keys.toList();
            ArrayList orderByKeys = Lists.newArrayList();
            ArrayList orderByDirections = Lists.newArrayList();
            ArrayList orderByNullDirections = Lists.newArrayList();
            anAnalyticGroup.orderKeys.getFieldCollations().stream().forEach(fc -> {
                orderByKeys.add(fc.getFieldIndex());
                orderByDirections.add(fc.direction == RelFieldCollation.Direction.ASCENDING);
                orderByNullDirections.add(fc.nullDirection == RelFieldCollation.NullDirection.FIRST);
            });
            BigDecimal lowerB = null;
            BigDecimal upperB = null;
            if (anAnalyticGroup.lowerBound.isCurrentRow()) {
                lowerB = BigDecimal.ZERO;
            } else if (anAnalyticGroup.lowerBound.isPreceding()) {
                if (!anAnalyticGroup.lowerBound.isUnbounded()) {
                    lowerB = this.getLiteralValueConstants(anAnalyticGroup.lowerBound.getOffset());
                }
            } else if (anAnalyticGroup.lowerBound.isFollowing() && !anAnalyticGroup.lowerBound.isUnbounded()) {
                lowerB = this.getLiteralValueConstants(anAnalyticGroup.lowerBound.getOffset()).negate();
            }
            if (anAnalyticGroup.upperBound.isCurrentRow()) {
                upperB = BigDecimal.ZERO;
            } else if (anAnalyticGroup.upperBound.isPreceding()) {
                if (!anAnalyticGroup.upperBound.isUnbounded()) {
                    upperB = this.getLiteralValueConstants(anAnalyticGroup.upperBound.getOffset()).negate();
                }
            } else if (anAnalyticGroup.upperBound.isFollowing() && !anAnalyticGroup.upperBound.isUnbounded()) {
                upperB = this.getLiteralValueConstants(anAnalyticGroup.upperBound.getOffset());
            }
            BigDecimal lowerBFinal = lowerB;
            BigDecimal upperBFinal = upperB;
            List aggregateCalls = anAnalyticGroup.getAggregateCalls((Window)this);
            aggregateCalls.stream().forEach(anAggCall -> {
                List argList = anAggCall.getArgList();
                Schema.Field field = CalciteUtils.toField(anAggCall.getName(), anAggCall.getType());
                Combine.CombineFn<?, ?, ?> combineFn = AggregationCombineFnAdapter.createCombineFnAnalyticsFunctions(anAggCall, field, anAggCall.getAggregation().getName());
                FieldAggregation fieldAggregation = new FieldAggregation(partitionKeysDef, orderByKeys, orderByDirections, orderByNullDirections, lowerBFinal, upperBFinal, anAnalyticGroup.isRows, argList, combineFn, field);
                analyticFields.add(fieldAggregation);
            });
        });
        return new Transform(outputSchema, analyticFields);
    }

    private BigDecimal getLiteralValueConstants(RexNode n) {
        int idx = ((RexInputRef)n).getIndex() - this.input.getRowType().getFieldCount();
        return (BigDecimal)((RexLiteral)this.constants.get(idx)).getValue();
    }

    @Override
    public NodeStats estimateNodeStats(RelMetadataQuery mq) {
        NodeStats inputStat = BeamSqlRelUtils.getNodeStats(this.input, mq);
        return inputStat;
    }

    @Override
    public BeamCostModel beamComputeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) {
        NodeStats inputStat = BeamSqlRelUtils.getNodeStats(this.input, mq);
        float multiplier = 1.125f;
        return BeamCostModel.FACTORY.makeCost(inputStat.getRowCount() * (double)multiplier, inputStat.getRate() * (double)multiplier);
    }

    private static DoFn<List<Row>, Row> aggField(final Schema expectedSchema, final FieldAggregation fieldAgg) {
        return new DoFn<List<Row>, Row>(){

            @DoFn.ProcessElement
            public void processElement(@DoFn.Element List<Row> inputPartition, DoFn.OutputReceiver<Row> out, DoFn.ProcessContext c) {
                List<Row> sortedRowsAsList = inputPartition;
                NavigableMap<BigDecimal, List<Row>> indexRange = null;
                if (!fieldAgg.rows) {
                    indexRange = this.indexRows(sortedRowsAsList);
                }
                for (int idx = 0; idx < sortedRowsAsList.size(); ++idx) {
                    List<Row> aggRange = null;
                    aggRange = fieldAgg.rows ? this.getRows(sortedRowsAsList, idx) : this.getRange(indexRange, sortedRowsAsList.get(idx));
                    Object accumulator = fieldAgg.combineFn.createAccumulator();
                    int aggFieldIndex = fieldAgg.inputFields.isEmpty() ? -1 : (Integer)fieldAgg.inputFields.get(0);
                    long count = 0L;
                    for (Row aggRow : aggRange) {
                        if (fieldAgg.combineFn instanceof BeamBuiltinAnalyticFunctions.PositionAwareCombineFn) {
                            BeamBuiltinAnalyticFunctions.PositionAwareCombineFn fn = (BeamBuiltinAnalyticFunctions.PositionAwareCombineFn)fieldAgg.combineFn;
                            accumulator = fn.addInput(accumulator, this.getOrderByValue(aggRow), count, Long.valueOf(idx), Long.valueOf(sortedRowsAsList.size()));
                        } else {
                            accumulator = fieldAgg.combineFn.addInput(accumulator, aggRow.getBaseValue(aggFieldIndex));
                        }
                        ++count;
                    }
                    Object result = fieldAgg.combineFn.extractOutput(accumulator);
                    Row processingRow = sortedRowsAsList.get(idx);
                    ArrayList fieldValues = Lists.newArrayListWithCapacity((int)processingRow.getFieldCount());
                    fieldValues.addAll(processingRow.getValues());
                    fieldValues.add(result);
                    Row build = Row.withSchema((Schema)expectedSchema).addValues((List)fieldValues).build();
                    out.output((Object)build);
                }
            }

            private NavigableMap<BigDecimal, List<Row>> indexRows(List<Row> input) {
                TreeMap<BigDecimal, List<Row>> map = new TreeMap<BigDecimal, List<Row>>();
                for (Row r : input) {
                    BigDecimal orderByValue = this.getOrderByValue(r);
                    if (orderByValue == null) {
                        orderByValue = BigDecimal.ZERO;
                    }
                    if (!map.containsKey(orderByValue)) {
                        map.put(orderByValue, Lists.newArrayList());
                    }
                    ((List)map.get(orderByValue)).add(r);
                }
                return map;
            }

            private List<Row> getRange(NavigableMap<BigDecimal, List<Row>> indexRanges, Row aRow) {
                NavigableMap<BigDecimal, List<Row>> subMap;
                BigDecimal ll;
                BigDecimal currentRowValue = this.getOrderByValue(aRow);
                if (currentRowValue != null && fieldAgg.lowerLimit != null && fieldAgg.upperLimit != null) {
                    ll = currentRowValue.subtract(fieldAgg.lowerLimit);
                    BigDecimal ul = currentRowValue.add(fieldAgg.upperLimit);
                    subMap = indexRanges.subMap(ll, true, ul, true);
                } else if (currentRowValue != null && fieldAgg.lowerLimit != null && fieldAgg.upperLimit == null) {
                    ll = currentRowValue.subtract(fieldAgg.lowerLimit);
                    subMap = indexRanges.tailMap(ll, true);
                } else if (currentRowValue != null && fieldAgg.lowerLimit == null && fieldAgg.upperLimit != null) {
                    BigDecimal ul = currentRowValue.add(fieldAgg.upperLimit);
                    subMap = indexRanges.headMap(ul, true);
                } else {
                    subMap = indexRanges;
                }
                ArrayList result = Lists.newArrayList();
                for (List partialList : subMap.values()) {
                    result.addAll(partialList);
                }
                return result;
            }

            private BigDecimal getOrderByValue(Row r) {
                if (fieldAgg.orderKeys.size() == 0) {
                    return null;
                }
                return new BigDecimal(((Number)r.getBaseValue(((Integer)fieldAgg.orderKeys.get(0)).intValue())).toString());
            }

            private List<Row> getRows(List<Row> input, int index) {
                Integer ll = fieldAgg.lowerLimit != null ? fieldAgg.lowerLimit.intValue() : Integer.MAX_VALUE;
                Integer ul = fieldAgg.upperLimit != null ? fieldAgg.upperLimit.intValue() : Integer.MAX_VALUE;
                int lowerIndex = ll == Integer.MAX_VALUE ? Integer.MIN_VALUE : index - ll;
                int upperIndex = ul == Integer.MAX_VALUE ? Integer.MAX_VALUE : index + ul + 1;
                lowerIndex = lowerIndex < 0 ? 0 : lowerIndex;
                upperIndex = upperIndex > input.size() ? input.size() : upperIndex;
                List<Row> out = input.subList(lowerIndex, upperIndex);
                return out;
            }
        };
    }

    private static DoFn<Iterable<Row>, List<Row>> sortPartition(final FieldAggregation fieldAgg) {
        return new DoFn<Iterable<Row>, List<Row>>(){

            @DoFn.ProcessElement
            public void processElement(@DoFn.Element Iterable<Row> inputPartition, DoFn.OutputReceiver<List<Row>> out, DoFn.ProcessContext c) {
                ArrayList partitionRows = Lists.newArrayList(inputPartition);
                BeamSortRel.BeamSqlRowComparator beamSqlRowComparator = new BeamSortRel.BeamSqlRowComparator(fieldAgg.orderKeys, fieldAgg.orderOrientations, fieldAgg.orderNulls);
                Collections.sort(partitionRows, beamSqlRowComparator);
                out.output((Object)partitionRows);
            }
        };
    }

    public RelNode copy(RelTraitSet traitSet, List<RelNode> inputs) {
        return this.copy(traitSet, (RelNode)BeamWindowRel.sole(inputs), (List<RexLiteral>)this.constants, this.rowType, (List<Window.Group>)this.groups);
    }

    public BeamWindowRel copy(RelTraitSet traitSet, RelNode input, List<RexLiteral> constants, RelDataType rowType, List<Window.Group> groups) {
        return new BeamWindowRel(this.getCluster(), traitSet, input, constants, rowType, groups);
    }

    static class SelectOnlyValues
    extends DoFn<KV<Row, Iterable<Row>>, Iterable<Row>> {
        SelectOnlyValues() {
        }

        @DoFn.ProcessElement
        public void processElement(@DoFn.Element KV<Row, Iterable<Row>> inputPartition, DoFn.OutputReceiver<Iterable<Row>> out, DoFn.ProcessContext c) {
            out.output((Object)((Iterable)inputPartition.getValue()));
        }
    }

    private static class Transform
    extends PTransform<PCollectionList<Row>, PCollection<Row>> {
        private Schema outputSchema;
        private List<FieldAggregation> aggFields;

        public Transform(Schema schema, List<FieldAggregation> fieldAgg) {
            this.outputSchema = schema;
            this.aggFields = fieldAgg;
        }

        public PCollection<Row> expand(PCollectionList<Row> input) {
            PCollection inputData = input.get(0);
            Schema inputSchema = inputData.getSchema();
            int ids = 0;
            for (FieldAggregation af : this.aggFields) {
                String prefix = "transform_" + ++ids;
                Coder rowCoder = inputData.getCoder();
                PCollection partitioned = null;
                if (af.partitionKeys.isEmpty()) {
                    partitioned = (PCollection)inputData.apply(prefix + "globalPartition", (PTransform)Group.globally());
                } else {
                    Group.ByFields myg = Group.byFieldIds((Iterable)af.partitionKeys);
                    PCollection partitionBy = (PCollection)inputData.apply(prefix + "partitionBy", (PTransform)myg.getToKvs());
                    partitioned = ((PCollection)partitionBy.apply(prefix + "selectOnlyValues", (PTransform)ParDo.of((DoFn)new SelectOnlyValues()))).setCoder((Coder)IterableCoder.of((Coder)rowCoder));
                }
                PCollection sortedPartition = ((PCollection)partitioned.apply(prefix + "orderBy", (PTransform)ParDo.of((DoFn)BeamWindowRel.sortPartition(af)))).setCoder((Coder)ListCoder.of((Coder)rowCoder));
                inputSchema = Schema.builder().addFields(inputSchema.getFields()).addFields(new Schema.Field[]{af.outputField}).build();
                inputData = ((PCollection)sortedPartition.apply(prefix + "aggCall", (PTransform)ParDo.of((DoFn)BeamWindowRel.aggField(inputSchema, af)))).setRowSchema(inputSchema);
            }
            return inputData.setRowSchema(this.outputSchema);
        }
    }

    private static class FieldAggregation
    implements Serializable {
        private List<Integer> partitionKeys;
        private List<Integer> orderKeys;
        private List<Boolean> orderOrientations;
        private List<Boolean> orderNulls;
        private BigDecimal lowerLimit = null;
        private BigDecimal upperLimit = null;
        private boolean rows = true;
        private List<Integer> inputFields;
        private Combine.CombineFn combineFn;
        private Schema.Field outputField;

        public FieldAggregation(List<Integer> partitionKeys, List<Integer> orderKeys, List<Boolean> orderOrientations, List<Boolean> orderNulls, BigDecimal lowerLimit, BigDecimal upperLimit, boolean rows, List<Integer> inputFields, Combine.CombineFn combineFn, Schema.Field outputField) {
            this.partitionKeys = partitionKeys;
            this.orderKeys = orderKeys;
            this.orderOrientations = orderOrientations;
            this.orderNulls = orderNulls;
            this.lowerLimit = lowerLimit;
            this.upperLimit = upperLimit;
            this.rows = rows;
            this.inputFields = inputFields;
            this.combineFn = combineFn;
            this.outputField = outputField;
        }
    }
}

