// TestFilteredAggregations.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.query;

import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.sql.Optimizer;
import com.facebook.presto.sql.planner.assertions.BasePlanTest;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.util.Map;

import static com.facebook.presto.SystemSessionProperties.AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan;
import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static com.facebook.presto.util.MorePredicates.isInstanceOfAny;
import static org.testng.Assert.assertFalse;

/**
 * Tests for aggregations restricted by a predicate, written either with the
 * {@code agg(x) FILTER (WHERE p)} clause or as {@code agg(IF(p, x))}.
 *
 * <p>Each test exercises both spellings side by side. Result checks go through a
 * {@link QueryAssertions} instance; plan checks go through the {@link BasePlanTest}
 * superclass. Both are configured with the same session properties, which set
 * {@code AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY} to {@code "filter_with_if"}.
 * NOTE(review): presumably this strategy is what lets the optimizer treat the
 * {@code IF} form like the {@code FILTER} form in the plan tests below - confirm
 * against the session property documentation.
 */
public class TestFilteredAggregations
        extends BasePlanTest
{
    /** Session properties shared by the plan-test superclass and the query assertions. */
    private static final Map<String, String> SESSION_PROPERTIES = ImmutableMap.of(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, "filter_with_if");

    /** Created in {@link #init()} and released in {@link #teardown()}. */
    private QueryAssertions assertions;

    public TestFilteredAggregations()
    {
        super(SESSION_PROPERTIES);
    }

    /**
     * Creates the query assertions used by the result-checking tests.
     */
    @BeforeClass
    public void init()
    {
        assertions = new QueryAssertions(SESSION_PROPERTIES);
    }

    /**
     * Closes the query assertions and clears the field.
     */
    @AfterClass(alwaysRun = true)
    public void teardown()
    {
        // alwaysRun = true means this also executes when init() threw; in that case
        // there is nothing to close and dereferencing the field would only add a
        // NullPointerException on top of the real failure.
        if (assertions == null) {
            return;
        }
        try {
            assertions.close();
        }
        finally {
            // Drop the reference even if close() fails.
            assertions = null;
        }
    }

    /**
     * Checks results of sum aggregations with one filter, two filters, mixed
     * FILTER/IF spellings, and a filtered aggregate next to an unfiltered one.
     */
    @Test
    public void testAddPredicateForFilterClauses()
    {
        // Single filtered aggregate.
        assertions.assertQuery(
                "SELECT sum(x) FILTER(WHERE x > 0) FROM (VALUES 1, 1, 0, 2, 3, 3) t(x)",
                "VALUES (BIGINT '10')");
        assertions.assertQuery(
                "SELECT sum(IF(x > 0, x)) FROM (VALUES 1, 1, 0, 2, 3, 3) t(x)",
                "VALUES (BIGINT '10')");

        // Two filtered aggregates with different predicates, including a mix of both spellings.
        assertions.assertQuery(
                "SELECT sum(x) FILTER(WHERE x > 0), sum(x) FILTER(WHERE x < 3) FROM (VALUES 1, 1, 0, 5, 3, 8) t(x)",
                "VALUES (BIGINT '18', BIGINT '2')");
        assertions.assertQuery(
                "SELECT sum(IF(x > 0, x)), sum(IF(x < 3, x)) FROM (VALUES 1, 1, 0, 5, 3, 8) t(x)",
                "VALUES (BIGINT '18', BIGINT '2')");
        assertions.assertQuery(
                "SELECT sum(IF(x > 0, x)), sum(x) FILTER(WHERE x < 3) FROM (VALUES 1, 1, 0, 5, 3, 8) t(x)",
                "VALUES (BIGINT '18', BIGINT '2')");

        // Filtered aggregate alongside an unfiltered one: the unfiltered sum must still see every row.
        assertions.assertQuery(
                "SELECT sum(x) FILTER(WHERE x > 1), sum(x) FROM (VALUES 1, 1, 0, 2, 3, 3) t(x)",
                "VALUES (BIGINT '8', BIGINT '10')");
        assertions.assertQuery(
                "SELECT sum(IF(x > 1, x)), sum(x) FROM (VALUES 1, 1, 0, 2, 3, 3) t(x)",
                "VALUES (BIGINT '8', BIGINT '10')");
    }

    /**
     * Checks filtered DISTINCT (and non-DISTINCT) aggregations without a GROUP BY clause.
     */
    @Test
    public void testGroupAll()
    {
        assertions.assertQuery(
                "SELECT count(DISTINCT x) FILTER (WHERE x > 1) " +
                        "FROM (VALUES 1, 1, 1, 2, 3, 3) t(x)",
                "VALUES BIGINT '2'");
        assertions.assertQuery(
                "SELECT count(DISTINCT IF(x > 1, x)) " +
                        "FROM (VALUES 1, 1, 1, 2, 3, 3) t(x)",
                "VALUES BIGINT '2'");

        assertions.assertQuery(
                "SELECT count(DISTINCT x) FILTER (WHERE x > 1), sum(DISTINCT x) " +
                        "FROM (VALUES 1, 1, 1, 2, 3, 3) t(x)",
                "VALUES (BIGINT '2', BIGINT '6')");
        assertions.assertQuery(
                "SELECT count(DISTINCT IF(x > 1, x)), sum(DISTINCT x) " +
                        "FROM (VALUES 1, 1, 1, 2, 3, 3) t(x)",
                "VALUES (BIGINT '2', BIGINT '6')");

        // Two-column input shared by the next pair of queries.
        String pairsRelation = "FROM (VALUES " +
                "(1, 10)," +
                "(1, 20)," +
                "(1, 20)," +
                "(2, 20)," +
                "(3, 30)) t(x, y)";

        assertions.assertQuery(
                "SELECT count(DISTINCT x) FILTER (WHERE x > 1), sum(DISTINCT y) FILTER (WHERE x < 3)" +
                        pairsRelation,
                "VALUES (BIGINT '2', BIGINT '30')");
        assertions.assertQuery(
                "SELECT count(DISTINCT IF(x > 1, x)), sum(DISTINCT IF(x < 3, y)) " +
                        pairsRelation,
                "VALUES (BIGINT '2', BIGINT '30')");

        // Non-DISTINCT filtered aggregate combined with a DISTINCT unfiltered one.
        assertions.assertQuery(
                "SELECT count(x) FILTER (WHERE x > 1), sum(DISTINCT x) " +
                        "FROM (VALUES 1, 2, 3, 3) t(x)",
                "VALUES (BIGINT '3', BIGINT '6')");
        assertions.assertQuery(
                "SELECT count(IF(x > 1, x)),  sum(DISTINCT x) " +
                        "FROM (VALUES 1, 2, 3, 3) t(x)",
                "VALUES (BIGINT '3', BIGINT '6')");
    }

    /**
     * Checks {@code set_agg} over input containing NULLs.
     *
     * <p>The two spellings are expected to differ for the group whose rows all fail
     * the predicate ({@code y = 20}): the FILTER form yields a NULL array, while the
     * IF form yields an array holding a single NULL element.
     */
    @Test
    public void testSetAggWithNulls()
    {
        assertions.assertQuery(
                "SELECT y, set_agg(y) FILTER (WHERE x = 1) FROM (SELECT 1 x, 2 y UNION ALL SELECT NULL x, 20 y UNION ALL SELECT 1 x, NULL y) GROUP BY y ORDER BY y",
                "VALUES (INTEGER '2', ARRAY[INTEGER '2']), (INTEGER '20', CAST(NULL AS ARRAY<INTEGER>)), (CAST(NULL AS INTEGER), ARRAY[CAST(NULL AS INTEGER)])");
        assertions.assertQuery(
                "SELECT y, set_agg(IF(x = 1,y)) FROM (SELECT 1 x, 2 y UNION ALL SELECT NULL x, 20 y UNION ALL SELECT 1 x, NULL y) GROUP BY y ORDER BY y",
                "VALUES (INTEGER '2', ARRAY[INTEGER '2']), (INTEGER '20', ARRAY[CAST(NULL AS INTEGER)]), (CAST(NULL AS INTEGER), ARRAY[CAST(NULL AS INTEGER)])");
    }

    /**
     * Checks {@code approx_set}: both spellings are expected to return a NULL
     * HyperLogLog for every group of this input.
     */
    @Test
    public void testApproxSet()
    {
        assertions.assertQuery(
                "SELECT y, approx_set(y) FILTER (WHERE x = 1) FROM (SELECT NULL x, 20 y UNION ALL SELECT 1 x, NULL y) GROUP BY y ORDER BY y",
                "VALUES (INTEGER '20', CAST(NULL AS HyperLogLog)), (CAST(NULL AS INTEGER), CAST(NULL AS HyperLogLog))");
        assertions.assertQuery(
                "SELECT y, approx_set(IF(x = 1,y)) FROM (SELECT NULL x, 20 y UNION ALL SELECT 1 x, NULL y) GROUP BY y ORDER BY y",
                "VALUES (INTEGER '20', CAST(NULL AS HyperLogLog)), (CAST(NULL AS INTEGER), CAST(NULL AS HyperLogLog))");
    }

    /**
     * Checks {@code set_union} when no row satisfies the predicate.
     *
     * <p>The two spellings are expected to differ: the FILTER form yields NULL,
     * while the IF form yields an empty array.
     */
    @Test
    public void testSetUnion()
    {
        assertions.assertQuery(
                "SELECT set_union(x) FILTER (WHERE y > 1) FROM (SELECT ARRAY[1] x, 1 y UNION ALL SELECT NULL x, 1 y)",
                "VALUES (CAST (NULL AS ARRAY<INTEGER>))");
        assertions.assertQuery(
                "SELECT set_union(IF(y > 1, x)) FROM (SELECT ARRAY[1] x, 1 y UNION ALL SELECT NULL x, 1 y)",
                "VALUES (CAST (ARRAY[] AS ARRAY<INTEGER>))");
    }

    /**
     * Checks {@code map_union} when no row satisfies the predicate: both spellings
     * are expected to return a NULL map.
     */
    @Test
    public void testMapUnion()
    {
        assertions.assertQuery(
                "SELECT map_union(x) FILTER (WHERE y > 1) FROM (SELECT MAP(ARRAY[1], ARRAY[1]) x, 1 y UNION ALL SELECT NULL x, 1 y)",
                "VALUES (CAST (NULL AS MAP<INTEGER, INTEGER>))");
        assertions.assertQuery(
                "SELECT map_union(IF(y > 1, x)) FROM (SELECT MAP(ARRAY[1], ARRAY[1]) x, 1 y UNION ALL SELECT NULL x, 1 y)",
                "VALUES (CAST (NULL AS MAP<INTEGER, INTEGER>))");
    }

    /**
     * Checks {@code map_union_sum} when no row satisfies the predicate: both
     * spellings are expected to return a NULL map.
     */
    @Test
    public void testMapUnionSum()
    {
        assertions.assertQuery(
                "SELECT map_union_sum(x) FILTER (WHERE y > 1) FROM (SELECT MAP(ARRAY[1], ARRAY[1]) x, 1 y UNION ALL SELECT NULL x, 1 y)",
                "VALUES (CAST (NULL AS MAP<INTEGER, INTEGER>))");
        assertions.assertQuery(
                "SELECT map_union_sum(IF(y > 1, x)) FROM (SELECT MAP(ARRAY[1], ARRAY[1]) x, 1 y UNION ALL SELECT NULL x, 1 y)",
                "VALUES (CAST (NULL AS MAP<INTEGER, INTEGER>))");
    }

    /**
     * Checks filtered DISTINCT counts under {@code GROUPING SETS ((), (k))}; both
     * spellings run over the same rows and must produce the same three result rows.
     */
    @Test
    public void testGroupingSets()
    {
        // Input relation and grouping clause shared by both spellings.
        String groupedRelation = "(VALUES " +
                "   (1, 1, 100)," +
                "   (1, 1, 200)," +
                "   (1, 2, 100)," +
                "   (1, 3, 300)," +
                "   (2, 1, 100)," +
                "   (2, 10, 100)," +
                "   (2, 20, 100)," +
                "   (2, 20, 200)," +
                "   (2, 30, 300)," +
                "   (2, 40, 100)" +
                ") t(k, x, y) " +
                "GROUP BY GROUPING SETS ((), (k))";

        // One row per k value plus the global (k IS NULL) grouping.
        String expected = "VALUES " +
                "(1, BIGINT '2', BIGINT '1'), " +
                "(2, BIGINT '4', BIGINT '1'), " +
                "(CAST(NULL AS INTEGER), BIGINT '5', BIGINT '2')";

        assertions.assertQuery(
                "SELECT k, count(DISTINCT x) FILTER (WHERE y = 100), count(DISTINCT x) FILTER (WHERE y = 200) FROM " +
                        groupedRelation,
                expected);

        assertions.assertQuery(
                "SELECT k, count(DISTINCT IF(y = 100, x)),  count(DISTINCT IF(y = 200, x)) FROM " +
                        groupedRelation,
                expected);
    }

    /**
     * Checks that when every aggregate is filtered, the plan contains a filter above
     * the table scan whose predicate is the disjunction of the individual predicates.
     */
    @Test
    public void rewriteAddFilterWithMultipleFilters()
    {
        assertPlan(
                "SELECT sum(totalprice) FILTER(WHERE totalprice > 0), sum(custkey) FILTER(WHERE custkey > 0) FROM orders",
                anyTree(
                        filter(
                                "(\"totalprice\" > 0E0 OR \"custkey\" > BIGINT '0')",
                                tableScan("orders", ImmutableMap.of("totalprice", "totalprice", "custkey", "custkey")))));

        assertPlan(
                "SELECT sum(IF(totalprice > 0, totalprice)), sum(IF(custkey > 0, custkey)) FROM orders",
                anyTree(
                        filter(
                                "(\"totalprice\" > 0E0 OR \"custkey\" > BIGINT '0')",
                                tableScan("orders", ImmutableMap.of("totalprice", "totalprice", "custkey", "custkey")))));
    }

    /**
     * Checks that no filter node appears in the optimized plan when an unfiltered
     * aggregate sits next to a filtered one.
     */
    @Test
    public void testDoNotPushdownPredicateIfNonFilteredAggregateIsPresent()
    {
        assertPlanContainsNoFilter("SELECT sum(totalprice) FILTER(WHERE totalprice > 0), sum(custkey) FROM orders");
        assertPlanContainsNoFilter("SELECT sum(IF(totalprice > 0, totalprice)), sum(custkey) FROM orders");
    }

    /**
     * Checks that constant {@code FALSE} and {@code TRUE} predicates leave no filter
     * node in the optimized plan.
     */
    @Test
    public void testPushDownConstantFilterPredicate()
    {
        assertPlanContainsNoFilter("SELECT sum(totalprice) FILTER(WHERE FALSE) FROM orders");
        assertPlanContainsNoFilter("SELECT sum(IF(FALSE, totalprice)) FROM orders");

        assertPlanContainsNoFilter("SELECT sum(totalprice) FILTER(WHERE TRUE) FROM orders");
        assertPlanContainsNoFilter("SELECT sum(IF(TRUE, totalprice)) FROM orders");
    }

    /**
     * Checks that no filter node appears in the optimized plan for these grouped
     * (GROUP BY) filtered aggregations, over both inline values and a table.
     */
    @Test
    public void testNoFilterAddedForConstantValueFilters()
    {
        assertPlanContainsNoFilter("SELECT sum(x) FILTER(WHERE x > 0) FROM (VALUES 1, 1, 0, 2, 3, 3) t(x) GROUP BY x");
        assertPlanContainsNoFilter("SELECT sum(IF(x > 0, x)) FROM (VALUES 1, 1, 0, 2, 3, 3) t(x) GROUP BY x");

        assertPlanContainsNoFilter("SELECT sum(totalprice) FILTER(WHERE totalprice > 0) FROM orders GROUP BY totalprice");
        assertPlanContainsNoFilter("SELECT sum(IF(totalprice > 0, totalprice)) FROM orders GROUP BY totalprice");
    }

    /**
     * Plans {@code sql} through the {@code OPTIMIZED} stage and fails if any
     * {@link FilterNode} is found when searching from the plan root.
     *
     * @param sql the query to plan
     */
    private void assertPlanContainsNoFilter(String sql)
    {
        boolean hasFilterNode = searchFrom(plan(sql, Optimizer.PlanStage.OPTIMIZED).getRoot())
                .where(isInstanceOfAny(FilterNode.class))
                .matches();
        assertFalse(hasFilterNode, "Unexpected node for query: " + sql);
    }
}