TestNativeSidecarPlugin.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sidecar;

import com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils;
import com.facebook.presto.sidecar.functionNamespace.FunctionDefinitionProvider;
import com.facebook.presto.sidecar.functionNamespace.NativeFunctionDefinitionProvider;
import com.facebook.presto.sidecar.functionNamespace.NativeFunctionNamespaceManager;
import com.facebook.presto.sidecar.functionNamespace.NativeFunctionNamespaceManagerFactory;
import com.facebook.presto.sidecar.sessionpropertyproviders.NativeSystemSessionPropertyProvider;
import com.facebook.presto.sidecar.sessionpropertyproviders.NativeSystemSessionPropertyProviderFactory;
import com.facebook.presto.sidecar.typemanager.NativeTypeManagerFactory;
import com.facebook.presto.spi.function.FunctionNamespaceManager;
import com.facebook.presto.spi.function.SqlFunction;
import com.facebook.presto.spi.session.WorkerSessionPropertyProvider;
import com.facebook.presto.testing.MaterializedResult;
import com.facebook.presto.testing.MaterializedRow;
import com.facebook.presto.testing.QueryRunner;
import com.facebook.presto.tests.AbstractTestQueryFramework;
import com.facebook.presto.tests.DistributedQueryRunner;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.units.DataSize;
import org.intellij.lang.annotations.Language;
import org.testng.annotations.Test;

import java.util.List;
import java.util.UUID;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static com.facebook.presto.common.Utils.checkArgument;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createLineitem;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createNation;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createOrders;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createOrdersEx;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createRegion;
import static io.airlift.units.DataSize.Unit.MEGABYTE;
import static java.lang.String.format;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.fail;

/**
 * Integration tests that run queries against a native Hive query runner with the coordinator
 * sidecar enabled and {@link NativeSidecarPlugin} installed. Results of {@code assertQuery}
 * calls are compared against a Java Hive query runner (see {@link #createExpectedQueryRunner()}).
 */
public class TestNativeSidecarPlugin
        extends AbstractTestQueryFramework
{
    // Patterns are compiled once because they are evaluated for every result row.
    // NOTE: '.' is left unescaped in the function namespace regex, so it matches any character
    // rather than only a literal dot.
    private static final Pattern FUNCTION_NAMESPACE_PATTERN = Pattern.compile("native.default.*");
    private static final Pattern NATIVE_SESSION_PROPERTY_PATTERN = Pattern.compile("Native Execution only.*");
    private static final long SIDECAR_HTTP_CLIENT_MAX_CONTENT_SIZE_MB = 128;

    /**
     * Creates the tables used by the queries in this class through the expected query runner.
     */
    @Override
    protected void createTables()
    {
        QueryRunner expectedRunner = (QueryRunner) getExpectedQueryRunner();
        createLineitem(expectedRunner);
        createNation(expectedRunner);
        createOrders(expectedRunner);
        createOrdersEx(expectedRunner);
        createRegion(expectedRunner);
    }

    /**
     * Builds the native Hive query runner with the coordinator sidecar enabled and then
     * installs the sidecar plugin and its providers on it.
     */
    @Override
    protected QueryRunner createQueryRunner()
            throws Exception
    {
        DistributedQueryRunner nativeRunner = (DistributedQueryRunner) PrestoNativeQueryRunnerUtils.nativeHiveQueryRunnerBuilder()
                .setAddStorageFormatToPath(true)
                .setCoordinatorSidecarEnabled(true)
                .build();
        setupNativeSidecarPlugin(nativeRunner);
        return nativeRunner;
    }

    /**
     * Builds the Java Hive query runner used to compute expected results.
     */
    @Override
    protected QueryRunner createExpectedQueryRunner()
            throws Exception
    {
        return PrestoNativeQueryRunnerUtils.javaHiveQueryRunnerBuilder()
                .setAddStorageFormatToPath(true)
                .build();
    }

    /**
     * Installs {@link NativeSidecarPlugin} on the given runner and loads the native session
     * property provider, the "native" function namespace manager, the native type manager and
     * the "native" plan checker provider. Both HTTP-client-backed components are configured with
     * the same {@code sidecar.http-client.max-content-length}.
     *
     * @param queryRunner the query runner to configure
     */
    public static void setupNativeSidecarPlugin(QueryRunner queryRunner)
    {
        String maxContentLength = SIDECAR_HTTP_CLIENT_MAX_CONTENT_SIZE_MB + "MB";

        queryRunner.installCoordinatorPlugin(new NativeSidecarPlugin());
        queryRunner.loadSessionPropertyProvider(
                NativeSystemSessionPropertyProviderFactory.NAME,
                ImmutableMap.of("sidecar.http-client.max-content-length", maxContentLength));
        queryRunner.loadFunctionNamespaceManager(
                NativeFunctionNamespaceManagerFactory.NAME,
                "native",
                ImmutableMap.of(
                        "supported-function-languages", "CPP",
                        "function-implementation-type", "CPP",
                        "sidecar.http-client.max-content-length", maxContentLength));
        queryRunner.loadTypeManager(NativeTypeManagerFactory.NAME);
        queryRunner.loadPlanCheckerProviderManager("native", ImmutableMap.of());
    }

    /**
     * Verifies that the max content length configured in {@link #setupNativeSidecarPlugin} reached
     * the HTTP clients of both the session property provider and the function definition provider.
     */
    @Test
    public void testHttpClientProperties()
    {
        long expectedMaxContentLength = new DataSize(SIDECAR_HTTP_CLIENT_MAX_CONTENT_SIZE_MB, MEGABYTE).toBytes();

        WorkerSessionPropertyProvider sessionPropertyProvider = getQueryRunner().getMetadata().getSessionPropertyManager().getWorkerSessionPropertyProviders().get(NativeSystemSessionPropertyProviderFactory.NAME);
        checkArgument(sessionPropertyProvider instanceof NativeSystemSessionPropertyProvider, "Expected  NativeSystemSessionPropertyProvider but got  %s", sessionPropertyProvider);
        long sessionProviderMaxContentLength = ((NativeSystemSessionPropertyProvider) sessionPropertyProvider).getHttpClient().getMaxContentLength();
        assertEquals(sessionProviderMaxContentLength, expectedMaxContentLength);

        FunctionNamespaceManager<? extends SqlFunction> functionNamespaceManager = getQueryRunner().getMetadata().getFunctionAndTypeManager().getFunctionNamespaceManagers().get(NativeFunctionNamespaceManagerFactory.NAME);
        checkArgument(functionNamespaceManager instanceof NativeFunctionNamespaceManager, "Expected  NativeFunctionNamespaceManager but got  %s", functionNamespaceManager);
        FunctionDefinitionProvider functionDefinitionProvider = ((NativeFunctionNamespaceManager) functionNamespaceManager).getFunctionDefinitionProvider();
        checkArgument(functionDefinitionProvider instanceof NativeFunctionDefinitionProvider, "Expected  NativeFunctionDefinitionProvider but got %s", functionDefinitionProvider);
        long functionProviderMaxContentLength = ((NativeFunctionDefinitionProvider) functionDefinitionProvider).getHttpClient().getMaxContentLength();
        assertEquals(functionProviderMaxContentLength, expectedMaxContentLength);
    }

    /**
     * Verifies that SHOW SESSION lists at least one property whose description marks it as
     * native-execution-only.
     */
    @Test
    public void testShowSession()
    {
        @Language("SQL") String sql = "SHOW SESSION";
        MaterializedResult actualResult = computeActual(sql);
        List<MaterializedRow> nativeRows = excludeSystemSessionProperties(actualResult.getMaterializedRows());
        assertFalse(nativeRows.isEmpty());
    }

    @Test
    public void testSetJavaWorkerSessionProperty()
    {
        assertQueryFails("SET SESSION aggregation_spill_enabled=false", "line 1:1: Session property aggregation_spill_enabled does not exist");
    }

    @Test
    public void testSetNativeWorkerSessionProperty()
    {
        @Language("SQL") String setSession = "SET SESSION driver_cpu_time_slice_limit_ms=500";
        MaterializedResult setSessionResult = computeActual(setSession);
        assertEquals(
                setSessionResult.toString(),
                "MaterializedResult{rows=[[true]], " +
                        "types=[boolean], " +
                        "setSessionProperties={driver_cpu_time_slice_limit_ms=500}, " +
                        "resetSessionProperties=[], updateInfo=UpdateInfo{updateType='SET SESSION', updateObject=''}}");
    }

    /**
     * Verifies, for every SHOW FUNCTIONS row, that field 0 does not match the function namespace
     * pattern while field 5 does.
     */
    @Test
    public void testShowFunctions()
    {
        @Language("SQL") String sql = "SHOW FUNCTIONS";
        List<MaterializedRow> actualRows = computeActual(sql).getMaterializedRows();
        // Without this guard the per-row checks below would pass vacuously on an empty result.
        assertFalse(actualRows.isEmpty(), "SHOW FUNCTIONS returned no rows");

        for (MaterializedRow actualRow : actualRows) {
            List<Object> row = actualRow.getFields();

            // No namespace should be present on the function names.
            String functionName = row.get(0).toString();
            if (FUNCTION_NAMESPACE_PATTERN.matcher(functionName).matches()) {
                fail(format("Namespace match found for row: %s", row));
            }

            // The function namespace should be present here.
            String fullFunctionName = row.get(5).toString();
            if (!FUNCTION_NAMESPACE_PATTERN.matcher(fullFunctionName).matches()) {
                fail(format("No namespace match found for row: %s", row));
            }
        }
    }

    @Test
    public void testGeneralQueries()
    {
        assertQuery("SELECT ARRAY['abc']");
        assertQuery("SELECT ARRAY[1, 2, 3]");
        assertQuery("SELECT substr(comment, 1, 10), length(comment), trim(comment) FROM orders");
        assertQuery("SELECT substr(comment, 1, 10), length(comment), ltrim(comment) FROM orders");
        assertQuery("SELECT substr(comment, 1, 10), length(comment), rtrim(comment) FROM orders");
        assertQuery("select lower(comment) from nation");
        assertQuery("SELECT trim(comment, ' ns'), ltrim(comment, 'a b c'), rtrim(comment, 'l y') FROM orders");
        assertQuery("select array[nationkey], array_constructor(comment) from nation");
        assertQuery("SELECT nationkey, bit_count(nationkey, 10) FROM nation ORDER BY 1");
        assertQuery("SELECT * FROM lineitem WHERE shipinstruct like 'TAKE BACK%'");
        assertQuery("SELECT * FROM lineitem WHERE shipinstruct like 'TAKE BACK#%' escape '#'");
        assertQuery("SELECT orderkey, date_trunc('year', from_unixtime(orderkey, '-03:00')), date_trunc('quarter', from_unixtime(orderkey, '+14:00')), " +
                "date_trunc('month', from_unixtime(orderkey, '+03:00')), date_trunc('day', from_unixtime(orderkey, '-07:00')), " +
                "date_trunc('hour', from_unixtime(orderkey, '-09:30')), date_trunc('minute', from_unixtime(orderkey, '+05:30')), " +
                "date_trunc('second', from_unixtime(orderkey, '+00:00')) FROM orders");
        assertQuery("SELECT mod(orderkey, linenumber) FROM lineitem");
        assertQueryFails("SELECT IF(true, 0/0, 1)", "[\\s\\S]*/ by zero native.default.fail[\\s\\S]*");
    }

    @Test
    public void testAggregateFunctions()
    {
        assertQuery("select corr(nationkey, nationkey) from nation");
        assertQuery("select count(comment) from orders");
        assertQuery("select count(*) from nation");
        assertQuery("select count(abs(orderkey) between 1 and 60000) from orders group by orderkey");
        assertQuery("SELECT count(orderkey) FROM orders WHERE orderkey < 0 GROUP BY GROUPING SETS (())");
        // tinyint
        assertQuery("SELECT sum(cast(linenumber as tinyint)), sum(cast(linenumber as tinyint)) FROM lineitem");
        // smallint
        assertQuery("SELECT sum(cast(linenumber as smallint)), sum(cast(linenumber as smallint)) FROM lineitem");
        // integer
        assertQuery("SELECT sum(linenumber), sum(linenumber) FROM lineitem");
        // bigint
        assertQuery("SELECT sum(orderkey), sum(orderkey) FROM lineitem");
        // real
        assertQuery("SELECT sum(tax_as_real), sum(tax_as_real) FROM lineitem");
        // double
        assertQuery("SELECT sum(quantity), sum(quantity) FROM lineitem");
        // date
        assertQuery("SELECT approx_distinct(orderdate, 0.023) FROM orders");
        // timestamp
        assertQuery("SELECT approx_distinct(CAST(orderdate AS TIMESTAMP)) FROM orders");
        assertQuery("SELECT approx_distinct(CAST(orderdate AS TIMESTAMP), 0.023) FROM orders");
        assertQuery("SELECT checksum(from_unixtime(orderkey, '+01:00')) FROM lineitem WHERE orderkey < 20");
        // shuffle() output order is not deterministic, so only success is checked here.
        assertQuerySucceeds("SELECT shuffle(array_sort(quantities)) FROM orders_ex");
        assertQuery("SELECT array_sort(shuffle(quantities)) FROM orders_ex");
        assertQuery("SELECT orderkey, array_sort(reduce_agg(linenumber, CAST(array[] as ARRAY(INTEGER)), (s, x) -> s || x, (s, s2) -> s || s2)) FROM lineitem group by orderkey");
    }

    @Test
    public void testWindowFunctions()
    {
        assertQuery("SELECT * FROM (SELECT row_number() over(partition by orderstatus order by orderkey, orderstatus) rn, * from orders) WHERE rn = 1");
        assertQuery("WITH t AS (SELECT linenumber, row_number() over (partition by linenumber order by linenumber) as rn FROM lineitem) SELECT * FROM t WHERE rn = 1");
        assertQuery("SELECT row_number() OVER (PARTITION BY orderdate ORDER BY orderdate) FROM orders");
        assertQuery("SELECT min(orderkey) OVER (PARTITION BY orderdate ORDER BY orderdate, totalprice) FROM orders");
        assertQuery("SELECT sum(rn) FROM (SELECT row_number() over() rn, * from orders) WHERE rn = 10");
        assertQuery("SELECT * FROM (SELECT row_number() over(partition by orderstatus order by orderkey) rn, * from orders) WHERE rn = 1");
        assertQuery("SELECT first_value(orderdate) OVER (PARTITION BY orderkey ORDER BY totalprice RANGE BETWEEN 5 PRECEDING AND CURRENT ROW) FROM orders");
        assertQuery("SELECT lead(orderkey, 5) OVER (PARTITION BY custkey, orderdate ORDER BY totalprice desc ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING) FROM orders");
    }

    @Test
    public void testLambdaFunctions()
    {
        // These function signatures are only supported in the native execution engine
        assertQuerySucceeds("select array_sort(array[row('apples', 23), row('bananas', 12), row('grapes', 44)], x -> x[2])");
        assertQuerySucceeds("SELECT array_sort(quantities, x -> abs(x)) FROM orders_ex");
        assertQuerySucceeds("SELECT array_sort(quantities, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) FROM orders_ex");

        assertQuery("SELECT array_sort(map_keys(map_union(quantity_by_linenumber))) FROM orders_ex");
        assertQuery("SELECT filter(quantities, q -> q > 10) FROM orders_ex");
        assertQuery("SELECT all_match(shuffle(quantities), x -> (x > 500.0)) FROM orders_ex");
        assertQuery("SELECT any_match(quantities, x -> TRY(((10 / x) > 2))) FROM orders_ex");
        assertQuery("SELECT TRY(none_match(quantities, x -> ((10 / x) > 2))) FROM orders_ex");
        assertQuery("SELECT reduce(array[nationkey, regionkey], 103, (s, x) -> s + x, s -> s) FROM nation");
        assertQuery("SELECT transform(array[1, 2, 3], x -> x * regionkey + nationkey) FROM nation");
        assertQueryFails(
                "SELECT array_sort(quantities, (x, y, z) -> if (x < y + z, cast(1 as bigint), if (x > y + z, cast(-1 as bigint), cast(0 as bigint)))) FROM orders_ex",
                "Failed to find matching function signature for array_sort, matching failures: \n" +
                        " Exception 1: line 1:31: Expected a lambda that takes ([12])" + Pattern.quote(" argument(s) but got 3\n") +
                        " Exception 2: line 1:31: Expected a lambda that takes ([12])" + Pattern.quote(" argument(s) but got 3\n"));
    }

    @Test
    public void testInformationSchemaTables()
    {
        assertQueryFails("select lower(table_name) from information_schema.tables "
                        + "where table_name = 'lineitem' or table_name = 'LINEITEM' ",
                "Compiler failed");
    }

    /**
     * Creates a temporary PARQUET table with two DECIMAL columns (including NULLs) and checks
     * that SHOW STATS on it agrees with the expected query runner.
     */
    @Test
    public void testShowStats()
    {
        String tmpTableName = generateRandomTableName();
        try {
            getQueryRunner().execute(format("CREATE TABLE %s (c0 DECIMAL(15,2), c1 DECIMAL(38,2)) WITH (format = 'PARQUET')", tmpTableName));
            getQueryRunner().execute(format("INSERT INTO %s VALUES (DECIMAL '0', DECIMAL '0'), (DECIMAL '1.2', DECIMAL '3.4'), " +
                    "(DECIMAL '1000000.12', DECIMAL '28239823232323.57'), " +
                    "(DECIMAL '-542392.89', DECIMAL '-6723982392109.29'), (NULL, NULL), " +
                    "(NULL, DECIMAL'-6723982392109.29'),(DECIMAL'1.2', NULL)", tmpTableName));
            assertQuery(format("SHOW STATS for %s", tmpTableName));
        }
        finally {
            dropTableIfExists(tmpTableName);
        }
    }

    /**
     * Runs ANALYZE on the region table and on a temporary partitioned copy of nation, and checks
     * the SHOW STATS output after a full and after a partition-restricted ANALYZE.
     */
    @Test
    public void testAnalyzeStats()
    {
        assertUpdate("ANALYZE region", 5);

        // Show stats returns the following stats for each column in region table (eight values per row):
        // column_name | data_size | distinct_values_count | nulls_fraction | row_count | low_value | high_value | histogram
        assertQuery("SHOW STATS FOR region",
                "SELECT * FROM (VALUES" +
                        "('regionkey', NULL, 5e0, 0e0, NULL, '0', '4', NULL)," +
                        "('name', 5.4e1, 5e0, 0e0, NULL, NULL, NULL, NULL)," +
                        "('comment', 3.5e2, 5e0, 0e0, NULL, NULL, NULL, NULL)," +
                        "(NULL, NULL, NULL, NULL, 5e0, NULL, NULL, NULL))");

        // Create a partitioned table and run analyze on it.
        String tmpTableName = generateRandomTableName();
        try {
            getQueryRunner().execute(format("CREATE TABLE %s (name VARCHAR, regionkey BIGINT," +
                    "nationkey BIGINT) WITH (partitioned_by = ARRAY['regionkey','nationkey'])", tmpTableName));
            getQueryRunner().execute(
                    format("INSERT INTO %s SELECT name, regionkey, nationkey FROM nation", tmpTableName));
            assertQuery(format("SELECT * FROM %s", tmpTableName),
                    "SELECT name, regionkey, nationkey FROM nation");
            assertUpdate(format("ANALYZE %s", tmpTableName), 25);
            assertQuery(format("SHOW STATS for %s", tmpTableName),
                    "SELECT * FROM (VALUES" +
                            "('name', 2.77e2, 1e0, 0e0, NULL, NULL, NULL, NULL)," +
                            "('regionkey', NULL, 5e0, 0e0, NULL, '0', '4', NULL)," +
                            "('nationkey', NULL, 2.5e1, 0e0, NULL, '0', '24', NULL)," +
                            "(NULL, NULL, NULL, NULL, 2.5e1, NULL, NULL, NULL))");
            assertUpdate(format("ANALYZE %s WITH (partitions = ARRAY[ARRAY['0','0'],ARRAY['4', '11']])", tmpTableName), 2);
            assertQuery(format("SHOW STATS for (SELECT * FROM %s where regionkey=4 and nationkey=11)", tmpTableName),
                    "SELECT * FROM (VALUES" +
                            "('name', 8e0, 1e0, 0e0, NULL, NULL, NULL, NULL)," +
                            "('regionkey', NULL, 1e0, 0e0, NULL, '4', '4', NULL)," +
                            "('nationkey', NULL, 1e0, 0e0, NULL, '11', '11', NULL)," +
                            "(NULL, NULL, NULL, NULL, 1e0, NULL, NULL, NULL))");
        }
        finally {
            dropTableIfExists(tmpTableName);
        }
    }

    @Test
    public void testGeometryQueries()
    {
        assertQuery("SELECT ST_DISTANCE(ST_POINT(0,  0), ST_POINT(3, 4))");
        assertQuery("SELECT ST_CONTAINS(" +
                "ST_GeometryFromText('POLYGON((0 0, 0 10, 10 10, 10 0, 0 0))'), " +
                "ST_POINT(5, 5))");
        assertQuery("SELECT ST_POINT(nationkey, regionkey) from nation");
        assertQuery("SELECT " +
                "ST_DISTANCE(ST_POINT(a.nationkey, a.regionkey), ST_POINT(b.nationkey, b.regionkey)) " +
                "FROM nation a JOIN nation b ON a.nationkey < b.nationkey");
    }

    /**
     * Returns a fresh table name of the form {@code tmp_presto_<uuid without dashes>}, after
     * issuing a DROP TABLE IF EXISTS for it.
     */
    private String generateRandomTableName()
    {
        String tableName = "tmp_presto_" + UUID.randomUUID().toString().replace("-", "");
        // Clean up if the temporary named table already exists.
        dropTableIfExists(tableName);
        return tableName;
    }

    /**
     * Drops the given table if it exists.
     */
    private void dropTableIfExists(String tableName)
    {
        // The DROP is issued through computeExpected(), which requires a list of result types;
        // a single BIGINT is supplied to satisfy that signature.
        // NOTE(review): presumably this runs on the expected query runner and relies on both runners
        // seeing the same tables - confirm against AbstractTestQueryFramework.
        computeExpected(format("DROP TABLE IF EXISTS %s", tableName), ImmutableList.of(BIGINT));
    }

    /**
     * Keeps only the rows whose field at index 4 matches the "Native Execution only" pattern,
     * dropping every other row.
     *
     * @param inputRows rows of a SHOW SESSION result (index 4 is presumably the description column)
     * @return the rows that matched, in their original order
     */
    private List<MaterializedRow> excludeSystemSessionProperties(List<MaterializedRow> inputRows)
    {
        return inputRows.stream()
                .filter(row -> NATIVE_SESSION_PROPERTY_PATTERN.matcher(row.getFields().get(4).toString()).matches())
                .collect(Collectors.toList());
    }
}