/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.spark.planner;

import com.facebook.airlift.concurrent.MoreFutures;
import com.facebook.airlift.json.JsonCodec;
import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.execution.Lifespan;
import com.facebook.presto.execution.ScheduledSplit;
import com.facebook.presto.execution.TaskSource;
import com.facebook.presto.execution.scheduler.TableWriteInfo;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.Split;
import com.facebook.presto.spark.PrestoSparkSessionProperties;
import com.facebook.presto.spark.PrestoSparkTaskDescriptor;
import com.facebook.presto.spark.classloader_interface.IntegerIdentityPartitioner;
import com.facebook.presto.spark.classloader_interface.PrestoSparkRow;
import com.facebook.presto.spark.classloader_interface.PrestoSparkSerializedPage;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskExecutorFactoryProvider;
import com.facebook.presto.spark.classloader_interface.PrestoSparkZipRdd;
import com.facebook.presto.spark.classloader_interface.SerializedPrestoSparkTaskDescriptor;
import com.facebook.presto.spark.classloader_interface.SerializedTaskStats;
import com.facebook.presto.spark.classloader_interface.TaskProcessors;
import com.facebook.presto.spi.ErrorCodeSupplier;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.StandardErrorCode;
import com.facebook.presto.spi.connector.ConnectorSplitManager;
import com.facebook.presto.spi.connector.NotPartitionedPartitionHandle;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.split.SplitManager;
import com.facebook.presto.split.SplitSource;
import com.facebook.presto.sql.planner.PartitioningHandle;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.SystemPartitioningHandle;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.plan.PlanFragmentId;
import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.Future;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.inject.Inject;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.RDD;
import org.apache.spark.util.CollectionAccumulator;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;

public class PrestoSparkRddFactory {
    /** Source of table splits for fragments that contain a table scan. */
    private final SplitManager splitManager;

    /** Injected metadata handle; stored here but not read by any method of this class. */
    private final Metadata metadata;

    /** Serializes task descriptors to JSON bytes before they are handed to Spark. */
    private final JsonCodec<PrestoSparkTaskDescriptor> taskDescriptorJsonCodec;

    /**
     * Creates the factory from its injected collaborators.
     *
     * @param splitManager used to enumerate splits of scanned tables
     * @param metadata metadata handle kept as a field
     * @param taskDescriptorJsonCodec codec used to serialize {@link PrestoSparkTaskDescriptor}s
     * @throws NullPointerException if any argument is null
     */
    @Inject
    public PrestoSparkRddFactory(
            SplitManager splitManager,
            Metadata metadata,
            JsonCodec<PrestoSparkTaskDescriptor> taskDescriptorJsonCodec) {
        this.splitManager = Objects.requireNonNull(splitManager, "splitManager is null");
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        this.taskDescriptorJsonCodec = Objects.requireNonNull(taskDescriptorJsonCodec, "taskDescriptorJsonCodec is null");
    }

    /**
     * Builds the Spark RDD that executes the given plan fragment.
     *
     * <p>The fragment distribution is validated first. Only system partitioning handles are
     * accepted, and of those only {@code SINGLE_DISTRIBUTION}, {@code FIXED_HASH_DISTRIBUTION}
     * and {@code SOURCE_DISTRIBUTION} are dispatched:
     * <ul>
     *   <li>{@code SINGLE_DISTRIBUTION} / {@code FIXED_HASH_DISTRIBUTION}: every input RDD is
     *       repartitioned with the partitioner from {@code createPartitioner} and an
     *       intermediate RDD is created on top of the repartitioned inputs;</li>
     *   <li>{@code SOURCE_DISTRIBUTION}: {@code rddInputs} must be empty and a source RDD is
     *       created.</li>
     * </ul>
     * When the fragment's output partitioning handle is {@code FIXED_HASH_DISTRIBUTION}, the
     * fragment is first rewritten with an identity bucket-to-partition mapping of
     * {@code hashPartitionCount} entries.
     *
     * @param sparkContext Spark context used to create RDDs
     * @param session query session
     * @param fragment plan fragment to execute
     * @param rddInputs upstream fragment outputs available as RDDs, keyed by fragment id
     * @param broadcastInputs upstream fragment outputs available as broadcasts, keyed by fragment id
     * @param executorFactoryProvider passed through to the task processor
     * @param taskStatsCollector passed through to the task processor
     * @param tableWriteInfo embedded into the task descriptor
     * @return the RDD producing the fragment's output rows
     * @throws PrestoException with {@code NOT_SUPPORTED} for non-system partitioning, scaled
     *         writers, or an order-sensitive remote source
     * @throws IllegalArgumentException for grouped execution, any other unsupported fragment
     *         distribution, or inputs that do not fit the distribution
     */
    public JavaPairRDD<Integer, PrestoSparkRow> createSparkRdd(
            JavaSparkContext sparkContext,
            Session session,
            PlanFragment fragment,
            Map<PlanFragmentId, JavaPairRDD<Integer, PrestoSparkRow>> rddInputs,
            Map<PlanFragmentId, Broadcast<List<PrestoSparkSerializedPage>>> broadcastInputs,
            PrestoSparkTaskExecutorFactoryProvider executorFactoryProvider,
            CollectionAccumulator<SerializedTaskStats> taskStatsCollector,
            TableWriteInfo tableWriteInfo) {
        Preconditions.checkArgument(!fragment.getStageExecutionDescriptor().isStageGroupedExecution(), "unexpected grouped execution fragment: %s", fragment.getId());

        PartitioningHandle partitioning = fragment.getPartitioning();
        boolean systemPartitioning = partitioning.getConnectorHandle() instanceof SystemPartitioningHandle;
        if (!systemPartitioning) {
            throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, "Partitioned (bucketed) tables are not yet supported by Presto on Spark");
        }
        if (partitioning.equals(SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION)) {
            throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, "Automatic writers scaling is not supported by Presto on Spark");
        }

        // Distributions that must never reach this factory as a fragment distribution.
        Preconditions.checkArgument(!partitioning.equals(SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION), "FIXED_ARBITRARY_DISTRIBUTION is not supported");
        Preconditions.checkArgument(!partitioning.equals(SystemPartitioningHandle.COORDINATOR_DISTRIBUTION), "COORDINATOR_DISTRIBUTION fragment must be run on the driver");
        Preconditions.checkArgument(!partitioning.equals(SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION), "FIXED_BROADCAST_DISTRIBUTION can only be set as an output partitioning scheme, and not as a fragment distribution");
        Preconditions.checkArgument(!partitioning.equals(SystemPartitioningHandle.FIXED_PASSTHROUGH_DISTRIBUTION), "FIXED_PASSTHROUGH_DISTRIBUTION can only be set as local exchange partitioning");
        Preconditions.checkArgument(!partitioning.equals(SystemPartitioningHandle.ARBITRARY_DISTRIBUTION), "ARBITRARY_DISTRIBUTION is not expected to be set as a fragment distribution");

        int hashPartitionCount = SystemSessionProperties.getHashPartitionCount(session);

        // A hash-partitioned output gets an identity bucket -> partition mapping.
        PlanFragment scheduledFragment = fragment;
        if (fragment.getPartitioningScheme().getPartitioning().getHandle().equals(SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION)) {
            int[] bucketToPartition = IntStream.range(0, hashPartitionCount).toArray();
            scheduledFragment = fragment.withBucketToPartition(Optional.of(bucketToPartition));
        }

        boolean consumesShuffledInputs = partitioning.equals(SystemPartitioningHandle.SINGLE_DISTRIBUTION)
                || partitioning.equals(SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION);
        if (!consumesShuffledInputs) {
            if (!partitioning.equals(SystemPartitioningHandle.SOURCE_DISTRIBUTION)) {
                throw new IllegalArgumentException(String.format("Unexpected fragment partitioning %s, fragmentId: %s", partitioning, scheduledFragment.getId()));
            }
            Preconditions.checkArgument(rddInputs.isEmpty(), "rddInputs is expected to be empty for SOURCE_DISTRIBUTION fragment: %s", scheduledFragment.getId());
            return createSourceRdd(sparkContext, session, scheduledFragment, executorFactoryProvider, taskStatsCollector, tableWriteInfo, broadcastInputs);
        }

        Preconditions.checkArgument(scheduledFragment.getTableScanSchedulingOrder().isEmpty(), "Fragment with is not expected to have table scans. fragmentId: %s, fragment partitioning %s", scheduledFragment.getId(), scheduledFragment.getPartitioning());
        for (RemoteSourceNode remoteSource : scheduledFragment.getRemoteSourceNodes()) {
            boolean orderSensitive = remoteSource.isEnsureSourceOrdering() || remoteSource.getOrderingScheme().isPresent();
            if (orderSensitive) {
                throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, String.format("Order sensitive exchange is not supported by Presto on Spark. fragmentId: %s, sourceFragmentIds: %s", scheduledFragment.getId(), remoteSource.getSourceFragmentIds()));
            }
        }

        // Repartition every upstream RDD so that it lines up with this fragment's distribution.
        Partitioner inputPartitioner = createPartitioner(partitioning, hashPartitionCount);
        ImmutableMap.Builder<PlanFragmentId, JavaPairRDD<Integer, PrestoSparkRow>> partitionedInputs = ImmutableMap.builder();
        for (Map.Entry<PlanFragmentId, JavaPairRDD<Integer, PrestoSparkRow>> input : rddInputs.entrySet()) {
            partitionedInputs.put(input.getKey(), input.getValue().partitionBy(inputPartitioner));
        }
        return createIntermediateRdd(sparkContext, session, scheduledFragment, executorFactoryProvider, taskStatsCollector, tableWriteInfo, partitionedInputs.build(), broadcastInputs);
    }

    /**
     * Chooses the partitioner used to shuffle the inputs of an intermediate fragment.
     *
     * @param partitioning fragment distribution
     * @param partitionCount value handed to the partitioner for {@code FIXED_HASH_DISTRIBUTION}
     * @return an {@link IntegerIdentityPartitioner} created with {@code 1} for
     *         {@code SINGLE_DISTRIBUTION}, or with {@code partitionCount} for
     *         {@code FIXED_HASH_DISTRIBUTION}
     * @throws IllegalArgumentException for any other partitioning
     */
    private static Partitioner createPartitioner(PartitioningHandle partitioning, int partitionCount) {
        boolean single = partitioning.equals(SystemPartitioningHandle.SINGLE_DISTRIBUTION);
        if (!single && !partitioning.equals(SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION)) {
            throw new IllegalArgumentException(String.format("Unexpected fragment partitioning %s", partitioning));
        }
        return new IntegerIdentityPartitioner(single ? 1 : partitionCount);
    }

    /**
     * Creates the RDD for a fragment that has no table scans and reads only from upstream
     * fragments (RDD inputs and/or broadcast inputs).
     *
     * <p>With no RDD inputs the fragment must be {@code SINGLE_DISTRIBUTION}; the serialized
     * task descriptor is then parallelized as a single-element, single-slice RDD and mapped
     * through the task processor. Otherwise the input RDDs are combined by a
     * {@link PrestoSparkZipRdd} driven by a task processor that receives the descriptor and the
     * input fragment ids.
     *
     * @throws IllegalArgumentException if the supplied inputs do not match the fragment's remote
     *         sources, or if there are no RDD inputs and the fragment is not
     *         {@code SINGLE_DISTRIBUTION}
     */
    private JavaPairRDD<Integer, PrestoSparkRow> createIntermediateRdd(
            JavaSparkContext sparkContext,
            Session session,
            PlanFragment fragment,
            PrestoSparkTaskExecutorFactoryProvider executorFactoryProvider,
            CollectionAccumulator<SerializedTaskStats> taskStatsCollector,
            TableWriteInfo tableWriteInfo,
            Map<PlanFragmentId, JavaPairRDD<Integer, PrestoSparkRow>> rddInputs,
            Map<PlanFragmentId, Broadcast<List<PrestoSparkSerializedPage>>> broadcastInputs) {
        checkInputs(fragment.getRemoteSourceNodes(), rddInputs, broadcastInputs);

        List<TableScanNode> scans = findTableScanNodes(fragment.getRoot());
        Verify.verify(scans.isEmpty(), "no table scans is expected");

        PrestoSparkTaskDescriptor descriptor = createIntermediateTaskDescriptor(session, tableWriteInfo, fragment);
        byte[] descriptorJson = taskDescriptorJsonCodec.toJsonBytes(descriptor);
        SerializedPrestoSparkTaskDescriptor serializedDescriptor = new SerializedPrestoSparkTaskDescriptor(descriptorJson);

        if (rddInputs.isEmpty()) {
            Preconditions.checkArgument(fragment.getPartitioning().equals(SystemPartitioningHandle.SINGLE_DISTRIBUTION), "SINGLE_DISTRIBUTION partitioning is expected: %s", fragment.getPartitioning());
            // Raw types are kept at the Spark boundary: the generic signature of
            // TaskProcessors.createTaskProcessor is not visible from this file.
            JavaRDD singleTask = sparkContext.parallelize((List) ImmutableList.of(serializedDescriptor), 1);
            return singleTask.mapPartitionsToPair(TaskProcessors.createTaskProcessor(executorFactoryProvider, taskStatsCollector, toTaskProcessorBroadcastInputs(broadcastInputs)));
        }

        // Fragment ids and RDDs are collected in the same iteration order.
        ImmutableList.Builder<String> inputFragmentIds = ImmutableList.builder();
        ImmutableList.Builder<RDD<?>> inputRdds = ImmutableList.builder();
        for (Map.Entry<PlanFragmentId, JavaPairRDD<Integer, PrestoSparkRow>> input : rddInputs.entrySet()) {
            inputFragmentIds.add(input.getKey().toString());
            inputRdds.add(input.getValue().rdd());
        }

        Function taskProcessor = TaskProcessors.createTaskProcessor(executorFactoryProvider, serializedDescriptor, inputFragmentIds.build(), taskStatsCollector, toTaskProcessorBroadcastInputs(broadcastInputs));
        RDD zipped = new PrestoSparkZipRdd(sparkContext.sc(), (List) inputRdds.build(), taskProcessor);
        return JavaPairRDD.fromRDD(zipped, classTag(Integer.class), classTag(PrestoSparkRow.class));
    }

    /**
     * Creates the RDD for a {@code SOURCE_DISTRIBUTION} fragment, i.e. a fragment that reads
     * exactly one table.
     *
     * <p>All splits of the table are enumerated, shuffled, and dealt round-robin into
     * {@code min(splitCount, sparkInitialPartitionCount)} tasks. One serialized task descriptor
     * is produced per task and the descriptors are parallelized into that many slices. When
     * there are no tasks an empty RDD is returned.
     *
     * @throws IllegalArgumentException if the inputs do not match the fragment's remote sources
     *         or the fragment does not contain exactly one table scan
     */
    private JavaPairRDD<Integer, PrestoSparkRow> createSourceRdd(
            JavaSparkContext sparkContext,
            Session session,
            PlanFragment fragment,
            PrestoSparkTaskExecutorFactoryProvider executorFactoryProvider,
            CollectionAccumulator<SerializedTaskStats> taskStatsCollector,
            TableWriteInfo tableWriteInfo,
            Map<PlanFragmentId, Broadcast<List<PrestoSparkSerializedPage>>> broadcastInputs) {
        Map<PlanFragmentId, JavaPairRDD<Integer, PrestoSparkRow>> noRddInputs = ImmutableMap.of();
        checkInputs(fragment.getRemoteSourceNodes(), noRddInputs, broadcastInputs);

        List<TableScanNode> scans = findTableScanNodes(fragment.getRoot());
        Preconditions.checkArgument(scans.size() == 1, "exactly one table scan is expected in SOURCE_DISTRIBUTION fragment. fragmentId: %s, actual number of table scans: %s", fragment.getId(), scans.size());
        TableScanNode scan = Iterables.getOnlyElement(scans);

        List<ScheduledSplit> allSplits = getSplits(session, scan);
        // Randomize the split order before the round-robin assignment below.
        Collections.shuffle(allSplits);

        int initialPartitionCount = PrestoSparkSessionProperties.getSparkInitialPartitionCount(session);
        int taskCount = Math.min(allSplits.size(), initialPartitionCount);
        if (taskCount == 0) {
            return JavaPairRDD.fromJavaRDD((JavaRDD) sparkContext.emptyRDD());
        }

        List<List<ScheduledSplit>> splitsPerTask = assignSplitsToTasks(allSplits, taskCount);
        // Drop the reference to the flat list; presumably to make it collectable while the
        // descriptors are being serialized.
        allSplits = null;

        ImmutableList.Builder<SerializedPrestoSparkTaskDescriptor> descriptors = ImmutableList.builder();
        for (int task = 0; task < splitsPerTask.size(); task++) {
            PrestoSparkTaskDescriptor descriptor = createSourceTaskDescriptor(session, tableWriteInfo, fragment, splitsPerTask.get(task));
            byte[] descriptorJson = taskDescriptorJsonCodec.toJsonBytes(descriptor);
            descriptors.add(new SerializedPrestoSparkTaskDescriptor(descriptorJson));
            // The batch has been serialized; release it.
            splitsPerTask.set(task, null);
        }

        // Raw types are kept at the Spark boundary: the generic signature of
        // TaskProcessors.createTaskProcessor is not visible from this file.
        JavaRDD tasks = sparkContext.parallelize((List) descriptors.build(), taskCount);
        return tasks.mapPartitionsToPair(TaskProcessors.createTaskProcessor(executorFactoryProvider, taskStatsCollector, toTaskProcessorBroadcastInputs(broadcastInputs)));
    }

    /**
     * Eagerly enumerates every split of the scanned table.
     *
     * <p>Splits are pulled from the split source in batches of up to 1000 until the source
     * reports that it is finished. Each split is wrapped in a {@link ScheduledSplit} carrying a
     * sequence id that starts at 0 and increases by one per split, together with the id of the
     * table scan node. The split source is closed once it has been drained, including when
     * enumeration fails part-way.
     *
     * @param session query session
     * @param tableScan table scan whose table handle is enumerated
     * @return a mutable list of all scheduled splits, in the order they were produced
     */
    private List<ScheduledSplit> getSplits(Session session, TableScanNode tableScan) {
        final int maxBatchSize = 1000;
        List<ScheduledSplit> splits = new ArrayList<>();
        long sequenceId = 0L;
        // NOTE(review): relies on SplitSource being AutoCloseable with a close() that declares
        // no checked exception - confirm against the SplitSource interface.
        try (SplitSource splitSource = splitManager.getSplits(session, tableScan.getTable(), ConnectorSplitManager.SplitSchedulingStrategy.UNGROUPED_SCHEDULING)) {
            while (!splitSource.isFinished()) {
                SplitSource.SplitBatch batch = MoreFutures.getFutureValue(splitSource.getNextBatch(NotPartitionedPartitionHandle.NOT_PARTITIONED, Lifespan.taskWide(), maxBatchSize));
                for (Split split : batch.getSplits()) {
                    splits.add(new ScheduledSplit(sequenceId++, tableScan.getId(), split));
                }
            }
        }
        return splits;
    }

    /**
     * Deals the splits round-robin into {@code numTasks} buckets: the split at position
     * {@code i} lands in bucket {@code i % numTasks}.
     *
     * @param splits splits to distribute, consumed in iteration order
     * @param numTasks number of buckets to create; must be positive
     * @return a mutable list of {@code numTasks} mutable buckets
     * @throws IllegalArgumentException if {@code numTasks} is not positive
     */
    private static List<List<ScheduledSplit>> assignSplitsToTasks(List<ScheduledSplit> splits, int numTasks) {
        Preconditions.checkArgument(numTasks > 0, "numTasks must be greater then zero");

        List<List<ScheduledSplit>> buckets = new ArrayList<>();
        while (buckets.size() < numTasks) {
            buckets.add(new ArrayList<>());
        }

        int target = 0;
        for (ScheduledSplit split : splits) {
            buckets.get(target).add(split);
            // Wrap around after the last bucket.
            target = (target + 1) % numTasks;
        }
        return buckets;
    }

    /**
     * Builds the task descriptor for a fragment without table scans by delegating to
     * {@code createSourceTaskDescriptor} with an empty split list.
     */
    private PrestoSparkTaskDescriptor createIntermediateTaskDescriptor(Session session, TableWriteInfo tableWriteInfo, PlanFragment fragment) {
        List<ScheduledSplit> noSplits = ImmutableList.of();
        return createSourceTaskDescriptor(session, tableWriteInfo, fragment, noSplits);
    }

    /**
     * Builds a task descriptor carrying the given splits.
     *
     * <p>The splits are grouped by plan node id and each group becomes one {@link TaskSource}
     * with an empty third argument and {@code true} as the fourth. The descriptor additionally
     * carries the session representation, the identity's extra credentials, the fragment and
     * the table write info.
     *
     * @param splits splits to embed; may be empty, which yields a descriptor with no task sources
     */
    private PrestoSparkTaskDescriptor createSourceTaskDescriptor(Session session, TableWriteInfo tableWriteInfo, PlanFragment fragment, List<ScheduledSplit> splits) {
        Map<PlanNodeId, Set<ScheduledSplit>> splitsByPlanNode = splits.stream()
                .collect(Collectors.groupingBy(ScheduledSplit::getPlanNodeId, Collectors.toSet()));

        ImmutableList.Builder<TaskSource> taskSources = ImmutableList.builder();
        for (Map.Entry<PlanNodeId, Set<ScheduledSplit>> group : splitsByPlanNode.entrySet()) {
            taskSources.add(new TaskSource(group.getKey(), group.getValue(), ImmutableSet.of(), true));
        }
        List<TaskSource> taskSourceList = taskSources.build();

        return new PrestoSparkTaskDescriptor(
                session.toSessionRepresentation(),
                session.getIdentity().getExtraCredentials(),
                fragment,
                taskSourceList,
                tableWriteInfo);
    }

    /**
     * Collects every {@link TableScanNode} found by searching the plan starting at {@code node}.
     */
    private static List<TableScanNode> findTableScanNodes(PlanNode node) {
        PlanNodeSearcher tableScanSearcher = PlanNodeSearcher.searchFrom(node)
                .where(candidate -> candidate instanceof TableScanNode);
        return tableScanSearcher.findAll();
    }

    /**
     * Re-keys the broadcast inputs by the string form of their fragment id.
     *
     * @return an immutable map with the same values, keyed by {@code PlanFragmentId.toString()}
     */
    private static Map<String, Broadcast<List<PrestoSparkSerializedPage>>> toTaskProcessorBroadcastInputs(Map<PlanFragmentId, Broadcast<List<PrestoSparkSerializedPage>>> broadcastInputs) {
        ImmutableMap.Builder<String, Broadcast<List<PrestoSparkSerializedPage>>> byFragmentName = ImmutableMap.builder();
        for (Map.Entry<PlanFragmentId, Broadcast<List<PrestoSparkSerializedPage>>> input : broadcastInputs.entrySet()) {
            byFragmentName.put(input.getKey().toString(), input.getValue());
        }
        return byFragmentName.build();
    }

    /**
     * Verifies that the supplied inputs are exactly the fragments referenced by the remote
     * source nodes.
     *
     * <p>The expected set is the union of the source fragment ids of all remote sources; the
     * actual set is the union of the keys of {@code rddInputs} and {@code broadcastInputs}.
     * Both missing and extra inputs are rejected.
     *
     * @param remoteSources remote source nodes of the fragment
     * @param rddInputs inputs supplied as RDDs, keyed by fragment id
     * @param broadcastInputs inputs supplied as broadcasts, keyed by fragment id
     * @throws IllegalArgumentException if the expected and actual input sets differ; the message
     *         lists the expected, actual, missing and extra fragment ids
     */
    private static void checkInputs(List<RemoteSourceNode> remoteSources, Map<PlanFragmentId, JavaPairRDD<Integer, PrestoSparkRow>> rddInputs, Map<PlanFragmentId, Broadcast<List<PrestoSparkSerializedPage>>> broadcastInputs) {
        Set<PlanFragmentId> expectedInputs = remoteSources.stream()
                .map(RemoteSourceNode::getSourceFragmentIds)
                .flatMap(Collection::stream)
                .collect(ImmutableSet.toImmutableSet());
        Set<PlanFragmentId> actualInputs = Sets.union(rddInputs.keySet(), broadcastInputs.keySet());
        Set<PlanFragmentId> missingInputs = Sets.difference(expectedInputs, actualInputs);
        Set<PlanFragmentId> extraInputs = Sets.difference(actualInputs, expectedInputs);
        // The final argument fills the "extra inputs" placeholder, so it must be extraInputs.
        Preconditions.checkArgument(
                missingInputs.isEmpty() && extraInputs.isEmpty(),
                "rddInputs mismatch discovered. expected inputs: %s, actual rdd inputs: %s, actual broadcast inputs: %s, missing inputs: %s, extra inputs: %s",
                expectedInputs,
                rddInputs.keySet(),
                broadcastInputs.keySet(),
                missingInputs,
                extraInputs);
    }

    /**
     * Obtains the Scala {@link ClassTag} for a Java class via the {@code ClassTag} companion
     * object.
     */
    private static <T> ClassTag<T> classTag(Class<T> clazz) {
        ClassTag<T> tag = ClassTag$.MODULE$.apply(clazz);
        return tag;
    }
}

