OAuthIntegrationTest.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.driver.jdbc;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.net.URI;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.util.Properties;
import java.util.concurrent.TimeUnit;
import mockwebserver3.MockResponse;
import mockwebserver3.MockWebServer;
import mockwebserver3.RecordedRequest;
import mockwebserver3.junit5.StartStop;
import org.apache.arrow.driver.jdbc.authentication.TokenAuthentication;
import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty;
import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer;
import org.apache.arrow.flight.sql.FlightSqlProducer.Schemas;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCatalogs;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetDbSchemas;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

/**
 * Integration tests for OAuth authentication flows in the JDBC driver.
 *
 * <p>These tests verify that OAuth tokens obtained from an OAuth server are correctly used in
 * Flight SQL requests.
 */
public class OAuthIntegrationTest {

  private static final String VALID_ACCESS_TOKEN = "valid-oauth-access-token-12345";
  private static final String CLIENT_ID = "test-client-id";
  private static final String CLIENT_SECRET = "test-client-secret";
  private static final String SUBJECT_TOKEN = "original-subject-token";
  private static final String SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt";
  private static final String TEST_SCOPE = "dremio.all";

  private static final MockFlightSqlProducer FLIGHT_SQL_PRODUCER = new MockFlightSqlProducer();

  @RegisterExtension public static FlightServerTestExtension FLIGHT_SERVER_TEST_EXTENSION;

  static {
    FLIGHT_SERVER_TEST_EXTENSION =
        new FlightServerTestExtension.Builder()
            .authentication(new TokenAuthentication.Builder().token(VALID_ACCESS_TOKEN).build())
            .producer(FLIGHT_SQL_PRODUCER)
            .build();
  }

  @StartStop private final MockWebServer oauthServer = new MockWebServer();
  private URI tokenEndpoint;

  @BeforeAll
  public static void setUpClass() {
    // Register a simple catalog query handler
    FLIGHT_SQL_PRODUCER.addCatalogQuery(
        CommandGetCatalogs.getDefaultInstance(),
        listener -> {
          try (BufferAllocator allocator = new RootAllocator();
              VectorSchemaRoot root =
                  VectorSchemaRoot.create(Schemas.GET_CATALOGS_SCHEMA, allocator)) {
            root.setRowCount(0);
            listener.start(root);
            listener.putNext();
          } catch (Throwable t) {
            listener.error(t);
          } finally {
            listener.completed();
          }
        });

    // Register a simple schema query handler for getSchemas()
    FLIGHT_SQL_PRODUCER.addCatalogQuery(
        CommandGetDbSchemas.getDefaultInstance(),
        listener -> {
          try (BufferAllocator allocator = new RootAllocator();
              VectorSchemaRoot root =
                  VectorSchemaRoot.create(Schemas.GET_SCHEMAS_SCHEMA, allocator)) {
            root.setRowCount(0);
            listener.start(root);
            listener.putNext();
          } catch (Throwable t) {
            listener.error(t);
          } finally {
            listener.completed();
          }
        });
  }

  @AfterAll
  public static void tearDownClass() {
    AutoCloseables.closeNoChecked(FLIGHT_SQL_PRODUCER);
  }

  @BeforeEach
  public void setUp() {
    tokenEndpoint = oauthServer.url("/oauth/token").uri();
  }

  @AfterEach
  public void tearDown() {
    oauthServer.close();
  }

  // Helper methods for mock OAuth responses

  private void enqueueSuccessfulTokenResponse() {
    enqueueSuccessfulTokenResponse(VALID_ACCESS_TOKEN, 3600);
  }

  private void enqueueSuccessfulTokenResponse(String token, int expiresIn) {
    String body =
        String.format(
            "{\"access_token\":\"%s\",\"token_type\":\"Bearer\",\"expires_in\":%d}",
            token, expiresIn);
    oauthServer.enqueue(
        new MockResponse.Builder()
            .code(200)
            .setHeader("Content-Type", "application/json")
            .body(body)
            .build());
  }

  private void enqueueErrorResponse(String error, String description) {
    String body =
        String.format("{\"error\":\"%s\",\"error_description\":\"%s\"}", error, description);
    oauthServer.enqueue(
        new MockResponse.Builder()
            .code(400)
            .setHeader("Content-Type", "application/json")
            .body(body)
            .build());
  }

  private Properties createBaseProperties() {
    Properties props = new Properties();
    props.put(ArrowFlightConnectionProperty.HOST.camelName(), "localhost");
    props.put(
        ArrowFlightConnectionProperty.PORT.camelName(), FLIGHT_SERVER_TEST_EXTENSION.getPort());
    props.put(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), false);
    return props;
  }

  private String getJdbcUrl() {
    return String.format(
        "jdbc:arrow-flight-sql://localhost:%d", FLIGHT_SERVER_TEST_EXTENSION.getPort());
  }

  // ==================== Client Credentials Flow Tests ====================

  @Test
  public void testClientCredentialsFlowSuccess() throws Exception {
    enqueueSuccessfulTokenResponse();

    Properties props = createBaseProperties();
    props.put(ArrowFlightConnectionProperty.OAUTH_FLOW.camelName(), "client_credentials");
    props.put(ArrowFlightConnectionProperty.OAUTH_TOKEN_URI.camelName(), tokenEndpoint.toString());
    props.put(ArrowFlightConnectionProperty.OAUTH_CLIENT_ID.camelName(), CLIENT_ID);
    props.put(ArrowFlightConnectionProperty.OAUTH_CLIENT_SECRET.camelName(), CLIENT_SECRET);
    props.put(ArrowFlightConnectionProperty.OAUTH_SCOPE.camelName(), TEST_SCOPE);

    try (Connection conn = DriverManager.getConnection(getJdbcUrl(), props)) {
      assertFalse(conn.isClosed());
      // Trigger a Flight call to force OAuth token retrieval
      conn.getMetaData().getCatalogs().close();
    }

    // Verify OAuth request was made
    RecordedRequest oauthRequest = oauthServer.takeRequest(5, TimeUnit.SECONDS);
    assertNotNull(oauthRequest, "OAuth request should have been made");
    assertEquals("POST", oauthRequest.getMethod());
    String body = oauthRequest.getBody().utf8();
    assertTrue(body.contains("grant_type=client_credentials"));
    assertTrue(body.contains("scope=" + TEST_SCOPE));
  }

  @Test
  public void testClientCredentialsFlowWithUrlParameters() throws Exception {
    enqueueSuccessfulTokenResponse();

    String url =
        String.format(
            "jdbc:arrow-flight-sql://localhost:%d?useEncryption=false"
                + "&oauth.flow=client_credentials"
                + "&oauth.tokenUri=%s"
                + "&oauth.clientId=%s"
                + "&oauth.clientSecret=%s",
            FLIGHT_SERVER_TEST_EXTENSION.getPort(),
            tokenEndpoint.toString(),
            CLIENT_ID,
            CLIENT_SECRET);

    try (Connection conn = DriverManager.getConnection(url)) {
      conn.getMetaData().getCatalogs().close();
    }

    assertEquals(1, oauthServer.getRequestCount());
  }

  @Test
  public void testClientCredentialsFlowInvalidCredentials() throws Exception {
    enqueueErrorResponse("invalid_client", "Client authentication failed");

    Properties props = createBaseProperties();
    props.put(ArrowFlightConnectionProperty.OAUTH_FLOW.camelName(), "client_credentials");
    props.put(ArrowFlightConnectionProperty.OAUTH_TOKEN_URI.camelName(), tokenEndpoint.toString());
    props.put(ArrowFlightConnectionProperty.OAUTH_CLIENT_ID.camelName(), "wrong-client");
    props.put(ArrowFlightConnectionProperty.OAUTH_CLIENT_SECRET.camelName(), "wrong-secret");

    Exception ex =
        assertThrows(
            Exception.class,
            () -> {
              try (Connection conn = DriverManager.getConnection(getJdbcUrl(), props)) {
                conn.getMetaData().getCatalogs().close();
              }
            });
    // Verify the error message contains the OAuth error somewhere in the exception chain
    assertTrue(
        containsInExceptionChain(ex, "invalid_client"),
        "Exception chain should contain 'invalid_client'");
  }

  private boolean containsInExceptionChain(Throwable t, String message) {
    while (t != null) {
      if (t.getMessage() != null && t.getMessage().contains(message)) {
        return true;
      }
      t = t.getCause();
    }
    return false;
  }

  // ==================== Token Exchange Flow Tests ====================

  @Test
  public void testTokenExchangeFlowMinimalParameters() throws Exception {
    enqueueSuccessfulTokenResponse();

    Properties props = createBaseProperties();
    props.put(ArrowFlightConnectionProperty.OAUTH_FLOW.camelName(), "token_exchange");
    props.put(ArrowFlightConnectionProperty.OAUTH_TOKEN_URI.camelName(), tokenEndpoint.toString());
    props.put(
        ArrowFlightConnectionProperty.OAUTH_EXCHANGE_SUBJECT_TOKEN.camelName(), SUBJECT_TOKEN);
    props.put(
        ArrowFlightConnectionProperty.OAUTH_EXCHANGE_SUBJECT_TOKEN_TYPE.camelName(),
        SUBJECT_TOKEN_TYPE);

    try (Connection conn = DriverManager.getConnection(getJdbcUrl(), props)) {
      conn.getMetaData().getCatalogs().close();
    }

    RecordedRequest oauthRequest = oauthServer.takeRequest(5, TimeUnit.SECONDS);
    assertNotNull(oauthRequest, "OAuth request should have been made");
    String body = oauthRequest.getBody().utf8();
    assertTrue(
        body.contains("grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange"),
        "Should contain token exchange grant type");
    assertTrue(body.contains("subject_token=" + SUBJECT_TOKEN));
  }

  @Test
  public void testTokenExchangeFlowWithAllParameters() throws Exception {
    enqueueSuccessfulTokenResponse();

    String actorToken = "actor-token-value";
    String actorTokenType = "urn:ietf:params:oauth:token-type:access_token";
    String audience = "https://api.example.com";
    String resource = "https://api.example.com/resource";
    String requestedTokenType = "urn:ietf:params:oauth:token-type:access_token";

    Properties props = createBaseProperties();
    props.put(ArrowFlightConnectionProperty.OAUTH_FLOW.camelName(), "token_exchange");
    props.put(ArrowFlightConnectionProperty.OAUTH_TOKEN_URI.camelName(), tokenEndpoint.toString());
    props.put(ArrowFlightConnectionProperty.OAUTH_CLIENT_ID.camelName(), CLIENT_ID);
    props.put(ArrowFlightConnectionProperty.OAUTH_CLIENT_SECRET.camelName(), CLIENT_SECRET);
    props.put(ArrowFlightConnectionProperty.OAUTH_SCOPE.camelName(), TEST_SCOPE);
    props.put(
        ArrowFlightConnectionProperty.OAUTH_EXCHANGE_SUBJECT_TOKEN.camelName(), SUBJECT_TOKEN);
    props.put(
        ArrowFlightConnectionProperty.OAUTH_EXCHANGE_SUBJECT_TOKEN_TYPE.camelName(),
        SUBJECT_TOKEN_TYPE);
    props.put(ArrowFlightConnectionProperty.OAUTH_EXCHANGE_ACTOR_TOKEN.camelName(), actorToken);
    props.put(
        ArrowFlightConnectionProperty.OAUTH_EXCHANGE_ACTOR_TOKEN_TYPE.camelName(), actorTokenType);
    props.put(ArrowFlightConnectionProperty.OAUTH_EXCHANGE_AUDIENCE.camelName(), audience);
    props.put(ArrowFlightConnectionProperty.OAUTH_RESOURCE.camelName(), resource);
    props.put(
        ArrowFlightConnectionProperty.OAUTH_EXCHANGE_REQUESTED_TOKEN_TYPE.camelName(),
        requestedTokenType);

    try (Connection conn = DriverManager.getConnection(getJdbcUrl(), props)) {
      conn.getMetaData().getCatalogs().close();
    }

    RecordedRequest oauthRequest = oauthServer.takeRequest(5, TimeUnit.SECONDS);
    assertNotNull(oauthRequest, "OAuth request should have been made");
    String body = oauthRequest.getBody().utf8();
    assertTrue(body.contains("subject_token=" + SUBJECT_TOKEN));
    assertTrue(body.contains("actor_token=" + actorToken));
  }

  @Test
  public void testTokenExchangeFlowWithClientAuthentication() throws Exception {
    enqueueSuccessfulTokenResponse();

    Properties props = createBaseProperties();
    props.put(ArrowFlightConnectionProperty.OAUTH_FLOW.camelName(), "token_exchange");
    props.put(ArrowFlightConnectionProperty.OAUTH_TOKEN_URI.camelName(), tokenEndpoint.toString());
    props.put(ArrowFlightConnectionProperty.OAUTH_CLIENT_ID.camelName(), CLIENT_ID);
    props.put(ArrowFlightConnectionProperty.OAUTH_CLIENT_SECRET.camelName(), CLIENT_SECRET);
    props.put(
        ArrowFlightConnectionProperty.OAUTH_EXCHANGE_SUBJECT_TOKEN.camelName(), SUBJECT_TOKEN);
    props.put(
        ArrowFlightConnectionProperty.OAUTH_EXCHANGE_SUBJECT_TOKEN_TYPE.camelName(),
        SUBJECT_TOKEN_TYPE);

    try (Connection conn = DriverManager.getConnection(getJdbcUrl(), props)) {
      conn.getMetaData().getCatalogs().close();
    }

    RecordedRequest oauthRequest = oauthServer.takeRequest(5, TimeUnit.SECONDS);
    assertNotNull(oauthRequest, "OAuth request should have been made");
    String authHeader = oauthRequest.getHeaders().get("Authorization");
    assertNotNull(authHeader, "Should have Basic auth header for client authentication");
    assertTrue(authHeader.startsWith("Basic "));
  }

  // ==================== Token Caching Tests ====================

  @Test
  public void testTokenCachingAcrossMultipleOperations() throws Exception {
    enqueueSuccessfulTokenResponse();

    Properties props = createBaseProperties();
    props.put(ArrowFlightConnectionProperty.OAUTH_FLOW.camelName(), "client_credentials");
    props.put(ArrowFlightConnectionProperty.OAUTH_TOKEN_URI.camelName(), tokenEndpoint.toString());
    props.put(ArrowFlightConnectionProperty.OAUTH_CLIENT_ID.camelName(), CLIENT_ID);
    props.put(ArrowFlightConnectionProperty.OAUTH_CLIENT_SECRET.camelName(), CLIENT_SECRET);

    try (Connection conn = DriverManager.getConnection(getJdbcUrl(), props)) {
      // Execute multiple operations
      conn.isValid(5);
      conn.getMetaData().getCatalogs().close();
      conn.getMetaData().getSchemas().close();
    }

    // Should only have made one OAuth request due to caching
    assertEquals(1, oauthServer.getRequestCount());
  }

  @Test
  public void testTokenRefreshAfterExpiration() throws Exception {
    enqueueSuccessfulTokenResponse(VALID_ACCESS_TOKEN, 1);
    enqueueSuccessfulTokenResponse(VALID_ACCESS_TOKEN, 3600);

    Properties props = createBaseProperties();
    props.put(ArrowFlightConnectionProperty.OAUTH_FLOW.camelName(), "token_exchange");
    props.put(ArrowFlightConnectionProperty.OAUTH_TOKEN_URI.camelName(), tokenEndpoint.toString());
    props.put(
        ArrowFlightConnectionProperty.OAUTH_EXCHANGE_SUBJECT_TOKEN.camelName(), SUBJECT_TOKEN);
    props.put(
        ArrowFlightConnectionProperty.OAUTH_EXCHANGE_SUBJECT_TOKEN_TYPE.camelName(),
        SUBJECT_TOKEN_TYPE);

    try (Connection conn = DriverManager.getConnection(getJdbcUrl(), props)) {
      // First operation triggers initial token fetch
      conn.getMetaData().getCatalogs().close();

      // Token with 1s expiry is immediately considered expired (due to 30s buffer)
      // so the next operation should trigger a refresh
      conn.getMetaData().getCatalogs().close();
    }

    // Should have made exactly 2 OAuth requests: initial + refresh
    assertEquals(2, oauthServer.getRequestCount());
  }

  // ==================== Error Handling Tests ====================

  @Test
  public void testMissingRequiredParametersClientCredentials() {
    Properties props = createBaseProperties();
    props.put(ArrowFlightConnectionProperty.OAUTH_FLOW.camelName(), "client_credentials");
    props.put(ArrowFlightConnectionProperty.OAUTH_TOKEN_URI.camelName(), tokenEndpoint.toString());
    // Missing client_id and client_secret

    assertThrows(SQLException.class, () -> DriverManager.getConnection(getJdbcUrl(), props));
  }

  @Test
  public void testMissingRequiredParametersTokenExchange() {
    Properties props = createBaseProperties();
    props.put(ArrowFlightConnectionProperty.OAUTH_FLOW.camelName(), "token_exchange");
    props.put(ArrowFlightConnectionProperty.OAUTH_TOKEN_URI.camelName(), tokenEndpoint.toString());
    // Missing subject_token and subject_token_type

    assertThrows(SQLException.class, () -> DriverManager.getConnection(getJdbcUrl(), props));
  }

  @Test
  public void testInvalidOAuthFlow() {
    Properties props = createBaseProperties();
    props.put(ArrowFlightConnectionProperty.OAUTH_FLOW.camelName(), "invalid_flow");
    props.put(ArrowFlightConnectionProperty.OAUTH_TOKEN_URI.camelName(), tokenEndpoint.toString());

    assertThrows(SQLException.class, () -> DriverManager.getConnection(getJdbcUrl(), props));
  }

  @Test
  public void testMalformedTokenEndpoint() {
    Properties props = createBaseProperties();
    props.put(ArrowFlightConnectionProperty.OAUTH_FLOW.camelName(), "client_credentials");
    props.put(ArrowFlightConnectionProperty.OAUTH_TOKEN_URI.camelName(), "not-a-valid-uri://");
    props.put(ArrowFlightConnectionProperty.OAUTH_CLIENT_ID.camelName(), CLIENT_ID);
    props.put(ArrowFlightConnectionProperty.OAUTH_CLIENT_SECRET.camelName(), CLIENT_SECRET);

    assertThrows(SQLException.class, () -> DriverManager.getConnection(getJdbcUrl(), props));
  }

  // ==================== Authorization Header Verification ====================

  @Test
  public void testOAuthTokenSentAsBearer() throws Exception {
    enqueueSuccessfulTokenResponse();

    Properties props = createBaseProperties();
    props.put(ArrowFlightConnectionProperty.OAUTH_FLOW.camelName(), "client_credentials");
    props.put(ArrowFlightConnectionProperty.OAUTH_TOKEN_URI.camelName(), tokenEndpoint.toString());
    props.put(ArrowFlightConnectionProperty.OAUTH_CLIENT_ID.camelName(), CLIENT_ID);
    props.put(ArrowFlightConnectionProperty.OAUTH_CLIENT_SECRET.camelName(), CLIENT_SECRET);

    try (Connection conn = DriverManager.getConnection(getJdbcUrl(), props)) {
      conn.getMetaData().getCatalogs().close();
    }

    // Verify the Flight server received the bearer token
    String authHeader =
        FLIGHT_SERVER_TEST_EXTENSION
            .getInterceptorFactory()
            .getHeader(org.apache.arrow.flight.FlightMethod.GET_FLIGHT_INFO, "authorization");
    assertNotNull(authHeader, "Authorization header should be present in Flight requests");
    assertEquals("Bearer " + VALID_ACCESS_TOKEN, authHeader);
  }
}