/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.spark.planner;

import com.facebook.airlift.json.Codec;
import com.facebook.airlift.json.JsonCodec;
import com.facebook.airlift.log.Logger;
import com.facebook.presto.Session;
import com.facebook.presto.execution.ScheduledSplit;
import com.facebook.presto.execution.TaskSource;
import com.facebook.presto.execution.scheduler.TableWriteInfo;
import com.facebook.presto.spark.PrestoSparkTaskDescriptor;
import com.facebook.presto.spark.classloader_interface.MutablePartitionId;
import com.facebook.presto.spark.classloader_interface.PrestoSparkMutableRow;
import com.facebook.presto.spark.classloader_interface.PrestoSparkShuffleStats;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskExecutorFactoryProvider;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskOutput;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskProcessor;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskRdd;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskSourceRdd;
import com.facebook.presto.spark.classloader_interface.SerializedPrestoSparkTaskDescriptor;
import com.facebook.presto.spark.classloader_interface.SerializedPrestoSparkTaskSource;
import com.facebook.presto.spark.classloader_interface.SerializedTaskInfo;
import com.facebook.presto.spark.planner.PrestoSparkPartitionedSplitAssigner;
import com.facebook.presto.spark.planner.PrestoSparkSourceDistributionSplitAssigner;
import com.facebook.presto.spark.planner.PrestoSparkSplitAssigner;
import com.facebook.presto.spark.util.PrestoSparkUtils;
import com.facebook.presto.spi.ErrorCodeSupplier;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.StandardErrorCode;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.split.CloseableSplitSourceProvider;
import com.facebook.presto.split.SplitManager;
import com.facebook.presto.split.SplitSource;
import com.facebook.presto.split.SplitSourceProvider;
import com.facebook.presto.sql.planner.PartitioningHandle;
import com.facebook.presto.sql.planner.PartitioningProviderManager;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.SplitSourceFactory;
import com.facebook.presto.sql.planner.SystemPartitioningHandle;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.plan.PlanFragmentId;
import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
import com.google.common.base.Preconditions;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multimaps;
import com.google.common.collect.SetMultimap;
import com.google.common.collect.Sets;
import com.google.common.collect.UnmodifiableIterator;
import io.airlift.units.DataSize;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import javax.inject.Inject;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.RDD;
import org.apache.spark.util.CollectionAccumulator;

public class PrestoSparkRddFactory {
    // Class-wide logger; used below to report per-scan split counts and serialized task source sizes.
    private static final Logger log = Logger.get(PrestoSparkRddFactory.class);
    // Source of splits for table scans; createRdd wraps its getSplits in a CloseableSplitSourceProvider.
    private final SplitManager splitManager;
    // Passed to PrestoSparkPartitionedSplitAssigner when the fragment is not SOURCE_DISTRIBUTION.
    private final PartitioningProviderManager partitioningProviderManager;
    // Encodes the per-fragment PrestoSparkTaskDescriptor as JSON bytes.
    private final JsonCodec<PrestoSparkTaskDescriptor> taskDescriptorJsonCodec;
    // Encodes each TaskSource (through PrestoSparkUtils.serializeZstdCompressed) for the task source RDD.
    private final Codec<TaskSource> taskSourceCodec;

    /**
     * Creates the factory from its injected collaborators.
     *
     * @throws NullPointerException if any argument is null; the message names the offending argument
     */
    @Inject
    public PrestoSparkRddFactory(SplitManager splitManager, PartitioningProviderManager partitioningProviderManager, JsonCodec<PrestoSparkTaskDescriptor> taskDescriptorJsonCodec, Codec<TaskSource> taskSourceCodec) {
        // Validate everything up front (in parameter order), then store the references.
        Objects.requireNonNull(splitManager, "splitManager is null");
        Objects.requireNonNull(partitioningProviderManager, "partitioningProviderManager is null");
        Objects.requireNonNull(taskDescriptorJsonCodec, "taskDescriptorJsonCodec is null");
        Objects.requireNonNull(taskSourceCodec, "taskSourceCodec is null");

        this.splitManager = splitManager;
        this.partitioningProviderManager = partitioningProviderManager;
        this.taskDescriptorJsonCodec = taskDescriptorJsonCodec;
        this.taskSourceCodec = taskSourceCodec;
    }

    /**
     * Validates that {@code fragment} can be executed as a Spark RDD and, if so, builds that RDD.
     *
     * <p>Accepted fragment partitionings are SINGLE, FIXED_HASH, FIXED_ARBITRARY and SOURCE
     * distribution, plus any partitioning handle that carries a connector id.
     *
     * @param rddInputs shuffle inputs keyed by the fragment id that produces them
     * @param broadcastInputs broadcast inputs keyed by the fragment id that produces them
     * @param outputType runtime class of the task output carried by the returned RDD
     * @return the RDD created by {@code createRdd} for this fragment
     * @throws PrestoException with {@code NOT_SUPPORTED} for SCALED_WRITER_DISTRIBUTION or when a
     *         remote source requires ordering
     * @throws IllegalArgumentException for grouped execution fragments and for any partitioning that
     *         is not accepted as a fragment distribution
     */
    public <T extends PrestoSparkTaskOutput> JavaPairRDD<MutablePartitionId, T> createSparkRdd(JavaSparkContext sparkContext, Session session, PlanFragment fragment, Map<PlanFragmentId, JavaPairRDD<MutablePartitionId, PrestoSparkMutableRow>> rddInputs, Map<PlanFragmentId, Broadcast<?>> broadcastInputs, PrestoSparkTaskExecutorFactoryProvider executorFactoryProvider, CollectionAccumulator<SerializedTaskInfo> taskInfoCollector, CollectionAccumulator<PrestoSparkShuffleStats> shuffleStatsCollector, TableWriteInfo tableWriteInfo, Class<T> outputType) {
        Preconditions.checkArgument(!fragment.getStageExecutionDescriptor().isStageGroupedExecution(), "unexpected grouped execution fragment: %s", fragment.getId());

        PartitioningHandle partitioning = fragment.getPartitioning();
        if (partitioning.equals(SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION)) {
            throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, "Automatic writers scaling is not supported by Presto on Spark");
        }

        // Distributions that must never appear as the distribution of a fragment handed to this method.
        Preconditions.checkArgument(!partitioning.equals(SystemPartitioningHandle.COORDINATOR_DISTRIBUTION), "COORDINATOR_DISTRIBUTION fragment must be run on the driver");
        Preconditions.checkArgument(!partitioning.equals(SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION), "FIXED_BROADCAST_DISTRIBUTION can only be set as an output partitioning scheme, and not as a fragment distribution");
        Preconditions.checkArgument(!partitioning.equals(SystemPartitioningHandle.FIXED_PASSTHROUGH_DISTRIBUTION), "FIXED_PASSTHROUGH_DISTRIBUTION can only be set as local exchange partitioning");
        Preconditions.checkArgument(!partitioning.equals(SystemPartitioningHandle.ARBITRARY_DISTRIBUTION), "ARBITRARY_DISTRIBUTION is not expected to be set as a fragment distribution");

        boolean supportedPartitioning = partitioning.equals(SystemPartitioningHandle.SINGLE_DISTRIBUTION)
                || partitioning.equals(SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION)
                || partitioning.equals(SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION)
                || partitioning.equals(SystemPartitioningHandle.SOURCE_DISTRIBUTION)
                || partitioning.getConnectorId().isPresent();
        if (!supportedPartitioning) {
            throw new IllegalArgumentException(String.format("Unexpected fragment partitioning %s, fragmentId: %s", partitioning, fragment.getId()));
        }

        // Reject the fragment as soon as one remote source asks for ordered input.
        for (RemoteSourceNode remoteSource : fragment.getRemoteSourceNodes()) {
            if (remoteSource.isEnsureSourceOrdering() || remoteSource.getOrderingScheme().isPresent()) {
                throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, String.format("Order sensitive exchange is not supported by Presto on Spark. fragmentId: %s, sourceFragmentIds: %s", fragment.getId(), remoteSource.getSourceFragmentIds()));
            }
        }

        return this.createRdd(sparkContext, session, fragment, executorFactoryProvider, taskInfoCollector, shuffleStatsCollector, tableWriteInfo, rddInputs, broadcastInputs, outputType);
    }

    /**
     * Assembles the task RDD for an already validated fragment.
     *
     * <p>Steps: verify that the supplied inputs match the fragment's remote sources, serialize the
     * task descriptor, collect the shuffle input RDDs (all of which must have the same number of
     * partitions), and build the optional task source RDD:
     * <ul>
     * <li>fragment contains table scans: splits are enumerated and turned into task sources;</li>
     * <li>no table scans and no shuffle inputs: the fragment must be SINGLE_DISTRIBUTION and gets a
     *     task source RDD built from one empty task source list;</li>
     * <li>no table scans but shuffle inputs: no task source RDD is created.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if the inputs do not match the remote sources, the shuffle
     *         inputs disagree on partition count, or a source-less fragment is not SINGLE_DISTRIBUTION
     */
    private <T extends PrestoSparkTaskOutput> JavaPairRDD<MutablePartitionId, T> createRdd(JavaSparkContext sparkContext, Session session, PlanFragment fragment, PrestoSparkTaskExecutorFactoryProvider executorFactoryProvider, CollectionAccumulator<SerializedTaskInfo> taskInfoCollector, CollectionAccumulator<PrestoSparkShuffleStats> shuffleStatsCollector, TableWriteInfo tableWriteInfo, Map<PlanFragmentId, JavaPairRDD<MutablePartitionId, PrestoSparkMutableRow>> rddInputs, Map<PlanFragmentId, Broadcast<?>> broadcastInputs, Class<T> outputType) {
        checkInputs(fragment.getRemoteSourceNodes(), rddInputs, broadcastInputs);

        PrestoSparkTaskDescriptor taskDescriptor = new PrestoSparkTaskDescriptor(session.toSessionRepresentation(), session.getIdentity().getExtraCredentials(), fragment, tableWriteInfo);
        SerializedPrestoSparkTaskDescriptor serializedTaskDescriptor = new SerializedPrestoSparkTaskDescriptor(this.taskDescriptorJsonCodec.toJsonBytes(taskDescriptor));

        Optional<Integer> numberOfShufflePartitions = Optional.empty();
        // NOTE(review): RDD is left raw as in the decompiled output; its element type is not
        // imported by this file. Confirm against the parameter type of PrestoSparkTaskRdd.create.
        Map<String, RDD> shuffleInputRddMap = new HashMap<>();
        for (Map.Entry<PlanFragmentId, JavaPairRDD<MutablePartitionId, PrestoSparkMutableRow>> input : rddInputs.entrySet()) {
            RDD rdd = input.getValue().rdd();
            shuffleInputRddMap.put(input.getKey().toString(), rdd);

            // The first shuffle input fixes the partition count; every later one must agree with it.
            if (!numberOfShufflePartitions.isPresent()) {
                numberOfShufflePartitions = Optional.of(rdd.getNumPartitions());
            } else {
                Preconditions.checkArgument(numberOfShufflePartitions.get() == rdd.getNumPartitions(), "Incompatible number of input partitions: %s != %s", numberOfShufflePartitions.get(), rdd.getNumPartitions());
            }
        }

        PrestoSparkTaskProcessor taskProcessor = new PrestoSparkTaskProcessor(executorFactoryProvider, serializedTaskDescriptor, taskInfoCollector, shuffleStatsCollector, toTaskProcessorBroadcastInputs(broadcastInputs), outputType);

        Optional<PrestoSparkTaskSourceRdd> taskSourceRdd;
        List<TableScanNode> tableScans = findTableScanNodes(fragment.getRoot());
        if (!tableScans.isEmpty()) {
            // The provider is closed once all splits have been enumerated into the task source RDD.
            try (CloseableSplitSourceProvider splitSourceProvider = new CloseableSplitSourceProvider(this.splitManager::getSplits)) {
                SplitSourceFactory splitSourceFactory = new SplitSourceFactory(splitSourceProvider, WarningCollector.NOOP);
                Map<PlanNodeId, SplitSource> splitSources = splitSourceFactory.createSplitSources(fragment, session, tableWriteInfo);
                taskSourceRdd = Optional.of(this.createTaskSourcesRdd(fragment.getId(), sparkContext, session, fragment.getPartitioning(), tableScans, splitSources, numberOfShufflePartitions));
            }
        } else if (rddInputs.isEmpty()) {
            Preconditions.checkArgument(fragment.getPartitioning().equals(SystemPartitioningHandle.SINGLE_DISTRIBUTION), "SINGLE_DISTRIBUTION partitioning is expected: %s", fragment.getPartitioning());
            // One entry holding an empty list of task sources.
            PrestoSparkTaskSourceRdd emptyTaskSourceRdd = new PrestoSparkTaskSourceRdd(sparkContext.sc(), ImmutableList.of(ImmutableList.of()));
            emptyTaskSourceRdd.setName(getRDDName(fragment.getId().getId()));
            taskSourceRdd = Optional.of(emptyTaskSourceRdd);
        } else {
            // Only shuffle inputs feed this fragment.
            taskSourceRdd = Optional.empty();
        }

        return JavaPairRDD.fromRDD((RDD) PrestoSparkTaskRdd.create(sparkContext.sc(), taskSourceRdd, shuffleInputRddMap, taskProcessor).setName(getRDDName(fragment.getId().getId())), PrestoSparkUtils.classTag(MutablePartitionId.class), PrestoSparkUtils.classTag(outputType));
    }

    /**
     * Enumerates the splits of every table scan in the fragment, converts them into serialized task
     * sources grouped by partition id, and wraps the result in a {@link PrestoSparkTaskSourceRdd}.
     *
     * @param splitSources split source per table scan node id; must contain every scan in {@code tableScans}
     * @param numberOfShufflePartitions when present, the result has exactly this many entries, one for
     *        each partition id from 0 up to (but excluding) that number; when absent, one entry is
     *        produced per partition id that received at least one task source
     * @return the task source RDD, named after the fragment
     * @throws NullPointerException if {@code splitSources} has no entry for one of the table scans
     */
    private PrestoSparkTaskSourceRdd createTaskSourcesRdd(PlanFragmentId fragmentId, JavaSparkContext sparkContext, Session session, PartitioningHandle partitioning, List<TableScanNode> tableScans, Map<PlanNodeId, SplitSource> splitSources, Optional<Integer> numberOfShufflePartitions) {
        ListMultimap<Integer, SerializedPrestoSparkTaskSource> taskSourcesMap = ArrayListMultimap.create();
        for (TableScanNode tableScan : tableScans) {
            int totalNumberOfSplits = 0;
            SplitSource splitSource = Objects.requireNonNull(splitSources.get(tableScan.getId()), "split source is missing for table scan node with id: " + tableScan.getId());
            try (PrestoSparkSplitAssigner splitAssigner = this.createSplitAssigner(session, tableScan.getId(), splitSource, partitioning)) {
                // Drain the assigner batch by batch until it reports no further batch.
                Optional<SetMultimap<Integer, ScheduledSplit>> batch;
                while ((batch = splitAssigner.getNextBatch()).isPresent()) {
                    int numberOfSplitsInCurrentBatch = batch.get().size();
                    log.info("Found %s splits for table scan node with id %s", numberOfSplitsInCurrentBatch, tableScan.getId());
                    totalNumberOfSplits += numberOfSplitsInCurrentBatch;
                    taskSourcesMap.putAll(this.createTaskSources(tableScan.getId(), batch.get()));
                }
            }
            log.info("Total number of splits for table scan node with id %s: %s", tableScan.getId(), totalNumberOfSplits);
        }

        long allTaskSourcesSerializedSizeInBytes = taskSourcesMap.values().stream()
                .mapToLong(serializedTaskSource -> serializedTaskSource.getBytes().length)
                .sum();
        log.info("Total serialized size of all task sources for fragment %s: %s", fragmentId, DataSize.succinctBytes(allTaskSourcesSerializedSizeInBytes));

        List<List<SerializedPrestoSparkTaskSource>> taskSourcesByPartitionId = new ArrayList<>();
        if (numberOfShufflePartitions.isPresent()) {
            // Emit one (possibly empty) list per shuffle partition so list index equals partition id.
            // NOTE(review): task sources assigned to a partition id >= numberOfShufflePartitions stay
            // behind in taskSourcesMap and are not part of the RDD; confirm this cannot happen.
            for (int partitionId = 0; partitionId < numberOfShufflePartitions.get(); ++partitionId) {
                taskSourcesByPartitionId.add(Objects.requireNonNull(taskSourcesMap.removeAll(partitionId), "taskSources is null"));
            }
        } else {
            taskSourcesByPartitionId.addAll(Multimaps.asMap(taskSourcesMap).values());
        }

        PrestoSparkTaskSourceRdd taskSourceRdd = new PrestoSparkTaskSourceRdd(sparkContext.sc(), taskSourcesByPartitionId);
        taskSourceRdd.setName(getRDDName(fragmentId.getId()));
        return taskSourceRdd;
    }

    /**
     * Chooses the split assigner for a table scan based on the fragment partitioning:
     * the source-distribution assigner for SOURCE_DISTRIBUTION, the partitioned assigner otherwise.
     */
    private PrestoSparkSplitAssigner createSplitAssigner(Session session, PlanNodeId tableScanNodeId, SplitSource splitSource, PartitioningHandle fragmentPartitioning) {
        boolean sourceDistributed = fragmentPartitioning.equals(SystemPartitioningHandle.SOURCE_DISTRIBUTION);
        if (!sourceDistributed) {
            // Any other partitioning is resolved through the partitioning provider manager.
            return PrestoSparkPartitionedSplitAssigner.create(session, tableScanNodeId, splitSource, fragmentPartitioning, this.partitioningProviderManager);
        }
        return PrestoSparkSourceDistributionSplitAssigner.create(session, tableScanNodeId, splitSource);
    }

    /**
     * Converts one batch of assigned splits into serialized, zstd-compressed task sources,
     * producing one task source per partition id present in the batch.
     *
     * <p>Side effect: every entry is removed from {@code assignedSplits}, so the multimap passed in
     * is empty when this method returns.
     *
     * @param tableScanId plan node id recorded in each created {@link TaskSource}
     * @param assignedSplits splits keyed by partition id; drained by this call
     * @return serialized task sources keyed by partition id
     */
    private ListMultimap<Integer, SerializedPrestoSparkTaskSource> createTaskSources(PlanNodeId tableScanId, SetMultimap<Integer, ScheduledSplit> assignedSplits) {
        ListMultimap<Integer, SerializedPrestoSparkTaskSource> serializedByPartition = ArrayListMultimap.create();
        // Iterate over a snapshot of the keys because entries are removed from the multimap in the loop.
        for (int partitionId : ImmutableSet.copyOf(assignedSplits.keySet())) {
            // NOTE(review): removing as we go presumably releases the splits early; confirm intent.
            Set<ScheduledSplit> partitionSplits = assignedSplits.removeAll(partitionId);
            TaskSource taskSource = new TaskSource(tableScanId, partitionSplits, true);
            serializedByPartition.put(partitionId, new SerializedPrestoSparkTaskSource(PrestoSparkUtils.serializeZstdCompressed(this.taskSourceCodec, taskSource)));
        }
        return serializedByPartition;
    }

    /**
     * Collects every {@link TableScanNode} that the plan node searcher finds when started from {@code node}.
     */
    private static List<TableScanNode> findTableScanNodes(PlanNode node) {
        return PlanNodeSearcher.searchFrom(node)
                .where(candidate -> candidate instanceof TableScanNode)
                .findAll();
    }

    /**
     * Re-keys the broadcast inputs by the string form of their fragment id.
     *
     * @return an immutable map with the same values, keyed by {@code PlanFragmentId.toString()}
     */
    private static Map<String, Broadcast<?>> toTaskProcessorBroadcastInputs(Map<PlanFragmentId, Broadcast<?>> broadcastInputs) {
        ImmutableMap.Builder<String, Broadcast<?>> byFragmentName = ImmutableMap.builder();
        for (Map.Entry<PlanFragmentId, Broadcast<?>> input : broadcastInputs.entrySet()) {
            byFragmentName.put(input.getKey().toString(), input.getValue());
        }
        return byFragmentName.build();
    }

    /**
     * Verifies that the union of shuffle and broadcast inputs is exactly the set of fragment ids
     * referenced by the fragment's remote source nodes.
     *
     * @param remoteSources remote source nodes of the fragment being planned
     * @param rddInputs shuffle inputs keyed by source fragment id
     * @param broadcastInputs broadcast inputs keyed by source fragment id
     * @throws IllegalArgumentException if an expected input is missing or an unexpected one is supplied;
     *         the message lists the expected, actual, missing and extra fragment ids
     */
    private static void checkInputs(List<RemoteSourceNode> remoteSources, Map<PlanFragmentId, JavaPairRDD<MutablePartitionId, PrestoSparkMutableRow>> rddInputs, Map<PlanFragmentId, Broadcast<?>> broadcastInputs) {
        Set<PlanFragmentId> expectedInputs = remoteSources.stream()
                .map(RemoteSourceNode::getSourceFragmentIds)
                .flatMap(Collection::stream)
                .collect(ImmutableSet.toImmutableSet());
        Set<PlanFragmentId> actualInputs = Sets.union(rddInputs.keySet(), broadcastInputs.keySet());
        Set<PlanFragmentId> missingInputs = Sets.difference(expectedInputs, actualInputs);
        Set<PlanFragmentId> extraInputs = Sets.difference(actualInputs, expectedInputs);
        // The last placeholder reports the extra inputs (the original passed expectedInputs here).
        Preconditions.checkArgument(missingInputs.isEmpty() && extraInputs.isEmpty(), "rddInputs mismatch discovered. expected inputs: %s, actual rdd inputs: %s, actual broadcast inputs: %s, missing inputs: %s, extra inputs: %s", expectedInputs, rddInputs.keySet(), broadcastInputs.keySet(), missingInputs, extraInputs);
    }

    /**
     * Builds the display name given to RDDs created for a plan fragment.
     *
     * @param planFragmentId numeric id of the plan fragment
     * @return {@code "PlanFragment #"} followed by the decimal id
     */
    public static String getRDDName(int planFragmentId) {
        String prefix = "PlanFragment #";
        return prefix.concat(Integer.toString(planFragmentId));
    }
}

