/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.extensions.sql.impl.transform;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import org.apache.beam.sdk.coders.BeamRecordCoder;
import org.apache.beam.sdk.coders.BigDecimalCoder;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.CustomCoder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType;
import org.apache.beam.sdk.extensions.sql.BeamSqlRecordHelper;
import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlInputRefExpression;
import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.UdafImpl;
import org.apache.beam.sdk.extensions.sql.impl.transform.BeamBuiltinAggregations;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
import org.apache.beam.sdk.values.BeamRecord;
import org.apache.beam.sdk.values.BeamRecordType;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.rel.core.AggregateCall;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.sql.validate.SqlUserDefinedAggFunction;
import org.apache.beam.sdks.java.extensions.sql.repackaged.org.apache.calcite.util.ImmutableBitSet;
import org.joda.time.Instant;

public class BeamAggregationTransforms
implements Serializable {

    public static class AggregationAccumulatorCoder
    extends CustomCoder<AggregationAccumulator> {
        private VarIntCoder sizeCoder = VarIntCoder.of();
        private List<Coder> elementCoders;

        public AggregationAccumulatorCoder(List<Coder> elementCoders) {
            this.elementCoders = elementCoders;
        }

        public void encode(AggregationAccumulator value, OutputStream outStream) throws CoderException, IOException {
            this.sizeCoder.encode(Integer.valueOf(value.accumulatorElements.size()), outStream);
            for (int idx = 0; idx < value.accumulatorElements.size(); ++idx) {
                this.elementCoders.get(idx).encode(value.accumulatorElements.get(idx), outStream);
            }
        }

        public AggregationAccumulator decode(InputStream inStream) throws CoderException, IOException {
            AggregationAccumulator accu = new AggregationAccumulator();
            int size = this.sizeCoder.decode(inStream);
            for (int idx = 0; idx < size; ++idx) {
                accu.accumulatorElements.add(this.elementCoders.get(idx).decode(inStream));
            }
            return accu;
        }
    }

    /**
     * Container for the intermediate states of several aggregations.
     *
     * <p>Holds one entry per position; a new instance starts with an empty list.
     */
    public static class AggregationAccumulator {
        private final List<Object> accumulatorElements = new ArrayList<>();
    }

    /**
     * A {@link Combine.CombineFn} that evaluates a list of aggregation calls together.
     *
     * <p>One delegate {@code CombineFn} is created per {@link AggregateCall}. The accumulator
     * holds one delegate accumulator per call, in call order, and the output record holds one
     * field per call, in the same order.
     */
    public static class AggregationAdaptor
    extends Combine.CombineFn<BeamRecord, AggregationAccumulator, BeamRecord> {
        private final List<Combine.CombineFn> aggregators = new ArrayList<Combine.CombineFn>();
        private final List<BeamSqlInputRefExpression> sourceFieldExps = new ArrayList<BeamSqlInputRefExpression>();
        private final BeamRecordSqlType finalRowType;

        /**
         * Builds the delegate aggregators and the output row type.
         *
         * @param aggregationCalls the aggregation calls to evaluate, in output-field order
         * @param sourceRowType row type of the input records
         * @throws UnsupportedOperationException if a call is neither one of the built-in
         *     aggregations handled here nor a {@link SqlUserDefinedAggFunction}
         * @throws IllegalStateException if obtaining the combine function of a user-defined
         *     aggregation fails
         */
        public AggregationAdaptor(List<AggregateCall> aggregationCalls, BeamRecordSqlType sourceRowType) {
            List<String> outFieldsName = new ArrayList<String>();
            List<Integer> outFieldsType = new ArrayList<Integer>();
            for (AggregateCall call : aggregationCalls) {
                // Only the first argument is used as the source field; calls without
                // arguments fall back to field 0.
                int refIndex = call.getArgList().isEmpty() ? 0 : call.getArgList().get(0);
                BeamSqlInputRefExpression sourceExp = new BeamSqlInputRefExpression(CalciteUtils.getFieldType(sourceRowType, refIndex), refIndex);
                this.sourceFieldExps.add(sourceExp);
                outFieldsName.add(call.name);
                int outFieldType = CalciteUtils.toJavaType(call.type.getSqlTypeName());
                outFieldsType.add(outFieldType);
                switch (call.getAggregation().getName()) {
                    case "COUNT":
                        this.aggregators.add(Count.combineFn());
                        break;
                    case "MAX":
                        this.aggregators.add(BeamBuiltinAggregations.createMax(call.type.getSqlTypeName()));
                        break;
                    case "MIN":
                        this.aggregators.add(BeamBuiltinAggregations.createMin(call.type.getSqlTypeName()));
                        break;
                    case "SUM":
                        this.aggregators.add(BeamBuiltinAggregations.createSum(call.type.getSqlTypeName()));
                        break;
                    case "AVG":
                        this.aggregators.add(BeamBuiltinAggregations.createAvg(call.type.getSqlTypeName()));
                        break;
                    case "VAR_POP":
                        this.aggregators.add(BeamBuiltinAggregations.createVar(call.type.getSqlTypeName(), false));
                        break;
                    case "VAR_SAMP":
                        this.aggregators.add(BeamBuiltinAggregations.createVar(call.type.getSqlTypeName(), true));
                        break;
                    default:
                        if (!(call.getAggregation() instanceof SqlUserDefinedAggFunction)) {
                            throw new UnsupportedOperationException(String.format("Aggregator [%s] is not supported", call.getAggregation().getName()));
                        }
                        // User-defined aggregation: delegate to the combine function it supplies.
                        SqlUserDefinedAggFunction udaf = (SqlUserDefinedAggFunction)call.getAggregation();
                        UdafImpl fn = (UdafImpl)udaf.function;
                        try {
                            this.aggregators.add(fn.getCombineFn());
                        }
                        catch (Exception e) {
                            throw new IllegalStateException(e);
                        }
                        break;
                }
            }
            this.finalRowType = BeamRecordSqlType.create(outFieldsName, outFieldsType);
        }

        /**
         * Creates an accumulator holding a fresh accumulator from every delegate aggregator.
         */
        @Override
        public AggregationAccumulator createAccumulator() {
            AggregationAccumulator initialAccu = new AggregationAccumulator();
            for (Combine.CombineFn agg : this.aggregators) {
                initialAccu.accumulatorElements.add(agg.createAccumulator());
            }
            return initialAccu;
        }

        /**
         * Feeds the source field of {@code input} for each call into the matching delegate.
         *
         * @return a new accumulator holding the values returned by the delegates
         */
        @Override
        @SuppressWarnings("unchecked") // delegates are raw-typed CombineFns
        public AggregationAccumulator addInput(AggregationAccumulator accumulator, BeamRecord input) {
            AggregationAccumulator deltaAcc = new AggregationAccumulator();
            for (int idx = 0; idx < this.aggregators.size(); ++idx) {
                Object fieldValue = this.sourceFieldExps.get(idx).evaluate(input, null).getValue();
                Object current = accumulator.accumulatorElements.get(idx);
                deltaAcc.accumulatorElements.add(this.aggregators.get(idx).addInput(current, fieldValue));
            }
            return deltaAcc;
        }

        /**
         * Merges partial accumulators position by position using the delegate aggregators.
         *
         * <p>{@code accumulators} is traversed exactly once: the partial states are first
         * grouped per aggregator and each group is then merged by its delegate.
         */
        @Override
        @SuppressWarnings("unchecked") // delegates are raw-typed CombineFns
        public AggregationAccumulator mergeAccumulators(Iterable<AggregationAccumulator> accumulators) {
            int aggregatorCount = this.aggregators.size();
            List<List<Object>> statesPerAggregator = new ArrayList<List<Object>>(aggregatorCount);
            for (int idx = 0; idx < aggregatorCount; ++idx) {
                statesPerAggregator.add(new ArrayList<Object>());
            }
            for (AggregationAccumulator partial : accumulators) {
                for (int idx = 0; idx < aggregatorCount; ++idx) {
                    statesPerAggregator.get(idx).add(partial.accumulatorElements.get(idx));
                }
            }
            AggregationAccumulator deltaAcc = new AggregationAccumulator();
            for (int idx = 0; idx < aggregatorCount; ++idx) {
                deltaAcc.accumulatorElements.add(this.aggregators.get(idx).mergeAccumulators(statesPerAggregator.get(idx)));
            }
            return deltaAcc;
        }

        /**
         * Builds the output record from the output of every delegate, in call order.
         */
        @Override
        @SuppressWarnings("unchecked") // delegates are raw-typed CombineFns
        public BeamRecord extractOutput(AggregationAccumulator accumulator) {
            List<Object> fieldValues = new ArrayList<Object>(this.aggregators.size());
            for (int idx = 0; idx < this.aggregators.size(); ++idx) {
                fieldValues.add(this.aggregators.get(idx).extractOutput(accumulator.accumulatorElements.get(idx)));
            }
            return new BeamRecord(this.finalRowType, fieldValues);
        }

        /**
         * Builds an {@link AggregationAccumulatorCoder} from the accumulator coder of every
         * delegate, each derived from the coder of that call's source field.
         *
         * <p>Side effect: registers {@link BigDecimalCoder} for {@link BigDecimal} on
         * {@code registry}. {@code inputCoder} is cast to {@link BeamRecordCoder}.
         *
         * @throws CannotProvideCoderException if a delegate cannot provide its accumulator coder
         */
        @Override
        @SuppressWarnings("unchecked") // delegates and field coders are raw-typed
        public Coder<AggregationAccumulator> getAccumulatorCoder(CoderRegistry registry, Coder<BeamRecord> inputCoder) throws CannotProvideCoderException {
            BeamRecordCoder beamRecordCoder = (BeamRecordCoder)inputCoder;
            registry.registerCoderForClass(BigDecimal.class, (Coder)BigDecimalCoder.of());
            List<Coder> aggAccuCoderList = new ArrayList<Coder>();
            for (int idx = 0; idx < this.aggregators.size(); ++idx) {
                int srcFieldIndex = this.sourceFieldExps.get(idx).getInputRef();
                Coder srcFieldCoder = (Coder)beamRecordCoder.getCoders().get(srcFieldIndex);
                aggAccuCoderList.add(this.aggregators.get(idx).getAccumulatorCoder(registry, srcFieldCoder));
            }
            return new AggregationAccumulatorCoder(aggAccuCoderList);
        }
    }

    public static class WindowTimestampFn
    implements SerializableFunction<BeamRecord, Instant> {
        private int windowFieldIdx = -1;

        public WindowTimestampFn(int windowFieldIdx) {
            this.windowFieldIdx = windowFieldIdx;
        }

        public Instant apply(BeamRecord input) {
            return new Instant(input.getDate(this.windowFieldIdx).getTime());
        }
    }

    public static class AggregationGroupByKeyFn
    implements SerializableFunction<BeamRecord, BeamRecord> {
        private List<Integer> groupByKeys = new ArrayList<Integer>();

        public AggregationGroupByKeyFn(int windowFieldIdx, ImmutableBitSet groupSet) {
            for (int i : groupSet.asList()) {
                if (i == windowFieldIdx) continue;
                this.groupByKeys.add(i);
            }
        }

        public BeamRecord apply(BeamRecord input) {
            BeamRecordSqlType typeOfKey = this.exTypeOfKeyRecord(BeamSqlRecordHelper.getSqlRecordType(input));
            ArrayList<Object> fieldValues = new ArrayList<Object>(this.groupByKeys.size());
            for (int idx = 0; idx < this.groupByKeys.size(); ++idx) {
                fieldValues.add(input.getFieldValue(this.groupByKeys.get(idx).intValue()));
            }
            BeamRecord keyOfRecord = new BeamRecord((BeamRecordType)typeOfKey, fieldValues);
            return keyOfRecord;
        }

        private BeamRecordSqlType exTypeOfKeyRecord(BeamRecordSqlType dataType) {
            ArrayList<String> fieldNames = new ArrayList<String>();
            ArrayList<Integer> fieldTypes = new ArrayList<Integer>();
            for (int idx : this.groupByKeys) {
                fieldNames.add(dataType.getFieldNameByIndex(idx));
                fieldTypes.add(dataType.getFieldTypeByIndex(idx));
            }
            return BeamRecordSqlType.create(fieldNames, fieldTypes);
        }
    }

    /**
     * Flattens a (group key, aggregation result) pair into a single output record.
     *
     * <p>The output values are the key record's values followed by the value record's
     * values. When a window-start field index other than {@code -1} is configured, the start
     * of the element's window is inserted at that index.
     */
    public static class MergeAggregationRecord
    extends DoFn<KV<BeamRecord, BeamRecord>, BeamRecord> {
        private BeamRecordSqlType outRowType;
        // Names of the aggregation calls; populated by the constructor and not read in this class.
        private List<String> aggFieldNames;
        private int windowStartFieldIdx;

        /**
         * @param outRowType row type of the emitted records
         * @param aggList aggregation calls whose names are recorded
         * @param windowStartFieldIdx position at which the window start is inserted, or
         *     {@code -1} to emit no window-start field
         */
        public MergeAggregationRecord(BeamRecordSqlType outRowType, List<AggregateCall> aggList, int windowStartFieldIdx) {
            this.outRowType = outRowType;
            this.aggFieldNames = new ArrayList<String>();
            for (AggregateCall ac : aggList) {
                this.aggFieldNames.add(ac.getName());
            }
            this.windowStartFieldIdx = windowStartFieldIdx;
        }

        /**
         * Emits one record per input pair.
         *
         * <p>The context parameter carries this DoFn's own type arguments so that
         * {@code element()} and {@code output()} are type-checked. When a window-start field
         * is configured, {@code window} is cast to {@link IntervalWindow}.
         */
        @DoFn.ProcessElement
        public void processElement(DoFn<KV<BeamRecord, BeamRecord>, BeamRecord>.ProcessContext c, BoundedWindow window) {
            KV<BeamRecord, BeamRecord> kvRecord = c.element();
            // Values are arbitrary field values, not only dates, so the list is untyped.
            List<Object> fieldValues = new ArrayList<Object>();
            fieldValues.addAll(kvRecord.getKey().getDataValues());
            fieldValues.addAll(kvRecord.getValue().getDataValues());
            if (this.windowStartFieldIdx != -1) {
                fieldValues.add(this.windowStartFieldIdx, ((IntervalWindow)window).start().toDate());
            }
            BeamRecord outRecord = new BeamRecord(this.outRowType, fieldValues);
            c.output(outRecord);
        }
    }
}

