PrestoSparkTaskSourceRdd.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.spark.classloader_interface;

import org.apache.spark.Dependency;
import org.apache.spark.InterruptibleIterator;
import org.apache.spark.Partition;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext;
import org.apache.spark.rdd.ParallelCollectionPartition;
import org.apache.spark.rdd.RDD;
import scala.collection.Iterator;
import scala.reflect.ClassTag;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;
import static scala.collection.JavaConversions.asScalaBuffer;

public class PrestoSparkTaskSourceRdd
extends RDD<SerializedPrestoSparkTaskSource>
{
    /**
     * Each element of taskSourcesByPartitionId is the list of task sources assigned to a single Spark partition.
     * When the input tables are unbucketed, task sources are distributed randomly across all partitions (tasks).
     * When the input tables are bucketed, the task sources for each bucket are assigned to a single Spark
     * partition (task), and the assignment is compatible with any potential shuffle inputs.
     */
private transient List<List<SerializedPrestoSparkTaskSource>> taskSourcesByPartitionId;
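
    // Shape example (hypothetical values): [[sourceA, sourceB], [sourceC]] assigns
    // sourceA and sourceB to Spark partition 0 and sourceC to Spark partition 1.
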
public PrestoSparkTaskSourceRdd(SparkContext sparkContext, List<List<SerializedPrestoSparkTaskSource>> taskSourcesByPartitionId)
{
super(sparkContext, asScalaBuffer(Collections.<Dependency<?>>emptyList()).toSeq(), fakeClassTag());
        // Defensive copy into ArrayLists: the inner lists are handed to
        // ParallelCollectionPartitions and shipped to executors, so a serializable
        // List implementation is required.
        this.taskSourcesByPartitionId = requireNonNull(taskSourcesByPartitionId, "taskSourcesByPartitionId is null").stream()
                .map(ArrayList::new)
                .collect(toList());
}
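
    // Spark requires a ClassTag for the RDD element type. The unchecked generic
    // signature lets this single tag, built from the concrete element class,
    // satisfy both the RDD super-constructor and the partitions.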
private static <T> ClassTag<T> fakeClassTag()
{
return scala.reflect.ClassTag$.MODULE$.apply(SerializedPrestoSparkTaskSource.class);
    }

    // Driver side: build one ParallelCollectionPartition per entry in
    // taskSourcesByPartitionId; each partition carries its task sources to the executors.
    @Override
    public Partition[] getPartitions()
{
Partition[] partitions = new Partition[taskSourcesByPartitionId.size()];
for (int partitionId = 0; partitionId < taskSourcesByPartitionId.size(); partitionId++) {
partitions[partitionId] = new ParallelCollectionPartition<>(
id(),
partitionId,
asScalaBuffer(taskSourcesByPartitionId.get(partitionId)).toSeq(),
fakeClassTag());
}
return partitions;
    }

    // Executor side: replay the task sources stored in this split's
    // ParallelCollectionPartition, wrapped in an InterruptibleIterator so that
    // Spark can cancel the task between elements.
    @Override
    public Iterator<SerializedPrestoSparkTaskSource> compute(Partition partition, TaskContext context)
{
ParallelCollectionPartition<SerializedPrestoSparkTaskSource> parallelCollectionPartition = toParallelCollectionPartition(partition);
return new InterruptibleIterator<>(context, parallelCollectionPartition.iterator());
    }

    // The unchecked cast is safe: Spark always passes back the exact partitions
    // created in getPartitions().
    @SuppressWarnings("unchecked")
    private static <T> ParallelCollectionPartition<T> toParallelCollectionPartition(Partition partition)
{
return (ParallelCollectionPartition<T>) partition;
    }

    // Called by Spark when this RDD is checkpointed; drop the driver-side reference
    // so the task sources can be garbage collected.
    @Override
    public void clearDependencies()
{
super.clearDependencies();
taskSourcesByPartitionId = null;
}
}
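
// Usage sketch (illustrative, not part of the original file): how a driver might
// build this RDD. The createTaskSources(...) helper and the two-partition layout
// are hypothetical; only the PrestoSparkTaskSourceRdd API itself is from this file.
//
//     SparkContext sparkContext = ...;
//     List<List<SerializedPrestoSparkTaskSource>> taskSourcesByPartitionId = new ArrayList<>();
//     taskSourcesByPartitionId.add(createTaskSources("bucket-0")); // Spark partition 0
//     taskSourcesByPartitionId.add(createTaskSources("bucket-1")); // Spark partition 1
//     PrestoSparkTaskSourceRdd taskSourceRdd =
//             new PrestoSparkTaskSourceRdd(sparkContext, taskSourcesByPartitionId);
//     // One Spark task per entry in taskSourcesByPartitionId:
//     assert taskSourceRdd.getPartitions().length == 2;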