PrestoSparkNativeTaskRdd.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.spark.classloader_interface;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.spark.MapOutputTracker;
import org.apache.spark.Partition;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.rdd.RDD;
import org.apache.spark.rdd.ShuffledRDD;
import org.apache.spark.rdd.ShuffledRDDPartition;
import org.apache.spark.rdd.ZippedPartitionsPartition;
import org.apache.spark.shuffle.ShuffleHandle;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.BlockManagerId;
import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.Seq;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

import static com.facebook.presto.spark.classloader_interface.ScalaUtils.emptyScalaIterator;
import static com.google.common.base.Preconditions.checkState;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static scala.collection.JavaConversions.asJavaCollection;
import static scala.collection.JavaConversions.seqAsJavaList;

/**
 * PrestoSparkNativeTaskRdd represents the execution of a single Presto stage. It contains:
 * - A list of shuffleInputRdds, each corresponding to a child stage.
 * - An optional taskSourceRdd, which represents ALL table scan inputs in this stage.
 * <p>
 * A table scan is present when joining a bucketed table with an unbucketed table, for example:
 * <pre>
 *      Join
 *      /  \
 *  Scan    Remote Source
 * </pre>
 * <p>
 * In this case, the bucket to Spark partition mapping has to be consistent with the Spark shuffle partitioning.
 * <p>
 * When the stage partitioning is SINGLE_DISTRIBUTION and shuffleInputRdds is empty,
 * taskSourceRdd is expected to be present and contain exactly one empty partition.
 * <p>
 * The broadcast inputs are encapsulated in taskProcessor.
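 * <p>
 * For illustration, a driver might assemble this RDD roughly as follows. This is a minimal
 * sketch: {@code sparkContext}, {@code taskSourceRdd}, {@code childStageRdd}, and
 * {@code taskProcessor} are assumed to exist, the fragment id "1" is hypothetical, and
 * {@code PrestoSparkSerializedPage} is used as the output type.
 * <pre>{@code
 * Map<String, RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>>> shuffleInputs =
 *         ImmutableMap.of("1", childStageRdd);
 * PrestoSparkNativeTaskRdd<PrestoSparkSerializedPage> rdd = PrestoSparkNativeTaskRdd.create(
 *         sparkContext, Optional.of(taskSourceRdd), shuffleInputs, taskProcessor);
 * }</pre>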
 */
public class PrestoSparkNativeTaskRdd<T extends PrestoSparkTaskOutput>
        extends PrestoSparkTaskRdd<T>
{
    public static <T extends PrestoSparkTaskOutput> PrestoSparkNativeTaskRdd<T> create(
            SparkContext context,
            Optional<PrestoSparkTaskSourceRdd> taskSourceRdd,
            // fragmentId -> RDD
            Map<String, RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>>> shuffleInputRddMap,
            PrestoSparkTaskProcessor<T> taskProcessor)
    {
        requireNonNull(context, "context is null");
        requireNonNull(taskSourceRdd, "taskSourceRdd is null");
        requireNonNull(shuffleInputRddMap, "shuffleInputRddMap is null");
        requireNonNull(taskProcessor, "taskProcessor is null");
        ImmutableList.Builder<String> shuffleInputFragmentIds = ImmutableList.builder();
        ImmutableList.Builder<RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>>> shuffleInputRdds = ImmutableList.builder();
        for (Map.Entry<String, RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>>> entry : shuffleInputRddMap.entrySet()) {
            shuffleInputFragmentIds.add(entry.getKey());
            shuffleInputRdds.add(entry.getValue());
        }
        return new PrestoSparkNativeTaskRdd<>(context, taskSourceRdd, shuffleInputFragmentIds.build(), shuffleInputRdds.build(), taskProcessor);
    }

    @Override
    public Iterator<Tuple2<MutablePartitionId, T>> compute(Partition split, TaskContext context)
    {
        PrestoSparkTaskSourceRdd taskSourceRdd = getTaskSourceRdd();
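        // A ZippedPartitionsPartition bundles one partition from each zipped RDD:
        // the shuffle input partitions come first, followed by the task source partition (when present)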
        List<Partition> partitions = seqAsJavaList(((ZippedPartitionsPartition) split).partitions());
        int expectedPartitionsSize = (taskSourceRdd != null ? 1 : 0) + getShuffleInputRdds().size();
        checkState(partitions.size() == expectedPartitionsSize,
                format("Unexpected number of partitions. Expected: %s. Actual: %s.", expectedPartitionsSize, partitions.size()));

        Iterator<SerializedPrestoSparkTaskSource> taskSourceIterator;
        if (taskSourceRdd != null) {
            taskSourceIterator = taskSourceRdd.iterator(partitions.get(partitions.size() - 1), context);
        }
        else {
            taskSourceIterator = emptyScalaIterator();
        }

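        // Hand the task sources and the shuffle read/write descriptors to the task processor,
        // which drives the native execution of this task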
        return getTaskProcessor().process(
                taskSourceIterator,
                getShuffleReadDescriptors(partitions),
                getShuffleWriteDescriptor(context.stageId(), split));
    }

    private PrestoSparkNativeTaskRdd(
            SparkContext context,
            Optional<PrestoSparkTaskSourceRdd> taskSourceRdd,
            List<String> shuffleInputFragmentIds,
            List<RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>>> shuffleInputRdds,
            PrestoSparkTaskProcessor<T> taskProcessor)
    {
        super(context, taskSourceRdd, shuffleInputFragmentIds, shuffleInputRdds, taskProcessor);
    }

    private Map<String, PrestoSparkShuffleReadDescriptor> getShuffleReadDescriptors(List<Partition> partitions)
    {
        // fragmentId -> shuffle read descriptor
        ImmutableMap.Builder<String, PrestoSparkShuffleReadDescriptor> shuffleReadDescriptors = ImmutableMap.builder();

        // Get shuffle information from ShuffledRdds for shuffle read
        int numPartitions = partitions.size();
        List<RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>>> shuffleInputRdds = getShuffleInputRdds();
        List<String> shuffleInputFragmentIds = getShuffleInputFragmentIds();
        checkState(
                numPartitions >= shuffleInputRdds.size() && numPartitions >= shuffleInputFragmentIds.size(),
                format(
                        "Size of shuffleInputRdds (%s) or shuffleInputFragmentIds (%s) exceeds the number of partitions (%s)",
                        shuffleInputRdds.size(),
                        shuffleInputFragmentIds.size(),
                        numPartitions));
        for (int i = 0; i < shuffleInputRdds.size(); i++) {
            Partition partition = partitions.get(i);
            checkState(partition != null, "partition is null at index %s", i);
            checkState(partition instanceof ShuffledRDDPartition,
                    "partition is required to be ShuffledRDDPartition, but got: %s", partition.getClass().getName());

            RDD<?> shuffleRdd = shuffleInputRdds.get(i);
            checkState(shuffleRdd != null, "shuffle input RDD is null at index %s", i);
            checkState(shuffleRdd instanceof ShuffledRDD, "ShuffledRDD is required, but got: %s", shuffleRdd.getClass().getName());

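            // Each shuffle input is read through the ShuffleHandle of its single ShuffleDependency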
            ShuffledRDDPartition shuffledRddPartition = (ShuffledRDDPartition) partition;
            ShuffleHandle handle = ((ShuffleDependency<?, ?, ?>) shuffleRdd.dependencies().head()).shuffleHandle();
            shuffleReadDescriptors.put(
                    shuffleInputFragmentIds.get(i),
                    new PrestoSparkShuffleReadDescriptor(
                            shuffledRddPartition,
                            handle,
                            shuffleRdd.getNumPartitions(),
                            getBlockIds(shuffledRddPartition, handle),
                            getPartitionIds(shuffledRddPartition, handle),
                            getPartitionSize(shuffledRddPartition, handle)));
        }
        return shuffleReadDescriptors.build();
    }

    private Optional<PrestoSparkShuffleWriteDescriptor> getShuffleWriteDescriptor(int stageId, Partition split)
    {
        // Get shuffle information from Spark shuffle manager for shuffle write
        checkState(
                SparkEnv.get().shuffleManager() instanceof PrestoSparkNativeExecutionShuffleManager,
                "Native execution requires PrestoSparkNativeExecutionShuffleManager, but got: %s", SparkEnv.get().shuffleManager().getClass().getName());
        PrestoSparkNativeExecutionShuffleManager shuffleManager = (PrestoSparkNativeExecutionShuffleManager) SparkEnv.get().shuffleManager();
        Optional<ShuffleHandle> shuffleHandle = shuffleManager.getShuffleHandle(stageId, split.index());

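        // The shuffle handle is absent when this task produces no shuffle output (e.g., a result stage)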
        return shuffleHandle.map(handle -> new PrestoSparkShuffleWriteDescriptor(handle, shuffleManager.getNumOfPartitions(handle.shuffleId())));
    }

    private List<String> getBlockIds(ShuffledRDDPartition partition, ShuffleHandle shuffleHandle)
    {
        // Collect the executor ids of the BlockManagers hosting the map output for this reduce partition
        return getMapSizes(partition, shuffleHandle).stream()
                .map(item -> item._1.executorId())
                .collect(Collectors.toList());
    }

    private List<String> getPartitionIds(ShuffledRDDPartition partition, ShuffleHandle shuffleHandle)
    {
        // Flatten the shuffle block ids across all BlockManagers hosting map output for this reduce partition
        return getMapSizes(partition, shuffleHandle).stream()
                .map(item -> asJavaCollection(item._2))
                .flatMap(Collection::stream)
                .map(block -> block._1.toString())
                .collect(Collectors.toList());
    }

    private List<Long> getPartitionSize(ShuffledRDDPartition partition, ShuffleHandle shuffleHandle)
    {
        // Each partition/BlockManagerId can contain multiple blocks (each with its own BlockId);
        // sum up the sizes of all blocks from each BlockManagerId
        return getMapSizes(partition, shuffleHandle).stream()
                .map(item -> seqAsJavaList(item._2).stream()
                        .map(block -> (Long) block._2)
                        .reduce(0L, Long::sum))
                .collect(Collectors.toList());
    }

    private Collection<Tuple2<BlockManagerId, Seq<Tuple2<BlockId, Object>>>> getMapSizes(ShuffledRDDPartition partition, ShuffleHandle shuffleHandle)
    {
        // Query the MapOutputTracker for the map output locations and block sizes of this single reduce partition
        MapOutputTracker mapOutputTracker = SparkEnv.get().mapOutputTracker();
        return asJavaCollection(mapOutputTracker.getMapSizesByExecutorId(
                shuffleHandle.shuffleId(), partition.idx(), partition.idx() + 1));
    }
}