/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.runners.spark.translation;

import java.io.IOException;
import java.io.Serializable;
import java.util.EnumMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.stream.Collectors;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
import org.apache.beam.runners.fnexecution.control.BundleProgressHandler;
import org.apache.beam.runners.fnexecution.control.DefaultJobBundleFactory;
import org.apache.beam.runners.fnexecution.control.JobBundleFactory;
import org.apache.beam.runners.fnexecution.control.OutputReceiverFactory;
import org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors;
import org.apache.beam.runners.fnexecution.control.RemoteBundle;
import org.apache.beam.runners.fnexecution.control.StageBundleFactory;
import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
import org.apache.beam.runners.fnexecution.state.StateRequestHandlers;
import org.apache.beam.runners.fnexecution.translation.BatchSideInputHandlerFactory;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.runners.spark.metrics.MetricsContainerStepMapAccumulator;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.transforms.join.RawUnionValue;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

/**
 * Spark {@link FlatMapFunction} that executes a portable {@link ExecutableStage} over one partition
 * of input elements through the Fn API.
 *
 * <p>Each element emitted by the stage is wrapped in a {@link RawUnionValue} whose union tag is
 * looked up in {@code outputMap} by the id of the PCollection that produced it. Side inputs are
 * served from Spark broadcasts holding coder-encoded {@link WindowedValue}s, and monitoring infos
 * reported during the bundle are merged into a metrics container obtained from
 * {@code metricsAccumulator}.
 */
class SparkExecutableStageFunction<InputT, SideInputT>
implements FlatMapFunction<Iterator<WindowedValue<InputT>>, RawUnionValue> {
    private static final Logger LOG = LoggerFactory.getLogger(SparkExecutableStageFunction.class);
    /** Serialized description of the stage to execute. */
    private final RunnerApi.ExecutableStagePayload stagePayload;
    /** Output PCollection id -> union tag used when wrapping outputs in {@link RawUnionValue}. */
    private final Map<String, Integer> outputMap;
    /** Creates the {@link JobBundleFactory} used by each {@link #call} invocation. */
    private final JobBundleFactoryCreator jobBundleFactoryCreator;
    /** Side-input PCollection id -> (broadcast of encoded elements, coder to decode them). */
    private final Map<String, Tuple2<Broadcast<List<byte[]>>, WindowedValue.WindowedValueCoder<SideInputT>>> sideInputs;
    /** Accumulator whose per-step containers receive the metrics reported by the bundle. */
    private final MetricsContainerStepMapAccumulator metricsAccumulator;

    /**
     * Creates a function whose job bundle factories come from
     * {@code DefaultJobBundleFactory.create(jobInfo)}.
     */
    SparkExecutableStageFunction(RunnerApi.ExecutableStagePayload stagePayload, JobInfo jobInfo, Map<String, Integer> outputMap, Map<String, Tuple2<Broadcast<List<byte[]>>, WindowedValue.WindowedValueCoder<SideInputT>>> sideInputs, MetricsContainerStepMapAccumulator metricsAccumulator) {
        this(stagePayload, outputMap, () -> DefaultJobBundleFactory.create(jobInfo), sideInputs, metricsAccumulator);
    }

    /** Creates a function that obtains job bundle factories from {@code jobBundleFactoryCreator}. */
    SparkExecutableStageFunction(RunnerApi.ExecutableStagePayload stagePayload, Map<String, Integer> outputMap, JobBundleFactoryCreator jobBundleFactoryCreator, Map<String, Tuple2<Broadcast<List<byte[]>>, WindowedValue.WindowedValueCoder<SideInputT>>> sideInputs, MetricsContainerStepMapAccumulator metricsAccumulator) {
        this.stagePayload = stagePayload;
        this.outputMap = outputMap;
        this.jobBundleFactoryCreator = jobBundleFactoryCreator;
        this.sideInputs = sideInputs;
        this.metricsAccumulator = metricsAccumulator;
    }

    /**
     * Runs the stage over one partition.
     *
     * <p>Creates a job bundle factory, a stage bundle factory and a single remote bundle, feeds every
     * element of {@code inputs} to the stage's main input, and returns an iterator over all outputs
     * collected while the bundle ran. All three resources are closed before this method returns.
     *
     * @param inputs the partition's elements
     * @return iterator over the stage's outputs, each tagged with its output's union tag
     * @throws Exception any failure while running the stage; it is logged before being rethrown
     */
    @Override
    public Iterator<RawUnionValue> call(Iterator<WindowedValue<InputT>> inputs) throws Exception {
        // The factory is created per call, so it must also be closed per call; otherwise every
        // partition would leave its factory (and anything it started) open.
        try (JobBundleFactory jobBundleFactory = this.jobBundleFactoryCreator.create()) {
            ExecutableStage executableStage = ExecutableStage.fromPayload(this.stagePayload);
            try (StageBundleFactory stageBundleFactory = jobBundleFactory.forStage(executableStage)) {
                // All outputs are buffered in memory and handed back only after the bundle is closed.
                ConcurrentLinkedQueue<RawUnionValue> collector = new ConcurrentLinkedQueue<>();
                ReceiverFactory receiverFactory = new ReceiverFactory(collector, this.outputMap);
                StateRequestHandler stateRequestHandler = getStateRequestHandler(executableStage);
                // NOTE(review): the stage's input PCollection id is used as the metrics step name;
                // confirm this matches the step name the accumulator is reported under.
                String stageName = this.stagePayload.getInput();
                MetricsContainerImpl container = this.metricsAccumulator.value().getContainer(stageName);
                BundleProgressHandler progressHandler = metricsUpdatingHandler(container);
                try (RemoteBundle bundle = stageBundleFactory.getBundle(receiverFactory, stateRequestHandler, progressHandler)) {
                    feedMainInput(executableStage, bundle, inputs);
                }
                return collector.iterator();
            }
        } catch (Exception e) {
            LOG.error("Spark executable stage fn terminated with exception: ", e);
            throw e;
        }
    }

    /** Returns a progress handler that merges every reported monitoring info into {@code container}. */
    private static BundleProgressHandler metricsUpdatingHandler(MetricsContainerImpl container) {
        return new BundleProgressHandler() {
            @Override
            public void onProgress(BeamFnApi.ProcessBundleProgressResponse progress) {
                container.update(progress.getMonitoringInfosList());
            }

            @Override
            public void onCompleted(BeamFnApi.ProcessBundleResponse response) {
                container.update(response.getMonitoringInfosList());
            }
        };
    }

    /**
     * Sends every remaining element of {@code inputs} to the receiver registered for the stage's
     * input PCollection.
     *
     * @throws IllegalStateException if the bundle has no receiver for the stage's input PCollection
     */
    private void feedMainInput(ExecutableStage executableStage, RemoteBundle bundle, Iterator<WindowedValue<InputT>> inputs) throws Exception {
        String inputPCollectionId = executableStage.getInputPCollection().getId();
        FnDataReceiver<WindowedValue<?>> mainReceiver = bundle.getInputReceivers().get(inputPCollectionId);
        if (mainReceiver == null) {
            throw new IllegalStateException(String.format(Locale.ENGLISH, "No input receiver for PCollectionId %s", inputPCollectionId));
        }
        while (inputs.hasNext()) {
            mainReceiver.accept(inputs.next());
        }
    }

    /**
     * Builds the state request handler for the stage. Only {@code MULTIMAP_SIDE_INPUT} requests are
     * registered; they are answered from the broadcast side inputs of this function.
     *
     * @throws RuntimeException wrapping the {@link IOException} if the side-input handler cannot be set up
     */
    private StateRequestHandler getStateRequestHandler(ExecutableStage executableStage) {
        EnumMap<BeamFnApi.StateKey.TypeCase, StateRequestHandler> handlerMap = new EnumMap<>(BeamFnApi.StateKey.TypeCase.class);
        BatchSideInputHandlerFactory sideInputHandlerFactory = BatchSideInputHandlerFactory.forStage(executableStage, new BatchSideInputHandlerFactory.SideInputGetter() {
            @Override
            public <T> List<T> getSideInput(String pCollectionId) {
                return decodeSideInput(pCollectionId);
            }
        });
        StateRequestHandler sideInputHandler;
        try {
            sideInputHandler = StateRequestHandlers.forSideInputHandlerFactory(ProcessBundleDescriptors.getSideInputs(executableStage), sideInputHandlerFactory);
        } catch (IOException e) {
            throw new RuntimeException("Failed to setup state handler", e);
        }
        handlerMap.put(BeamFnApi.StateKey.TypeCase.MULTIMAP_SIDE_INPUT, sideInputHandler);
        return StateRequestHandlers.delegateBasedUponType(handlerMap);
    }

    /**
     * Decodes every broadcast element of the side input {@code pCollectionId} with its coder.
     *
     * @throws IllegalStateException if no side input is registered under {@code pCollectionId}
     */
    private <T> List<T> decodeSideInput(String pCollectionId) {
        Tuple2<Broadcast<List<byte[]>>, WindowedValue.WindowedValueCoder<SideInputT>> entry = this.sideInputs.get(pCollectionId);
        if (entry == null) {
            throw new IllegalStateException(String.format(Locale.ENGLISH, "Unknown side input PCollectionId %s", pCollectionId));
        }
        Broadcast<List<byte[]>> broadcast = entry._1();
        WindowedValue.WindowedValueCoder<SideInputT> coder = entry._2();
        List<WindowedValue<SideInputT>> decoded = broadcast.value().stream()
                .map(bytes -> CoderHelpers.fromByteArray(bytes, coder))
                .collect(Collectors.toList());
        // The getter's element type is chosen by the caller; the elements really are WindowedValues.
        @SuppressWarnings("unchecked")
        List<T> result = (List<T>) (List<?>) decoded;
        return result;
    }

    /**
     * Output receiver factory that appends each received element, tagged with its output's union
     * tag, to a shared queue.
     */
    private static class ReceiverFactory
    implements OutputReceiverFactory {
        private final ConcurrentLinkedQueue<RawUnionValue> collector;
        private final Map<String, Integer> outputMap;

        ReceiverFactory(ConcurrentLinkedQueue<RawUnionValue> collector, Map<String, Integer> outputMap) {
            this.collector = collector;
            this.outputMap = outputMap;
        }

        /**
         * Returns a receiver that tags elements with the union tag of {@code pCollectionId}.
         *
         * @throws IllegalStateException if {@code pCollectionId} has no entry in the output map
         */
        @Override
        public <OutputT> FnDataReceiver<OutputT> create(String pCollectionId) {
            Integer unionTag = this.outputMap.get(pCollectionId);
            if (unionTag == null) {
                throw new IllegalStateException(String.format(Locale.ENGLISH, "Unknown PCollectionId %s", pCollectionId));
            }
            int tagInt = unionTag;
            return receivedElement -> this.collector.add(new RawUnionValue(tagInt, receivedElement));
        }
    }

    /** Serializable supplier of {@link JobBundleFactory}, shipped to executors with the function. */
    static interface JobBundleFactoryCreator
    extends Serializable {
        public JobBundleFactory create();
    }
}

