// TestPrestoSparkQueryExecution.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.spark;
import com.facebook.presto.Session;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.execution.scheduler.TableWriteInfo;
import com.facebook.presto.spark.classloader_interface.IPrestoSparkQueryExecution;
import com.facebook.presto.spark.classloader_interface.PrestoSparkSerializedPage;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskRdd;
import com.facebook.presto.spark.execution.FragmentExecutionResult;
import com.facebook.presto.spark.execution.PrestoSparkAdaptiveQueryExecution;
import com.facebook.presto.spark.execution.PrestoSparkStaticQueryExecution;
import com.facebook.presto.sql.planner.SubPlan;
import com.facebook.presto.testing.MaterializedResult;
import com.facebook.presto.testing.QueryRunner;
import com.facebook.presto.tests.AbstractTestQueryFramework;
import com.facebook.presto.tests.QueryAssertions;
import com.google.common.collect.ImmutableList;
import org.apache.spark.Dependency;
import org.apache.spark.MapOutputStatistics;
import org.apache.spark.rdd.RDD;
import org.apache.spark.rdd.ShuffledRDD;
import org.testng.annotations.Test;
import java.util.Collection;
import java.util.Optional;
import java.util.stream.Collectors;
import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE;
import static com.facebook.presto.spark.PrestoSparkQueryRunner.createHivePrestoSparkQueryRunner;
import static com.facebook.presto.spark.PrestoSparkSessionProperties.SPARK_ADAPTIVE_QUERY_EXECUTION_ENABLED;
import static com.facebook.presto.spark.PrestoSparkSessionProperties.SPARK_RETRY_ON_OUT_OF_MEMORY_WITH_INCREASED_MEMORY_SETTINGS_ENABLED;
import static com.facebook.presto.spark.PrestoSparkSessionProperties.STORAGE_BASED_BROADCAST_JOIN_ENABLED;
import static com.facebook.presto.spark.execution.RuntimeStatistics.createRuntimeStats;
import static com.google.common.base.Throwables.throwIfUnchecked;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;
/**
 * Tests for Presto-on-Spark query execution: verifies which {@link IPrestoSparkQueryExecution}
 * implementation is created for static vs. adaptive execution, that adaptive execution produces
 * the same results as static execution, that fragment-by-fragment execution builds the same RDD
 * graph as whole-plan execution, and that runtime statistics are extracted correctly from Spark
 * {@link MapOutputStatistics}.
 */
public class TestPrestoSparkQueryExecution
        extends AbstractTestQueryFramework
{
    // Shared runner created once by createQueryRunner(); used by all tests to build executions.
    PrestoSparkQueryRunner prestoSparkQueryRunner;

    @Override
    protected QueryRunner createQueryRunner()
    {
        prestoSparkQueryRunner = createHivePrestoSparkQueryRunner();
        return prestoSparkQueryRunner;
    }

    /**
     * Creates (but does not run) a Presto-on-Spark query execution for the given session and SQL.
     */
    private IPrestoSparkQueryExecution getPrestoSparkQueryExecution(Session session, String sql)
    {
        return prestoSparkQueryRunner.createPrestoSparkQueryExecution(session, sql, ImmutableList.of());
    }

    /**
     * Verifies that the adaptive-execution session property selects the expected
     * {@link IPrestoSparkQueryExecution} implementation.
     */
    @Test
    public void testQueryExecutionCreation()
    {
        String sqlText = "select * from lineitem";
        Session session = Session.builder(getSession())
                .setSystemProperty(SPARK_ADAPTIVE_QUERY_EXECUTION_ENABLED, "false")
                .build();
        IPrestoSparkQueryExecution psQueryExecution = getPrestoSparkQueryExecution(session, sqlText);
        assertTrue(psQueryExecution instanceof PrestoSparkStaticQueryExecution);

        session = Session.builder(getSession())
                .setSystemProperty(SPARK_ADAPTIVE_QUERY_EXECUTION_ENABLED, "true")
                .build();
        psQueryExecution = getPrestoSparkQueryExecution(session, sqlText);
        assertTrue(psQueryExecution instanceof PrestoSparkAdaptiveQueryExecution);
    }

    /**
     * A single-fragment query must return the same rows under adaptive execution as under
     * static execution.
     */
    @Test
    public void testSingleFragmentQueryAdaptiveExecution()
    {
        String sqlText = "select * from lineitem";
        Session session = Session.builder(getSession())
                .setSystemProperty(SPARK_ADAPTIVE_QUERY_EXECUTION_ENABLED, "false")
                .build();
        MaterializedResult staticResults = prestoSparkQueryRunner.execute(session, sqlText);

        session = Session.builder(getSession())
                .setSystemProperty(SPARK_ADAPTIVE_QUERY_EXECUTION_ENABLED, "true")
                .build();
        MaterializedResult dynamicResults = prestoSparkQueryRunner.execute(session, sqlText);
        QueryAssertions.assertEqualsIgnoreOrder(staticResults, dynamicResults);
    }

    /**
     * A partitioned join must return the same rows under adaptive execution as under
     * static execution.
     */
    @Test
    public void testJoinQueryAdaptiveExecution()
    {
        String sqlText = "select * from lineitem l join orders o on l.orderkey = o.orderkey";
        Session session = Session.builder(getSession())
                .setSystemProperty(SPARK_ADAPTIVE_QUERY_EXECUTION_ENABLED, "false")
                .setSystemProperty(JOIN_DISTRIBUTION_TYPE, "partitioned")
                .setSystemProperty(SPARK_RETRY_ON_OUT_OF_MEMORY_WITH_INCREASED_MEMORY_SETTINGS_ENABLED, "false")
                .build();
        MaterializedResult staticResults = prestoSparkQueryRunner.execute(session, sqlText);

        session = Session.builder(getSession())
                .setSystemProperty(SPARK_ADAPTIVE_QUERY_EXECUTION_ENABLED, "true")
                .setSystemProperty(JOIN_DISTRIBUTION_TYPE, "partitioned")
                .setSystemProperty(SPARK_RETRY_ON_OUT_OF_MEMORY_WITH_INCREASED_MEMORY_SETTINGS_ENABLED, "false")
                .build();
        MaterializedResult dynamicResults = prestoSparkQueryRunner.execute(session, sqlText);
        QueryAssertions.assertEqualsIgnoreOrder(staticResults, dynamicResults);
    }

    /**
     * A multi-stage (ORDER BY ... LIMIT) query must return the same rows under adaptive
     * execution as under static execution.
     * NOTE(review): despite the test name, the query contains no GROUP BY — confirm whether
     * the name or the query should be updated.
     */
    @Test
    public void testGroupByAdaptiveExecution()
    {
        String sqlText = "SELECT custkey, orderstatus FROM orders ORDER BY orderkey DESC LIMIT 10";
        Session session = Session.builder(getSession())
                .setSystemProperty(SPARK_ADAPTIVE_QUERY_EXECUTION_ENABLED, "false")
                .setSystemProperty(JOIN_DISTRIBUTION_TYPE, "partitioned")
                .setSystemProperty(SPARK_RETRY_ON_OUT_OF_MEMORY_WITH_INCREASED_MEMORY_SETTINGS_ENABLED, "false")
                .build();
        MaterializedResult staticResults = prestoSparkQueryRunner.execute(session, sqlText);

        session = Session.builder(getSession())
                .setSystemProperty(SPARK_ADAPTIVE_QUERY_EXECUTION_ENABLED, "true")
                .setSystemProperty(JOIN_DISTRIBUTION_TYPE, "partitioned")
                .setSystemProperty(SPARK_RETRY_ON_OUT_OF_MEMORY_WITH_INCREASED_MEMORY_SETTINGS_ENABLED, "false")
                .build();
        MaterializedResult dynamicResults = prestoSparkQueryRunner.execute(session, sqlText);
        QueryAssertions.assertEqualsIgnoreOrder(staticResults, dynamicResults);
    }

    @Test
    public void testRddCreationForPartitionedJoinWithoutShuffle()
    {
        Session session = Session.builder(getSession())
                .setSystemProperty(JOIN_DISTRIBUTION_TYPE, "partitioned")
                .setSystemProperty(SPARK_RETRY_ON_OUT_OF_MEMORY_WITH_INCREASED_MEMORY_SETTINGS_ENABLED, "false")
                .build();
        String sql = "select * from lineitem l join orders o on l.orderkey = o.orderkey";
        validateFragmentedRddCreation(session, sql);
    }

    @Test
    public void testRddCreationForPartitionedJoinWithShuffle()
    {
        Session session = Session.builder(getSession())
                .setSystemProperty(JOIN_DISTRIBUTION_TYPE, "partitioned")
                .setSystemProperty(SPARK_RETRY_ON_OUT_OF_MEMORY_WITH_INCREASED_MEMORY_SETTINGS_ENABLED, "false")
                .build();
        String sql = "select * from lineitem l join orders o on l.orderkey = o.orderkey UNION ALL select * from lineitem l join orders o on l.orderkey = o.orderkey";
        validateFragmentedRddCreation(session, sql);
    }

    @Test
    public void testRddCreationForMemoryBasedBroadcastJoin()
    {
        Session session = Session.builder(getSession())
                .setSystemProperty(JOIN_DISTRIBUTION_TYPE, "broadcast")
                .setSystemProperty(SPARK_RETRY_ON_OUT_OF_MEMORY_WITH_INCREASED_MEMORY_SETTINGS_ENABLED, "false")
                .setSystemProperty(STORAGE_BASED_BROADCAST_JOIN_ENABLED, "false")
                .build();
        String sql = "select * from lineitem l join orders o on l.orderkey = o.orderkey";
        validateFragmentedRddCreation(session, sql);
    }

    @Test
    public void testRddCreationForStorageBasedBroadcastJoin()
    {
        Session session = Session.builder(getSession())
                .setSystemProperty(JOIN_DISTRIBUTION_TYPE, "broadcast")
                .setSystemProperty(SPARK_RETRY_ON_OUT_OF_MEMORY_WITH_INCREASED_MEMORY_SETTINGS_ENABLED, "false")
                .setSystemProperty(STORAGE_BASED_BROADCAST_JOIN_ENABLED, "true")
                .build();
        String sql = "select * from lineitem l join orders o on l.orderkey = o.orderkey";
        validateFragmentedRddCreation(session, sql);
    }

    /**
     * Verifies that {@code createRuntimeStats} sums per-partition byte counts from Spark's
     * {@link MapOutputStatistics} into the plan-node output-size estimate, and returns an
     * empty {@link Optional} when no statistics are available.
     */
    @Test
    public void testMapOutputStatsExtraction()
    {
        Session session = Session.builder(getSession())
                .setSystemProperty(JOIN_DISTRIBUTION_TYPE, "broadcast")
                .setSystemProperty(SPARK_RETRY_ON_OUT_OF_MEMORY_WITH_INCREASED_MEMORY_SETTINGS_ENABLED, "false")
                .build();
        Optional<PlanNodeStatsEstimate> planNodeStatsEstimate;

        // Empty stats case
        planNodeStatsEstimate = createRuntimeStats(Optional.empty());
        assertFalse(planNodeStatsEstimate.isPresent());

        // Empty partition array
        planNodeStatsEstimate = createRuntimeStats(Optional.of(new MapOutputStatistics(0, new long[] {})));
        assertEquals(planNodeStatsEstimate.get().getOutputSizeInBytes(), 0);

        // One partition case
        planNodeStatsEstimate = createRuntimeStats(Optional.of(new MapOutputStatistics(0, new long[] {23})));
        assertEquals(planNodeStatsEstimate.get().getOutputSizeInBytes(), 23);

        // Multiple partition case
        planNodeStatsEstimate = createRuntimeStats(Optional.of(new MapOutputStatistics(0, new long[] {23, 520, 190})));
        assertEquals(planNodeStatsEstimate.get().getOutputSizeInBytes(), 733);
    }

    /**
     * Builds the RDD for {@code sql} in two ways — once from the whole fragmented plan, and once
     * by executing the plan fragment by fragment — and asserts the resulting RDD graphs match.
     */
    private void validateFragmentedRddCreation(Session session, String sql)
    {
        PrestoSparkStaticQueryExecution execution = (PrestoSparkStaticQueryExecution) getPrestoSparkQueryExecution(session, sql);
        RddAndMore rddAndMoreStatic = null;
        RddAndMore rddAndMoreFromFragmentedExecution = null;
        try {
            SubPlan rootFragmentedPlan = execution.createFragmentedPlan();
            TableWriteInfo tableWriteInfo = execution.getTableWriteInfo(session, rootFragmentedPlan);
            rddAndMoreStatic = execution.createRdd(rootFragmentedPlan, PrestoSparkSerializedPage.class, tableWriteInfo);
            FragmentExecutionResult fragmentExecutionResult = executeInStages(execution, rootFragmentedPlan, tableWriteInfo);
            rddAndMoreFromFragmentedExecution = fragmentExecutionResult.getRddAndMore();
        }
        catch (Exception e) {
            // fail() throws AssertionError, so the comparison below is never reached with nulls
            fail("Failed while creating RDD", e);
        }
        assertRddAndMoreEquals(rddAndMoreStatic, rddAndMoreFromFragmentedExecution);
    }

    // For testing purposes, executes the plan by running sub-plans one level at a time,
    // deepest first (in this order):
    //   1. level 2
    //   2. level 1
    //   3. level 0 (root)
    // For plans with more than 3 levels, deeper descendant levels are executed in the same step.
    // TODO: this can be updated in the future with methods in adaptive execution that return
    // sub-plans at shuffle boundaries.
    private FragmentExecutionResult executeInStages(PrestoSparkStaticQueryExecution execution,
            SubPlan rootFragmentedPlan,
            TableWriteInfo tableWriteInfo)
    {
        // Level 2
        rootFragmentedPlan.getChildren().stream()
                .map(SubPlan::getChildren)
                .flatMap(Collection::stream)
                .forEach(subPlan -> executeSubPlanWithUncheckedException(execution, subPlan, tableWriteInfo, Optional.empty()));
        // Level 1
        rootFragmentedPlan.getChildren().stream()
                .forEach(subPlan -> executeSubPlanWithUncheckedException(execution, subPlan, tableWriteInfo, Optional.empty()));
        // Level 0 - Root
        return executeSubPlanWithUncheckedException(execution, rootFragmentedPlan, tableWriteInfo, Optional.empty());
    }

    // This exists as forEach can't handle methods with checked exceptions; any checked
    // exception is rethrown wrapped in a RuntimeException (with the original as cause).
    private FragmentExecutionResult executeSubPlanWithUncheckedException(PrestoSparkStaticQueryExecution execution,
            SubPlan subPlan,
            TableWriteInfo tableWriteInfo,
            Optional<Class<?>> outputType)
    {
        try {
            return execution.executeFragment(subPlan, tableWriteInfo, outputType);
        }
        catch (Exception e) {
            throwIfUnchecked(e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Asserts two {@code RddAndMore} instances describe the same broadcast dependencies and
     * structurally equal RDD graphs.
     */
    private void assertRddAndMoreEquals(RddAndMore rddAndMore1, RddAndMore rddAndMore2)
    {
        assertEquals(rddAndMore1.getBroadcastDependencies().size(), rddAndMore2.getBroadcastDependencies().size());
        assertRddEquals(rddAndMore1.getRdd().rdd(), rddAndMore2.getRdd().rdd());
    }

    /**
     * Structurally compares two RDDs: name, class, dependency count, partition count, then
     * recurses with type-specific checks for the RDD types produced by Presto-on-Spark.
     */
    private void assertRddEquals(RDD rdd1, RDD rdd2)
    {
        assertEquals(rdd1.name(), rdd2.name());
        assertEquals(rdd1.getClass(), rdd2.getClass());
        assertEquals(rdd1.getDependencies().size(), rdd2.getDependencies().size());
        assertEquals(rdd1.getNumPartitions(), rdd2.getNumPartitions());
        // type specific assertions
        if (rdd1 instanceof PrestoSparkTaskRdd) {
            assertPrestoSparkTaskRddEquals((PrestoSparkTaskRdd) rdd1, (PrestoSparkTaskRdd) rdd2);
        }
        else if (rdd1 instanceof ShuffledRDD) {
            assertShuffledRddEquals((ShuffledRDD) rdd1, (ShuffledRDD) rdd2);
        }
    }

    // Recursively compares two ShuffledRDDs via their predecessor and dependency RDDs.
    private void assertShuffledRddEquals(ShuffledRDD shuffledRDD1, ShuffledRDD shuffledRDD2)
    {
        assertEquals(shuffledRDD1.getNumPartitions(), shuffledRDD2.getNumPartitions());
        assertRddEquals(shuffledRDD1.prev(), shuffledRDD2.prev());
        for (int i = 0; i < shuffledRDD1.getDependencies().size(); i++) {
            assertRddEquals(
                    ((Dependency) shuffledRDD1.getDependencies().apply(i)).rdd(),
                    ((Dependency) shuffledRDD2.getDependencies().apply(i)).rdd());
        }
    }

    // Compares two PrestoSparkTaskRdds: shuffle inputs (count and fragment ids), presence of a
    // task-source RDD, then recurses into each shuffle-input and dependency RDD.
    private void assertPrestoSparkTaskRddEquals(PrestoSparkTaskRdd prestoSparkTaskRdd1, PrestoSparkTaskRdd prestoSparkTaskRdd2)
    {
        assertEquals(
                prestoSparkTaskRdd1.getShuffleInputRdds().size(),
                prestoSparkTaskRdd2.getShuffleInputRdds().size(),
                "Expected same number of shuffle inputs");
        assertEquals(
                prestoSparkTaskRdd1.getShuffleInputFragmentIds().stream().collect(Collectors.toSet()),
                prestoSparkTaskRdd2.getShuffleInputFragmentIds().stream().collect(Collectors.toSet()),
                "Expected same input fragment ids");
        assertEquals(
                prestoSparkTaskRdd1.getTaskSourceRdd() == null,
                prestoSparkTaskRdd2.getTaskSourceRdd() == null,
                "Expected both RDDs to either contain TaskSourceRdd or not contain it");
        for (int i = 0; i < prestoSparkTaskRdd1.getShuffleInputRdds().size(); i++) {
            assertRddEquals((RDD) prestoSparkTaskRdd1.getShuffleInputRdds().get(i), (RDD) prestoSparkTaskRdd2.getShuffleInputRdds().get(i));
        }
        for (int i = 0; i < prestoSparkTaskRdd1.getDependencies().size(); i++) {
            assertRddEquals(
                    ((Dependency) prestoSparkTaskRdd1.getDependencies().apply(i)).rdd(),
                    ((Dependency) prestoSparkTaskRdd2.getDependencies().apply(i)).rdd());
        }
    }
}