/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.spark.classloader_interface;

import com.facebook.presto.spark.classloader_interface.PrestoSparkRow;
import com.facebook.presto.spark.classloader_interface.PrestoSparkSerializedPage;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskExecutorFactoryProvider;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskInputs;
import com.facebook.presto.spark.classloader_interface.SerializedPrestoSparkTaskDescriptor;
import com.facebook.presto.spark.classloader_interface.SerializedTaskStats;
import com.facebook.presto.spark.classloader_interface.SparkProcessType;
import java.io.Serializable;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.apache.spark.TaskContext;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.util.CollectionAccumulator;
import scala.Tuple2;

public class TaskProcessors {
    private TaskProcessors() {
    }

    public static PairFlatMapFunction<Iterator<SerializedPrestoSparkTaskDescriptor>, Integer, PrestoSparkRow> createTaskProcessor(final PrestoSparkTaskExecutorFactoryProvider taskExecutorFactoryProvider, final CollectionAccumulator<SerializedTaskStats> taskStatsCollector, final Map<String, Broadcast<List<PrestoSparkSerializedPage>>> broadcastInputs) {
        return new PairFlatMapFunction<Iterator<SerializedPrestoSparkTaskDescriptor>, Integer, PrestoSparkRow>(){

            public Iterator<Tuple2<Integer, PrestoSparkRow>> call(Iterator<SerializedPrestoSparkTaskDescriptor> serializedTaskRequestIterator) {
                SerializedPrestoSparkTaskDescriptor serializedTaskDescriptor = serializedTaskRequestIterator.next();
                if (serializedTaskRequestIterator.hasNext()) {
                    throw new IllegalArgumentException("each partition is expected to contain an exactly one task descriptor");
                }
                int partitionId = TaskContext.get().partitionId();
                int attemptNumber = TaskContext.get().attemptNumber();
                return taskExecutorFactoryProvider.get(SparkProcessType.EXECUTOR).create(partitionId, attemptNumber, serializedTaskDescriptor, new PrestoSparkTaskInputs(Collections.emptyMap(), broadcastInputs), (CollectionAccumulator<SerializedTaskStats>)taskStatsCollector);
            }
        };
    }

    public static Function<List<Iterator<Tuple2<Integer, PrestoSparkRow>>>, Iterator<Tuple2<Integer, PrestoSparkRow>>> createTaskProcessor(final PrestoSparkTaskExecutorFactoryProvider taskExecutorFactoryProvider, final SerializedPrestoSparkTaskDescriptor serializedTaskDescriptor, final List<String> fragmentIds, final CollectionAccumulator<SerializedTaskStats> taskStatsCollector, final Map<String, Broadcast<List<PrestoSparkSerializedPage>>> broadcastInputs) {
        return new SerializableFunction<List<Iterator<Tuple2<Integer, PrestoSparkRow>>>, Iterator<Tuple2<Integer, PrestoSparkRow>>>(){

            @Override
            public Iterator<Tuple2<Integer, PrestoSparkRow>> apply(List<Iterator<Tuple2<Integer, PrestoSparkRow>>> iterators) {
                int partitionId = TaskContext.get().partitionId();
                int attemptNumber = TaskContext.get().attemptNumber();
                HashMap inputsMap = new HashMap();
                for (int i = 0; i < fragmentIds.size(); ++i) {
                    inputsMap.put(fragmentIds.get(i), iterators.get(i));
                }
                return taskExecutorFactoryProvider.get(SparkProcessType.EXECUTOR).create(partitionId, attemptNumber, serializedTaskDescriptor, new PrestoSparkTaskInputs(Collections.unmodifiableMap(inputsMap), broadcastInputs), (CollectionAccumulator<SerializedTaskStats>)taskStatsCollector);
            }
        };
    }

    public static interface SerializableFunction<T, R>
    extends Function<T, R>,
    Serializable {
    }
}

