TestDruidQueryGenerator.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.druid;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.Test;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import static org.testng.Assert.assertEquals;
public class TestDruidQueryGenerator
extends TestDruidQueryBase
{
// Session holder used whenever a test does not supply its own: the two-argument testDQL overload,
// buildPlan, and the getRowExpression calls in the tests below all pass it.
private static final SessionHolder defaultSessionHolder = new SessionHolder();
// Table handle scanned by every test in this class. It is an alias for realtimeOnlyTable, which is
// declared outside this file (presumably in TestDruidQueryBase — confirm).
private static final DruidTableHandle druidTable = realtimeOnlyTable;
/**
 * Builds a plan, generates the Druid query (DQL) for it and asserts that it equals the expected text.
 *
 * <p>When {@code expectedDQL} contains the placeholder {@code __expressions__}, the placeholder is
 * replaced by a {@code ", "}-joined list obtained by looking up the name of each output variable of
 * the plan in {@code outputVariables}; names without an entry are skipped.
 *
 * <p>Every backslash-followed-by-double-quote sequence is removed from the generated DQL before the
 * comparison.
 *
 * @param givenDruidConfig not read by this method
 * @param planBuilderConsumer produces the plan under test from a fresh plan builder
 * @param expectedDQL the DQL text the generator is expected to emit
 * @param sessionHolder supplies the plan builder session and the connector session
 * @param outputVariables maps output variable names to the text substituted for {@code __expressions__}
 */
private void testDQL(
        DruidConfig givenDruidConfig,
        Function<PlanBuilder, PlanNode> planBuilderConsumer,
        String expectedDQL, SessionHolder sessionHolder,
        Map<String, String> outputVariables)
{
    PlanBuilder builder = createPlanBuilder(sessionHolder);
    PlanNode plan = planBuilderConsumer.apply(builder);

    DruidQueryGenerator generator = new DruidQueryGenerator(functionAndTypeManager, functionAndTypeManager, standardFunctionResolution);
    DruidQueryGenerator.DruidQueryGeneratorResult result = generator.generate(plan, sessionHolder.getConnectorSession()).get();

    String expected = expectedDQL;
    if (expected.contains("__expressions__")) {
        String joined = plan.getOutputVariables().stream()
                .map(outputVariable -> outputVariables.get(outputVariable.getName()))
                .filter(expression -> expression != null)
                .collect(Collectors.joining(", "));
        expected = expected.replace("__expressions__", joined);
    }

    // The regex matches a literal backslash immediately followed by a double quote.
    String actual = result.getGeneratedDql().getDql().replaceAll("\\\\\"", "");
    assertEquals(actual, expected);
}
/**
 * Convenience overload that forwards to the five-argument form, passing the {@code druidConfig}
 * field (declared outside this file) as the configuration.
 */
private void testDQL(Function<PlanBuilder, PlanNode> planBuilderConsumer, String expectedDQL, SessionHolder sessionHolder, Map<String, String> outputVariables)
{
    testDQL(
            druidConfig,
            planBuilderConsumer,
            expectedDQL,
            sessionHolder,
            outputVariables);
}
/**
 * Convenience overload for expectations that need no {@code __expressions__} substitution:
 * forwards with an empty output-variable map.
 */
private void testDQL(Function<PlanBuilder, PlanNode> planBuilderConsumer, String expectedDQL, SessionHolder sessionHolder)
{
    Map<String, String> noOutputVariables = ImmutableMap.of();
    testDQL(planBuilderConsumer, expectedDQL, sessionHolder, noOutputVariables);
}
/**
 * Shortest overload: checks the generated DQL using {@code defaultSessionHolder}.
 */
private void testDQL(Function<PlanBuilder, PlanNode> planBuilderConsumer, String expectedDQL)
{
    testDQL(
            planBuilderConsumer,
            expectedDQL,
            defaultSessionHolder);
}
/**
 * Applies {@code consumer} to a plan builder created for {@code defaultSessionHolder} and returns
 * the resulting plan node.
 */
private PlanNode buildPlan(Function<PlanBuilder, PlanNode> consumer)
{
    return consumer.apply(createPlanBuilder(defaultSessionHolder));
}
/**
 * A limit placed directly over a table scan is expected to produce a SELECT of the scanned
 * columns followed by a LIMIT clause.
 */
@Test
public void testSimpleSelectStar()
{
    // Four scanned columns, limit 50.
    Function<PlanBuilder, PlanNode> fourColumnScan = builder ->
            limit(builder, 50L, tableScan(builder, druidTable, regionId, city, fare, secondsSinceEpoch));
    testDQL(
            fourColumnScan,
            "SELECT \"region.Id\", \"city\", \"fare\", \"secondsSinceEpoch\" FROM \"realtimeOnly\" LIMIT 50");

    // Two scanned columns, limit 10.
    Function<PlanBuilder, PlanNode> twoColumnScan = builder ->
            limit(builder, 10L, tableScan(builder, druidTable, regionId, secondsSinceEpoch));
    testDQL(
            twoColumnScan,
            "SELECT \"region.Id\", \"secondsSinceEpoch\" FROM \"realtimeOnly\" LIMIT 10");
}
/**
 * A limit over a projection over a filter over a scan is expected to collapse into a single
 * SELECT ... WHERE ... LIMIT statement.
 */
@Test
public void testSimpleSelectWithFilterLimit()
{
    String expectedDql = "SELECT \"city\", \"secondsSinceEpoch\" FROM \"realtimeOnly\" WHERE (\"secondsSinceEpoch\" > 20) LIMIT 30";

    Function<PlanBuilder, PlanNode> limitedFilteredProjection = builder -> limit(
            builder,
            30L,
            project(
                    builder,
                    filter(
                            builder,
                            tableScan(builder, druidTable, regionId, city, fare, secondsSinceEpoch),
                            getRowExpression("secondssinceepoch > 20", defaultSessionHolder)),
                    ImmutableList.of("city", "secondssinceepoch")));

    testDQL(limitedFilteredProjection, expectedDql);
}
/**
 * Exercises {@code count(*)} over a bare scan and over filtered scans, with global grouping,
 * with one or two grouping columns, and with a limit on top of the aggregation.
 */
@Test
public void testCountStar()
{
    // Adds "agg := count(*)" to whichever aggregation builder it is handed.
    BiConsumer<PlanBuilder, PlanBuilder.AggregationBuilder> addCountStar = (builder, target) ->
            target.addAggregation(builder.variable("agg"), getRowExpression("count(*)", defaultSessionHolder));

    PlanNode scan = buildPlan(builder -> tableScan(builder, druidTable, regionId, secondsSinceEpoch, city, fare));
    PlanNode fareFilter = buildPlan(builder -> filter(
            builder,
            tableScan(builder, druidTable, regionId, secondsSinceEpoch, city, fare),
            getRowExpression("fare > 3", defaultSessionHolder)));
    PlanNode rangeFilter = buildPlan(builder -> filter(
            builder,
            tableScan(builder, druidTable, regionId, secondsSinceEpoch, city, fare),
            getRowExpression("secondssinceepoch between 200 and 300 and \"region.id\" >= 40", defaultSessionHolder)));

    // Global count over the bare scan.
    testDQL(
            builder -> builder.aggregation(agg -> addCountStar.accept(builder, agg.source(scan).globalGrouping())),
            "SELECT count(*) FROM \"realtimeOnly\"");

    // Global count over the fare filter.
    testDQL(
            builder -> builder.aggregation(agg -> addCountStar.accept(builder, agg.source(fareFilter).globalGrouping())),
            "SELECT count(*) FROM \"realtimeOnly\" WHERE (\"fare\" > 3)");

    // Count per region over the fare filter.
    testDQL(
            builder -> builder.aggregation(agg -> addCountStar.accept(builder, agg.source(fareFilter).singleGroupingSet(variable("region.id")))),
            "SELECT \"region.Id\", count(*) FROM \"realtimeOnly\" WHERE (\"fare\" > 3) GROUP BY \"region.Id\"");

    // Count per region over the bare scan.
    testDQL(
            builder -> builder.aggregation(agg -> addCountStar.accept(builder, agg.source(scan).singleGroupingSet(variable("region.id")))),
            "SELECT \"region.Id\", count(*) FROM \"realtimeOnly\" GROUP BY \"region.Id\"");

    // Count per region over the bare scan, limited to 5 rows.
    testDQL(
            builder -> limit(
                    builder,
                    5L,
                    builder.aggregation(agg -> addCountStar.accept(builder, agg.source(scan).singleGroupingSet(variable("region.id"))))),
            "SELECT \"region.Id\", count(*) FROM \"realtimeOnly\" GROUP BY \"region.Id\" LIMIT 5");

    // Count per region and city over the compound range filter.
    testDQL(
            builder -> builder.aggregation(agg -> addCountStar.accept(builder, agg.source(rangeFilter).singleGroupingSet(variable("region.id"), variable("city")))),
            "SELECT \"region.Id\", \"city\", count(*) FROM \"realtimeOnly\" WHERE ((\"secondsSinceEpoch\" BETWEEN 200 AND 300) AND (\"region.Id\" >= 40)) GROUP BY \"region.Id\", \"city\"");
}
/**
 * An aggregation that only groups by a column (no aggregate functions) is expected to be rendered
 * as a GROUP BY on that column with a {@code count(*)} in the select list.
 */
@Test
public void testDistinctSelection()
{
    PlanNode scan = buildPlan(builder -> tableScan(builder, druidTable, regionId, secondsSinceEpoch, city, fare));

    Function<PlanBuilder, PlanNode> groupByRegionOnly = builder ->
            builder.aggregation(agg -> agg.source(scan).singleGroupingSet(variable("region.id")));

    testDQL(
            groupByRegionOnly,
            "SELECT \"region.Id\", count(*) FROM \"realtimeOnly\" GROUP BY \"region.Id\"");
}
/**
 * A global {@code count} stacked on a grouping-only aggregation of the same column is expected to
 * be rendered as a single {@code count ( distinct ...)} query.
 */
@Test
public void testDistinctCountPushdown()
{
    PlanNode scan = buildPlan(builder -> tableScan(builder, druidTable, regionId, secondsSinceEpoch, city, fare));
    PlanNode distinctRegions = buildPlan(builder ->
            builder.aggregation(agg -> agg.source(scan).singleGroupingSet(variable("region.id"))));

    Function<PlanBuilder, PlanNode> countOverDistinct = builder -> builder.aggregation(
            agg -> agg.source(distinctRegions)
                    .globalGrouping()
                    .addAggregation(variable("region.id"), getRowExpression("count(\"region.id\")", defaultSessionHolder)));

    testDQL(
            countOverDistinct,
            "SELECT count ( distinct \"region.Id\") FROM \"realtimeOnly\"");
}
/**
 * A {@code sum} grouped by three columns is expected to list the grouping columns in the select
 * list, followed by the aggregate, with a matching GROUP BY clause.
 */
@Test
public void testGroupByPushdown()
{
    PlanNode scan = buildPlan(builder -> tableScan(builder, druidTable, regionId, secondsSinceEpoch, city, fare));

    Function<PlanBuilder, PlanNode> fareSumByThreeColumns = builder -> builder.aggregation(
            agg -> agg.source(scan)
                    .singleGroupingSet(variable("city"), variable("region.id"), variable("secondssinceepoch"))
                    .addAggregation(variable("totalfare"), getRowExpression("sum(\"fare\")", defaultSessionHolder)));

    testDQL(
            fareSumByThreeColumns,
            "SELECT \"city\", \"region.Id\", \"secondsSinceEpoch\", sum(fare) FROM \"realtimeOnly\" GROUP BY \"city\", \"region.Id\", \"secondsSinceEpoch\"");
}
/**
 * A per-city {@code count} stacked on a grouping-only aggregation by city and region is expected
 * to be rendered as a {@code count ( distinct ...)} grouped by city.
 */
@Test
public void testDistinctCountGroupByPushdown()
{
    PlanNode scan = buildPlan(builder -> tableScan(builder, druidTable, regionId, secondsSinceEpoch, city, fare));
    PlanNode distinctCityRegion = buildPlan(builder -> builder.aggregation(
            agg -> agg.source(scan).singleGroupingSet(variable("city"), variable("region.id"))));

    Function<PlanBuilder, PlanNode> regionCountPerCity = builder -> builder.aggregation(
            agg -> agg.source(distinctCityRegion)
                    .singleGroupingSet(variable("city"))
                    .addAggregation(variable("region.id"), getRowExpression("count(\"region.id\")", defaultSessionHolder)));

    testDQL(
            regionCountPerCity,
            "SELECT \"city\", count ( distinct \"region.Id\") FROM \"realtimeOnly\" GROUP BY \"city\"");
}
/**
 * Checks how timestamp literals in a filter are rendered in the generated DQL, both without and
 * with an explicit time zone in the literal.
 */
@Test
public void testTimestampLiteralPushdown()
{
    // The session time zone is Pacific/Apia (UTC+13) while the connector session uses UTC, so a
    // literal without an explicit zone has to be shifted by 13 hours.
    Function<PlanBuilder, PlanNode> zonelessLiteral = builder -> project(
            builder,
            filter(
                    builder,
                    tableScan(builder, druidTable, regionId, city, fare, datetime),
                    getRowExpression("datetime = timestamp '2016-06-26 19:00:00.000'", defaultSessionHolder)),
            ImmutableList.of("city", "datetime"));
    testDQL(
            zonelessLiteral,
            "SELECT \"city\", \"datetime\" FROM \"realtimeOnly\" WHERE (\"datetime\" = TIMESTAMP '2016-06-26 06:00:00.000')");

    // A literal that names its time zone (UTC here) is expected to keep its wall-clock value.
    Function<PlanBuilder, PlanNode> utcLiteral = builder -> project(
            builder,
            filter(
                    builder,
                    tableScan(builder, druidTable, regionId, city, fare, datetime),
                    getRowExpression("datetime > timestamp '2016-06-26 19:00:00.000 UTC'", defaultSessionHolder)),
            ImmutableList.of("city", "datetime"));
    testDQL(
            utcLiteral,
            "SELECT \"city\", \"datetime\" FROM \"realtimeOnly\" WHERE (\"datetime\" > TIMESTAMP '2016-06-26 19:00:00.000')");
}
}