TestStatementBuilder.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.parser;

import com.facebook.presto.sql.SqlFormatter;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.Statement;
import com.google.common.io.Resources;
import org.testng.annotations.Test;

import java.io.IOException;
import java.util.Optional;

import static com.facebook.presto.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DOUBLE;
import static com.facebook.presto.sql.testing.TreeAssertions.assertFormattedSql;
import static com.google.common.base.Strings.repeat;
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;

/**
 * Parser/formatter round-trip tests.
 * <p>
 * Each statement is parsed with {@link SqlParser}, formatted back to SQL with
 * {@link SqlFormatter}, and verified with {@code assertFormattedSql}. Intermediate
 * output (original SQL, tree {@code toString()}, formatted SQL) is written to stdout
 * only when the {@code printParse} system property is {@code true}.
 */
public class TestStatementBuilder
{
    private static final SqlParser SQL_PARSER = new SqlParser();

    @Test
    public void testStatementBuilder()
    {
        printStatement("select * from foo");
        printStatement("explain select * from foo");
        printStatement("explain (type distributed, format graphviz) select * from foo");

        printStatement("select * from foo /* end */");
        printStatement("/* start */ select * from foo");
        printStatement("/* start */ select * /* middle */ from foo /* end */");
        printStatement("-- start\nselect * -- junk\n-- hi\nfrom foo -- done");

        printStatement("select * from foo a (x, y, z)");

        printStatement("select *, 123, * from foo");

        printStatement("select show from foo");
        printStatement("select extract(day from x), extract(dow from x) from y");

        printStatement("select 1 + 13 || '15' from foo");

        printStatement("select x is distinct from y from foo where a is not distinct from b");

        printStatement("select x[1] from my_table");
        printStatement("select x[1][2] from my_table");
        printStatement("select x[cast(10 * sin(x) as bigint)] from my_table");

        printStatement("select * from unnest(t.my_array)");
        printStatement("select * from unnest(array[1, 2, 3])");
        printStatement("select x from unnest(array[1, 2, 3]) t(x)");
        printStatement("select * from users cross join unnest(friends)");
        printStatement("select id, friend from users cross join unnest(friends) t(friend)");
        printStatement("select * from unnest(t.my_array) with ordinality");
        printStatement("select * from unnest(array[1, 2, 3]) with ordinality");
        printStatement("select x from unnest(array[1, 2, 3]) with ordinality t(x)");
        printStatement("select * from users cross join unnest(friends) with ordinality");
        printStatement("select id, friend from users cross join unnest(friends) with ordinality t(friend)");

        printStatement("select count(*) x from src group by k, v");
        printStatement("select count(*) x from src group by cube (k, v)");
        printStatement("select count(*) x from src group by rollup (k, v)");
        printStatement("select count(*) x from src group by grouping sets ((k, v))");
        printStatement("select count(*) x from src group by grouping sets ((k, v), (v))");
        printStatement("select count(*) x from src group by grouping sets (k, v, k)");

        printStatement("select count(*) filter (where x > 4) y from t");
        printStatement("select sum(x) filter (where x > 4) y from t");
        printStatement("select sum(x) filter (where x > 4) y, sum(x) filter (where x < 2) z from t");
        printStatement("select sum(distinct x) filter (where x > 4) y, sum(x) filter (where x < 2) z from t");
        printStatement("select sum(x) filter (where x > 4) over (partition by y) z from t");

        printStatement("" +
                "select depname, empno, salary\n" +
                ", count(*) over ()\n" +
                ", avg(salary) over (partition by depname)\n" +
                ", rank() over (partition by depname order by salary desc)\n" +
                ", sum(salary) over (order by salary rows unbounded preceding)\n" +
                ", sum(salary) over (partition by depname order by salary rows between current row and 3 following)\n" +
                ", sum(salary) over (partition by depname range unbounded preceding)\n" +
                ", sum(salary) over (rows between 2 preceding and unbounded following)\n" +
                ", lag(salary, 1) ignore nulls over (partition by depname)\n" +
                ", lag(salary, 1) respect nulls over (partition by depname)\n" +
                "from emp");

        printStatement("" +
                "with a (id) as (with x as (select 123 from z) select * from x) " +
                "   , b (id) as (select 999 from z) " +
                "select * from a join b using (id)");

        printStatement("with recursive t as (select * from x) select * from t");

        printStatement("select * from information_schema.tables");

        printStatement("show catalogs");

        printStatement("show schemas");
        printStatement("show schemas from sys");

        printStatement("show tables");
        printStatement("show tables from information_schema");
        printStatement("show tables like '%'");
        printStatement("show tables from information_schema like '%'");

        printStatement("show functions");

        printStatement("select cast('123' as bigint), try_cast('foo' as bigint)");

        printStatement("select * from a.b.c");
        printStatement("select * from a.b.c.e.f.g");

        printStatement("select \"TOTALPRICE\" \"my price\" from \"$MY\"\"ORDERS\"");

        printStatement("select * from foo tablesample system (10+1)");
        printStatement("select * from foo tablesample system (10) join bar tablesample bernoulli (30) on a.id = b.id");
        printStatement("select * from foo tablesample system (10) join bar tablesample bernoulli (30) on not(a.id > b.id)");

        printStatement("select * from foo for version as of 8772871542276440693");
        printStatement("select * from foo for system_version as of 8772871542276440693");
        printStatement("select * from foo for timestamp as of timestamp '2023-08-17 13:29:46.822 America/Los_Angeles'");
        printStatement("select * from foo for system_time as of timestamp '2023-08-17 13:29:46.822 America/Los_Angeles'");
        printStatement("select * from foo for version before 8772871542276440693");
        printStatement("select * from foo for system_version before 8772871542276440693");
        printStatement("select * from foo for timestamp before timestamp '2023-08-17 13:29:46.822 America/Los_Angeles'");
        printStatement("select * from foo for system_time before timestamp '2023-08-17 13:29:46.822 America/Los_Angeles'");

        printStatement("create table foo as (select * from abc)");
        printStatement("create table if not exists foo as (select * from abc)");
        printStatement("create table foo with (a = 'apple', b = 'banana') as select * from abc");
        printStatement("create table foo comment 'test' with (a = 'apple') as select * from abc");
        printStatement("create table foo as select * from abc WITH NO DATA");

        printStatement("create table foo as (with t(x) as (values 1) select x from t)");
        printStatement("create table if not exists foo as (with t(x) as (values 1) select x from t)");
        printStatement("create table foo as (with t(x) as (values 1) select x from t) WITH DATA");
        printStatement("create table if not exists foo as (with t(x) as (values 1) select x from t) WITH DATA");
        printStatement("create table foo as (with t(x) as (values 1) select x from t) WITH NO DATA");
        printStatement("create table if not exists foo as (with t(x) as (values 1) select x from t) WITH NO DATA");

        printStatement("create table foo(a) as (with t(x) as (values 1) select x from t)");
        printStatement("create table if not exists foo(a) as (with t(x) as (values 1) select x from t)");
        printStatement("create table foo(a) as (with t(x) as (values 1) select x from t) WITH DATA");
        printStatement("create table if not exists foo(a) as (with t(x) as (values 1) select x from t) WITH DATA");
        printStatement("create table foo(a) as (with t(x) as (values 1) select x from t) WITH NO DATA");
        printStatement("create table if not exists foo(a) as (with t(x) as (values 1) select x from t) WITH NO DATA");
        printStatement("drop table foo");

        printStatement("insert into foo select * from abc");

        printStatement("delete from foo");
        printStatement("delete from foo where a = b");

        printStatement("truncate table foo");

        printStatement("values ('a', 1, 2.2), ('b', 2, 3.3)");

        printStatement("table foo");
        printStatement("table foo order by x limit 10");
        printStatement("(table foo)");
        printStatement("(table foo) limit 10");
        printStatement("(table foo limit 5) limit 10");

        printStatement("select * from a limit all");
        printStatement("select * from a order by x limit all");

        printStatement("select * from a union select * from b");
        printStatement("table a union all table b");
        printStatement("(table foo) union select * from foo union (table foo order by x)");

        printStatement("table a union table b intersect table c");
        printStatement("(table a union table b) intersect table c");
        printStatement("table a union table b except table c intersect table d");
        printStatement("(table a union table b except table c) intersect table d");
        printStatement("((table a union table b) except table c) intersect table d");
        printStatement("(table a union (table b except table c)) intersect table d");
        printStatement("table a intersect table b union table c");
        printStatement("table a intersect (table b union table c)");

        printStatement("alter table foo rename to bar");
        printStatement("alter table a.b.c rename to d.e.f");

        printStatement("alter table foo set properties (a='1')");
        printStatement("alter table a.b.c set properties (a=true, b=123, c='x')");
        printStatement("alter table if exists bar set properties (b='1')");

        printStatement("alter table a.b.c rename column x to y");

        printStatement("alter table a.b.c add column x bigint");

        printStatement("alter table a.b.c add column x bigint comment 'large x'");
        printStatement("alter table a.b.c add column x bigint with (weight = 2)");
        printStatement("alter table a.b.c add column x bigint comment 'xtra' with (compression = 'LZ4', special = true)");

        printStatement("alter table a.b.c drop column x");

        printStatement("create schema test");
        printStatement("create schema if not exists test");
        printStatement("create schema test with (a = 'apple', b = 123)");

        printStatement("drop schema test");
        printStatement("drop schema test cascade");
        printStatement("drop schema if exists test");
        printStatement("drop schema if exists test restrict");

        printStatement("alter schema foo rename to bar");
        printStatement("alter schema foo.bar rename to baz");

        printStatement("create table test (a boolean, b bigint, c double, d varchar, e timestamp)");
        printStatement("create table test (a boolean, b bigint comment 'test')");
        printStatement("create table if not exists baz (a timestamp, b varchar)");
        printStatement("create table test (a boolean, b bigint) with (a = 'apple', b = 'banana')");
        printStatement("create table test (a boolean, b bigint) comment 'test' with (a = 'apple')");
        printStatement("create table test (a boolean with (a = 'apple', b = 'banana'), b bigint comment 'bla' with (c = 'cherry')) comment 'test' with (a = 'apple')");
        printStatement("drop table test");

        printStatement("create view foo as with a as (select 123) select * from a");
        printStatement("create or replace view foo as select 123 from t");

        printStatement("drop view foo");

        printStatement("insert into t select * from t");
        printStatement("insert into t (c1, c2) select * from t");

        printStatement("start transaction");
        printStatement("start transaction isolation level read uncommitted");
        printStatement("start transaction isolation level read committed");
        printStatement("start transaction isolation level repeatable read");
        printStatement("start transaction isolation level serializable");
        printStatement("start transaction read only");
        printStatement("start transaction read write");
        printStatement("start transaction isolation level read committed, read only");
        printStatement("start transaction read only, isolation level read committed");
        printStatement("start transaction read write, isolation level serializable");
        printStatement("commit");
        printStatement("commit work");
        printStatement("rollback");
        printStatement("rollback work");

        printStatement("call foo()");
        printStatement("call foo(123, a => 1, b => 'go', 456)");

        printStatement("grant select on foo to alice with grant option");
        printStatement("grant all privileges on foo to alice");
        printStatement("grant delete, select on foo to role public");
        printStatement("revoke grant option for select on foo from alice");
        printStatement("revoke all privileges on foo from alice");
        printStatement("revoke insert, delete on foo from role public");
        printStatement("show grants on table t");
        printStatement("show grants on t");
        printStatement("show grants");
        printStatement("show roles");
        printStatement("show roles from foo");
        printStatement("show current roles");
        printStatement("show current roles from foo");
        printStatement("show role grants");
        printStatement("show role grants from foo");

        printStatement("prepare p from select * from (select * from T) \"A B\"");

        printStatement("SELECT * FROM table1 WHERE a >= ALL (VALUES 2, 3, 4)");
        printStatement("SELECT * FROM table1 WHERE a <> ANY (SELECT 2, 3, 4)");
        printStatement("SELECT * FROM table1 WHERE a = SOME (SELECT id FROM table2)");

        // Table Constraints
        printStatement("alter table foo add constraint bar unique (col1, col2, col3)");
        printStatement("alter table foo add constraint bar primary key (col1, col2)");
        printStatement("alter table foo add constraint bar primary key (col1, col2) disabled not rely not enforced");
        printStatement("alter table foo add constraint bar primary key (col1, col2) enforced");
        printStatement("alter table foo add constraint bar primary key (col1, col2) disabled not rely");
        printStatement("alter table foo add constraint bar unique (col1, col2, col3) disabled");
        printStatement("alter table foo add constraint bar unique (col1, col2, col3) not rely not enforced");
        printStatement("alter table a.b.c add constraint bar unique (col1, col2, col3)");
        printStatement("create table t1 (c1 int, c2 varchar, c3 double, c4 int, primary key (c1, c2))");
        printStatement("create table t1 (c1 int, c2 varchar, c3 double, c4 int, primary key (c1, c2), unique (c4), unique (c3))");
        printStatement("create table t1 (c1 int, c2 varchar, c3 double, c4 int, constraint pk1 primary key (c1, c2))");
        printStatement("create table t1 (c1 int, c2 varchar, c3 double, c4 int, constraint pk1 primary key (c1, c2), constraint uq1 unique (c4), unique (c3))");
        printStatement("create table t1 (c1 int, c2 varchar, c3 double, c4 int, constraint pk1 primary key (c1, c2) disabled not rely enforced , constraint uq1 unique (c4) not rely enforced, unique (c3) disabled)");

        printStatement("alter table foo alter column bar set not null");
        printStatement("alter table foo alter column bar drop not null");
        printStatement("alter table if exists foo alter bar set not null");
        printStatement("alter table if exists foo alter bar drop not null");
    }

    @Test
    public void testStringFormatter()
    {
        assertSqlFormatter("U&'hello\\6d4B\\8Bd5\\+10FFFFworld\\7F16\\7801'",
                "U&'hello\\6D4B\\8BD5\\+10FFFFworld\\7F16\\7801'");
        assertSqlFormatter("'hello world'", "'hello world'");
        assertSqlFormatter("U&'!+10FFFF!6d4B!8Bd5ABC!6d4B!8Bd5' UESCAPE '!'", "U&'\\+10FFFF\\6D4B\\8BD5ABC\\6D4B\\8BD5'");
        assertSqlFormatter("U&'\\+10FFFF\\6D4B\\8BD5\\0041\\0042\\0043\\6D4B\\8BD5'", "U&'\\+10FFFF\\6D4B\\8BD5ABC\\6D4B\\8BD5'");
        assertSqlFormatter("U&'\\\\abc\\6D4B'''", "U&'\\\\abc\\6D4B'''");
    }

    @Test
    public void testStatementBuilderTpch()
            throws Exception
    {
        printTpchQuery(1, 3);
        printTpchQuery(2, 33, "part type like", "region name");
        printTpchQuery(3, "market segment", "2013-03-05");
        printTpchQuery(4, "2013-03-05");
        printTpchQuery(5, "region name", "2013-03-05");
        printTpchQuery(6, "2013-03-05", 33, 44);
        printTpchQuery(7, "nation name 1", "nation name 2");
        printTpchQuery(8, "nation name", "region name", "part type");
        printTpchQuery(9, "part name like");
        printTpchQuery(10, "2013-03-05");
        printTpchQuery(11, "nation name", 33);
        printTpchQuery(12, "ship mode 1", "ship mode 2", "2013-03-05");
        printTpchQuery(13, "comment like 1", "comment like 2");
        printTpchQuery(14, "2013-03-05");
        // query 15: views not supported
        printTpchQuery(16, "part brand", "part type like", 3, 4, 5, 6, 7, 8, 9, 10);
        printTpchQuery(17, "part brand", "part container");
        printTpchQuery(18, 33);
        printTpchQuery(19, "part brand 1", "part brand 2", "part brand 3", 11, 22, 33);
        printTpchQuery(20, "part name like", "2013-03-05", "nation name");
        printTpchQuery(21, "nation name");
        printTpchQuery(22,
                "phone 1",
                "phone 2",
                "phone 3",
                "phone 4",
                "phone 5",
                "phone 6",
                "phone 7");
    }

    /**
     * Parses {@code sql}, optionally prints the statement tree and its formatted form,
     * and asserts that formatting round-trips via {@code assertFormattedSql}.
     *
     * @param sql the statement text to parse; it is parsed as given (leading/trailing
     *            whitespace is trimmed only for the printed echo)
     */
    private static void printStatement(String sql)
    {
        println(sql.trim());
        println("");

        // The decimal-literal treatment is irrelevant here; any value works.
        ParsingOptions parsingOptions = new ParsingOptions(AS_DOUBLE /* anything */);
        Statement statement = SQL_PARSER.createStatement(sql, parsingOptions);
        println(statement.toString());
        println("");

        println(SqlFormatter.formatSql(statement, Optional.empty()));
        println("");
        assertFormattedSql(SQL_PARSER, statement);

        println(repeat("=", 60));
        println("");
    }

    /**
     * Parses {@code expression} with default parsing options and asserts that formatting
     * it yields exactly {@code formatted}.
     *
     * @param expression the SQL expression text to parse
     * @param formatted  the expected output of {@link SqlFormatter#formatSql}
     */
    private static void assertSqlFormatter(String expression, String formatted)
    {
        Expression originalExpression = SQL_PARSER.createExpression(expression, new ParsingOptions());
        String real = SqlFormatter.formatSql(originalExpression, Optional.empty());
        assertEquals(real, formatted);
    }

    /**
     * Writes {@code s} to stdout, but only when the {@code printParse} system property
     * is {@code "true"} (case-insensitive); otherwise does nothing.
     */
    private static void println(String s)
    {
        if (Boolean.parseBoolean(System.getProperty("printParse"))) {
            System.out.println(s);
        }
    }

    /**
     * Loads the raw text of TPC-H query template {@code q} from the classpath resource
     * {@code tpch/queries/<q>.sql}.
     *
     * @throws IOException if the resource cannot be read
     */
    private static String getTpchQuery(int q)
            throws IOException
    {
        return readResource("tpch/queries/" + q + ".sql");
    }

    /**
     * Loads TPC-H query template {@code query}, substitutes the positional bind
     * parameters {@code :1}, {@code :2}, ... with {@code values}, normalizes the result
     * with {@link #fixTpchQuery}, and runs it through {@link #printStatement}.
     *
     * @param query  the TPC-H query number
     * @param values the values for {@code :1..:n}, in order; each is inserted verbatim
     *               using {@link String#valueOf(Object)}
     * @throws IOException if the query template cannot be read
     */
    private static void printTpchQuery(int query, Object... values)
            throws IOException
    {
        String sql = getTpchQuery(query);

        // Substitute from the highest index down so that ":10" is replaced before ":1".
        // Otherwise ":1" would clobber the prefix of ":10". Use literal String.replace
        // rather than replaceAll: replaceAll treats '$' and '\' in the value as
        // regex replacement syntax.
        for (int i = values.length - 1; i >= 0; i--) {
            sql = sql.replace(format(":%s", i + 1), String.valueOf(values[i]));
        }

        assertFalse(sql.matches("(?s).*:[0-9].*"), "Not all bind parameters were replaced: " + sql);

        sql = fixTpchQuery(sql);
        printStatement(sql);
    }

    /**
     * Reads the classpath resource {@code name} as a UTF-8 string.
     *
     * @throws IOException if the resource cannot be read
     */
    private static String readResource(String name)
            throws IOException
    {
        return Resources.toString(Resources.getResource(name), UTF_8);
    }

    /**
     * Rewrites template-specific markup in a TPC-H query so that the parser accepts it.
     *
     * @param s the query text after bind-parameter substitution
     * @return the normalized query text
     */
    private static String fixTpchQuery(String s)
    {
        // Drop the first line-terminating semicolon; the parser takes a single statement.
        s = s.replaceFirst("(?m);$", "");
        // Remove standalone ":x" / ":o" lines (presumably qgen template directives).
        s = s.replaceAll("(?m)^:[xo]$", "");
        // ":n -1" means "no row limit": remove it.
        s = s.replaceAll("(?m)^:n -1$", "");
        // ":n <count>" becomes a LIMIT clause.
        s = s.replaceAll("(?m)^:n ([0-9]+)$", "LIMIT $1");
        s = s.replace("day (3)", "day"); // for query 1
        return s;
    }
}