PrestoSparkTaskRdd.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.spark.classloader_interface;

import org.apache.spark.Partition;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext;
import org.apache.spark.rdd.RDD;
import org.apache.spark.rdd.ZippedPartitionsBaseRDD;
import org.apache.spark.rdd.ZippedPartitionsPartition;
import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.reflect.ClassTag;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.spark.classloader_interface.ScalaUtils.emptyScalaIterator;
import static java.lang.String.format;
import static java.util.Collections.unmodifiableMap;
import static java.util.Objects.requireNonNull;
import static scala.collection.JavaConversions.asScalaBuffer;
import static scala.collection.JavaConversions.seqAsJavaList;

/**
 * PrestoSparkTaskRdd represents the execution of a single Presto stage. It zips together:
 * - A list of shuffleInputRdds, each corresponding to a child stage.
 * - An optional taskSourceRdd, which represents ALL table scan inputs in this stage.
 * <p>
 * A table scan is present alongside a remote source when, for example, a bucketed table
 * is joined with an unbucketed table:
 * <pre>
 *       Join
 *      /    \
 *   Scan    Remote Source
 * </pre>
 * In this case, the bucket to Spark partition mapping has to be consistent with the Spark
 * shuffle partitioning.
 * <p>
 * When the stage partitioning is SINGLE_DISTRIBUTION and shuffleInputRdds is empty,
 * the taskSourceRdd is expected to be present and to contain exactly one empty partition.
 * <p>
 * The broadcast inputs are encapsulated in the taskProcessor.
 */
public class PrestoSparkTaskRdd<T extends PrestoSparkTaskOutput>
        extends ZippedPartitionsBaseRDD<Tuple2<MutablePartitionId, T>>
{
    private List<RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>>> shuffleInputRdds;
    private List<String> shuffleInputFragmentIds;
    private PrestoSparkTaskSourceRdd taskSourceRdd;
    private PrestoSparkTaskProcessor<T> taskProcessor;

    public static <T extends PrestoSparkTaskOutput> PrestoSparkTaskRdd<T> create(
            SparkContext context,
            Optional<PrestoSparkTaskSourceRdd> taskSourceRdd,
            // fragmentId -> RDD
            Map<String, RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>>> shuffleInputRddMap,
            PrestoSparkTaskProcessor<T> taskProcessor)
    {
        requireNonNull(context, "context is null");
        requireNonNull(taskSourceRdd, "taskSourceRdd is null");
        requireNonNull(shuffleInputRddMap, "shuffleInputRddMap is null");
        requireNonNull(taskProcessor, "taskProcessor is null");
        List<String> shuffleInputFragmentIds = new ArrayList<>();
        List<RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>>> shuffleInputRdds = new ArrayList<>();
        for (Map.Entry<String, RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>>> entry : shuffleInputRddMap.entrySet()) {
            shuffleInputFragmentIds.add(entry.getKey());
            shuffleInputRdds.add(entry.getValue());
        }
        return new PrestoSparkTaskRdd<>(context, taskSourceRdd, shuffleInputFragmentIds, shuffleInputRdds, taskProcessor);
    }
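
    // A minimal usage sketch, assuming a single child stage with fragment id "1"; the
    // variable names below (childStageOutputRdd, taskSourceRdd, taskProcessor) are
    // hypothetical placeholders, not identifiers from this module:
    //
    //   Map<String, RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>>> shuffleInputs = new HashMap<>();
    //   shuffleInputs.put("1", childStageOutputRdd);
    //   PrestoSparkTaskRdd<PrestoSparkSerializedPage> taskRdd = PrestoSparkTaskRdd.create(
    //           sparkContext,
    //           Optional.of(taskSourceRdd),
    //           shuffleInputs,
    //           taskProcessor);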

    protected PrestoSparkTaskRdd(
            SparkContext context,
            Optional<PrestoSparkTaskSourceRdd> taskSourceRdd,
            List<String> shuffleInputFragmentIds,
            List<RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>>> shuffleInputRdds,
            PrestoSparkTaskProcessor<T> taskProcessor)
    {
        super(context, getRDDSequence(taskSourceRdd, shuffleInputRdds), false, fakeClassTag());
        this.shuffleInputFragmentIds = shuffleInputFragmentIds;
        this.shuffleInputRdds = shuffleInputRdds;
        // java.util.Optional is not Serializable, so the task source RDD is stored unwrapped
        this.taskSourceRdd = taskSourceRdd.orElse(null);
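        // SparkContext.clean runs the closure cleaner over the processor and, with the second
        // argument set to true, verifies that it is actually serializable before shipping it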
        this.taskProcessor = context.clean(taskProcessor, true);
    }

    private static Seq<RDD<?>> getRDDSequence(Optional<PrestoSparkTaskSourceRdd> taskSourceRdd, List<RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>>> shuffleInputRdds)
    {
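        // Order matters here: compute() expects the task source partition, when present,
        // to be the last element of each zipped partition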
        List<RDD<?>> list = new ArrayList<>(shuffleInputRdds);
        taskSourceRdd.ifPresent(list::add);
        return asScalaBuffer(list).toSeq();
    }

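    // The ClassTag is only used by Spark for reflective array creation; since the element type
    // is always a Tuple2, a tag for the raw Tuple2 class is sufficient (the same "fake class
    // tag" idea used by Spark's Java API)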
    private static <T> ClassTag<T> fakeClassTag()
    {
        return scala.reflect.ClassTag$.MODULE$.apply(Tuple2.class);
    }

    @Override
    public Iterator<Tuple2<MutablePartitionId, T>> compute(Partition split, TaskContext context)
    {
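        // The split is a ZippedPartitionsPartition bundling one partition from each zipped RDD,
        // in the order established by getRDDSequence: shuffle inputs first, task source last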
        List<Partition> partitions = seqAsJavaList(((ZippedPartitionsPartition) split).partitions());
        int expectedPartitionsSize = (taskSourceRdd != null ? 1 : 0) + shuffleInputRdds.size();
        if (partitions.size() != expectedPartitionsSize) {
            throw new IllegalArgumentException(format("Unexpected partitions size. Expected: %s. Actual: %s.", expectedPartitionsSize, partitions.size()));
        }

        Map<String, Iterator<Tuple2<MutablePartitionId, PrestoSparkMutableRow>>> shuffleInputIterators = new HashMap<>();
        for (int inputIndex = 0; inputIndex < shuffleInputRdds.size(); inputIndex++) {
            shuffleInputIterators.put(
                    shuffleInputFragmentIds.get(inputIndex),
                    shuffleInputRdds.get(inputIndex).iterator(partitions.get(inputIndex), context));
        }

        Iterator<SerializedPrestoSparkTaskSource> taskSourceIterator;
        if (taskSourceRdd != null) {
            taskSourceIterator = taskSourceRdd.iterator(partitions.get(partitions.size() - 1), context);
        }
        else {
            taskSourceIterator = emptyScalaIterator();
        }

        return taskProcessor.process(taskSourceIterator, unmodifiableMap(shuffleInputIterators));
    }

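    // Invoked by Spark when the RDD lineage is truncated (e.g. on checkpointing); dropping
    // these references allows the parent RDDs and the task processor to be garbage collected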
    @Override
    public void clearDependencies()
    {
        super.clearDependencies();
        shuffleInputFragmentIds = null;
        shuffleInputRdds = null;
        taskSourceRdd = null;
        taskProcessor = null;
    }

    public List<RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>>> getShuffleInputRdds()
    {
        return shuffleInputRdds;
    }

    public PrestoSparkTaskSourceRdd getTaskSourceRdd()
    {
        return taskSourceRdd;
    }

    public List<String> getShuffleInputFragmentIds()
    {
        return shuffleInputFragmentIds;
    }

    public PrestoSparkTaskProcessor<T> getTaskProcessor()
    {
        return taskProcessor;
    }
}