TestJdbcConnection.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.jdbc;

import com.facebook.airlift.log.Logging;
import com.facebook.presto.common.predicate.TupleDomain;
import com.facebook.presto.hive.HiveHadoop2Plugin;
import com.facebook.presto.server.testing.TestingPrestoServer;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.ConnectorTableMetadata;
import com.facebook.presto.spi.InMemoryRecordSet;
import com.facebook.presto.spi.RecordCursor;
import com.facebook.presto.spi.SchemaTableName;
import com.facebook.presto.spi.SystemTable;
import com.facebook.presto.spi.connector.ConnectorTransactionHandle;
import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.inject.Module;
import com.google.inject.Scopes;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;

import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType;
import static com.facebook.presto.jdbc.TestPrestoDriver.closeQuietly;
import static com.facebook.presto.metadata.MetadataUtil.TableMetadataBuilder.tableMetadataBuilder;
import static com.facebook.presto.spi.SystemTable.Distribution.ALL_NODES;
import static com.google.inject.multibindings.Multibinder.newSetBinder;
import static java.lang.String.format;
import static org.assertj.core.api.Assertions.assertThat;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertTrue;

/**
 * Integration tests for {@link PrestoConnection} behavior against an in-process
 * {@link TestingPrestoServer} with a file-metastore Hive catalog and an extra
 * {@code system.test.extra_credentials} system table.
 */
public class TestJdbcConnection
{
    private TestingPrestoServer server;

    @DataProvider(name = "customHeaderWithSpecialCharacter")
    public static Object[][] customHeaderWithSpecialCharacter()
    {
        return new Object[][] {{"test.com:1234"}, {"test@test.com"}};
    }

    /**
     * Starts the testing server, registers the extra-credentials system table, installs the
     * Hive plugin with a file metastore, and creates the {@code default} and {@code fruit} schemas.
     */
    @BeforeClass
    public void setupServer()
            throws Exception
    {
        Logging.initialize();
        Module systemTables = binder -> newSetBinder(binder, SystemTable.class)
                .addBinding().to(ExtraCredentialsSystemTable.class).in(Scopes.SINGLETON);
        server = new TestingPrestoServer(ImmutableList.of(systemTables));

        server.installPlugin(new HiveHadoop2Plugin());
        server.createCatalog("hive", "hive-hadoop2", ImmutableMap.<String, String>builder()
                .put("hive.metastore", "file")
                .put("hive.metastore.catalog.dir", server.getDataDirectory().resolve("hive").toFile().toURI().toString())
                .put("hive.security", "sql-standard")
                .build());

        try (Connection connection = createConnection();
                Statement statement = connection.createStatement()) {
            statement.execute("SET ROLE admin");
            statement.execute("CREATE SCHEMA default");
            statement.execute("CREATE SCHEMA fruit");
        }
    }

    @AfterClass(alwaysRun = true)
    public void teardownServer()
    {
        closeQuietly(server);
    }

    @Test
    public void testCommit()
            throws SQLException
    {
        try {
            try (Connection connection = createConnection()) {
                connection.setAutoCommit(false);
                try (Statement statement = connection.createStatement()) {
                    statement.execute("CREATE TABLE test_commit (x bigint)");
                }

                try (Connection otherConnection = createConnection()) {
                    assertThat(listTables(otherConnection)).doesNotContain("test_commit");
                }

                connection.commit();
            }

            try (Connection connection = createConnection()) {
                assertThat(listTables(connection)).contains("test_commit");
            }
        }
        finally {
            try (Connection connection = createConnection()) {
                try (Statement statement = connection.createStatement()) {
                    statement.execute("DROP TABLE test_commit");
                }
            }
        }
    }

    @Test
    public void testAutoCommit()
            throws SQLException
    {
        try {
            try (Connection connection = createConnection()) {
                connection.setAutoCommit(true);
                try (Statement statement = connection.createStatement()) {
                    statement.execute("CREATE TABLE test_autocommit (x bigint)");
                }
            }

            try (Connection connection = createConnection()) {
                assertThat(listTables(connection)).contains("test_autocommit");
            }
        }
        finally {
            try (Connection connection = createConnection()) {
                try (Statement statement = connection.createStatement()) {
                    statement.execute("DROP TABLE test_autocommit");
                }
            }
        }
    }

    @Test
    public void testResetAutoCommit()
            throws SQLException
    {
        try {
            try (Connection connection = createConnection()) {
                connection.setAutoCommit(false);
                try (Statement statement = connection.createStatement()) {
                    statement.execute("CREATE TABLE test_reset_autocommit (x bigint)");
                }

                try (Connection otherConnection = createConnection()) {
                    assertThat(listTables(otherConnection)).doesNotContain("test_reset_autocommit");
                }
                // Re-enabling auto-commit is expected to commit the open transaction.
                connection.setAutoCommit(true);
            }

            try (Connection connection = createConnection()) {
                assertThat(listTables(connection)).contains("test_reset_autocommit");
            }
        }
        finally {
            try (Connection connection = createConnection()) {
                try (Statement statement = connection.createStatement()) {
                    statement.execute("DROP TABLE test_reset_autocommit");
                }
            }
        }
    }

    /**
     * Verifies that {@code system.jdbc.tables} reports {@code VIEW} for a view and {@code TABLE}
     * for a table. The created objects are dropped in {@code finally} blocks so a failed assertion
     * does not leave them behind; the view is dropped before the table it depends on.
     */
    @Test
    public void testTableType()
            throws SQLException
    {
        try (Connection connection = createConnection()) {
            assertThat(connection.getCatalog()).isEqualTo("hive");
            assertThat(connection.getSchema()).isEqualTo("default");

            try (Statement statement = connection.createStatement()) {
                statement.execute("CREATE TABLE test_table_type (x bigint)");
                try {
                    statement.execute("CREATE VIEW table_type_view AS SELECT * FROM test_table_type");
                    try {
                        assertSingleTableType(statement, "table_type_view", "VIEW");
                        assertSingleTableType(statement, "test_table_type", "TABLE");
                    }
                    finally {
                        statement.execute("DROP VIEW table_type_view");
                    }
                }
                finally {
                    statement.execute("DROP TABLE test_table_type");
                }
            }
        }
    }

    @Test
    public void testRollback()
            throws SQLException
    {
        try (Connection connection = createConnection()) {
            connection.setAutoCommit(false);
            try (Statement statement = connection.createStatement()) {
                statement.execute("CREATE TABLE test_rollback (x bigint)");
            }

            try (Connection otherConnection = createConnection()) {
                assertThat(listTables(otherConnection)).doesNotContain("test_rollback");
            }

            connection.rollback();
        }

        try (Connection connection = createConnection()) {
            assertThat(listTables(connection)).doesNotContain("test_rollback");
        }
    }

    @Test
    public void testImmediateCommit()
            throws SQLException
    {
        try (Connection connection = createConnection()) {
            connection.setAutoCommit(false);
            connection.commit();
        }
    }

    @Test
    public void testImmediateRollback()
            throws SQLException
    {
        try (Connection connection = createConnection()) {
            connection.setAutoCommit(false);
            connection.rollback();
        }
    }

    @Test
    public void testUse()
            throws SQLException
    {
        try (Connection connection = createConnection()) {
            assertThat(connection.getCatalog()).isEqualTo("hive");
            assertThat(connection.getSchema()).isEqualTo("default");

            // change schema
            try (Statement statement = connection.createStatement()) {
                statement.execute("USE fruit");
            }

            assertThat(connection.getCatalog()).isEqualTo("hive");
            assertThat(connection.getSchema()).isEqualTo("fruit");

            // change catalog and schema
            try (Statement statement = connection.createStatement()) {
                statement.execute("USE system.runtime");
            }

            assertThat(connection.getCatalog()).isEqualTo("system");
            assertThat(connection.getSchema()).isEqualTo("runtime");

            // run multiple queries
            assertThat(listTables(connection)).contains("nodes");
            assertThat(listTables(connection)).contains("queries");
            assertThat(listTables(connection)).contains("tasks");
        }
    }

    @Test
    public void testSession()
            throws SQLException
    {
        try (Connection connection = createConnection("sessionProperties=query_max_run_time:2d;max_failed_task_percentage:0.6")) {
            assertThat(listSession(connection))
                    .contains("join_distribution_type|AUTOMATIC|AUTOMATIC")
                    .contains("exchange_compression_codec|NONE|NONE")
                    .contains("query_max_run_time|2d|100.00d")
                    .contains("max_failed_task_percentage|0.6|0.3");

            try (Statement statement = connection.createStatement()) {
                statement.execute("SET SESSION join_distribution_type = 'BROADCAST'");
            }

            assertThat(listSession(connection))
                    .contains("join_distribution_type|BROADCAST|AUTOMATIC")
                    .contains("exchange_compression_codec|NONE|NONE");

            try (Statement statement = connection.createStatement()) {
                statement.execute("SET SESSION exchange_compression_codec = 'LZ4'");
            }

            assertThat(listSession(connection))
                    .contains("join_distribution_type|BROADCAST|AUTOMATIC")
                    .contains("exchange_compression_codec|LZ4|NONE");
        }
    }

    @Test
    public void testApplicationName()
            throws SQLException
    {
        try (Connection connection = createConnection()) {
            assertConnectionSource(connection, "presto-jdbc");
        }

        try (Connection connection = createConnection()) {
            connection.setClientInfo("ApplicationName", "testing");
            assertConnectionSource(connection, "testing");
        }

        try (Connection connection = createConnection("applicationNamePrefix=fruit:")) {
            assertConnectionSource(connection, "fruit:");
        }

        try (Connection connection = createConnection("applicationNamePrefix=fruit:")) {
            connection.setClientInfo("ApplicationName", "testing");
            assertConnectionSource(connection, "fruit:testing");
        }
    }

    @Test
    public void testHttpProtocols()
            throws SQLException
    {
        String extra = "protocols=http11";
        try (Connection connection = createConnection(extra)) {
            assertThat(connection.getCatalog()).isEqualTo("hive");
        }

        // deduplication
        extra = "protocols=http11,http11";
        try (Connection connection = createConnection(extra)) {
            assertThat(connection.getCatalog()).isEqualTo("hive");
        }
    }

    /**
     * Verifies that extra credentials from the URL are exposed by the connection and reach the
     * server session. The connection is closed via try-with-resources.
     */
    @Test
    public void testExtraCredentials()
            throws SQLException
    {
        Map<String, String> credentials = ImmutableMap.of("test.token.foo", "bar", "test.token.abc", "xyz");
        try (Connection connection = createConnection("extraCredentials=test.token.foo:bar;test.token.abc:xyz")) {
            assertTrue(connection instanceof PrestoConnection);
            PrestoConnection prestoConnection = connection.unwrap(PrestoConnection.class);
            assertEquals(prestoConnection.getExtraCredentials(), credentials);
            assertEquals(listExtraCredentials(connection), credentials);
        }
    }

    /**
     * Verifies that a URL-encoded {@code customHeaders} property is decoded into the connection's
     * custom header map.
     */
    @Test
    public void testCustomHeaders()
            throws SQLException, UnsupportedEncodingException
    {
        Map<String, String> customHeadersMap = ImmutableMap.of("testHeaderKey", "testHeaderValue");
        String customHeaders = "testHeaderKey:testHeaderValue";
        String encodedCustomHeaders = URLEncoder.encode(customHeaders, StandardCharsets.UTF_8.toString());
        try (Connection connection = createConnection("customHeaders=" + encodedCustomHeaders)) {
            assertTrue(connection instanceof PrestoConnection);
            PrestoConnection prestoConnection = connection.unwrap(PrestoConnection.class);
            assertEquals(prestoConnection.getCustomHeaders(), customHeadersMap);
        }
    }

    /**
     * Verifies custom headers whose value contains characters that clash with the
     * {@code key:value} syntax. The value is URL-encoded once inside the header string, then the
     * whole property is URL-encoded again for the JDBC URL. The connection is expected to expose
     * the singly-encoded value.
     *
     * @param testHeaderValue raw header value supplied by {@link #customHeaderWithSpecialCharacter()}
     */
    @Test(dataProvider = "customHeaderWithSpecialCharacter")
    public void testCustomHeadersWithSpecialCharacters(String testHeaderValue)
            throws SQLException, UnsupportedEncodingException
    {
        String encodedHeaderValue = URLEncoder.encode(testHeaderValue, StandardCharsets.UTF_8.toString());
        Map<String, String> customHeadersMap = ImmutableMap.of("testHeaderKey", encodedHeaderValue);
        String customHeaders = "testHeaderKey:" + encodedHeaderValue;
        String encodedCustomHeaders = URLEncoder.encode(customHeaders, StandardCharsets.UTF_8.toString());
        try (Connection connection = createConnection("customHeaders=" + encodedCustomHeaders)) {
            assertTrue(connection instanceof PrestoConnection);
            PrestoConnection prestoConnection = connection.unwrap(PrestoConnection.class);
            assertEquals(prestoConnection.getCustomHeaders(), customHeadersMap);
        }
    }

    @Test
    public void testClientTags()
            throws SQLException
    {
        try (Connection connection = createConnection("clientTags=c2,c3")) {
            assertEquals(connection.getClientInfo("ClientTags"), "c2,c3");
        }
    }

    @Test
    public void testQueryInterceptors()
            throws SQLException
    {
        String extra = "queryInterceptors=" + TestNoopQueryInterceptor.class.getName();
        try (PrestoConnection connection = createConnection(extra).unwrap(PrestoConnection.class)) {
            List<QueryInterceptor> queryInterceptorInstances = connection.getQueryInterceptorInstances();
            assertEquals(queryInterceptorInstances.size(), 1);
            assertEquals(queryInterceptorInstances.get(0).getClass().getName(), TestNoopQueryInterceptor.class.getName());
        }
    }

    @Test
    public void testConnectionProperties()
            throws SQLException
    {
        String extra = "extraCredentials=test.token.foo:bar;test.token.abc:xyz";
        try (PrestoConnection connection = createConnection(extra).unwrap(PrestoConnection.class)) {
            Properties connectionProperties = connection.getConnectionProperties();
            assertTrue(connectionProperties.size() > 0);
            assertNotNull(connectionProperties.getProperty("extraCredentials"));
        }
    }

    /**
     * Interceptor with no overridden behavior, used only to check that {@code queryInterceptors}
     * class names are instantiated by the connection.
     */
    public static class TestNoopQueryInterceptor
            implements QueryInterceptor
    {
    }

    private Connection createConnection()
            throws SQLException
    {
        return createConnection("");
    }

    /**
     * Opens a connection to {@code hive.default} on the testing server as user {@code admin}.
     *
     * @param extra raw query-string parameters appended to the JDBC URL (may be empty)
     */
    private Connection createConnection(String extra)
            throws SQLException
    {
        String url = format("jdbc:presto://%s/hive/default?%s", server.getAddress(), extra);
        return DriverManager.getConnection(url, "admin", null);
    }

    /**
     * Asserts that {@code system.jdbc.tables} has exactly one row for {@code tableName} in schema
     * {@code default}, and that the row's {@code TABLE_TYPE} equals {@code expectedType}.
     */
    private static void assertSingleTableType(Statement statement, String tableName, String expectedType)
            throws SQLException
    {
        String query = format("SELECT TABLE_NAME, TABLE_TYPE FROM system.jdbc.tables WHERE TABLE_SCHEM = 'default' AND TABLE_NAME = '%s'", tableName);
        int rowCount = 0;
        try (ResultSet rs = statement.executeQuery(query)) {
            while (rs.next()) {
                assertEquals(rs.getString("TABLE_NAME"), tableName);
                assertEquals(rs.getString("TABLE_TYPE"), expectedType);
                rowCount++;
            }
        }
        assertEquals(rowCount, 1);
    }

    /**
     * Returns the first column of {@code SHOW TABLES} for the connection's current catalog/schema.
     */
    private static Set<String> listTables(Connection connection)
            throws SQLException
    {
        ImmutableSet.Builder<String> set = ImmutableSet.builder();
        try (Statement statement = connection.createStatement();
                ResultSet rs = statement.executeQuery("SHOW TABLES")) {
            while (rs.next()) {
                set.add(rs.getString(1));
            }
        }
        return set.build();
    }

    /**
     * Returns each {@code SHOW SESSION} row as its first three columns joined with {@code |}.
     */
    private static Set<String> listSession(Connection connection)
            throws SQLException
    {
        ImmutableSet.Builder<String> set = ImmutableSet.builder();
        try (Statement statement = connection.createStatement();
                ResultSet rs = statement.executeQuery("SHOW SESSION")) {
            while (rs.next()) {
                set.add(Joiner.on('|').join(
                        rs.getString(1),
                        rs.getString(2),
                        rs.getString(3)));
            }
        }
        return set.build();
    }

    /**
     * Reads {@code system.test.extra_credentials} (backed by {@link ExtraCredentialsSystemTable})
     * into a name-to-value map. The statement and result set are closed before returning.
     */
    private static Map<String, String> listExtraCredentials(Connection connection)
            throws SQLException
    {
        ImmutableMap.Builder<String, String> builder = ImmutableMap.builder();
        try (Statement statement = connection.createStatement();
                ResultSet rs = statement.executeQuery("SELECT * FROM system.test.extra_credentials")) {
            while (rs.next()) {
                builder.put(rs.getString("name"), rs.getString("value"));
            }
        }
        return builder.build();
    }

    /**
     * Runs a trivial query on {@code connection}, then looks that query up in
     * {@code system.runtime.queries}. Asserts that its {@code source} column equals
     * {@code expectedSource}.
     */
    private static void assertConnectionSource(Connection connection, String expectedSource)
            throws SQLException
    {
        String queryId;
        try (Statement statement = connection.createStatement();
                ResultSet rs = statement.executeQuery("SELECT 123")) {
            queryId = rs.unwrap(PrestoResultSet.class).getQueryId();
        }

        try (PreparedStatement statement = connection.prepareStatement(
                "SELECT source FROM system.runtime.queries WHERE query_id = ?")) {
            statement.setString(1, queryId);
            try (ResultSet rs = statement.executeQuery()) {
                assertTrue(rs.next());
                assertThat(rs.getString("source")).isEqualTo(expectedSource);
                assertFalse(rs.next());
            }
        }
    }

    /**
     * System table {@code test.extra_credentials} that returns one {@code (name, value)} row per
     * extra credential in the querying session's identity.
     */
    private static class ExtraCredentialsSystemTable
            implements SystemTable
    {
        private static final SchemaTableName NAME = new SchemaTableName("test", "extra_credentials");

        public static final ConnectorTableMetadata METADATA = tableMetadataBuilder(NAME)
                .column("name", createUnboundedVarcharType())
                .column("value", createUnboundedVarcharType())
                .build();

        @Override
        public Distribution getDistribution()
        {
            return ALL_NODES;
        }

        @Override
        public ConnectorTableMetadata getTableMetadata()
        {
            return METADATA;
        }

        @Override
        public RecordCursor cursor(ConnectorTransactionHandle transactionHandle, ConnectorSession session, TupleDomain<Integer> constraint)
        {
            InMemoryRecordSet.Builder table = InMemoryRecordSet.builder(METADATA);
            session.getIdentity().getExtraCredentials().forEach(table::addRow);
            return table.build().cursor();
        }
    }
}