// TestPrestoDriverAuth.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.jdbc;
import com.facebook.airlift.log.Logging;
import com.facebook.airlift.security.pem.PemReader;
import com.facebook.presto.server.testing.TestingPrestoServer;
import com.facebook.presto.sql.parser.SqlParserOptions;
import com.facebook.presto.tpch.TpchPlugin;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SignatureAlgorithm;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
import java.io.File;
import java.net.URL;
import java.security.PrivateKey;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Base64;
import java.util.Map;
import java.util.Optional;
import java.util.Properties;
import static com.facebook.presto.jdbc.TestPrestoDriver.closeQuietly;
import static com.facebook.presto.jdbc.TestPrestoDriver.waitForNodeRefresh;
import static com.google.common.io.Files.asCharSource;
import static com.google.common.io.Resources.getResource;
import static io.jsonwebtoken.JwsHeader.KEY_ID;
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.US_ASCII;
import static java.util.Base64.getMimeDecoder;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertTrue;
/**
 * Exercises JWT authentication through the Presto JDBC driver.
 *
 * <p>A single {@link TestingPrestoServer} is started for the whole class with HTTPS enabled and
 * the {@code JWT} authentication type. Each test builds an access token (or omits it), opens a
 * JDBC connection with that token and runs {@code SELECT 123}; success tests verify the single
 * returned row, failure tests rely on TestNG's {@code expectedExceptions} to check the
 * {@link SQLException} message produced by the driver.
 */
public class TestPrestoDriverAuth
{
    private static final String TEST_CATALOG = "test_catalog";

    private TestingPrestoServer server;
    // Key material loaded from the test resources in setup()
    private byte[] defaultKey;
    private byte[] hmac222;
    private PrivateKey privateKey33;

    /**
     * Loads the signing keys from the test resources and starts an HTTPS server that requires
     * JWT authentication, with a TPCH catalog installed.
     *
     * @throws Exception if a key cannot be read or the server fails to start
     */
    @BeforeClass
    public void setup()
            throws Exception
    {
        Logging.initialize();

        // The directory holding "33.privateKey" is used as the key directory for all key files
        URL resource = getClass().getClassLoader().getResource("33.privateKey");
        assertNotNull(resource, "key directory not found");
        File keyDir = new File(resource.getFile()).getAbsoluteFile().getParentFile();

        // HMAC keys are stored as Base64 text; the private key is stored as PEM
        defaultKey = getMimeDecoder().decode(asCharSource(new File(keyDir, "default-key.key"), US_ASCII).read().getBytes(US_ASCII));
        hmac222 = getMimeDecoder().decode(asCharSource(new File(keyDir, "222.key"), US_ASCII).read().getBytes(US_ASCII));
        privateKey33 = PemReader.loadPrivateKey(new File(keyDir, "33.privateKey"), Optional.empty());

        server = new TestingPrestoServer(
                true,
                ImmutableMap.<String, String>builder()
                        .put("http-server.authentication.type", "JWT")
                        // NOTE(review): "${KID}" is presumably replaced by the server with the token's key ID - confirm
                        .put("http.authentication.jwt.key-file", new File(keyDir, "${KID}.key").toString())
                        .put("http-server.https.enabled", "true")
                        .put("http-server.https.keystore.path", getResource("localhost.keystore").getPath())
                        .put("http-server.https.keystore.key", "changeit")
                        .build(),
                null,
                null,
                new SqlParserOptions(),
                ImmutableList.of());
        server.installPlugin(new TpchPlugin());
        server.createCatalog(TEST_CATALOG, "tpch");
        waitForNodeRefresh(server);
    }

    /**
     * Shuts the test server down; runs even when setup or tests failed.
     */
    @AfterClass(alwaysRun = true)
    public void teardown()
    {
        closeQuietly(server);
    }

    /**
     * A token without a key ID header, signed with the key from {@code default-key.key}, is accepted.
     */
    @Test
    public void testSuccessDefaultKey()
            throws Exception
    {
        String accessToken = Jwts.builder()
                .setSubject("test")
                .signWith(SignatureAlgorithm.HS512, defaultKey)
                .compact();

        assertSelectSucceeds(accessToken);
    }

    /**
     * A token with key ID {@code 222}, signed with the HMAC key from {@code 222.key}, is accepted.
     */
    @Test
    public void testSuccessHmac()
            throws Exception
    {
        String accessToken = Jwts.builder()
                .setSubject("test")
                .setHeaderParam(KEY_ID, "222")
                .signWith(SignatureAlgorithm.HS512, hmac222)
                .compact();

        assertSelectSucceeds(accessToken);
    }

    /**
     * A token with key ID {@code 33}, signed with the RSA private key from {@code 33.privateKey},
     * is accepted.
     */
    @Test
    public void testSuccessPublicKey()
            throws Exception
    {
        String accessToken = Jwts.builder()
                .setSubject("test")
                .setHeaderParam(KEY_ID, "33")
                .signWith(SignatureAlgorithm.RS256, privateKey33)
                .compact();

        assertSelectSucceeds(accessToken);
    }

    /**
     * Connecting without any access token is rejected.
     */
    @Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Authentication failed: Unauthorized")
    public void testFailedNoToken()
            throws Exception
    {
        executeSelect(ImmutableMap.of());
    }

    /**
     * A token that carries no signature is rejected.
     */
    @Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Authentication failed: Unsigned Claims JWTs are not supported.")
    public void testFailedUnsigned()
            throws Exception
    {
        String accessToken = Jwts.builder()
                .setSubject("test")
                .compact();

        executeSelect(ImmutableMap.of("accessToken", accessToken));
    }

    /**
     * A token without a key ID header, signed with an HMAC key other than the one loaded from
     * {@code default-key.key}, is rejected.
     */
    @Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Authentication failed: JWT signature does not match.*")
    public void testFailedBadHmacSignature()
            throws Exception
    {
        String badKey = "iPqFfWmGvClP953xU9110Q48qB4F5dcJ7QQel3O1k0xU52mlR6fT51SMa2f4KzhFRqqpwGUOud8Eo12pK9EW5H4N";
        // This signWith overload takes the key as a Base64-encoded string, hence the encoding step
        String accessToken = Jwts.builder()
                .setSubject("test")
                .signWith(SignatureAlgorithm.HS512, Base64.getEncoder().encodeToString(badKey.getBytes(US_ASCII)))
                .compact();

        executeSelect(ImmutableMap.of("accessToken", accessToken));
    }

    /**
     * A token that claims key ID {@code 42} but is signed with the private key for ID {@code 33}
     * is rejected.
     */
    @Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Authentication failed: JWT signature does not match.*")
    public void testFailedWrongPublicKey()
            throws Exception
    {
        // NOTE(review): presumably a "42.key" with a different public key exists in the key directory - confirm
        String accessToken = Jwts.builder()
                .setSubject("test")
                .setHeaderParam(KEY_ID, "42")
                .signWith(SignatureAlgorithm.RS256, privateKey33)
                .compact();

        executeSelect(ImmutableMap.of("accessToken", accessToken));
    }

    /**
     * A token whose key ID is {@code unknown} is rejected.
     */
    @Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Authentication failed: Unknown signing key ID")
    public void testFailedUnknownPublicKey()
            throws Exception
    {
        String accessToken = Jwts.builder()
                .setSubject("test")
                .setHeaderParam(KEY_ID, "unknown")
                .signWith(SignatureAlgorithm.RS256, privateKey33)
                .compact();

        executeSelect(ImmutableMap.of("accessToken", accessToken));
    }

    /**
     * Connects with the given access token, runs {@code SELECT 123} and asserts that exactly one
     * row containing {@code 123} is returned.
     *
     * @param accessToken compact JWT passed to the driver as the {@code accessToken} property
     * @throws SQLException if connecting or executing the query fails
     */
    private void assertSelectSucceeds(String accessToken)
            throws SQLException
    {
        try (Connection connection = createConnection(ImmutableMap.of("accessToken", accessToken));
                Statement statement = connection.createStatement()) {
            assertTrue(statement.execute("SELECT 123"));
            // Close the result set explicitly instead of relying on the statement to do it
            try (ResultSet rs = statement.getResultSet()) {
                assertTrue(rs.next());
                assertEquals(rs.getLong(1), 123);
                assertFalse(rs.next());
            }
        }
    }

    /**
     * Connects with the given extra connection properties and runs {@code SELECT 123}, ignoring
     * the result. Used by the failure tests, which expect a {@link SQLException} to propagate.
     *
     * @param additionalProperties properties added on top of the defaults set by
     *        {@link #createConnection(Map)}
     * @throws SQLException if connecting or executing the query fails
     */
    private void executeSelect(Map<String, String> additionalProperties)
            throws SQLException
    {
        try (Connection connection = createConnection(additionalProperties);
                Statement statement = connection.createStatement()) {
            statement.execute("SELECT 123");
        }
    }

    /**
     * Opens a JDBC connection to the test server's HTTPS port as user {@code test}, trusting the
     * bundled {@code localhost.truststore}.
     *
     * @param additionalProperties extra driver properties; they are applied last and therefore
     *        override the defaults on key collision
     * @return a new connection that the caller must close
     * @throws SQLException if the driver cannot create the connection
     */
    private Connection createConnection(Map<String, String> additionalProperties)
            throws SQLException
    {
        String url = format("jdbc:presto://localhost:%s", server.getHttpsAddress().getPort());
        Properties properties = new Properties();
        properties.setProperty("user", "test");
        properties.setProperty("SSL", "true");
        properties.setProperty("SSLTrustStorePath", getResource("localhost.truststore").getPath());
        properties.setProperty("SSLTrustStorePassword", "changeit");
        properties.putAll(additionalProperties);
        return DriverManager.getConnection(url, properties);
    }
}