// TestTpchDistributedStats.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.tests;

import com.facebook.presto.tests.statistics.StatisticsAssertion;
import com.facebook.presto.tests.tpch.TpchQueryRunnerBuilder;
import com.facebook.presto.tpch.ColumnNaming;
import com.google.common.collect.ImmutableMap;
import io.airlift.tpch.TpchTable;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import static com.facebook.presto.SystemSessionProperties.PREFER_PARTIAL_AGGREGATION;
import static com.facebook.presto.SystemSessionProperties.PRINT_STATS_FOR_NON_JOIN_QUERY;
import static com.facebook.presto.tests.statistics.MetricComparisonStrategies.absoluteError;
import static com.facebook.presto.tests.statistics.MetricComparisonStrategies.defaultTolerance;
import static com.facebook.presto.tests.statistics.MetricComparisonStrategies.noError;
import static com.facebook.presto.tests.statistics.MetricComparisonStrategies.relativeError;
import static com.facebook.presto.tests.statistics.Metrics.OUTPUT_ROW_COUNT;
import static com.facebook.presto.tests.statistics.Metrics.distinctValuesCount;
import static com.facebook.presto.tpch.TpchConnectorFactory.TPCH_COLUMN_NAMING_PROPERTY;

/**
 * Statistics checks for queries run against a {@code tpch} catalog on a
 * {@link DistributedQueryRunner}.
 *
 * <p>Each test hands a SQL query and a set of expectations to
 * {@link StatisticsAssertion#check}. Judging by the method names,
 * {@code estimate(metric, strategy)} presumably compares the estimated metric
 * with the actual one under the given comparison strategy, and
 * {@code noEstimate(metric)} presumably expects that no estimate is produced
 * -- confirm against {@code StatisticsAssertion}.
 */
public class TestTpchDistributedStats
{
    private StatisticsAssertion statisticsAssertion;

    /**
     * Builds a query runner without catalogs, registers the {@code tpch} catalog
     * with {@link ColumnNaming#STANDARD} column naming, and wraps the runner in a
     * {@link StatisticsAssertion}.
     *
     * <p>If anything fails after the runner has been built, the runner is closed
     * before the failure is propagated so that it is not leaked.
     *
     * @throws Exception if the runner cannot be built or the catalog cannot be created
     */
    @BeforeClass
    public void setup()
            throws Exception
    {
        DistributedQueryRunner runner = TpchQueryRunnerBuilder.builder()
                // We are not able to calculate stats for PARTIAL aggregations
                .amendSession(builder -> builder.setSystemProperty(PREFER_PARTIAL_AGGREGATION, "false"))
                .amendSession(builder -> builder.setSystemProperty(PRINT_STATS_FOR_NON_JOIN_QUERY, "true"))
                .buildWithoutCatalogs();
        try {
            runner.createCatalog(
                    "tpch",
                    "tpch",
                    ImmutableMap.of(TPCH_COLUMN_NAMING_PROPERTY, ColumnNaming.STANDARD.name()));
            statisticsAssertion = new StatisticsAssertion(runner);
        }
        catch (Throwable e) {
            // Nothing owns the runner yet, so release it here; keep the original failure as the primary one.
            try {
                runner.close();
            }
            catch (Throwable closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /**
     * Closes the {@link StatisticsAssertion} created by {@link #setup()}, if any,
     * and clears the field.
     */
    @AfterClass(alwaysRun = true)
    public void tearDown()
    {
        // alwaysRun = true lets this run even when setup() failed, in which case the field was never assigned.
        if (statisticsAssertion != null) {
            statisticsAssertion.close();
            statisticsAssertion = null;
        }
    }

    /**
     * Checks the output row count estimate of a full scan of every table
     * returned by {@link TpchTable#getTables()}, using {@code noError()}.
     */
    @Test
    public void testTableScanStats()
    {
        TpchTable.getTables()
                .forEach(table -> statisticsAssertion.check("SELECT * FROM " + table.getTableName(),
                        checks -> checks.estimate(OUTPUT_ROW_COUNT, noError())));
    }

    /**
     * Checks the output row count estimate of a date-range filter on {@code lineitem}.
     */
    @Test
    public void testFilter()
    {
        statisticsAssertion.check("SELECT * FROM lineitem WHERE l_shipdate <= DATE '1998-12-01' - INTERVAL '90' DAY",
                checks -> checks.estimate(OUTPUT_ROW_COUNT, defaultTolerance()));
    }

    /**
     * Checks the output row count estimate of an equi-join between {@code part} and {@code partsupp}.
     */
    @Test
    public void testJoin()
    {
        statisticsAssertion.check("SELECT * FROM  part, partsupp WHERE p_partkey = ps_partkey",
                checks -> checks.estimate(OUTPUT_ROW_COUNT, defaultTolerance()));
    }

    /**
     * Checks output row count estimates for {@code UNION} and {@code UNION ALL}
     * over identical, overlapping and filtered inputs.
     */
    @Test
    public void testUnion()
    {
        statisticsAssertion.check("SELECT * FROM nation UNION SELECT * FROM nation",
                // real count is 25, estimation cannot know all rows are duplicate.
                checks -> checks.estimate(OUTPUT_ROW_COUNT, relativeError(1, 1)));

        statisticsAssertion.check("SELECT * FROM nation UNION ALL SELECT * FROM nation",
                checks -> checks.estimate(OUTPUT_ROW_COUNT, noError()));

        // NOTE(review): the literal '0' below is the digit zero; it may have been meant as the letter 'O'.
        // Left unchanged because the tolerances may be tuned to the current predicate -- confirm before editing.
        statisticsAssertion.check("SELECT * FROM orders WHERE o_custkey < 755 OR o_orderstatus = '0' UNION SELECT * FROM orders WHERE o_custkey > 755 OR o_orderstatus = 'F'",
                checks -> checks.estimate(OUTPUT_ROW_COUNT, relativeError(.3, .35)));

        statisticsAssertion.check("SELECT * FROM orders WHERE o_custkey < 755 OR o_orderstatus = '0' UNION ALL SELECT * FROM orders WHERE o_custkey > 755 OR o_orderstatus = 'F'",
                checks -> checks.estimate(OUTPUT_ROW_COUNT, defaultTolerance()));

        statisticsAssertion.check("SELECT * FROM orders WHERE o_custkey < 900 UNION SELECT * FROM orders WHERE o_custkey > 600",
                checks -> checks.estimate(OUTPUT_ROW_COUNT, relativeError(.15, .25)));

        statisticsAssertion.check("SELECT * FROM orders WHERE o_custkey < 900 UNION ALL SELECT * FROM orders WHERE o_custkey > 600",
                checks -> checks.estimate(OUTPUT_ROW_COUNT, defaultTolerance()));
    }

    /**
     * Runs {@code INTERSECT} queries with a {@code noEstimate} expectation on the output row count.
     */
    @Test
    public void testIntersect()
    {
        statisticsAssertion.check("SELECT * FROM nation INTERSECT SELECT * FROM nation",
                checks -> checks.noEstimate(OUTPUT_ROW_COUNT));

        statisticsAssertion.check("SELECT * FROM orders WHERE o_custkey < 900 INTERSECT SELECT * FROM orders WHERE o_custkey > 600",
                checks -> checks.noEstimate(OUTPUT_ROW_COUNT));
    }

    /**
     * Runs {@code EXCEPT} queries with a {@code noEstimate} expectation on the output row count.
     */
    @Test
    public void testExcept()
    {
        statisticsAssertion.check("SELECT * FROM nation EXCEPT SELECT * FROM nation",
                checks -> checks.noEstimate(OUTPUT_ROW_COUNT));

        statisticsAssertion.check("SELECT * FROM orders WHERE o_custkey < 900 EXCEPT SELECT * FROM orders WHERE o_custkey > 600",
                checks -> checks.noEstimate(OUTPUT_ROW_COUNT));
    }

    /**
     * Checks estimates for scalar subqueries: one filtering on a name that is
     * not expected to match any nation, and one filtering on {@code 'GERMANY'}.
     */
    @Test
    public void testEnforceSingleRow()
    {
        statisticsAssertion.check("SELECT (SELECT n_regionkey FROM nation WHERE n_name = 'nosuchvalue') AS sub",
                checks -> checks
                        .estimate(OUTPUT_ROW_COUNT, noError()));

        statisticsAssertion.check("SELECT (SELECT n_regionkey FROM nation WHERE n_name = 'GERMANY') AS sub",
                checks -> checks
                        .estimate(distinctValuesCount("sub"), noError())
                        .estimate(OUTPUT_ROW_COUNT, noError()));
    }

    /**
     * Checks the output row count estimate of a single-row {@code VALUES} query.
     */
    @Test
    public void testValues()
    {
        statisticsAssertion.check("VALUES 1",
                checks -> checks.estimate(OUTPUT_ROW_COUNT, noError()));
    }

    /**
     * Checks output row count estimates for {@code IN (subquery)} predicates,
     * with and without a filter inside the subquery.
     */
    @Test
    public void testSemiJoin()
    {
        statisticsAssertion.check("SELECT * FROM nation WHERE n_regionkey IN (SELECT r_regionkey FROM region)",
                checks -> checks.estimate(OUTPUT_ROW_COUNT, noError()));

        statisticsAssertion.check("SELECT * FROM nation WHERE n_regionkey IN (SELECT r_regionkey FROM region WHERE r_regionkey % 3 = 0)",
                checks -> checks.estimate(OUTPUT_ROW_COUNT, absoluteError(15.)));
    }

    /**
     * Checks the output row count estimate of a {@code LIMIT 10} query over {@code nation}.
     */
    @Test
    public void testLimit()
    {
        statisticsAssertion.check("SELECT * FROM nation LIMIT 10",
                checks -> checks.estimate(OUTPUT_ROW_COUNT, noError()));
    }

    /**
     * Checks the output row count estimate of a two-column {@code GROUP BY} over {@code lineitem}.
     */
    @Test
    public void testGroupBy()
    {
        statisticsAssertion.check("SELECT l_returnflag, l_linestatus FROM lineitem GROUP BY l_returnflag, l_linestatus",
                checks -> checks.estimate(OUTPUT_ROW_COUNT, absoluteError(2))); // real row count is 4
    }

    /**
     * Checks the output row count estimate of an {@code ORDER BY} query over {@code nation}.
     */
    @Test
    public void testSort()
    {
        statisticsAssertion.check("SELECT * FROM nation ORDER BY n_nationkey",
                checks -> checks.estimate(OUTPUT_ROW_COUNT, noError()));
    }
}