AbstractTestNativeJoinQueries.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.nativeworker;
import com.facebook.presto.Session;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.testing.QueryRunner;
import com.facebook.presto.tests.AbstractTestQueryFramework;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createBucketedCustomer;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createBucketedLineitemAndOrders;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createCustomer;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createLineitem;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createNation;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createOrders;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createOrdersEx;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createPartitionedNation;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createRegion;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.semiJoin;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPLICATE;
/**
 * Join-query tests run with native execution enabled in {@link FeaturesConfig}.
 * <p>
 * Most tests are parameterized by {@link #joinTypeProvider()}, which supplies sessions that force
 * either a partitioned or a broadcast join distribution.
 */
public abstract class AbstractTestNativeJoinQueries
        extends AbstractTestQueryFramework
{
    /**
     * Creates the tables these tests read: lineitem, orders, bucketed lineitem/orders, orders_ex,
     * nation, partitioned nation, region, customer and bucketed customer. They are created through
     * {@code NativeQueryRunnerUtils} on the expected query runner.
     */
    @Override
    protected void createTables()
    {
        QueryRunner queryRunner = (QueryRunner) getExpectedQueryRunner();
        createLineitem(queryRunner);
        createOrders(queryRunner);
        createBucketedLineitemAndOrders(queryRunner);
        createOrdersEx(queryRunner);
        createNation(queryRunner);
        createPartitionedNation(queryRunner);
        createRegion(queryRunner);
        createCustomer(queryRunner);
        createBucketedCustomer(queryRunner);
    }

    /**
     * Returns a {@link FeaturesConfig} with native execution enabled.
     */
    @Override
    protected FeaturesConfig createFeaturesConfig()
    {
        return new FeaturesConfig().setNativeExecutionEnabled(true);
    }

    /**
     * Equi-joins orders with lineitem, with and without extra filters on the join key.
     *
     * @param joinTypeSession session that forces a join distribution type (see {@link #joinTypeProvider()})
     */
    @Test(dataProvider = "joinTypeProvider")
    public void testInnerJoin(Session joinTypeSession)
    {
        assertQuery(joinTypeSession, "SELECT o.orderstatus, l.linenumber FROM orders o, lineitem l WHERE o.orderkey = l.orderkey");
        assertQuery(joinTypeSession, "SELECT count(*) FROM orders o, lineitem l WHERE o.orderkey = l.orderkey AND o.orderkey > 10000");
        assertQuery(joinTypeSession, "SELECT count(*) FROM orders o, lineitem l WHERE o.orderkey = l.orderkey AND o.orderkey % 2 = 1");
    }

    /**
     * Joins the bucketed customer table with the plain one. Some queries also filter on the
     * {@code $bucket} hidden column.
     *
     * @param joinTypeSession session that forces a join distribution type (see {@link #joinTypeProvider()})
     */
    @Test(dataProvider = "joinTypeProvider")
    public void testBucketedInnerJoin(Session joinTypeSession)
    {
        assertQuery(joinTypeSession, "SELECT b.name, c.name FROM customer_bucketed b, customer c WHERE b.name=c.name");
        assertQuery(joinTypeSession, "SELECT b.name, c.custkey FROM customer_bucketed b, customer c " +
                "WHERE b.name=c.name AND \"$bucket\" = 7");
        assertQuery(joinTypeSession, "SELECT b.* FROM customer_bucketed b, customer c " +
                "WHERE b.name=c.name AND \"$bucket\" IN (2, 5, 8)");
        assertQuery(joinTypeSession, "SELECT * FROM customer_bucketed b, customer c " +
                "WHERE b.name=c.name AND \"$bucket\" = 5");
    }

    /**
     * Checks the plan shape of an {@code IN (subquery)} semi join under both join distributions.
     * <ul>
     *   <li>Partitioned: both the orders and the lineitem scans feed the semi join through a remote
     *   streaming REPARTITION exchange.</li>
     *   <li>Broadcast: the orders scan feeds the semi join directly, and the lineitem scan goes
     *   through a remote streaming REPLICATE exchange.</li>
     * </ul>
     */
    @Test
    public void testSemiJoinPlan()
    {
        String sql = "SELECT orderkey FROM orders WHERE orderdate IN (SELECT shipdate FROM lineitem)";
        assertPlan(
                partitionedJoin(),
                sql,
                anyTree(
                        semiJoin("orderdate", "shipdate", "orderkey_1",
                                exchange(REMOTE_STREAMING, REPARTITION,
                                        tableScan("orders", ImmutableMap.of(
                                                "orderkey", "orderkey",
                                                "orderdate", "orderdate"))),
                                exchange(REMOTE_STREAMING, REPARTITION,
                                        tableScan("lineitem", ImmutableMap.of(
                                                "shipdate", "shipdate"))))));
        assertPlan(
                broadcastJoin(),
                sql,
                anyTree(
                        semiJoin("orderdate", "shipdate", "orderkey_1",
                                tableScan("orders", ImmutableMap.of(
                                        "orderkey", "orderkey",
                                        "orderdate", "orderdate")),
                                exchange(REMOTE_STREAMING, REPLICATE,
                                        tableScan("lineitem", ImmutableMap.of(
                                                "shipdate", "shipdate"))))));
    }

    /**
     * Runs {@code IN (subquery)} semi joins, including ones combined with OR predicates and one over
     * an inline VALUES table that contains NULLs.
     *
     * @param joinTypeSession session that forces a join distribution type (see {@link #joinTypeProvider()})
     */
    @Test(dataProvider = "joinTypeProvider")
    public void testSemiJoin(Session joinTypeSession)
    {
        assertQuery(joinTypeSession, "SELECT * FROM orders WHERE orderdate IN (SELECT shipdate FROM lineitem) or orderdate IN (SELECT commitdate FROM lineitem)");
        assertQuery(joinTypeSession, "SELECT * FROM lineitem WHERE orderkey IN (SELECT orderkey FROM orders WHERE (orderkey + custkey) % 2 = 0)");
        assertQuery(joinTypeSession, "SELECT * FROM lineitem WHERE linenumber = 3 OR orderkey IN (SELECT orderkey FROM orders WHERE (orderkey + custkey) % 2 = 0)");
        assertQuery(joinTypeSession, "WITH\n" +
                "users AS (SELECT orderkey FROM orders ),\n" +
                "left_table AS (SELECT * FROM ( VALUES (0, NULL), (283755559, NULL), (NULL, NULL) ) AS left_table (userid, sid_cast))\n" +
                "SELECT userid FROM left_table WHERE (sid_cast IS NOT NULL) OR (NOT (userid IN (SELECT * FROM users)))");
    }

    /**
     * Runs {@code NOT IN (subquery)} anti joins.
     *
     * @param joinTypeSession session that forces a join distribution type (see {@link #joinTypeProvider()})
     */
    @Test(dataProvider = "joinTypeProvider")
    public void testAntiJoin(Session joinTypeSession)
    {
        assertQuery(joinTypeSession, "SELECT * FROM lineitem WHERE orderkey NOT IN (SELECT orderkey FROM orders WHERE (orderkey + custkey) % 2 = 0)");
        assertQuery(joinTypeSession, "SELECT * FROM lineitem " +
                "WHERE linenumber = 3 OR orderkey NOT IN (SELECT orderkey FROM orders WHERE (orderkey + custkey) % 2 = 0)");
        // Pass the provided session so this query is also exercised under each join distribution,
        // instead of running twice with the default session.
        assertQuery(joinTypeSession, "WITH mapping AS (\n" +
                " SELECT orderkey, custkey FROM orders GROUP BY 1, 2\n" +
                ")\n" +
                "SELECT \n" +
                " custkey\n" +
                "FROM \n" +
                " mapping m \n" +
                "WHERE \n" +
                " m.custkey = 38 \n" +
                " AND m.orderkey NOT IN (SELECT orderkey FROM lineitem)");
    }

    /**
     * Runs a LEFT JOIN whose ON clause has an extra non-equi predicate on the build side.
     *
     * @param joinTypeSession session that forces a join distribution type (see {@link #joinTypeProvider()})
     */
    @Test(dataProvider = "joinTypeProvider")
    public void testLeftJoin(Session joinTypeSession)
    {
        assertQuery(joinTypeSession, "SELECT * FROM orders o LEFT JOIN lineitem l ON o.orderkey = l.orderkey AND l.linenumber > 5");
    }

    /**
     * Runs RIGHT JOINs with the partitioned join distribution only.
     */
    @Test
    public void testRightJoinPartitioned()
    {
        Session partitionedJoin = partitionedJoin();
        assertQuery(partitionedJoin, "SELECT * FROM nation n RIGHT JOIN region r ON n.regionkey = r.regionkey");
        assertQuery(partitionedJoin, "SELECT * FROM (SELECT * FROM nation WHERE regionkey % 2 = 1) n RIGHT JOIN region r ON n.regionkey = r.regionkey");
    }

    /**
     * Runs the following cross-join cases with the default session:
     * <ul>
     *   <li>plain and filtered cross joins;</li>
     *   <li>a query expected to return no rows;</li>
     *   <li>a self cross join of the partitioned nation table;</li>
     *   <li>a correlated scalar subquery.</li>
     * </ul>
     */
    @Test
    public void testCrossJoin()
    {
        assertQuery("SELECT * FROM nation, region");
        assertQuery("SELECT * FROM nation n, region r WHERE n.regionkey < r.regionkey");
        assertQueryReturnsEmptyResult("SELECT l.linenumber FROM lineitem l, orders o WHERE l.orderkey = o.orderkey AND o.orderkey = 12345 AND o.totalprice > 0");
        assertQuery("SELECT l.linenumber FROM lineitem l, orders o WHERE l.orderkey = o.orderkey AND o.orderkey = 14209 AND o.totalprice > 0");
        assertQuery("SELECT * FROM nation_partitioned a, nation_partitioned b");
        assertQuery("SELECT name, (SELECT max(name) FROM region WHERE regionkey = nation.regionkey) FROM nation");
    }

    /**
     * Runs a join of the bucketed lineitem and orders tables with the session from
     * {@link #mergeJoin()}. The same SQL under {@code getSession()} is passed as the expected query.
     */
    @Test
    public void testMergeJoin()
    {
        String sql = "SELECT COUNT(*) FROM lineitem_bucketed a, orders_bucketed b WHERE a.orderkey = b.orderkey AND a.ds = '2021-12-20' AND b.ds = '2021-12-20'";
        assertQuery(mergeJoin(), sql, getSession(), sql);
    }

    /**
     * Runs LEFT, RIGHT, FULL and INNER joins whose ON clause compares only the right side with a
     * constant. Both inputs are pre-filtered, so the joins have no equi-join clause.
     */
    @Test
    public void testJoinsWithoutEquiClause()
    {
        // Test double filtered left, right, full and inner joins with right constant equality.
        // "%%" is String.format's escape for a literal '%' (the SQL modulo operator); "%s" takes the join type.
        String query = "SELECT count(*) FROM (SELECT * FROM lineitem WHERE orderkey %% 1024 = 0) "
                + "lineitem %s JOIN (SELECT * FROM orders WHERE orderkey %% 1024 = 0) "
                + "orders ON orders.orderkey = 1024";
        assertQuery(String.format(query, "LEFT"));
        assertQuery(String.format(query, "RIGHT"));
        assertQuery(String.format(query, "FULL"));
        assertQuery(String.format(query, "INNER"));
    }

    /**
     * TestNG data provider named {@code joinTypeProvider}. It delegates to
     * {@link #joinTypeProviderImpl()} so that subclasses can override the set of sessions.
     *
     * @return one single-element row per session
     */
    @DataProvider(name = "joinTypeProvider")
    public Object[][] joinTypeProvider()
    {
        return joinTypeProviderImpl();
    }

    /**
     * Supplies the sessions used by {@code joinTypeProvider}. By default these are a partitioned-join
     * session and a broadcast-join session.
     *
     * @return rows of {@code {Session}}
     */
    protected Object[][] joinTypeProviderImpl()
    {
        return new Object[][] {{partitionedJoin()}, {broadcastJoin()}};
    }

    /**
     * Returns {@code getSession()} with system property {@code join_distribution_type} set to {@code PARTITIONED}.
     */
    protected Session partitionedJoin()
    {
        return Session.builder(getSession())
                .setSystemProperty("join_distribution_type", "PARTITIONED")
                .build();
    }

    /**
     * Returns {@code getSession()} with system property {@code join_distribution_type} set to {@code BROADCAST}.
     */
    protected Session broadcastJoin()
    {
        return Session.builder(getSession())
                .setSystemProperty("join_distribution_type", "BROADCAST")
                .build();
    }

    /**
     * Returns {@code getSession()} with two extra properties:
     * <ul>
     *   <li>system property {@code prefer_merge_join_for_sorted_inputs} set to {@code true};</li>
     *   <li>hive catalog property {@code order_based_execution_enabled} set to {@code true}.</li>
     * </ul>
     */
    private Session mergeJoin()
    {
        return Session.builder(getSession())
                .setSystemProperty("prefer_merge_join_for_sorted_inputs", "true")
                .setCatalogSessionProperty("hive", "order_based_execution_enabled", "true")
                .build();
    }
}