/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.runners.spark.structuredstreaming.translation.batch;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
import org.apache.beam.runners.spark.structuredstreaming.translation.batch.AggregatorCombiner;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.KVHelpers;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.KeyValueGroupedDataset;
import org.joda.time.Instant;

/**
 * Batch translator for {@link Combine.PerKey}.
 *
 * <p>The input dataset is grouped by key, every group is reduced through an
 * {@link AggregatorCombiner}, and each resulting (key, windowed values) tuple is
 * expanded back into individual windowed {@link KV} elements.
 *
 * @param <K> key type of the combined pairs
 * @param <InputT> value type of the input pairs
 * @param <AccumT> accumulator type of the combine function
 * @param <OutputT> value type of the output pairs
 */
class CombinePerKeyTranslatorBatch<K, InputT, AccumT, OutputT>
implements TransformTranslator<PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>>> {
    CombinePerKeyTranslatorBatch() {
    }

    /**
     * Builds the Spark dataset for a per-key combine and registers it as the
     * dataset of the transform's output collection.
     *
     * @param transform the transform to translate; it is cast to {@link Combine.PerKey}
     * @param context translation context supplying the input/output collections and
     *     the dataset registry
     * @throws RuntimeException wrapping a {@link CannotProvideCoderException} when the
     *     combine function cannot provide an accumulator coder
     */
    @Override
    public void translateTransform(PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>> transform, AbstractTranslationContext context) {
        Combine.PerKey perKey = (Combine.PerKey)transform;
        PCollection input = (PCollection)context.getInput();
        PCollection output = (PCollection)context.getOutput();
        Combine.CombineFn fn = (Combine.CombineFn)perKey.getFn();
        WindowingStrategy windowing = input.getWindowingStrategy();
        Dataset source = context.getDataset(input);

        // Coders for both sides of the combine: the key coder drives the grouping
        // encoder, the output value coder is handed to the aggregator.
        KvCoder inputKvCoder = (KvCoder)input.getCoder();
        Coder keyCoder = inputKvCoder.getKeyCoder();
        KvCoder outputKvCoder = (KvCoder)output.getCoder();
        Coder outputValueCoder = outputKvCoder.getValueCoder();

        KeyValueGroupedDataset grouped = source.groupByKey(KVHelpers.extractKey(), EncoderHelpers.fromBeamCoder(keyCoder));

        Coder accumCoder = CombinePerKeyTranslatorBatch.accumulatorCoderFor(fn, input, inputKvCoder);
        AggregatorCombiner combiner = new AggregatorCombiner(fn, windowing, accumCoder, outputValueCoder);
        Dataset combined = grouped.agg(combiner.toColumn());

        // Expand every (key, windowed values) tuple into separate elements,
        // putting the key back into each one.
        WindowedValue.FullWindowedValueCoder windowedKvCoder = WindowedValue.FullWindowedValueCoder.of(outputKvCoder, input.getWindowingStrategy().getWindowFn().windowCoder());
        Dataset expandedDataset = combined.flatMap((FlatMapFunction & Serializable)keyedGroup -> {
            Object groupKey = keyedGroup._1();
            Iterable groupValues = (Iterable)keyedGroup._2();
            Collection<WindowedValue> expanded = new ArrayList<>();
            for (Object element : groupValues) {
                WindowedValue combinedValue = (WindowedValue)element;
                // Timestamp, windows and pane are carried over unchanged; only the
                // payload is re-wrapped as a key/value pair.
                expanded.add(WindowedValue.of(KV.of(groupKey, combinedValue.getValue()), combinedValue.getTimestamp(), combinedValue.getWindows(), combinedValue.getPane()));
            }
            return expanded.iterator();
        }, EncoderHelpers.fromBeamCoder(windowedKvCoder));

        context.putDataset(output, expandedDataset);
    }

    /**
     * Asks the combine function for its accumulator coder, using the coder registry
     * of the input's pipeline and the input value coder.
     *
     * @throws RuntimeException wrapping the {@link CannotProvideCoderException} when no
     *     accumulator coder can be provided
     */
    private static Coder accumulatorCoderFor(Combine.CombineFn fn, PCollection input, KvCoder inputKvCoder) {
        try {
            return fn.getAccumulatorCoder(input.getPipeline().getCoderRegistry(), inputKvCoder.getValueCoder());
        }
        catch (CannotProvideCoderException e) {
            throw new RuntimeException(e);
        }
    }
}

