/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.runners.spark.translation;

import java.io.Serializable;
import java.util.Collection;
import java.util.Map;
import org.apache.beam.repackaged.beam_runners_spark.com.google.common.base.Optional;
import org.apache.beam.repackaged.beam_runners_spark.com.google.common.base.Preconditions;
import org.apache.beam.repackaged.beam_runners_spark.com.google.common.collect.FluentIterable;
import org.apache.beam.repackaged.beam_runners_spark.com.google.common.collect.Lists;
import org.apache.beam.repackaged.beam_runners_spark.com.google.common.collect.Maps;
import org.apache.beam.runners.core.SystemReduceFn;
import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
import org.apache.beam.runners.spark.aggregators.AggregatorsAccumulator;
import org.apache.beam.runners.spark.aggregators.NamedAggregators;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.runners.spark.io.SourceRDD;
import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
import org.apache.beam.runners.spark.translation.BoundedDataset;
import org.apache.beam.runners.spark.translation.Dataset;
import org.apache.beam.runners.spark.translation.EvaluationContext;
import org.apache.beam.runners.spark.translation.GroupCombineFunctions;
import org.apache.beam.runners.spark.translation.MultiDoFnFunction;
import org.apache.beam.runners.spark.translation.SparkAssignWindowFn;
import org.apache.beam.runners.spark.translation.SparkGlobalCombineFn;
import org.apache.beam.runners.spark.translation.SparkGroupAlsoByWindowViaOutputBufferFn;
import org.apache.beam.runners.spark.translation.SparkKeyedCombineFn;
import org.apache.beam.runners.spark.translation.SparkPipelineTranslator;
import org.apache.beam.runners.spark.translation.TransformEvaluator;
import org.apache.beam.runners.spark.translation.TranslationUtils;
import org.apache.beam.runners.spark.translation.WindowingHelpers;
import org.apache.beam.runners.spark.util.SideInputBroadcast;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.io.Read;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.CombineFnBase;
import org.apache.beam.sdk.transforms.CombineWithContext;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Flatten;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Reshuffle;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.transforms.windowing.WindowFn;
import org.apache.beam.sdk.util.CombineFnUtil;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.spark.Accumulator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.storage.StorageLevel;
import org.joda.time.Instant;

/**
 * Registry of {@link TransformEvaluator}s that translate Beam batch transforms into Spark RDD
 * operations.
 *
 * <p>Evaluators are stored in the static {@code EVALUATORS} map, keyed by transform class, and are
 * looked up by the nested {@link Translator}. Only BOUNDED translation is supported:
 * {@code Translator.translateUnbounded} always throws.
 *
 * <p>NOTE(review): this file is CFR decompiler output. Raw types, {@code (Object)} casts and
 * intersection casts such as {@code (Function & Serializable)} are decompilation artifacts. Prefer
 * editing the original source over this file.
 */
public final class TransformTranslator {
    // Transform class -> evaluator. Written only by the static initializer of this class and read by
    // Translator. Lookups use the exact Class object, so subclasses of a registered transform do not match.
    private static final Map<Class<? extends PTransform>, TransformEvaluator<?>> EVALUATORS = Maps.newHashMap();

    // Not instantiable: every member is static.
    private TransformTranslator() {
    }

    /**
     * Evaluator for {@link Flatten.PCollections}: unions the RDDs of all input PCollections with
     * {@code JavaSparkContext.union}. With no inputs, the result is an empty RDD.
     */
    private static <T> TransformEvaluator<Flatten.PCollections<T>> flattenPColl() {
        return new TransformEvaluator<Flatten.PCollections<T>>(){

            @Override
            public void evaluate(Flatten.PCollections<T> transform, EvaluationContext context) {
                JavaRDD unionRDD;
                Collection<PValue> pcs = context.getInputs((PTransform<?, ?>)transform).values();
                if (pcs.isEmpty()) {
                    unionRDD = context.getSparkContext().emptyRDD();
                } else {
                    JavaRDD[] rdds = new JavaRDD[pcs.size()];
                    int index = 0;
                    for (PValue pc : pcs) {
                        // Each input must be a PCollection. Its dataset is then cast to BoundedDataset
                        // (batch translation only).
                        Preconditions.checkArgument(pc instanceof PCollection, "Flatten had non-PCollection value in input: %s of type %s", (Object)pc, (Object)pc.getClass().getSimpleName());
                        rdds[index] = ((BoundedDataset)context.borrowDataset(pc)).getRDD();
                        ++index;
                    }
                    unionRDD = context.getSparkContext().union(rdds);
                }
                context.putDataset((PTransform<?, ? extends PValue>)transform, new BoundedDataset(unionRDD));
            }

            @Override
            public String toNativeString() {
                return "sparkContext.union(...)";
            }
        };
    }

    /**
     * Evaluator for {@link GroupByKey}, which runs in two phases:
     * <ol>
     *   <li>Groups the KV input by key with {@code GroupCombineFunctions.groupByKeyOnly}.</li>
     *   <li>Groups also by window with {@code SparkGroupAlsoByWindowViaOutputBufferFn}, driven by a
     *       buffering {@code SystemReduceFn} and {@code TranslationUtils.InMemoryStateInternalsFactory}.</li>
     * </ol>
     */
    private static <K, V, W extends BoundedWindow> TransformEvaluator<GroupByKey<K, V>> groupByKey() {
        return new TransformEvaluator<GroupByKey<K, V>>(){

            @Override
            public void evaluate(GroupByKey<K, V> transform, EvaluationContext context) {
                JavaRDD inRDD = ((BoundedDataset)context.borrowDataset((PTransform<? extends PValue, ?>)transform)).getRDD();
                KvCoder coder = (KvCoder)((PCollection)context.getInput(transform)).getCoder();
                Accumulator<NamedAggregators> accum = AggregatorsAccumulator.getInstance();
                WindowingStrategy windowingStrategy = ((PCollection)context.getInput(transform)).getWindowingStrategy();
                WindowFn windowFn = windowingStrategy.getWindowFn();
                Coder keyCoder = coder.getKeyCoder();
                // Windowed-value coder (input value coder + input window coder) handed to groupByKeyOnly.
                WindowedValue.FullWindowedValueCoder wvCoder = WindowedValue.FullWindowedValueCoder.of((Coder)coder.getValueCoder(), (Coder)windowFn.windowCoder());
                JavaRDD groupedByKey = GroupCombineFunctions.groupByKeyOnly(inRDD, keyCoder, wvCoder);
                JavaRDD groupedAlsoByWindow = groupedByKey.flatMap(new SparkGroupAlsoByWindowViaOutputBufferFn(windowingStrategy, new TranslationUtils.InMemoryStateInternalsFactory(), SystemReduceFn.buffering((Coder)coder.getValueCoder()), context.getSerializableOptions(), accum));
                context.putDataset((PTransform<?, ? extends PValue>)transform, new BoundedDataset(groupedAlsoByWindow));
            }

            @Override
            public String toNativeString() {
                return "groupByKey()";
            }
        };
    }

    /**
     * Evaluator for {@link Combine.GroupedValues}: maps every input element to a KV holding the
     * element's key and {@code sparkCombineFn.apply(element)}. The element's timestamp, windows and
     * pane are kept.
     *
     * <p>NOTE(review): the input value is presumably the already-grouped Iterable of values from a
     * preceding GroupByKey. Confirm against SparkKeyedCombineFn.apply.
     */
    private static <K, InputT, OutputT> TransformEvaluator<Combine.GroupedValues<K, InputT, OutputT>> combineGrouped() {
        return new TransformEvaluator<Combine.GroupedValues<K, InputT, OutputT>>(){

            @Override
            public void evaluate(Combine.GroupedValues<K, InputT, OutputT> transform, EvaluationContext context) {
                CombineWithContext.CombineFnWithContext combineFn = CombineFnUtil.toFnWithContext((CombineFnBase.GlobalCombineFn)transform.getFn());
                SparkKeyedCombineFn sparkCombineFn = new SparkKeyedCombineFn(combineFn, context.getSerializableOptions(), TranslationUtils.getSideInputs(transform.getSideInputs(), context), ((PCollection)context.getInput(transform)).getWindowingStrategy());
                JavaRDD inRDD = ((BoundedDataset)context.borrowDataset((PTransform<? extends PValue, ?>)transform)).getRDD();
                JavaRDD outRDD = inRDD.map((Function & Serializable)in -> WindowedValue.of((Object)KV.of((Object)((KV)in.getValue()).getKey(), sparkCombineFn.apply(in)), (Instant)in.getTimestamp(), (Collection)in.getWindows(), (PaneInfo)in.getPane()));
                context.putDataset((PTransform<?, ? extends PValue>)transform, new BoundedDataset(outRDD));
            }

            @Override
            public String toNativeString() {
                return "map(new <fn>())";
            }
        };
    }

    /**
     * Evaluator for {@link Combine.Globally}. Accumulates the whole input RDD with
     * {@code GroupCombineFunctions.combineGlobally}, which yields an Optional accumulator:
     * <ul>
     *   <li>If an accumulator is present, {@code sparkCombineFn.extractOutput} is applied. The
     *       outputs are encoded with the full windowed-value output coder, parallelized and decoded.</li>
     *   <li>If it is absent and {@code transform.isInsertDefault()} is true, the result is a single
     *       element: {@code combineFn.defaultValue()}, wrapped by {@code WindowingHelpers.windowFunction()}.</li>
     *   <li>Otherwise the result is an empty RDD.</li>
     * </ul>
     *
     * @throws IllegalStateException if the accumulator coder cannot be determined from the coder registry
     */
    private static <InputT, AccumT, OutputT> TransformEvaluator<Combine.Globally<InputT, OutputT>> combineGlobally() {
        return new TransformEvaluator<Combine.Globally<InputT, OutputT>>(){

            @Override
            public void evaluate(Combine.Globally<InputT, OutputT> transform, EvaluationContext context) {
                JavaRDD outRdd;
                Coder aCoder;
                PCollection input = (PCollection)context.getInput(transform);
                // NOTE(review): re-fetches the input that is already held in `input` above.
                Coder iCoder = ((PCollection)context.getInput(transform)).getCoder();
                Coder oCoder = ((PCollection)context.getOutput(transform)).getCoder();
                WindowingStrategy windowingStrategy = input.getWindowingStrategy();
                CombineWithContext.CombineFnWithContext combineFn = CombineFnUtil.toFnWithContext((CombineFnBase.GlobalCombineFn)transform.getFn());
                WindowedValue.FullWindowedValueCoder wvoCoder = WindowedValue.FullWindowedValueCoder.of((Coder)oCoder, (Coder)windowingStrategy.getWindowFn().windowCoder());
                boolean hasDefault = transform.isInsertDefault();
                SparkGlobalCombineFn sparkCombineFn = new SparkGlobalCombineFn(combineFn, context.getSerializableOptions(), TranslationUtils.getSideInputs(transform.getSideInputs(), context), windowingStrategy);
                try {
                    aCoder = combineFn.getAccumulatorCoder(context.getPipeline().getCoderRegistry(), iCoder);
                }
                catch (CannotProvideCoderException e) {
                    throw new IllegalStateException("Could not determine coder for accumulator", e);
                }
                JavaRDD inRdd = ((BoundedDataset)context.borrowDataset((PTransform<? extends PValue, ?>)transform)).getRDD();
                Optional maybeAccumulated = GroupCombineFunctions.combineGlobally(inRdd, sparkCombineFn, iCoder, aCoder, windowingStrategy);
                if (maybeAccumulated.isPresent()) {
                    Iterable output = sparkCombineFn.extractOutput(maybeAccumulated.get());
                    // Outputs are round-tripped through byte arrays with the Beam coder, presumably so
                    // Spark never Java-serializes the user values. Confirm the intent.
                    outRdd = context.getSparkContext().parallelize(CoderHelpers.toByteArrays(output, wvoCoder)).map(CoderHelpers.fromByteFunction(wvoCoder));
                } else {
                    // NOTE(review): this branch wraps inRdd.context() in a new JavaSparkContext, whereas the
                    // branch above uses context.getSparkContext(). Both presumably refer to the same
                    // SparkContext. Consider unifying the two.
                    JavaSparkContext jsc = new JavaSparkContext(inRdd.context());
                    if (hasDefault) {
                        Object defaultValue = combineFn.defaultValue();
                        outRdd = jsc.parallelize(Lists.newArrayList(new byte[][]{CoderHelpers.toByteArray(defaultValue, oCoder)})).map(CoderHelpers.fromByteFunction(oCoder)).map(WindowingHelpers.windowFunction());
                    } else {
                        outRdd = jsc.emptyRDD();
                    }
                }
                context.putDataset((PTransform<?, ? extends PValue>)transform, new BoundedDataset(outRdd));
            }

            @Override
            public String toNativeString() {
                return "aggregate(..., new <fn>(), ...)";
            }
        };
    }

    /**
     * Evaluator for {@link Combine.PerKey}, which runs in three steps:
     * <ol>
     *   <li>Accumulates values per key with {@code GroupCombineFunctions.combinePerKey}.</li>
     *   <li>Expands each key's accumulator with {@code sparkCombineFn.extractOutput} via {@code flatMapValues}.</li>
     *   <li>Converts the pairs back to elements with {@code TranslationUtils.fromPairFunction()} and
     *       {@code TranslationUtils.toKVByWindowInValue()}.</li>
     * </ol>
     *
     * @throws IllegalStateException if the accumulator coder for the input value type cannot be determined
     */
    private static <K, InputT, AccumT, OutputT> TransformEvaluator<Combine.PerKey<K, InputT, OutputT>> combinePerKey() {
        return new TransformEvaluator<Combine.PerKey<K, InputT, OutputT>>(){

            @Override
            public void evaluate(Combine.PerKey<K, InputT, OutputT> transform, EvaluationContext context) {
                Coder vaCoder;
                PCollection input = (PCollection)context.getInput(transform);
                KvCoder inputCoder = (KvCoder)((PCollection)context.getInput(transform)).getCoder();
                CombineWithContext.CombineFnWithContext combineFn = CombineFnUtil.toFnWithContext((CombineFnBase.GlobalCombineFn)transform.getFn());
                WindowingStrategy windowingStrategy = input.getWindowingStrategy();
                Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs = TranslationUtils.getSideInputs(transform.getSideInputs(), context);
                SparkKeyedCombineFn sparkCombineFn = new SparkKeyedCombineFn(combineFn, context.getSerializableOptions(), sideInputs, windowingStrategy);
                try {
                    // The accumulator coder is derived from the VALUE coder only (the key is carried separately).
                    vaCoder = combineFn.getAccumulatorCoder(context.getPipeline().getCoderRegistry(), inputCoder.getValueCoder());
                }
                catch (CannotProvideCoderException e) {
                    throw new IllegalStateException("Could not determine coder for accumulator", e);
                }
                JavaRDD inRdd = ((BoundedDataset)context.borrowDataset((PTransform<? extends PValue, ?>)transform)).getRDD();
                JavaPairRDD accumulatePerKey = GroupCombineFunctions.combinePerKey(inRdd, sparkCombineFn, inputCoder.getKeyCoder(), inputCoder.getValueCoder(), vaCoder, windowingStrategy);
                JavaRDD outRdd = accumulatePerKey.flatMapValues(sparkCombineFn::extractOutput).map(TranslationUtils.fromPairFunction()).map(TranslationUtils.toKVByWindowInValue());
                context.putDataset((PTransform<?, ? extends PValue>)transform, new BoundedDataset(outRdd));
            }

            @Override
            public String toNativeString() {
                return "combineByKey(..., new <fn>(), ...)";
            }
        };
    }

    /**
     * Evaluator for {@link ParDo.MultiOutput}. It first rejects splittable DoFns. It then runs the
     * DoFn through a {@code MultiDoFnFunction}:
     * <ul>
     *   <li>Per partition ({@code mapPartitionsToPair}) for stateless DoFns.</li>
     *   <li>After grouping the KV input by key ({@code statefulParDoTransform}) for DoFns that
     *       declare any state or timer.</li>
     * </ul>
     * The resulting (tag, value) pairs are split into one {@link BoundedDataset} per output tag.
     */
    private static <InputT, OutputT> TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>> parDo() {
        return new TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>>(){

            @Override
            public void evaluate(ParDo.MultiOutput<InputT, OutputT> transform, EvaluationContext context) {
                String stepName = context.getCurrentTransform().getFullName();
                DoFn doFn = transform.getFn();
                TranslationUtils.rejectSplittable(doFn);
                JavaRDD inRDD = ((BoundedDataset)context.borrowDataset((PTransform<? extends PValue, ?>)transform)).getRDD();
                WindowingStrategy windowingStrategy = ((PCollection)context.getInput(transform)).getWindowingStrategy();
                Accumulator<MetricsContainerStepMap> metricsAccum = MetricsAccumulator.getInstance();
                DoFnSignature signature = DoFnSignatures.getSignature(transform.getFn().getClass());
                // A DoFn counts as stateful if it declares at least one state cell or timer.
                boolean stateful = signature.stateDeclarations().size() > 0 || signature.timerDeclarations().size() > 0;
                MultiDoFnFunction multiDoFnFunction = new MultiDoFnFunction(metricsAccum, stepName, doFn, context.getSerializableOptions(), transform.getMainOutputTag(), transform.getAdditionalOutputTags().getAll(), TranslationUtils.getSideInputs(transform.getSideInputs(), context), windowingStrategy, stateful);
                // Stateful path casts the input coder to KvCoder, so stateful DoFns presumably require KV input.
                JavaPairRDD all = stateful ? TransformTranslator.statefulParDoTransform((KvCoder)((PCollection)context.getInput(transform)).getCoder(), (Coder<? extends BoundedWindow>)windowingStrategy.getWindowFn().windowCoder(), inRDD, multiDoFnFunction) : inRDD.mapPartitionsToPair(multiDoFnFunction);
                Map<TupleTag<?>, PValue> outputs = context.getOutputs((PTransform<?, ?>)transform);
                if (outputs.size() > 1) {
                    // With several outputs, `all` is filtered once per output tag in the loop below. It is
                    // therefore persisted, presumably so the DoFn is not recomputed for every output.
                    // When avoidRddSerialization(level) is false, elements are encoded with their per-tag
                    // coders before persisting and decoded right after.
                    StorageLevel level = StorageLevel.fromString((String)context.storageLevel());
                    if (TranslationUtils.avoidRddSerialization(level)) {
                        all = all.persist(level);
                    } else {
                        Map<TupleTag<?>, Coder<WindowedValue<?>>> coderMap = TranslationUtils.getTupleTagCoders(outputs);
                        all = all.mapToPair(TranslationUtils.getTupleTagEncodeFunction(coderMap)).persist(level).mapToPair(TranslationUtils.getTupleTagDecodeFunction(coderMap));
                    }
                }
                for (Map.Entry<TupleTag<?>, PValue> output : outputs.entrySet()) {
                    JavaPairRDD filtered = all.filter(new TranslationUtils.TupleTagFilter(output.getKey()));
                    JavaRDD values = filtered.values();
                    // NOTE(review): the meaning of the boolean flag (false here, true in readBounded) is
                    // defined by EvaluationContext.putDataset. Confirm there.
                    context.putDataset(output.getValue(), new BoundedDataset(values), false);
                }
            }

            @Override
            public String toNativeString() {
                return "mapPartitions(new <fn>())";
            }
        };
    }

    /**
     * Prepares the input of a stateful DoFn and runs it. It groups the KV input by key
     * ({@code GroupCombineFunctions.groupByKeyOnly}) and turns each group back into an iterator of
     * {@code KV<K, V>} windowed values. Each value is rebuilt with {@code withValue}, which presumably
     * keeps the element's timestamp, windows and pane. Each iterator is then fed to
     * {@code doFnFunction} through {@code flatMapToPair}, so a single invocation sees all values of
     * one key group.
     *
     * @param kvCoder coder of the input KV elements; its key and value coders drive the grouping
     * @param windowCoder coder of the input windows
     * @param kvInRDD the windowed KV input elements
     * @param doFnFunction the DoFn runner applied to each key group's iterator
     * @return pairs of (output tag, windowed output value)
     */
    private static <K, V, OutputT> JavaPairRDD<TupleTag<?>, WindowedValue<?>> statefulParDoTransform(KvCoder<K, V> kvCoder, Coder<? extends BoundedWindow> windowCoder, JavaRDD<WindowedValue<KV<K, V>>> kvInRDD, MultiDoFnFunction<KV<K, V>, OutputT> doFnFunction) {
        Coder keyCoder = kvCoder.getKeyCoder();
        WindowedValue.FullWindowedValueCoder wvCoder = WindowedValue.FullWindowedValueCoder.of((Coder)kvCoder.getValueCoder(), windowCoder);
        JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> groupRDD = GroupCombineFunctions.groupByKeyOnly(kvInRDD, keyCoder, wvCoder);
        return groupRDD.map((Function & Serializable)input -> {
            Object key = ((KV)input.getValue()).getKey();
            Iterable value = (Iterable)((KV)input.getValue()).getValue();
            // Lazily re-attach the group's key to every windowed value of the group.
            return FluentIterable.from(value).transform(windowedValue -> windowedValue.withValue((Object)KV.of((Object)key, (Object)windowedValue.getValue()))).iterator();
        }).flatMapToPair(doFnFunction);
    }

    /**
     * Evaluator for {@link Read.Bounded}: wraps the transform's source in a {@code SourceRDD.Bounded}
     * named after the current step, and registers it as a {@link BoundedDataset}.
     */
    private static <T> TransformEvaluator<Read.Bounded<T>> readBounded() {
        return new TransformEvaluator<Read.Bounded<T>>(){

            @Override
            public void evaluate(Read.Bounded<T> transform, EvaluationContext context) {
                String stepName = context.getCurrentTransform().getFullName();
                JavaSparkContext jsc = context.getSparkContext();
                JavaRDD input = new SourceRDD.Bounded(jsc.sc(), transform.getSource(), context.getSerializableOptions(), stepName).toJavaRDD();
                context.putDataset((PTransform<?, ? extends PValue>)transform, (Dataset)new BoundedDataset(input), true);
            }

            @Override
            public String toNativeString() {
                return "sparkContext.<readFrom(<source>)>()";
            }
        };
    }

    /**
     * Evaluator for {@link Window.Assign}. When {@code TranslationUtils.skipAssignWindows(...)} returns
     * true, the input RDD is passed through unchanged. Otherwise every element is mapped with
     * {@code SparkAssignWindowFn} using the transform's WindowFn.
     */
    private static <T, W extends BoundedWindow> TransformEvaluator<Window.Assign<T>> window() {
        return new TransformEvaluator<Window.Assign<T>>(){

            @Override
            public void evaluate(Window.Assign<T> transform, EvaluationContext context) {
                JavaRDD inRDD = ((BoundedDataset)context.borrowDataset((PTransform<? extends PValue, ?>)transform)).getRDD();
                if (TranslationUtils.skipAssignWindows(transform, context)) {
                    context.putDataset((PTransform<?, ? extends PValue>)transform, new BoundedDataset(inRDD));
                } else {
                    context.putDataset((PTransform<?, ? extends PValue>)transform, new BoundedDataset(inRDD.map(new SparkAssignWindowFn(transform.getWindowFn()))));
                }
            }

            @Override
            public String toNativeString() {
                return "map(new <windowFn>())";
            }
        };
    }

    /**
     * Evaluator for {@link Create.Values}: hands the transform's in-memory elements and the output
     * PCollection's coder to {@code EvaluationContext.putBoundedDatasetFromValues}.
     */
    private static <T> TransformEvaluator<Create.Values<T>> create() {
        return new TransformEvaluator<Create.Values<T>>(){

            @Override
            public void evaluate(Create.Values<T> transform, EvaluationContext context) {
                Iterable elems = transform.getElements();
                Coder coder = ((PCollection)context.getOutput(transform)).getCoder();
                context.putBoundedDatasetFromValues((PTransform<?, ? extends PValue>)transform, elems, coder);
            }

            @Override
            public String toNativeString() {
                return "sparkContext.parallelize(Arrays.asList(...))";
            }
        };
    }

    /**
     * Evaluator for {@link View.CreatePCollectionView}. It obtains the input PCollection's windowed
     * values from the context and registers them as the view's contents with
     * {@code EvaluationContext.putPView}. The coder is an IterableCoder of full windowed-value coders,
     * built from the view's internal element coder and its window coder.
     */
    private static <ReadT, WriteT> TransformEvaluator<View.CreatePCollectionView<ReadT, WriteT>> createPCollView() {
        return new TransformEvaluator<View.CreatePCollectionView<ReadT, WriteT>>(){

            @Override
            public void evaluate(View.CreatePCollectionView<ReadT, WriteT> transform, EvaluationContext context) {
                Iterable<WindowedValue<?>> iter = context.getWindowedValues((PCollection)context.getInput(transform));
                PCollectionView output = transform.getView();
                IterableCoder coderInternal = IterableCoder.of((Coder)WindowedValue.getFullCoder((Coder)output.getCoderInternal(), (Coder)output.getWindowingStrategyInternal().getWindowFn().windowCoder()));
                // Redundant alias of `iter`, likely a leftover of a generic cast in the original source.
                Iterable<WindowedValue<?>> iterCast = iter;
                context.putPView((PCollectionView<?>)output, iterCast, (Coder<Iterable<WindowedValue<?>>>)coderInternal);
            }

            @Override
            public String toNativeString() {
                return "<createPCollectionView>";
            }
        };
    }

    /**
     * Evaluator for {@link Reshuffle}: delegates to {@code GroupCombineFunctions.reshuffle} with the
     * input's key coder and a windowed-value coder. That coder is built from the input's value coder
     * and window coder.
     */
    private static <K, V, W extends BoundedWindow> TransformEvaluator<Reshuffle<K, V>> reshuffle() {
        return new TransformEvaluator<Reshuffle<K, V>>(){

            @Override
            public void evaluate(Reshuffle<K, V> transform, EvaluationContext context) {
                JavaRDD inRDD = ((BoundedDataset)context.borrowDataset((PTransform<? extends PValue, ?>)transform)).getRDD();
                WindowingStrategy windowingStrategy = ((PCollection)context.getInput(transform)).getWindowingStrategy();
                KvCoder coder = (KvCoder)((PCollection)context.getInput(transform)).getCoder();
                WindowFn windowFn = windowingStrategy.getWindowFn();
                Coder keyCoder = coder.getKeyCoder();
                WindowedValue.FullWindowedValueCoder wvCoder = WindowedValue.FullWindowedValueCoder.of((Coder)coder.getValueCoder(), (Coder)windowFn.windowCoder());
                JavaRDD reshuffled = GroupCombineFunctions.reshuffle(inRDD, keyCoder, wvCoder);
                context.putDataset((PTransform<?, ? extends PValue>)transform, new BoundedDataset(reshuffled));
            }

            @Override
            public String toNativeString() {
                return "repartition(...)";
            }
        };
    }

    // Register one evaluator per supported batch transform class (exact-class keys).
    static {
        EVALUATORS.put(Read.Bounded.class, TransformTranslator.readBounded());
        EVALUATORS.put(ParDo.MultiOutput.class, TransformTranslator.parDo());
        EVALUATORS.put(GroupByKey.class, TransformTranslator.groupByKey());
        EVALUATORS.put(Combine.GroupedValues.class, TransformTranslator.combineGrouped());
        EVALUATORS.put(Combine.Globally.class, TransformTranslator.combineGlobally());
        EVALUATORS.put(Combine.PerKey.class, TransformTranslator.combinePerKey());
        EVALUATORS.put(Flatten.PCollections.class, TransformTranslator.flattenPColl());
        EVALUATORS.put(Create.Values.class, TransformTranslator.create());
        EVALUATORS.put(View.CreatePCollectionView.class, TransformTranslator.createPCollView());
        EVALUATORS.put(Window.Assign.class, TransformTranslator.window());
        EVALUATORS.put(Reshuffle.class, TransformTranslator.reshuffle());
    }

    /**
     * {@link SparkPipelineTranslator} backed by the static {@code EVALUATORS} registry. It supports
     * BOUNDED (batch) transforms only.
     */
    public static class Translator
    implements SparkPipelineTranslator {
        /** Returns true if an evaluator is registered for exactly this transform class. */
        @Override
        public boolean hasTranslation(Class<? extends PTransform<?, ?>> clazz) {
            return EVALUATORS.containsKey(clazz);
        }

        /**
         * Returns the registered evaluator for {@code clazz}.
         *
         * @throws IllegalStateException if no evaluator is registered for the class
         */
        @Override
        public <TransformT extends PTransform<?, ?>> TransformEvaluator<TransformT> translateBounded(Class<TransformT> clazz) {
            TransformEvaluator transformEvaluator = (TransformEvaluator)EVALUATORS.get(clazz);
            Preconditions.checkState(transformEvaluator != null, "No TransformEvaluator registered for BOUNDED transform %s", clazz);
            return transformEvaluator;
        }

        /**
         * Unsupported in batch translation.
         *
         * @throws IllegalStateException always
         */
        @Override
        public <TransformT extends PTransform<?, ?>> TransformEvaluator<TransformT> translateUnbounded(Class<TransformT> clazz) {
            throw new IllegalStateException("TransformTranslator used in a batch pipeline only supports BOUNDED transforms.");
        }
    }
}

