/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.spark.classloader_interface;

import com.facebook.presto.spark.classloader_interface.MutablePartitionId;
import com.facebook.presto.spark.classloader_interface.PrestoSparkMutableRow;
import com.facebook.presto.spark.classloader_interface.PrestoSparkNativeExecutionShuffleManager;
import com.facebook.presto.spark.classloader_interface.PrestoSparkShuffleReadDescriptor;
import com.facebook.presto.spark.classloader_interface.PrestoSparkShuffleWriteDescriptor;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskOutput;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskProcessor;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskRdd;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskSourceRdd;
import com.facebook.presto.spark.classloader_interface.ScalaUtils;
import com.facebook.presto.spark.classloader_interface.SerializedPrestoSparkTaskSource;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.spark.MapOutputTracker;
import org.apache.spark.Partition;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.rdd.RDD;
import org.apache.spark.rdd.ShuffledRDD;
import org.apache.spark.rdd.ShuffledRDDPartition;
import org.apache.spark.rdd.ZippedPartitionsPartition;
import org.apache.spark.shuffle.ShuffleHandle;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.BlockManagerId;
import scala.Tuple2;
import scala.collection.Iterable;
import scala.collection.Iterator;
import scala.collection.JavaConversions;
import scala.collection.Seq;

/**
 * Task RDD for Presto-on-Spark native execution.
 *
 * <p>Each split handed to {@link #compute} is a {@link ZippedPartitionsPartition}. Its child
 * partitions are expected to be laid out as one partition per shuffle input RDD (in the order of
 * {@link #getShuffleInputRdds()}), followed by a single task source partition when a task source
 * RDD is present. {@link #compute} does not iterate the shuffle inputs itself; it describes them
 * with {@link PrestoSparkShuffleReadDescriptor}s and passes them, together with an optional
 * {@link PrestoSparkShuffleWriteDescriptor}, to the task processor.
 *
 * @param <T> the task output type produced by the task processor
 */
public class PrestoSparkNativeTaskRdd<T extends PrestoSparkTaskOutput>
        extends PrestoSparkTaskRdd<T> {
    /**
     * Creates a native task RDD.
     *
     * @param context the Spark context
     * @param taskSourceRdd optional RDD supplying serialized task sources
     * @param shuffleInputRddMap shuffle input RDDs keyed by fragment id; the map's iteration order
     *     determines the order of the shuffle inputs
     * @param taskProcessor processor that executes the task for each partition
     * @return a new RDD
     * @throws NullPointerException if any argument is null
     */
    public static <T extends PrestoSparkTaskOutput> PrestoSparkNativeTaskRdd<T> create(
            SparkContext context,
            Optional<PrestoSparkTaskSourceRdd> taskSourceRdd,
            Map<String, RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>>> shuffleInputRddMap,
            PrestoSparkTaskProcessor<T> taskProcessor) {
        Objects.requireNonNull(context, "context is null");
        Objects.requireNonNull(taskSourceRdd, "taskSourceRdd is null");
        Objects.requireNonNull(shuffleInputRddMap, "shuffleInputRddMap is null");
        Objects.requireNonNull(taskProcessor, "taskProcessor is null");

        // Both lists are filled from the same entry, so index i of one corresponds to index i of the other.
        ImmutableList.Builder<String> shuffleInputFragmentIds = ImmutableList.builder();
        ImmutableList.Builder<RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>>> shuffleInputRdds = ImmutableList.builder();
        for (Map.Entry<String, RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>>> entry : shuffleInputRddMap.entrySet()) {
            shuffleInputFragmentIds.add(entry.getKey());
            shuffleInputRdds.add(entry.getValue());
        }
        return new PrestoSparkNativeTaskRdd<>(
                context,
                taskSourceRdd,
                shuffleInputFragmentIds.build(),
                shuffleInputRdds.build(),
                taskProcessor);
    }

    /**
     * Runs the task for one zipped partition by delegating to the task processor.
     *
     * @param split a {@link ZippedPartitionsPartition} laid out as described on the class
     * @param context the Spark task context
     * @return the iterator produced by the task processor
     * @throws IllegalStateException if the number of child partitions does not match the number of
     *     shuffle inputs plus the optional task source
     */
    @Override
    public Iterator<Tuple2<MutablePartitionId, T>> compute(Partition split, TaskContext context) {
        PrestoSparkTaskSourceRdd taskSourceRdd = getTaskSourceRdd();
        List<Partition> partitions = JavaConversions.seqAsJavaList(((ZippedPartitionsPartition) split).partitions());
        int expectedPartitionsSize = (taskSourceRdd != null ? 1 : 0) + getShuffleInputRdds().size();
        Preconditions.checkState(
                partitions.size() == expectedPartitionsSize,
                "Unexpected partitions size. Expected: %s. Actual: %s.",
                expectedPartitionsSize,
                partitions.size());

        Iterator<SerializedPrestoSparkTaskSource> taskSourceIterator;
        if (taskSourceRdd != null) {
            // The task source partition is always the last child partition.
            taskSourceIterator = taskSourceRdd.iterator(partitions.get(partitions.size() - 1), context);
        }
        else {
            taskSourceIterator = ScalaUtils.emptyScalaIterator();
        }
        return getTaskProcessor().process(
                taskSourceIterator,
                getShuffleReadDescriptors(partitions),
                getShuffleWriteDescriptor(split));
    }

    private PrestoSparkNativeTaskRdd(
            SparkContext context,
            Optional<PrestoSparkTaskSourceRdd> taskSourceRdd,
            List<String> shuffleInputFragmentIds,
            List<RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>>> shuffleInputRdds,
            PrestoSparkTaskProcessor<T> taskProcessor) {
        super(context, taskSourceRdd, shuffleInputFragmentIds, shuffleInputRdds, taskProcessor);
    }

    /**
     * Builds one shuffle read descriptor per shuffle input, keyed by its fragment id.
     *
     * @param partitions the child partitions of the zipped split; the first
     *     {@code getShuffleInputRdds().size()} entries must be {@link ShuffledRDDPartition}s
     * @return read descriptors keyed by fragment id
     * @throws IllegalStateException if the shuffle input lists are misaligned, there are too few
     *     partitions, or a partition/RDD is null or of the wrong type
     */
    private Map<String, PrestoSparkShuffleReadDescriptor> getShuffleReadDescriptors(List<Partition> partitions) {
        List<RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>>> shuffleInputRdds = getShuffleInputRdds();
        List<String> shuffleInputFragmentIds = getShuffleInputFragmentIds();
        int numPartitions = partitions.size();
        // The loop below indexes both lists by the same i, so they must be aligned.
        Preconditions.checkState(
                shuffleInputRdds.size() == shuffleInputFragmentIds.size(),
                "Size of shuffleInputRdds %s is not equal to size of shuffleInputFragmentIds %s",
                shuffleInputRdds.size(),
                shuffleInputFragmentIds.size());
        Preconditions.checkState(
                numPartitions >= shuffleInputRdds.size(),
                "Number of partitions %s is smaller than number of shuffleInputRdds %s",
                numPartitions,
                shuffleInputRdds.size());

        ImmutableMap.Builder<String, PrestoSparkShuffleReadDescriptor> shuffleReadDescriptors = ImmutableMap.builder();
        for (int i = 0; i < shuffleInputRdds.size(); i++) {
            Partition partition = partitions.get(i);
            Preconditions.checkState(partition != null, "partition at index %s is null", i);
            Preconditions.checkState(
                    partition instanceof ShuffledRDDPartition,
                    "partition is required to be ShuffledRddPartition, but got: %s",
                    partition.getClass().getName());
            ShuffledRDDPartition shuffledPartition = (ShuffledRDDPartition) partition;

            RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>> shuffleRdd = shuffleInputRdds.get(i);
            Preconditions.checkState(shuffleRdd != null, "shuffleInputRdd at index %s is null", i);
            Preconditions.checkState(
                    shuffleRdd instanceof ShuffledRDD,
                    "ShuffledRdd is required but got: %s",
                    shuffleRdd.getClass().getName());
            ShuffleHandle handle = ((ShuffleDependency<?, ?, ?>) shuffleRdd.dependencies().head()).shuffleHandle();

            // Query the map output tracker once so block ids and sizes are derived from the same
            // snapshot; the two lists are consumed element-wise by the descriptor.
            Collection<Tuple2<BlockManagerId, Seq<Tuple2<BlockId, Object>>>> mapSizes = getMapSizes(shuffledPartition, handle);
            shuffleReadDescriptors.put(
                    shuffleInputFragmentIds.get(i),
                    new PrestoSparkShuffleReadDescriptor(
                            partition,
                            handle,
                            shuffleRdd.getNumPartitions(),
                            getBlockIds(mapSizes),
                            getPartitionSize(mapSizes)));
        }
        return shuffleReadDescriptors.build();
    }

    /**
     * Returns the shuffle write descriptor for this split, if the native shuffle manager has a
     * shuffle handle registered for the split's index.
     *
     * @throws IllegalStateException if the active shuffle manager is not a
     *     {@link PrestoSparkNativeExecutionShuffleManager}
     */
    private Optional<PrestoSparkShuffleWriteDescriptor> getShuffleWriteDescriptor(Partition split) {
        Object shuffleManager = SparkEnv.get().shuffleManager();
        Preconditions.checkState(
                shuffleManager instanceof PrestoSparkNativeExecutionShuffleManager,
                "Native execution requires to use PrestoSparkNativeExecutionShuffleManager. But got: %s",
                shuffleManager.getClass().getName());
        PrestoSparkNativeExecutionShuffleManager nativeShuffleManager = (PrestoSparkNativeExecutionShuffleManager) shuffleManager;
        return nativeShuffleManager.getShuffleHandle(split.index())
                .map(handle -> new PrestoSparkShuffleWriteDescriptor(handle, nativeShuffleManager.getNumOfPartitions(handle.shuffleId())));
    }

    /**
     * Fetches, for the single reduce partition {@code partition.idx()}, the map output locations
     * and their block sizes from the {@link MapOutputTracker}.
     */
    private static Collection<Tuple2<BlockManagerId, Seq<Tuple2<BlockId, Object>>>> getMapSizes(
            ShuffledRDDPartition partition,
            ShuffleHandle shuffleHandle) {
        MapOutputTracker mapOutputTracker = SparkEnv.get().mapOutputTracker();
        return JavaConversions.asJavaCollection(
                mapOutputTracker.getMapSizesByExecutorId(shuffleHandle.shuffleId(), partition.idx(), partition.idx() + 1));
    }

    /**
     * Returns one entry per map output location, in tracker order.
     *
     * <p>NOTE(review): despite the name, each entry is the executor id of the location's
     * {@link BlockManagerId}, not a block id.
     */
    private static List<String> getBlockIds(Collection<Tuple2<BlockManagerId, Seq<Tuple2<BlockId, Object>>>> mapSizes) {
        return mapSizes.stream()
                .map(location -> location._1().executorId())
                .collect(Collectors.toList());
    }

    /**
     * Returns, per map output location (same order as {@link #getBlockIds}), the sum of the sizes
     * of all blocks reported at that location.
     */
    private static List<Long> getPartitionSize(Collection<Tuple2<BlockManagerId, Seq<Tuple2<BlockId, Object>>>> mapSizes) {
        return mapSizes.stream()
                .map(location -> JavaConversions.seqAsJavaList(location._2()).stream()
                        // Scala boxes the Long size as java.lang.Long behind the erased Object type.
                        .mapToLong(block -> (Long) block._2())
                        .sum())
                .collect(Collectors.toList());
    }
}

