/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.runners.spark.structuredstreaming.translation.batch;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.beam.runners.core.construction.ParDoTranslation;
import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator;
import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsContainerStepMapAccumulator;
import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
import org.apache.beam.runners.spark.structuredstreaming.translation.TranslationContext;
import org.apache.beam.runners.spark.structuredstreaming.translation.batch.DoFnFunction;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.MultiOuputCoder;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SideInputBroadcast;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FilterFunction;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.sql.Dataset;
import scala.Tuple2;

/**
 * Batch translator that turns a {@code ParDo} ({@code PCollection<InputT>} to
 * {@code PCollectionTuple}) into Spark {@link Dataset} operations.
 *
 * <p>The {@link DoFn} is run per partition with {@code mapPartitions} through a
 * {@link DoFnFunction}, producing {@code (TupleTag, WindowedValue)} pairs for all outputs. With a
 * single output the pairs are unwrapped directly; with several outputs the tagged dataset is
 * persisted and then split into one dataset per tag with a filter. Side inputs are collected to
 * the driver, encoded and broadcast.
 *
 * <p>Splittable DoFns and DoFns that declare state or timers are rejected.
 *
 * @param <InputT> element type of the input {@link PCollection}
 * @param <OutputT> output type of the translated {@link DoFn}
 */
class ParDoTranslatorBatch<InputT, OutputT>
implements TransformTranslator<PTransform<PCollection<InputT>, PCollectionTuple>> {
    ParDoTranslatorBatch() {
    }

    /**
     * Registers in {@code context} one dataset per output of the current transform.
     *
     * @param transform the transform being translated (not read; everything is taken from
     *     {@code context})
     * @param context translation state supplying the current transform, its input dataset,
     *     its outputs and their coders
     * @throws IllegalStateException if the DoFn is splittable or declares state or timers
     */
    @Override
    public void translateTransform(PTransform<PCollection<InputT>, PCollectionTuple> transform, TranslationContext context) {
        String stepName = context.getCurrentTransform().getFullName();

        // Reject DoFn features this translator does not support. The signature is computed once
        // and reused for both checks.
        DoFn<InputT, OutputT> doFn = getDoFn(context);
        DoFnSignature signature = DoFnSignatures.signatureForDoFn(doFn);
        Preconditions.checkState(
                !signature.processElement().isSplittable(),
                "Not expected to directly translate splittable DoFn, should have been overridden: %s",
                doFn);
        boolean stateful =
                !signature.stateDeclarations().isEmpty() || !signature.timerDeclarations().isEmpty();
        Preconditions.checkState(!stateful, "States and timers are not supported for the moment.");

        DoFnSchemaInformation doFnSchemaInformation =
                ParDoTranslation.getSchemaInformation(context.getCurrentTransform());

        // Main input: its dataset, element coder and windowing.
        PValue input = context.getInput();
        @SuppressWarnings("unchecked")
        PCollection<InputT> inputCollection = (PCollection<InputT>) input;
        Dataset<WindowedValue<InputT>> inputDataSet = context.getDataset(input);
        WindowingStrategy<?, ?> windowingStrategy = inputCollection.getWindowingStrategy();
        Coder<InputT> inputCoder = inputCollection.getCoder();
        Coder<? extends BoundedWindow> windowCoder = windowingStrategy.getWindowFn().windowCoder();

        // Outputs: every tag other than the main one is an additional output.
        Map<TupleTag<?>, PValue> outputs = context.getOutputs();
        TupleTag<?> mainOutputTag = getTupleTag(context);
        List<TupleTag<?>> additionalOutputTags = new ArrayList<>();
        for (TupleTag<?> tag : outputs.keySet()) {
            if (!tag.equals(mainOutputTag)) {
                additionalOutputTags.add(tag);
            }
        }

        // Side inputs: windowing strategy of each one (handed to DoFnFunction) plus broadcast data.
        List<PCollectionView<?>> sideInputs = getSideInputs(context);
        Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputStrategies = new HashMap<>();
        for (PCollectionView<?> sideInput : sideInputs) {
            sideInputStrategies.put(sideInput, sideInput.getPCollection().getWindowingStrategy());
        }
        SideInputBroadcast broadcastStateData = createBroadcastSideInputs(sideInputs, context);
        Map<String, PCollectionView<?>> sideInputMapping =
                ParDoTranslation.getSideInputMapping(context.getCurrentTransform());

        Map<TupleTag<?>, Coder<?>> outputCoderMap = context.getOutputCoders();
        MetricsContainerStepMapAccumulator metricsAccum = MetricsAccumulator.getInstance();
        @SuppressWarnings("unchecked")
        TupleTag<OutputT> typedMainOutputTag = (TupleTag<OutputT>) mainOutputTag;
        DoFnFunction<InputT, OutputT> doFnFunction = new DoFnFunction<InputT, OutputT>(
                metricsAccum,
                stepName,
                doFn,
                windowingStrategy,
                sideInputStrategies,
                context.getSerializableOptions(),
                additionalOutputTags,
                typedMainOutputTag,
                inputCoder,
                outputCoderMap,
                broadcastStateData,
                doFnSchemaInformation,
                sideInputMapping);

        // MultiOuputCoder is handled as a raw type here, so the encoder built from it (and thus the
        // mapPartitions call) is unchecked; the elements are (output tag, windowed value) pairs.
        @SuppressWarnings("rawtypes")
        MultiOuputCoder multipleOutputCoder =
                MultiOuputCoder.of(SerializableCoder.of(TupleTag.class), outputCoderMap, windowCoder);
        @SuppressWarnings("unchecked")
        Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> allOutputs =
                inputDataSet.mapPartitions(doFnFunction, EncoderHelpers.fromBeamCoder(multipleOutputCoder));

        if (outputs.size() > 1) {
            // Persisted because the loop below derives one filtered dataset per output from it.
            // NOTE(review): allOutputs is never unpersisted in this translator.
            allOutputs.persist();
            for (Map.Entry<TupleTag<?>, PValue> output : outputs.entrySet()) {
                pruneOutputFilteredByTag(context, allOutputs, output, windowCoder);
            }
        } else {
            // A single output: elements are taken as-is without filtering by tag. The coder and the
            // registered PValue come from the same map entry so they cannot disagree.
            Map.Entry<TupleTag<?>, PValue> onlyOutput = outputs.entrySet().iterator().next();
            Coder<WindowedValue<?>> windowedValueCoder = wildcardWindowedValueCoder(
                    ((PCollection<?>) onlyOutput.getValue()).getCoder(), windowCoder);
            context.putDatasetWildcard(
                    onlyOutput.getValue(), extractWindowedValues(allOutputs, windowedValueCoder));
        }
    }

    /**
     * Collects every side-input dataset to the driver, encodes each element with a full
     * windowed-value coder, and broadcasts the encoded bytes keyed by the side input's tag id.
     *
     * @param sideInputs side inputs of the current transform
     * @param context translation state providing the Spark session and the side-input datasets
     * @return broadcast handles and their coders for all {@code sideInputs}
     */
    private static SideInputBroadcast createBroadcastSideInputs(List<PCollectionView<?>> sideInputs, TranslationContext context) {
        JavaSparkContext jsc = JavaSparkContext.fromSparkContext(context.getSparkSession().sparkContext());
        SideInputBroadcast sideInputBroadcast = new SideInputBroadcast();
        for (PCollectionView<?> sideInput : sideInputs) {
            PCollection<?> sideInputCollection = sideInput.getPCollection();
            Coder<? extends BoundedWindow> windowCoder =
                    sideInputCollection.getWindowingStrategy().getWindowFn().windowCoder();
            Coder<WindowedValue<?>> windowedValueCoder =
                    wildcardWindowedValueCoder(sideInputCollection.getCoder(), windowCoder);

            // collectAsList materializes the whole side input on the driver.
            Dataset<WindowedValue<?>> broadcastSet = context.getSideInputDataSet(sideInput);
            List<WindowedValue<?>> values = broadcastSet.collectAsList();
            List<byte[]> codedValues = new ArrayList<>(values.size());
            for (WindowedValue<?> value : values) {
                codedValues.add(CoderHelpers.toByteArray(value, windowedValueCoder));
            }
            sideInputBroadcast.add(
                    sideInput.getTagInternal().getId(), jsc.broadcast(codedValues), windowedValueCoder);
        }
        return sideInputBroadcast;
    }

    /**
     * Reads the side inputs declared by the current transform.
     *
     * @throws RuntimeException wrapping the {@link IOException} raised while reading them
     */
    private List<PCollectionView<?>> getSideInputs(TranslationContext context) {
        try {
            return ParDoTranslation.getSideInputs(context.getCurrentTransform());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads the main output tag of the current transform.
     *
     * @throws RuntimeException wrapping the {@link IOException} raised while reading it
     */
    private TupleTag<?> getTupleTag(TranslationContext context) {
        try {
            return ParDoTranslation.getMainOutputTag(context.getCurrentTransform());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads the {@link DoFn} of the current transform, cast to this translator's type parameters.
     *
     * @throws RuntimeException wrapping the {@link IOException} raised while reading it
     */
    @SuppressWarnings("unchecked")
    private DoFn<InputT, OutputT> getDoFn(TranslationContext context) {
        try {
            return (DoFn<InputT, OutputT>) ParDoTranslation.getDoFn(context.getCurrentTransform());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Keeps only the elements of {@code allOutputs} tagged with {@code output}'s tag, strips the
     * tag, and registers the result in {@code context} as the dataset of {@code output}'s PValue.
     */
    private void pruneOutputFilteredByTag(TranslationContext context, Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> allOutputs, Map.Entry<TupleTag<?>, PValue> output, Coder<? extends BoundedWindow> windowCoder) {
        Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> filteredDataset =
                allOutputs.filter(new DoFnFilterFunction(output.getKey()));
        Coder<WindowedValue<?>> windowedValueCoder = wildcardWindowedValueCoder(
                ((PCollection<?>) output.getValue()).getCoder(), windowCoder);
        context.putDatasetWildcard(
                output.getValue(), extractWindowedValues(filteredDataset, windowedValueCoder));
    }

    /**
     * Builds a full windowed-value coder for {@code valueCoder} and widens it to the wildcard
     * element type used by the datasets in this translator.
     */
    @SuppressWarnings("unchecked")
    private static Coder<WindowedValue<?>> wildcardWindowedValueCoder(
            Coder<?> valueCoder, Coder<? extends BoundedWindow> windowCoder) {
        return (Coder<WindowedValue<?>>) (Coder<?>) WindowedValue.getFullCoder(valueCoder, windowCoder);
    }

    /** Drops the output tag from each (tag, value) pair, encoding the values with {@code coder}. */
    private static Dataset<WindowedValue<?>> extractWindowedValues(
            Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> taggedValues, Coder<WindowedValue<?>> coder) {
        return taggedValues.map(
                (MapFunction<Tuple2<TupleTag<?>, WindowedValue<?>>, WindowedValue<?>>) value -> value._2(),
                EncoderHelpers.fromBeamCoder(coder));
    }

    /** Spark filter that keeps only the (tag, value) pairs whose tag equals {@code key}. */
    static class DoFnFilterFunction
    implements FilterFunction<Tuple2<TupleTag<?>, WindowedValue<?>>> {
        private final TupleTag<?> key;

        DoFnFilterFunction(TupleTag<?> key) {
            this.key = key;
        }

        @Override
        public boolean call(Tuple2<TupleTag<?>, WindowedValue<?>> value) {
            return value._1().equals(this.key);
        }
    }
}

