/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.extensions.sql.impl.transform;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.IntStream;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.com.google.common.collect.ImmutableList;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.com.google.common.collect.Lists;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.rel.core.AggregateCall;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.sql.validate.SqlUserDefinedAggFunction;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.util.ImmutableBitSet;
import org.apache.beam.repackaged.beam_sdks_java_extensions_sql.org.apache.calcite.util.Pair;
import org.apache.beam.sdk.coders.BigDecimalCoder;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.CustomCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.extensions.sql.impl.UdafImpl;
import org.apache.beam.sdk.extensions.sql.impl.transform.BeamBuiltinAggregations;
import org.apache.beam.sdk.extensions.sql.impl.transform.agg.CovarianceFn;
import org.apache.beam.sdk.extensions.sql.impl.transform.agg.VarianceFn;
import org.apache.beam.sdk.extensions.sql.impl.utils.BigDecimalConverter;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.Row;
import org.joda.time.Instant;

public class BeamAggregationTransforms
implements Serializable {

    /**
     * Coder for {@link AggregationAccumulator}.
     *
     * <p>Wire format: the number of sub-accumulators as a var-int, followed by each
     * sub-accumulator encoded with the coder found at the same position in
     * {@code elementCoders}.
     */
    public static class AggregationAccumulatorCoder
    extends CustomCoder<AggregationAccumulator> {
        private final VarIntCoder sizeCoder = VarIntCoder.of();
        private final List<Coder> elementCoders;

        /**
         * @param elementCoders one coder per sub-accumulator, in the same order as the
         *     accumulator's elements
         */
        public AggregationAccumulatorCoder(List<Coder> coders) {
            elementCoders = coders;
        }

        public void encode(AggregationAccumulator value, OutputStream outStream) throws IOException {
            List elements = value.accumulatorElements;
            sizeCoder.encode(elements.size(), outStream);
            int position = 0;
            for (Object element : elements) {
                // Element i is always paired with coder i.
                elementCoders.get(position++).encode(element, outStream);
            }
        }

        public AggregationAccumulator decode(InputStream inStream) throws CoderException, IOException {
            int count = sizeCoder.decode(inStream);
            AggregationAccumulator result = new AggregationAccumulator();
            for (int position = 0; position < count; position++) {
                Object element = elementCoders.get(position).decode(inStream);
                result.accumulatorElements.add(element);
            }
            return result;
        }
    }

    /**
     * Mutable holder for the sub-accumulators of {@link AggregationAdaptor}; element {@code i}
     * belongs to the {@code i}-th aggregator.
     */
    public static class AggregationAccumulator {
        private List<Object> accumulatorElements = new ArrayList<Object>();
    }

    public static class AggregationAdaptor
    extends Combine.CombineFn<Row, AggregationAccumulator, Row> {
        private List<Combine.CombineFn> aggregators = new ArrayList<Combine.CombineFn>();
        private List<Object> sourceFieldExps = new ArrayList<Object>();
        private Schema sourceSchema;
        private Schema finalSchema;

        public AggregationAdaptor(List<Pair<AggregateCall, String>> aggregationCalls, Schema sourceSchema) {
            this.sourceSchema = sourceSchema;
            ImmutableList.Builder fields = ImmutableList.builder();
            block25: for (Pair<AggregateCall, String> aggCall : aggregationCalls) {
                AggregateCall call = (AggregateCall)aggCall.left;
                String aggName = (String)aggCall.right;
                if (call.getArgList().size() == 2) {
                    int refIndexKey = call.getArgList().get(0);
                    int refIndexValue = call.getArgList().get(1);
                    this.sourceFieldExps.add(KV.of((Object)refIndexKey, (Object)refIndexValue));
                } else {
                    Integer refIndex = call.getArgList().size() > 0 ? call.getArgList().get(0) : 0;
                    this.sourceFieldExps.add(refIndex);
                }
                Schema.Field field = CalciteUtils.toField(aggName, call.type);
                Schema.TypeName fieldTypeName = field.getType().getTypeName();
                fields.add(field);
                switch (call.getAggregation().getName()) {
                    case "COUNT": {
                        this.aggregators.add(Count.combineFn());
                        continue block25;
                    }
                    case "MAX": {
                        this.aggregators.add(BeamBuiltinAggregations.createMax(call.type.getSqlTypeName()));
                        continue block25;
                    }
                    case "MIN": {
                        this.aggregators.add(BeamBuiltinAggregations.createMin(call.type.getSqlTypeName()));
                        continue block25;
                    }
                    case "SUM": 
                    case "$SUM0": {
                        this.aggregators.add(BeamBuiltinAggregations.createSum(call.type.getSqlTypeName()));
                        continue block25;
                    }
                    case "AVG": {
                        this.aggregators.add(BeamBuiltinAggregations.createAvg(call.type.getSqlTypeName()));
                        continue block25;
                    }
                    case "VAR_POP": {
                        this.aggregators.add(VarianceFn.newPopulation(BigDecimalConverter.forSqlType(fieldTypeName)));
                        continue block25;
                    }
                    case "VAR_SAMP": {
                        this.aggregators.add(VarianceFn.newSample(BigDecimalConverter.forSqlType(fieldTypeName)));
                        continue block25;
                    }
                    case "COVAR_POP": {
                        this.aggregators.add(CovarianceFn.newPopulation(BigDecimalConverter.forSqlType(fieldTypeName)));
                        continue block25;
                    }
                    case "COVAR_SAMP": {
                        this.aggregators.add(CovarianceFn.newSample(BigDecimalConverter.forSqlType(fieldTypeName)));
                        continue block25;
                    }
                }
                if (call.getAggregation() instanceof SqlUserDefinedAggFunction) {
                    SqlUserDefinedAggFunction udaf = (SqlUserDefinedAggFunction)call.getAggregation();
                    UdafImpl fn = (UdafImpl)udaf.function;
                    try {
                        this.aggregators.add(fn.getCombineFn());
                        continue;
                    }
                    catch (Exception e) {
                        throw new IllegalStateException(e);
                    }
                }
                throw new UnsupportedOperationException(String.format("Aggregator [%s] is not supported", call.getAggregation().getName()));
            }
            this.finalSchema = (Schema)fields.build().stream().collect(Schema.toSchema());
        }

        public AggregationAccumulator createAccumulator() {
            AggregationAccumulator initialAccu = new AggregationAccumulator();
            for (Combine.CombineFn agg : this.aggregators) {
                initialAccu.accumulatorElements.add(agg.createAccumulator());
            }
            return initialAccu;
        }

        public AggregationAccumulator addInput(AggregationAccumulator accumulator, Row input) {
            AggregationAccumulator deltaAcc = new AggregationAccumulator();
            for (int idx = 0; idx < this.aggregators.size(); ++idx) {
                Combine.CombineFn aggregator = this.aggregators.get(idx);
                Object element = accumulator.accumulatorElements.get(idx);
                if (this.sourceFieldExps.get(idx) instanceof Integer) {
                    Object value = input.getValue(((Integer)this.sourceFieldExps.get(idx)).intValue());
                    if (value != null) {
                        Object delta = aggregator.addInput(element, value);
                        deltaAcc.accumulatorElements.add(delta);
                        continue;
                    }
                    deltaAcc.accumulatorElements.add(element);
                    continue;
                }
                if (!(this.sourceFieldExps.get(idx) instanceof KV)) continue;
                KV exp = (KV)this.sourceFieldExps.get(idx);
                Object key = input.getValue(((Integer)exp.getKey()).intValue());
                Object value = input.getValue(((Integer)exp.getValue()).intValue());
                if (key != null && value != null) {
                    deltaAcc.accumulatorElements.add(aggregator.addInput(element, (Object)KV.of((Object)key, (Object)value)));
                    continue;
                }
                deltaAcc.accumulatorElements.add(element);
            }
            return deltaAcc;
        }

        public AggregationAccumulator mergeAccumulators(Iterable<AggregationAccumulator> accumulators) {
            AggregationAccumulator deltaAcc = new AggregationAccumulator();
            for (int idx = 0; idx < this.aggregators.size(); ++idx) {
                ArrayList accs = new ArrayList();
                for (AggregationAccumulator accumulator : accumulators) {
                    accs.add(accumulator.accumulatorElements.get(idx));
                }
                deltaAcc.accumulatorElements.add(this.aggregators.get(idx).mergeAccumulators(accs));
            }
            return deltaAcc;
        }

        public Row extractOutput(AggregationAccumulator accumulator) {
            return (Row)IntStream.range(0, this.aggregators.size()).mapToObj(idx -> this.getAggregatorOutput(accumulator, idx)).collect(Row.toRow((Schema)this.finalSchema));
        }

        private Object getAggregatorOutput(AggregationAccumulator accumulator, int idx) {
            return this.aggregators.get(idx).extractOutput(accumulator.accumulatorElements.get(idx));
        }

        public Coder<AggregationAccumulator> getAccumulatorCoder(CoderRegistry registry, Coder<Row> inputCoder) throws CannotProvideCoderException {
            registry.registerCoderForClass(BigDecimal.class, (Coder)BigDecimalCoder.of());
            ArrayList<Coder> aggAccuCoderList = new ArrayList<Coder>();
            for (int idx = 0; idx < this.aggregators.size(); ++idx) {
                if (this.sourceFieldExps.get(idx) instanceof Integer) {
                    int srcFieldIndex = (Integer)this.sourceFieldExps.get(idx);
                    Coder srcFieldCoder = RowCoder.coderForFieldType((Schema.FieldType)this.sourceSchema.getField(srcFieldIndex).getType());
                    aggAccuCoderList.add(this.aggregators.get(idx).getAccumulatorCoder(registry, srcFieldCoder));
                    continue;
                }
                if (!(this.sourceFieldExps.get(idx) instanceof KV)) continue;
                KV exp = (KV)this.sourceFieldExps.get(idx);
                int srcFieldIndexKey = (Integer)exp.getKey();
                int srcFieldIndexValue = (Integer)exp.getValue();
                Coder srcFieldCoderKey = RowCoder.coderForFieldType((Schema.FieldType)this.sourceSchema.getField(srcFieldIndexKey).getType());
                Coder srcFieldCoderValue = RowCoder.coderForFieldType((Schema.FieldType)this.sourceSchema.getField(srcFieldIndexValue).getType());
                aggAccuCoderList.add(this.aggregators.get(idx).getAccumulatorCoder(registry, (Coder)KvCoder.of((Coder)srcFieldCoderKey, (Coder)srcFieldCoderValue)));
            }
            return new AggregationAccumulatorCoder(aggAccuCoderList);
        }
    }

    /**
     * Reads the datetime field at a fixed index of each row and converts it to a Joda
     * {@link Instant}.
     */
    public static class WindowTimestampFn
    implements SerializableFunction<Row, Instant> {
        /** Index of the datetime field read from every row; fixed at construction. */
        private final int windowFieldIdx;

        public WindowTimestampFn(int fieldIndex) {
            windowFieldIdx = fieldIndex;
        }

        public Instant apply(Row input) {
            // NOTE(review): Joda's Instant(Object) constructor treats null as "now". A NULL
            // datetime field therefore silently becomes the current time. Confirm this is intended.
            Object dateTime = input.getDateTime(windowFieldIdx);
            return new Instant(dateTime);
        }
    }

    /**
     * Projects the GROUP BY fields of a row into a key row with {@code keySchema}. The window
     * field index is excluded from the projected fields.
     */
    public static class AggregationGroupByKeyFn
    implements SerializableFunction<Row, Row> {
        private Schema keySchema;
        private List<Integer> groupByKeys;

        /**
         * @param keySchema schema of the produced key rows
         * @param windowFieldIdx index dropped from {@code groupSet} when it is present
         * @param groupSet the GROUP BY field indices, in ascending order
         */
        public AggregationGroupByKeyFn(Schema keySchema, int windowFieldIdx, ImmutableBitSet groupSet) {
            this.keySchema = keySchema;
            ArrayList<Integer> keyIndices = new ArrayList<Integer>(groupSet.asList());
            keyIndices.removeIf(fieldIdx -> fieldIdx == windowFieldIdx);
            this.groupByKeys = keyIndices;
        }

        public Row apply(Row input) {
            return (Row) this.groupByKeys.stream()
                    .map(fieldIdx -> input.getValue(fieldIdx))
                    .collect(Row.toRow(this.keySchema));
        }
    }

    /**
     * Builds one output row per element from a {@code KV<Row, Row>}: the key row's values,
     * followed by the value row's values, with the window start optionally inserted at a fixed
     * position.
     */
    public static class MergeAggregationRecord
    extends DoFn<KV<Row, Row>, Row> {
        private Schema outSchema;
        private int windowStartFieldIdx;

        /**
         * @param outSchema schema of the emitted rows
         * @param windowStartFieldIdx output position at which the window start is inserted, or -1
         *     to insert nothing
         */
        public MergeAggregationRecord(Schema outSchema, int windowStartFieldIdx) {
            this.outSchema = outSchema;
            this.windowStartFieldIdx = windowStartFieldIdx;
        }

        @DoFn.ProcessElement
        public void processElement(DoFn.ProcessContext c, BoundedWindow window) {
            KV kvRow = (KV) c.element();
            Row keyRow = (Row) kvRow.getKey();
            Row valueRow = (Row) kvRow.getValue();
            // The list mixes arbitrary field values and (optionally) an Instant, so it must be
            // typed List<Object>. The extra slot is reserved for the window start.
            List<Object> fieldValues = Lists.newArrayListWithCapacity(keyRow.getValues().size() + valueRow.getValues().size() + 1);
            fieldValues.addAll(keyRow.getValues());
            fieldValues.addAll(valueRow.getValues());
            if (this.windowStartFieldIdx != -1) {
                // Assumes windowed aggregations run in IntervalWindows. Any other window type
                // fails the cast.
                fieldValues.add(this.windowStartFieldIdx, ((IntervalWindow) window).start());
            }
            c.output(Row.withSchema(this.outSchema).addValues(fieldValues).build());
        }
    }
}

