/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.runners.spark.translation;

import java.io.IOException;
import java.io.Serializable;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.SystemReduceFn;
import org.apache.beam.runners.core.construction.NativeTransforms;
import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.runners.core.construction.graph.PipelineNode;
import org.apache.beam.runners.core.construction.graph.QueryablePipeline;
import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
import org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
import org.apache.beam.runners.spark.translation.BoundedDataset;
import org.apache.beam.runners.spark.translation.Dataset;
import org.apache.beam.runners.spark.translation.GroupCombineFunctions;
import org.apache.beam.runners.spark.translation.GroupNonMergingWindowsFunctions;
import org.apache.beam.runners.spark.translation.SparkExecutableStageContextFactory;
import org.apache.beam.runners.spark.translation.SparkExecutableStageExtractionFunction;
import org.apache.beam.runners.spark.translation.SparkExecutableStageFunction;
import org.apache.beam.runners.spark.translation.SparkGroupAlsoByWindowViaOutputBufferFn;
import org.apache.beam.runners.spark.translation.SparkPortablePipelineTranslator;
import org.apache.beam.runners.spark.translation.SparkTranslationContext;
import org.apache.beam.runners.spark.translation.TranslationUtils;
import org.apache.beam.runners.spark.util.ByteArray;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.transforms.windowing.WindowFn;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.BiMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.spark.HashPartitioner;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.storage.StorageLevel;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

/**
 * Translates a portable (Runner API) pipeline into Spark batch operations.
 *
 * <p>Each supported primitive is looked up by its URN and handled by a static translation method
 * that pops its input datasets from the {@link SparkTranslationContext} and pushes its outputs back.
 * A transform whose URN has no registered translator fails translation with an
 * {@link IllegalArgumentException}.
 */
public class SparkBatchPortablePipelineTranslator
implements SparkPortablePipelineTranslator<SparkTranslationContext> {
    private static final Logger LOG = LoggerFactory.getLogger(SparkBatchPortablePipelineTranslator.class);

    // URNs of the transforms handled by this translator.
    private static final String IMPULSE_URN = "beam:transform:impulse:v1";
    private static final String GROUP_BY_KEY_URN = "beam:transform:group_by_key:v1";
    private static final String EXECUTABLE_STAGE_URN = "beam:runner:executable_stage:v1";
    private static final String FLATTEN_URN = "beam:transform:flatten:v1";
    private static final String RESHUFFLE_URN = "beam:transform:reshuffle:v1";

    /** Dispatch table from transform URN to the method that translates it. */
    private final ImmutableMap<String, PTransformTranslator> urnToTransformTranslator;

    @Override
    public Set<String> knownUrns() {
        return this.urnToTransformTranslator.keySet();
    }

    /** Creates a translator that supports impulse, GBK, executable stage, flatten and reshuffle. */
    public SparkBatchPortablePipelineTranslator() {
        // The builder must be parameterized: a raw builder's put(Object, Object) offers the method
        // references below no functional-interface target type, which does not compile.
        ImmutableMap.Builder<String, PTransformTranslator> translatorMap = ImmutableMap.builder();
        translatorMap.put(IMPULSE_URN, SparkBatchPortablePipelineTranslator::translateImpulse);
        translatorMap.put(GROUP_BY_KEY_URN, SparkBatchPortablePipelineTranslator::translateGroupByKey);
        translatorMap.put(EXECUTABLE_STAGE_URN, SparkBatchPortablePipelineTranslator::translateExecutableStage);
        translatorMap.put(FLATTEN_URN, SparkBatchPortablePipelineTranslator::translateFlatten);
        translatorMap.put(RESHUFFLE_URN, SparkBatchPortablePipelineTranslator::translateReshuffle);
        this.urnToTransformTranslator = translatorMap.build();
    }

    /**
     * Translates {@code pipeline} in two passes over its topologically ordered transforms.
     *
     * <p>The first pass records how many times each PCollection is consumed and registers the coder
     * of every output. The second pass dispatches each transform to its translator.
     *
     * @throws IllegalArgumentException if a transform has a URN with no registered translator
     */
    @Override
    public void translate(RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
        RunnerApi.Components components = pipeline.getComponents();
        QueryablePipeline p = QueryablePipeline.forTransforms(pipeline.getRootTransformIdsList(), components);
        for (PipelineNode.PTransformNode transformNode : p.getTopologicallyOrderedTransforms()) {
            RunnerApi.PTransform transform = transformNode.getTransform();
            for (String inputId : transform.getInputsMap().values()) {
                context.incrementConsumptionCountBy(inputId, 1);
            }
            if (EXECUTABLE_STAGE_URN.equals(transform.getSpec().getUrn())) {
                // The stage's intermediate result is read once per declared output of the stage.
                context.incrementConsumptionCountBy(
                        PipelineTranslatorUtils.getExecutableStageIntermediateId(transformNode),
                        transform.getOutputsMap().size());
            }
            for (String outputId : transform.getOutputsMap().values()) {
                WindowedValue.WindowedValueCoder outputCoder =
                        PipelineTranslatorUtils.getWindowedValueCoder(outputId, components);
                context.putCoder(outputId, outputCoder);
            }
        }
        for (PipelineNode.PTransformNode transformNode : p.getTopologicallyOrderedTransforms()) {
            String urn = transformNode.getTransform().getSpec().getUrn();
            PTransformTranslator translator =
                    this.urnToTransformTranslator.getOrDefault(urn, SparkBatchPortablePipelineTranslator::urnNotFound);
            translator.translate(transformNode, pipeline, context);
        }
    }

    /** Fallback translator for URNs that have no registered translation. */
    private static void urnNotFound(PipelineNode.PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
        throw new IllegalArgumentException(String.format("Transform %s has unknown URN %s", transformNode.getId(), transformNode.getTransform().getSpec().getUrn()));
    }

    /** Translates an impulse into a bounded dataset holding a single empty byte array. */
    private static void translateImpulse(PipelineNode.PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
        BoundedDataset<byte[]> output = new BoundedDataset<byte[]>(
                Collections.singletonList(new byte[0]), context.getSparkContext(), ByteArrayCoder.of());
        context.pushDataset(PipelineTranslatorUtils.getOutputId(transformNode), output);
    }

    /**
     * Translates a GroupByKey.
     *
     * <p>Inputs whose windowing strategy is eligible for {@link GroupNonMergingWindowsFunctions} are
     * grouped by key and window in one step. Otherwise the values are grouped by key only, and windows
     * are assigned afterwards by a {@link SparkGroupAlsoByWindowViaOutputBufferFn} that uses
     * in-memory state.
     */
    private static <K, V> void translateGroupByKey(PipelineNode.PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
        RunnerApi.Components components = pipeline.getComponents();
        String inputId = PipelineTranslatorUtils.getInputId(transformNode);
        Dataset inputDataset = context.popDataset(inputId);
        JavaRDD inputRdd = ((BoundedDataset) inputDataset).getRDD();
        WindowedValue.WindowedValueCoder inputCoder = PipelineTranslatorUtils.getWindowedValueCoder(inputId, components);
        KvCoder inputKvCoder = (KvCoder) inputCoder.getValueCoder();
        Coder inputKeyCoder = inputKvCoder.getKeyCoder();
        Coder inputValueCoder = inputKvCoder.getValueCoder();
        WindowingStrategy windowingStrategy = PipelineTranslatorUtils.getWindowingStrategy(inputId, components);
        Partitioner partitioner = getPartitioner(context);

        JavaRDD groupedByKeyAndWindow;
        if (GroupNonMergingWindowsFunctions.isEligibleForGroupByWindow(windowingStrategy)) {
            groupedByKeyAndWindow = GroupNonMergingWindowsFunctions.groupByKeyAndWindow(
                    inputRdd, inputKeyCoder, inputValueCoder, windowingStrategy, partitioner);
        } else {
            WindowFn windowFn = windowingStrategy.getWindowFn();
            WindowedValue.FullWindowedValueCoder wvCoder =
                    WindowedValue.FullWindowedValueCoder.of(inputValueCoder, windowFn.windowCoder());
            JavaRDD groupedByKeyOnly = GroupCombineFunctions.groupByKeyOnly(inputRdd, inputKeyCoder, wvCoder, partitioner);
            groupedByKeyAndWindow = groupedByKeyOnly.flatMap(new SparkGroupAlsoByWindowViaOutputBufferFn(
                    windowingStrategy,
                    new TranslationUtils.InMemoryStateInternalsFactory(),
                    SystemReduceFn.buffering(inputValueCoder),
                    context.serializablePipelineOptions));
        }
        context.pushDataset(PipelineTranslatorUtils.getOutputId(transformNode), new BoundedDataset(groupedByKeyAndWindow));
    }

    /**
     * Translates an executable stage.
     *
     * <p>Stages that declare user state or timers get their input grouped by key first. Other stages
     * run partition by partition. The stage's combined result is registered under an intermediate id,
     * and one dataset per declared output is derived from it. A stage with no outputs still gets a
     * sink dataset registered under a fresh {@code EmptyOutputSink_<n>} id.
     *
     * @throws RuntimeException wrapping the parse failure if the stage payload is not a valid
     *     {@code ExecutableStagePayload}
     * @throws IllegalStateException if a stateful stage's input element coder is not a {@link KvCoder}
     */
    private static <InputT, OutputT, SideInputT> void translateExecutableStage(PipelineNode.PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
        RunnerApi.ExecutableStagePayload stagePayload;
        try {
            stagePayload = RunnerApi.ExecutableStagePayload.parseFrom(transformNode.getTransform().getSpec().getPayload());
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
        String inputPCollectionId = stagePayload.getInput();
        Dataset inputDataset = context.popDataset(inputPCollectionId);
        // Typed map: iterating a raw Map's values() yields Object and would not compile with a
        // String loop variable below.
        Map<String, String> outputs = transformNode.getTransform().getOutputsMap();
        BiMap<String, Integer> outputExtractionMap = PipelineTranslatorUtils.createOutputMap(outputs.values());
        RunnerApi.Components components = pipeline.getComponents();
        // Resolved once and shared by the stage function and, for stateful stages, the grouping coder.
        WindowingStrategy windowingStrategy = PipelineTranslatorUtils.getWindowingStrategy(inputPCollectionId, components);
        Coder windowCoder = windowingStrategy.getWindowFn().windowCoder();
        ImmutableMap<String, Tuple2<Broadcast<List<byte[]>>, WindowedValue.WindowedValueCoder<SideInputT>>> broadcastVariables =
                broadcastSideInputs(stagePayload, context);

        final JavaRDD staged;
        if (stagePayload.getUserStatesCount() > 0 || stagePayload.getTimersCount() > 0) {
            JavaPairRDD groupedByKey = groupStatefulStageInput(inputPCollectionId, inputDataset, windowCoder, components);
            SparkExecutableStageFunction function =
                    newStageFunction(stagePayload, context, outputExtractionMap, broadcastVariables, windowCoder);
            staged = groupedByKey.flatMap(function.forPair());
        } else {
            JavaRDD inputRdd = ((BoundedDataset) inputDataset).getRDD();
            SparkExecutableStageFunction function =
                    newStageFunction(stagePayload, context, outputExtractionMap, broadcastVariables, windowCoder);
            staged = inputRdd.mapPartitions(function);
        }

        String intermediateId = PipelineTranslatorUtils.getExecutableStageIntermediateId(transformNode);
        // Register the stage's combined result, then pop it immediately. translate() gave this id a
        // consumption count equal to the number of outputs.
        // NOTE(review): presumably the push/pop is what lets the context decide whether to cache
        // `staged`; confirm against SparkTranslationContext.
        context.pushDataset(intermediateId, new Dataset(){

            @Override
            public void cache(String storageLevel, Coder<?> coder) {
                StorageLevel level = StorageLevel.fromString(storageLevel);
                staged.persist(level);
            }

            @Override
            public void action() {
                // Spark foreach is an action, so this evaluates the RDD; the function is presumably a no-op.
                staged.foreach(TranslationUtils.emptyVoidFunction());
            }

            @Override
            public void setName(String name) {
                staged.setName(name);
            }
        });
        context.popDataset(intermediateId);

        for (String outputId : outputs.values()) {
            JavaRDD outputRdd = staged.flatMap(new SparkExecutableStageExtractionFunction(outputExtractionMap.get(outputId)));
            context.pushDataset(outputId, new BoundedDataset(outputRdd));
        }
        if (outputs.isEmpty()) {
            // No output reads the stage, so register a sink that discards every element.
            FlatMapFunction dropAll = rawUnionValue -> Collections.emptyIterator();
            JavaRDD outputRdd = staged.flatMap(dropAll);
            context.pushDataset(String.format("EmptyOutputSink_%d", context.nextSinkId()), new BoundedDataset(outputRdd));
        }
    }

    /**
     * Checks that a stateful stage's input is KV-typed and groups it by key.
     *
     * @throws IllegalStateException if the input element coder is not a {@link KvCoder}
     */
    private static JavaPairRDD groupStatefulStageInput(String inputPCollectionId, Dataset inputDataset, Coder windowCoder, RunnerApi.Components components) {
        Coder windowedInputCoder = PipelineTranslatorUtils.instantiateCoder(inputPCollectionId, components);
        Coder valueCoder = ((WindowedValue.FullWindowedValueCoder) windowedInputCoder).getValueCoder();
        if (!(valueCoder instanceof KvCoder)) {
            throw new IllegalStateException(String.format(Locale.ENGLISH, "The element coder for stateful DoFn '%s' must be KvCoder but is: %s", inputPCollectionId, valueCoder.getClass().getSimpleName()));
        }
        KvCoder kvCoder = (KvCoder) valueCoder;
        WindowedValue.FullWindowedValueCoder wvCoder =
                WindowedValue.FullWindowedValueCoder.of(kvCoder.getValueCoder(), windowCoder);
        return groupByKeyPair(inputDataset, kvCoder.getKeyCoder(), wvCoder);
    }

    /** Builds the Spark function that runs an executable stage; shared by both stage code paths. */
    private static <SideInputT> SparkExecutableStageFunction newStageFunction(
            RunnerApi.ExecutableStagePayload stagePayload,
            SparkTranslationContext context,
            Map<String, Integer> outputExtractionMap,
            Map<String, Tuple2<Broadcast<List<byte[]>>, WindowedValue.WindowedValueCoder<SideInputT>>> broadcastVariables,
            Coder windowCoder) {
        return new SparkExecutableStageFunction(stagePayload, context.jobInfo, outputExtractionMap, SparkExecutableStageContextFactory.getInstance(), broadcastVariables, MetricsAccumulator.getInstance(), windowCoder);
    }

    /** Groups the KV elements of a bounded dataset by their encoded key. */
    private static <K, V> JavaPairRDD<ByteArray, Iterable<WindowedValue<KV<K, V>>>> groupByKeyPair(Dataset dataset, Coder<K> keyCoder, WindowedValue.WindowedValueCoder<V> wvCoder) {
        JavaRDD inputRdd = ((BoundedDataset) dataset).getRDD();
        return GroupCombineFunctions.groupByKeyPair(inputRdd, keyCoder, wvCoder);
    }

    /**
     * Broadcasts every distinct side-input PCollection of the stage.
     *
     * @return map from side-input PCollection id to its broadcast encoded elements and their coder
     */
    private static <SideInputT> ImmutableMap<String, Tuple2<Broadcast<List<byte[]>>, WindowedValue.WindowedValueCoder<SideInputT>>> broadcastSideInputs(RunnerApi.ExecutableStagePayload stagePayload, SparkTranslationContext context) {
        RunnerApi.Components stagePayloadComponents = stagePayload.getComponents();
        Map<String, Tuple2<Broadcast<List<byte[]>>, WindowedValue.WindowedValueCoder<SideInputT>>> broadcastVariables = new HashMap<>();
        for (RunnerApi.ExecutableStagePayload.SideInputId sideInputId : stagePayload.getSideInputsList()) {
            String collectionId = stagePayloadComponents
                    .getTransformsOrThrow(sideInputId.getTransformId())
                    .getInputsOrThrow(sideInputId.getLocalName());
            // A PCollection used as several side inputs is broadcast (and popped) only once.
            if (!broadcastVariables.containsKey(collectionId)) {
                broadcastVariables.put(collectionId, broadcastSideInput(collectionId, stagePayloadComponents, context));
            }
        }
        return ImmutableMap.copyOf(broadcastVariables);
    }

    /** Pops a side-input dataset, encodes its elements, and broadcasts the encoded bytes. */
    private static <T> Tuple2<Broadcast<List<byte[]>>, WindowedValue.WindowedValueCoder<T>> broadcastSideInput(String collectionId, RunnerApi.Components components, SparkTranslationContext context) {
        BoundedDataset dataset = (BoundedDataset) context.popDataset(collectionId);
        WindowedValue.WindowedValueCoder<T> coder = PipelineTranslatorUtils.getWindowedValueCoder(collectionId, components);
        List<byte[]> bytes = dataset.getBytes(coder);
        Broadcast<List<byte[]>> broadcast = context.getSparkContext().broadcast(bytes);
        return new Tuple2<>(broadcast, coder);
    }

    /** Translates a Flatten into a Spark union of its inputs, or an empty RDD when it has none. */
    private static <T> void translateFlatten(PipelineNode.PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
        // Typed map: iterating a raw Map's values() with a String loop variable does not compile.
        Map<String, String> inputsMap = transformNode.getTransform().getInputsMap();
        JavaRDD unionRDD;
        if (inputsMap.isEmpty()) {
            unionRDD = context.getSparkContext().emptyRDD();
        } else {
            JavaRDD[] rdds = new JavaRDD[inputsMap.size()];
            int index = 0;
            for (String inputId : inputsMap.values()) {
                rdds[index++] = ((BoundedDataset) context.popDataset(inputId)).getRDD();
            }
            unionRDD = context.getSparkContext().union(rdds);
        }
        context.pushDataset(PipelineTranslatorUtils.getOutputId(transformNode), new BoundedDataset(unionRDD));
    }

    /** Translates a Reshuffle via {@link GroupCombineFunctions#reshuffle}. */
    private static <T> void translateReshuffle(PipelineNode.PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
        String inputId = PipelineTranslatorUtils.getInputId(transformNode);
        WindowedValue.WindowedValueCoder coder = PipelineTranslatorUtils.getWindowedValueCoder(inputId, pipeline.getComponents());
        JavaRDD inRDD = ((BoundedDataset) context.popDataset(inputId)).getRDD();
        JavaRDD reshuffled = GroupCombineFunctions.reshuffle(inRDD, coder);
        context.pushDataset(PipelineTranslatorUtils.getOutputId(transformNode), new BoundedDataset(reshuffled));
    }

    /**
     * Chooses the partitioner for grouping operations.
     *
     * @return {@code null} when the configured bundle size is positive; otherwise a
     *     {@link HashPartitioner} sized to the Spark context's default parallelism
     */
    private static @Nullable Partitioner getPartitioner(SparkTranslationContext context) {
        Long bundleSize = context.serializablePipelineOptions.get().as(SparkPipelineOptions.class).getBundleSize();
        return bundleSize > 0L ? null : new HashPartitioner(context.getSparkContext().defaultParallelism());
    }

    @Override
    public SparkTranslationContext createTranslationContext(JavaSparkContext jsc, SparkPipelineOptions options, JobInfo jobInfo) {
        return new SparkTranslationContext(jsc, options, jobInfo);
    }

    /** {@link NativeTransforms.IsNativeTransform} predicate that reports Reshuffle as native. */
    public static class IsSparkNativeTransform
    implements NativeTransforms.IsNativeTransform {
        @Override
        public boolean test(RunnerApi.PTransform pTransform) {
            return RESHUFFLE_URN.equals(PTransformTranslation.urnForTransformOrNull(pTransform));
        }
    }

    /** Translates one pipeline node, reading and writing datasets through the context. */
    @FunctionalInterface
    interface PTransformTranslator {
        void translate(PipelineNode.PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context);
    }
}

