ArrowFlightSqlClientHandler.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.driver.jdbc.client;

import com.google.common.collect.ImmutableMap;
import io.grpc.netty.NettyChannelBuilder;
import io.netty.channel.ChannelOption;
import java.io.IOException;
import java.net.URI;
import java.security.GeneralSecurityException;
import java.sql.SQLException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.apache.arrow.driver.jdbc.client.oauth.OAuthConfiguration;
import org.apache.arrow.driver.jdbc.client.oauth.OAuthCredentialWriter;
import org.apache.arrow.driver.jdbc.client.oauth.OAuthTokenProvider;
import org.apache.arrow.driver.jdbc.client.utils.ClientAuthenticationUtils;
import org.apache.arrow.driver.jdbc.client.utils.FlightClientCache;
import org.apache.arrow.driver.jdbc.client.utils.FlightLocationQueue;
import org.apache.arrow.flight.CallOption;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.CloseSessionRequest;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightClientMiddleware;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightGrpcUtils;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.FlightStatusCode;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.LocationSchemes;
import org.apache.arrow.flight.SessionOptionValueFactory;
import org.apache.arrow.flight.SetSessionOptionsRequest;
import org.apache.arrow.flight.SetSessionOptionsResult;
import org.apache.arrow.flight.auth2.BearerCredentialWriter;
import org.apache.arrow.flight.auth2.ClientBearerHeaderHandler;
import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware;
import org.apache.arrow.flight.client.ClientCookieMiddleware;
import org.apache.arrow.flight.grpc.CredentialCallOption;
import org.apache.arrow.flight.grpc.NettyClientBuilder;
import org.apache.arrow.flight.sql.FlightSqlClient;
import org.apache.arrow.flight.sql.impl.FlightSql.SqlInfo;
import org.apache.arrow.flight.sql.util.TableRef;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.util.VisibleForTesting;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.calcite.avatica.DriverVersion;
import org.apache.calcite.avatica.Meta.StatementType;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** A {@link FlightSqlClient} handler. */
public final class ArrowFlightSqlClientHandler implements AutoCloseable {
  private static final Logger LOGGER = LoggerFactory.getLogger(ArrowFlightSqlClientHandler.class);
  // Name of the JDBC connection string query parameter; also used as the key of the session
  // option sent by setCatalog(String).
  private static final String CATALOG = "catalog";

  // Non-null key supplied at construction. NOTE(review): presumably identifies this handler in a
  // client cache; it is not read in the visible part of this class -- confirm.
  private final String cacheKey;
  // Flight SQL client through which every RPC of this handler is issued; released by close().
  private final FlightSqlClient sqlClient;
  // Call options (builder options plus credential options) passed along with every RPC.
  private final Set<CallOption> options = new HashSet<>();
  // Builder this handler was created from; copied by getStreams() to reach other endpoints.
  private final Builder builder;
  // Catalog to set as a session option; when present, close() also closes the server session.
  private final Optional<String> catalog;
  // Optional location cache consulted and updated by getStreams(); may be null.
  private final @Nullable FlightClientCache flightClientCache;

  /**
   * Creates a handler that issues its RPCs through the given {@link FlightSqlClient}.
   *
   * <p>The call options of the handler are the union of {@code builder.options} and {@code
   * credentialOptions}.
   *
   * @param cacheKey the key stored with this handler; must not be null.
   * @param sqlClient the client to issue RPCs through; must not be null.
   * @param builder the builder this handler originates from; its options are copied.
   * @param credentialOptions additional call options to attach to every RPC.
   * @param catalog the catalog to use for the session, if any.
   * @param flightClientCache the location cache to use in {@code getStreams}, or null.
   */
  ArrowFlightSqlClientHandler(
      final String cacheKey,
      final FlightSqlClient sqlClient,
      final Builder builder,
      final Collection<CallOption> credentialOptions,
      final Optional<String> catalog,
      final @Nullable FlightClientCache flightClientCache) {
    // builder and credentialOptions are dereferenced first, so a null value for either fails
    // here before the explicit null checks below run.
    this.options.addAll(builder.options);
    this.options.addAll(credentialOptions);
    this.cacheKey = Preconditions.checkNotNull(cacheKey);
    this.sqlClient = Preconditions.checkNotNull(sqlClient);
    this.builder = builder;
    this.catalog = catalog;
    this.flightClientCache = flightClientCache;
  }

  /**
   * Wraps the provided {@code client} in a {@link FlightSqlClient}, creates a handler around it
   * and, when a catalog is present, sets that catalog as a session option before returning.
   *
   * @param cacheKey the key stored with the new handler; must not be null.
   * @param client the {@link FlightClient} to manage under a {@link FlightSqlClient} wrapper.
   * @param builder the builder the handler originates from.
   * @param options the {@link CallOption}s to persist in between subsequent client calls.
   * @param catalog the catalog to set in the session, if any.
   * @param flightClientCache the location cache handed to the handler, or null.
   * @return a new {@link ArrowFlightSqlClientHandler}.
   */
  static ArrowFlightSqlClientHandler createNewHandler(
      final String cacheKey,
      final FlightClient client,
      final Builder builder,
      final Collection<CallOption> options,
      final Optional<String> catalog,
      final @Nullable FlightClientCache flightClientCache) {
    final FlightSqlClient wrappedClient = new FlightSqlClient(client);
    final ArrowFlightSqlClientHandler created =
        new ArrowFlightSqlClientHandler(
            cacheKey, wrappedClient, builder, options, catalog, flightClientCache);
    // Applies the catalog (if any) right away so the session is ready for the first query.
    created.setSetCatalogInSessionIfPresent();
    return created;
  }

  /**
   * Returns the {@link #options} of this handler as a freshly allocated array, in the form
   * expected by the varargs parameters of the client calls.
   *
   * @return the {@link CallOption}s.
   */
  private CallOption[] getOptions() {
    final CallOption[] emptyTemplate = new CallOption[0];
    return options.toArray(emptyTemplate);
  }

  /**
   * Opens one result stream for every endpoint of the provided {@link FlightInfo}.
   *
   * <p>An endpoint without locations, or one whose location uses the {@link
   * LocationSchemes#REUSE_CONNECTION} scheme, is read through this handler's own client. For any
   * other location a dedicated client is built from a copy of this handler's builder; the
   * locations of the endpoint are tried in the order produced by {@link FlightLocationQueue}
   * until one of them yields a schema.
   *
   * @param flightInfo The {@link FlightInfo} instance from which to fetch results.
   * @return one {@link CloseableEndpointStreamPair} per endpoint, in endpoint order.
   * @throws SQLException if any endpoint could not be opened; the streams opened so far are
   *     closed first. When every location of an endpoint fails, the first failure is reported
   *     (wrapped unless it already is a {@link SQLException}) and the remaining failures are
   *     attached to it as suppressed exceptions.
   */
  public List<CloseableEndpointStreamPair> getStreams(final FlightInfo flightInfo)
      throws SQLException {
    final List<CloseableEndpointStreamPair> endpoints =
        new ArrayList<>(flightInfo.getEndpoints().size());

    try {
      for (FlightEndpoint endpoint : flightInfo.getEndpoints()) {
        if (endpoint.getLocations().isEmpty()) {
          // Create a stream using the current client only and do not close the client at
          // the end.
          endpoints.add(
              new CloseableEndpointStreamPair(
                  sqlClient.getStream(endpoint.getTicket(), getOptions()), null));
        } else {
          // Clone the builder and then set the new endpoint on it.

          // GH-38574: Currently a new FlightClient will be made for each partition that returns
          // a non-empty Location then disposed of. It may be better to cache clients because a
          // server may report the same Locations. It would also be good to identify when the
          // reported location is the same as the original connection's Location and skip
          // creating a FlightClient in that scenario.
          // Also copy the cache to the client so we can share a cache. Cache needs to cache
          // negative attempts too.
          List<Exception> exceptions = new ArrayList<>();
          CloseableEndpointStreamPair stream = null;
          FlightLocationQueue locations =
              new FlightLocationQueue(flightClientCache, endpoint.getLocations());
          while (locations.hasNext()) {
            Location location = locations.next();
            final URI endpointUri = location.getUri();
            if (endpointUri.getScheme().equals(LocationSchemes.REUSE_CONNECTION)) {
              stream =
                  new CloseableEndpointStreamPair(
                      sqlClient.getStream(endpoint.getTicket(), getOptions()), null);
              break;
            }
            final Builder builderForEndpoint =
                new Builder(ArrowFlightSqlClientHandler.this.builder)
                    .withHost(endpointUri.getHost())
                    .withPort(endpointUri.getPort())
                    .withEncryption(endpointUri.getScheme().equals(LocationSchemes.GRPC_TLS))
                    .withClientCache(flightClientCache)
                    .withConnectTimeout(builder.connectTimeout);

            ArrowFlightSqlClientHandler endpointHandler = null;
            try {
              endpointHandler = builderForEndpoint.build();
              stream =
                  new CloseableEndpointStreamPair(
                      endpointHandler.sqlClient.getStream(
                          endpoint.getTicket(), endpointHandler.getOptions()),
                      endpointHandler.sqlClient);
              // Make sure we actually get data from the server
              stream.getStream().getSchema();
            } catch (Exception ex) {
              // Discard the half-opened pair: its client is closed below, so it must never be
              // returned if this turns out to be the last location tried. Without this reset a
              // failure in getSchema() on the final location would hand out a dead stream and
              // drop the collected exceptions.
              stream = null;
              if (endpointHandler != null) {
                // If the exception is related to connectivity, mark the client as a dud.
                if (flightClientCache != null) {
                  if (ex instanceof FlightRuntimeException
                      && ((FlightRuntimeException) ex).status().code()
                          == FlightStatusCode.UNAVAILABLE
                      &&
                      // IOException covers SocketException and Netty's (private)
                      // AnnotatedSocketException
                      // We are looking for things like "Network is unreachable"
                      ex.getCause() instanceof IOException) {
                    flightClientCache.markLocationAsDud(location.toString());
                  }
                }

                AutoCloseables.close(endpointHandler);
              }
              exceptions.add(ex);
              continue;
            }

            if (flightClientCache != null) {
              flightClientCache.markLocationAsReachable(location.toString());
            }
            break;
          }
          if (stream != null) {
            endpoints.add(stream);
          } else if (exceptions.isEmpty()) {
            // This should never happen...
            throw new IllegalStateException("Could not connect to endpoint and no errors occurred");
          } else {
            // Report the first failure and keep the others as suppressed exceptions.
            Exception ex = exceptions.remove(0);
            while (!exceptions.isEmpty()) {
              ex.addSuppressed(exceptions.remove(exceptions.size() - 1));
            }
            throw ex;
          }
        }
      }
    } catch (Exception outerException) {
      // Release every stream opened for earlier endpoints before reporting the failure.
      try {
        AutoCloseables.close(endpoints);
      } catch (Exception innerEx) {
        outerException.addSuppressed(innerEx);
      }

      if (outerException instanceof SQLException) {
        throw (SQLException) outerException;
      }
      throw new SQLException(outerException);
    }
    return endpoints;
  }

  /**
   * Makes an RPC "getInfo" request that executes the provided {@code query}.
   *
   * @param query The query.
   * @return the {@link FlightInfo} describing the results; pass it to {@link
   *     #getStreams(FlightInfo)} to fetch the data.
   */
  public FlightInfo getInfo(final String query) {
    final FlightInfo queryInfo = sqlClient.execute(query, getOptions());
    return queryInfo;
  }

  /**
   * Closes the server session (only when a catalog was set for this connection) and then releases
   * the underlying {@link FlightSqlClient}.
   *
   * <p>The client is released even when closing the session fails with a {@link
   * FlightRuntimeException}; that failure is reported only after the client has been released.
   * Benign gRPC shutdown errors are logged and suppressed in both steps.
   *
   * @throws SQLException if closing the session or releasing the client genuinely failed. When
   *     both steps fail, the session failure is thrown and the client failure is attached to it
   *     as a suppressed exception.
   */
  @Override
  public void close() throws SQLException {
    SQLException failure = null;

    if (catalog.isPresent()) {
      try {
        sqlClient.closeSession(new CloseSessionRequest(), getOptions());
      } catch (FlightRuntimeException fre) {
        try {
          handleBenignCloseException(
              fre, "Failed to close Flight SQL session.", "closing Flight SQL session");
        } catch (SQLException sessionFailure) {
          // Defer the error: the client below still has to be released, otherwise it leaks.
          failure = sessionFailure;
        }
      }
    }

    try {
      AutoCloseables.close(sqlClient);
    } catch (FlightRuntimeException fre) {
      try {
        handleBenignCloseException(
            fre, "Failed to clean up client resources.", "closing Flight SQL client");
      } catch (SQLException clientFailure) {
        if (failure == null) {
          failure = clientFailure;
        } else {
          failure.addSuppressed(clientFailure);
        }
      }
    } catch (final Exception e) {
      final SQLException clientFailure =
          new SQLException("Failed to clean up client resources.", e);
      if (failure == null) {
        failure = clientFailure;
      } else {
        failure.addSuppressed(clientFailure);
      }
    }

    if (failure != null) {
      throw failure;
    }
  }

  /**
   * Decides what to do with a {@link FlightRuntimeException} raised while closing: benign gRPC
   * shutdown errors are logged and swallowed, anything else is rethrown wrapped in a {@link
   * SQLException}.
   *
   * @param fre the FlightRuntimeException to handle
   * @param sqlErrorMessage the SQLException message to use for genuine failures
   * @param operationDescription description of the operation for logging
   * @throws SQLException if the exception represents a genuine failure
   */
  private void handleBenignCloseException(
      FlightRuntimeException fre, String sqlErrorMessage, String operationDescription)
      throws SQLException {
    if (!isBenignCloseException(fre)) {
      throw new SQLException(sqlErrorMessage, fre);
    }
    logSuppressedCloseException(fre, operationDescription);
  }

  /**
   * Decides what to do with a {@link FlightRuntimeException} raised while closing: benign gRPC
   * shutdown errors are logged and swallowed, anything else is rethrown unchanged.
   *
   * @param fre the FlightRuntimeException to handle
   * @param operationDescription description of the operation for logging
   * @throws FlightRuntimeException if the exception represents a genuine failure
   */
  private void handleBenignCloseException(FlightRuntimeException fre, String operationDescription)
      throws FlightRuntimeException {
    if (!isBenignCloseException(fre)) {
      throw fre;
    }
    logSuppressedCloseException(fre, operationDescription);
  }

  /**
   * Tells whether a {@link FlightRuntimeException} raised during a close operation is benign and
   * should be suppressed. That is the case for status {@code UNAVAILABLE}, and for status {@code
   * INTERNAL} when the message mentions a connection closed after GOAWAY.
   *
   * @param fre the FlightRuntimeException to check
   * @return true if the exception should be suppressed, false otherwise
   */
  private boolean isBenignCloseException(FlightRuntimeException fre) {
    final FlightStatusCode code = fre.status().code();
    if (code.equals(FlightStatusCode.UNAVAILABLE)) {
      return true;
    }
    if (!code.equals(FlightStatusCode.INTERNAL)) {
      return false;
    }
    final String message = fre.getMessage();
    return message != null && message.contains("Connection closed after GOAWAY");
  }

  /**
   * Logs a suppressed close exception at debug level, together with its stack trace.
   *
   * @param fre the FlightRuntimeException being suppressed
   * @param operationDescription description of the operation for logging
   */
  private void logSuppressedCloseException(
      FlightRuntimeException fre, String operationDescription) {
    // ARROW-17785 and GH-863: suppress exceptions caused by flaky gRPC layer during
    // shutdown
    final String template = "Suppressed error {}";
    LOGGER.debug(template, operationDescription, fre);
  }

  /**
   * A prepared statement handler: the driver-side view of a statement prepared on the server
   * through {@link #prepare(String)}.
   */
  public interface PreparedStatement extends AutoCloseable {
    /**
     * Executes this {@link PreparedStatement}.
     *
     * @return the {@link FlightInfo} representing the outcome of this query execution.
     * @throws SQLException on error.
     */
    FlightInfo executeQuery() throws SQLException;

    /**
     * Executes a {@link StatementType#UPDATE} query.
     *
     * @return the number of rows affected.
     */
    long executeUpdate();

    /**
     * Gets the {@link StatementType} of this {@link PreparedStatement}.
     *
     * @return the Statement Type.
     */
    StatementType getType();

    /**
     * Gets the {@link Schema} of this {@link PreparedStatement}.
     *
     * @return {@link Schema}.
     */
    Schema getDataSetSchema();

    /**
     * Gets the {@link Schema} of the parameters for this {@link PreparedStatement}.
     *
     * @return {@link Schema}.
     */
    Schema getParameterSchema();

    /**
     * Sets the parameter values to bind to this {@link PreparedStatement}.
     *
     * @param parameters the root holding the parameter values.
     */
    void setParameters(VectorSchemaRoot parameters);

    /** Closes this {@link PreparedStatement}; declared without checked exceptions. */
    @Override
    void close();
  }

  /**
   * Sets the configured catalog, if any, as a session option. A failure to do so is reported as
   * an {@code INVALID_ARGUMENT} Flight runtime exception carrying the original error as cause.
   */
  private void setSetCatalogInSessionIfPresent() {
    if (!catalog.isPresent()) {
      return;
    }
    final String requestedCatalog = catalog.get();
    try {
      setCatalog(requestedCatalog);
    } catch (SQLException e) {
      throw CallStatus.INVALID_ARGUMENT
          .withDescription(e.getMessage())
          .withCause(e)
          .toRuntimeException();
    }
  }

  /**
   * Sets the catalog for the current session by sending it as the {@code "catalog"} session
   * option. Errors reported by the server for the option are logged as warnings.
   *
   * @param catalog the catalog to set.
   * @throws SQLException if the server reports an error for the option or the call itself fails.
   */
  public void setCatalog(final String catalog) throws SQLException {
    final SetSessionOptionsRequest request =
        new SetSessionOptionsRequest(
            ImmutableMap.of(CATALOG, SessionOptionValueFactory.makeSessionOptionValue(catalog)));
    try {
      final SetSessionOptionsResult outcome = sqlClient.setSessionOptions(request, getOptions());
      if (!outcome.hasErrors()) {
        return;
      }
      // Log every per-option error before failing; the exception only points at the log.
      for (final Map.Entry<String, SetSessionOptionsResult.Error> failed :
          outcome.getErrors().entrySet()) {
        LOGGER.warn(failed.toString());
      }
      final String message =
          String.format(
              "Cannot set session option for catalog = %s. Check log for details.", catalog);
      throw new SQLException(message);
    } catch (final FlightRuntimeException e) {
      throw new SQLException(e);
    }
  }

  /**
   * Prepares the given {@code query} on the server and returns a {@link PreparedStatement} that
   * forwards every operation to it, using this handler's call options.
   *
   * @param query the SQL query.
   * @return a new prepared statement.
   */
  public PreparedStatement prepare(final String query) {
    final FlightSqlClient.PreparedStatement delegate = sqlClient.prepare(query, getOptions());
    return new PreparedStatement() {
      @Override
      public FlightInfo executeQuery() throws SQLException {
        return delegate.execute(getOptions());
      }

      @Override
      public long executeUpdate() {
        return delegate.executeUpdate(getOptions());
      }

      @Override
      public StatementType getType() {
        // A statement whose result set schema has no fields is treated as an update.
        final boolean hasResultColumns = !delegate.getResultSetSchema().getFields().isEmpty();
        return hasResultColumns ? StatementType.SELECT : StatementType.UPDATE;
      }

      @Override
      public Schema getDataSetSchema() {
        return delegate.getResultSetSchema();
      }

      @Override
      public Schema getParameterSchema() {
        return delegate.getParameterSchema();
      }

      @Override
      public void setParameters(VectorSchemaRoot parameters) {
        delegate.setParameters(parameters);
      }

      @Override
      public void close() {
        try {
          delegate.close(getOptions());
        } catch (FlightRuntimeException fre) {
          // Benign shutdown errors are only logged; genuine failures are rethrown.
          handleBenignCloseException(fre, "closing PreparedStatement");
        }
      }
    };
  }

  /**
   * Makes an RPC "getCatalogs" request.
   *
   * @return the {@link FlightInfo} describing the results; pass it to {@link
   *     #getStreams(FlightInfo)} to fetch the data.
   */
  public FlightInfo getCatalogs() {
    final FlightInfo catalogsInfo = sqlClient.getCatalogs(getOptions());
    return catalogsInfo;
  }

  /**
   * Makes an RPC "getImportedKeys" request for the table identified by the provided info.
   *
   * @param catalog The catalog name; must match the catalog name as it is stored in the database.
   *     "" retrieves those without a catalog. Null means that the catalog name should not be used
   *     to narrow the search.
   * @param schema The schema name; must match the schema name as it is stored in the database. ""
   *     retrieves those without a schema. Null means that the schema name should not be used to
   *     narrow the search.
   * @param table The table name. Must match the table name as it is stored in the database.
   * @return the {@link FlightInfo} describing the results; pass it to {@link
   *     #getStreams(FlightInfo)} to fetch the data.
   */
  public FlightInfo getImportedKeys(final String catalog, final String schema, final String table) {
    final TableRef importingTable = TableRef.of(catalog, schema, table);
    return sqlClient.getImportedKeys(importingTable, getOptions());
  }

  /**
   * Makes an RPC "getExportedKeys" request for the table identified by the provided info.
   *
   * @param catalog The catalog name; must match the catalog name as it is stored in the database.
   *     "" retrieves those without a catalog. Null means that the catalog name should not be used
   *     to narrow the search.
   * @param schema The schema name; must match the schema name as it is stored in the database. ""
   *     retrieves those without a schema. Null means that the schema name should not be used to
   *     narrow the search.
   * @param table The table name. Must match the table name as it is stored in the database.
   * @return the {@link FlightInfo} describing the results; pass it to {@link
   *     #getStreams(FlightInfo)} to fetch the data.
   */
  public FlightInfo getExportedKeys(final String catalog, final String schema, final String table) {
    final TableRef exportingTable = TableRef.of(catalog, schema, table);
    return sqlClient.getExportedKeys(exportingTable, getOptions());
  }

  /**
   * Makes an RPC "getSchemas" request based on the provided info.
   *
   * @param catalog The catalog name; must match the catalog name as it is stored in the database.
   *     "" retrieves those without a catalog. Null means that the catalog name should not be used
   *     to narrow the search.
   * @param schemaPattern The schema name pattern. Must match the schema name as it is stored in the
   *     database. Null means that schema name should not be used to narrow down the search.
   * @return the {@link FlightInfo} describing the results; pass it to {@link
   *     #getStreams(FlightInfo)} to fetch the data.
   */
  public FlightInfo getSchemas(final String catalog, final String schemaPattern) {
    final FlightInfo schemasInfo = sqlClient.getSchemas(catalog, schemaPattern, getOptions());
    return schemasInfo;
  }

  /**
   * Makes an RPC "getTableTypes" request.
   *
   * @return the {@link FlightInfo} describing the results; pass it to {@link
   *     #getStreams(FlightInfo)} to fetch the data.
   */
  public FlightInfo getTableTypes() {
    final FlightInfo tableTypesInfo = sqlClient.getTableTypes(getOptions());
    return tableTypesInfo;
  }

  /**
   * Makes an RPC "getTables" request based on the provided info.
   *
   * @param catalog The catalog name; must match the catalog name as it is stored in the database.
   *     "" retrieves those without a catalog. Null means that the catalog name should not be used
   *     to narrow the search.
   * @param schemaPattern The schema name pattern. Must match the schema name as it is stored in the
   *     database. "" retrieves those without a schema. Null means that the schema name should not
   *     be used to narrow the search.
   * @param tableNamePattern The table name pattern. Must match the table name as it is stored in
   *     the database.
   * @param types The list of table types, which must be from the list of table types to include.
   *     Null returns all types.
   * @param includeSchema Whether to include schema.
   * @return the {@link FlightInfo} describing the results; pass it to {@link
   *     #getStreams(FlightInfo)} to fetch the data.
   */
  public FlightInfo getTables(
      final String catalog,
      final String schemaPattern,
      final String tableNamePattern,
      final List<String> types,
      final boolean includeSchema) {
    final FlightInfo tablesInfo =
        sqlClient.getTables(
            catalog, schemaPattern, tableNamePattern, types, includeSchema, getOptions());
    return tablesInfo;
  }

  /**
   * Makes an RPC "getSqlInfo" request for the given pieces of SQL info.
   *
   * @param info the {@link SqlInfo} entries to request.
   * @return the {@link FlightInfo} describing the results; pass it to {@link
   *     #getStreams(FlightInfo)} to fetch the data.
   */
  public FlightInfo getSqlInfo(SqlInfo... info) {
    final FlightInfo sqlInfoResult = sqlClient.getSqlInfo(info, getOptions());
    return sqlInfoResult;
  }

  /**
   * Makes an RPC "getPrimaryKeys" request for the table identified by the provided info.
   *
   * @param catalog The catalog name; must match the catalog name as it is stored in the database.
   *     "" retrieves those without a catalog. Null means that the catalog name should not be used
   *     to narrow the search.
   * @param schema The schema name; must match the schema name as it is stored in the database. ""
   *     retrieves those without a schema. Null means that the schema name should not be used to
   *     narrow the search.
   * @param table The table name. Must match the table name as it is stored in the database.
   * @return the {@link FlightInfo} describing the results; pass it to {@link
   *     #getStreams(FlightInfo)} to fetch the data.
   */
  public FlightInfo getPrimaryKeys(final String catalog, final String schema, final String table) {
    final TableRef keyedTable = TableRef.of(catalog, schema, table);
    return sqlClient.getPrimaryKeys(keyedTable, getOptions());
  }

  /**
   * Makes an RPC "getCrossReference" request relating a primary key table to a foreign key table.
   *
   * @param pkCatalog The catalog name of the primary key table; must match the catalog name as it
   *     is stored in the database. "" retrieves those without a catalog. Null means that the
   *     catalog name should not be used to narrow the search.
   * @param pkSchema The schema name of the primary key table; must match the schema name as it is
   *     stored in the database. "" retrieves those without a schema. Null means that the schema
   *     name should not be used to narrow the search.
   * @param pkTable The primary key table name. Must match the table name as it is stored in the
   *     database.
   * @param fkCatalog The catalog name of the foreign key table; must match the catalog name as it
   *     is stored in the database. "" retrieves those without a catalog. Null means that the
   *     catalog name should not be used to narrow the search.
   * @param fkSchema The schema name of the foreign key table; must match the schema name as it is
   *     stored in the database. "" retrieves those without a schema. Null means that the schema
   *     name should not be used to narrow the search.
   * @param fkTable The foreign key table name. Must match the table name as it is stored in the
   *     database.
   * @return the {@link FlightInfo} describing the results; pass it to {@link
   *     #getStreams(FlightInfo)} to fetch the data.
   */
  public FlightInfo getCrossReference(
      String pkCatalog,
      String pkSchema,
      String pkTable,
      String fkCatalog,
      String fkSchema,
      String fkTable) {
    final TableRef primaryKeyTable = TableRef.of(pkCatalog, pkSchema, pkTable);
    final TableRef foreignKeyTable = TableRef.of(fkCatalog, fkSchema, fkTable);
    return sqlClient.getCrossReference(primaryKeyTable, foreignKeyTable, getOptions());
  }

  /** Builder for {@link ArrowFlightSqlClientHandler}. */
  public static final class Builder {
    // Format template with a single %s placeholder. NOTE(review): presumably filled with the
    // driver version, or DEFAULT_VERSION when none is known -- the usage is not visible here.
    static final String USER_AGENT_TEMPLATE = "JDBC Flight SQL Driver %s";
    static final String DEFAULT_VERSION = "(unknown or development build)";

    // Middleware factories and call options collected by this builder; both are carried over by
    // the copy constructor.
    private final Set<FlightClientMiddleware.Factory> middlewareFactories = new HashSet<>();
    private final Set<CallOption> options = new HashSet<>();
    private String host;
    private int port;

    @VisibleForTesting String username;

    @VisibleForTesting String password;

    @VisibleForTesting String trustStorePath;

    @VisibleForTesting String trustStorePassword;

    @VisibleForTesting String token;

    // TLS encryption is enabled unless withEncryption(false) is called.
    @VisibleForTesting boolean useEncryption = true;

    @VisibleForTesting boolean disableCertificateVerification;

    // The system trust store is used unless withSystemTrustStore(false) is called.
    @VisibleForTesting boolean useSystemTrustStore = true;

    @VisibleForTesting String tlsRootCertificatesPath;

    @VisibleForTesting String clientCertificatePath;

    @VisibleForTesting String clientKeyPath;

    // Child allocator created by withBufferAllocator(); shared with copies of this builder.
    @VisibleForTesting private BufferAllocator allocator;

    // When true, a copy of this builder shares this builder's cookieFactory.
    @VisibleForTesting boolean retainCookies = true;

    // When true, a copy of this builder shares this builder's authFactory.
    @VisibleForTesting boolean retainAuth = true;

    @VisibleForTesting Optional<String> catalog = Optional.empty();

    // Not carried over by the copy constructor; getStreams() sets it explicitly on copies.
    @VisibleForTesting @Nullable FlightClientCache flightClientCache;

    // Not carried over by the copy constructor; getStreams() sets it explicitly on copies.
    @VisibleForTesting @Nullable Duration connectTimeout;

    @VisibleForTesting @Nullable OAuthConfiguration oauthConfig;

    // These two middleware are for internal use within build() and should not be exposed by
    // builder APIs.
    // Note that these middleware may not necessarily be registered.
    @VisibleForTesting
    ClientIncomingAuthHeaderMiddleware.Factory authFactory =
        new ClientIncomingAuthHeaderMiddleware.Factory(new ClientBearerHeaderHandler());

    @VisibleForTesting
    ClientCookieMiddleware.Factory cookieFactory = new ClientCookieMiddleware.Factory();

    DriverVersion driverVersion;

    /** Creates a builder with the default settings declared on the fields above. */
    public Builder() {}

    /**
     * Copies the builder.
     *
     * <p>The cookie and auth middleware factories are shared with {@code original} only when its
     * {@code retainCookies} / {@code retainAuth} flag is set; otherwise the copy keeps its own
     * fresh factory. The flags themselves, the client cache and the connect timeout are not
     * copied.
     *
     * @param original The builder to base this copy off of.
     */
    @VisibleForTesting
    Builder(Builder original) {
      // Collected middleware and call options.
      this.middlewareFactories.addAll(original.middlewareFactories);
      this.options.addAll(original.options);

      // Target and session settings.
      this.host = original.host;
      this.port = original.port;
      this.catalog = original.catalog;
      this.allocator = original.allocator;
      this.driverVersion = original.driverVersion;

      // Credentials.
      this.username = original.username;
      this.password = original.password;
      this.token = original.token;
      this.oauthConfig = original.oauthConfig;

      // TLS / mTLS settings.
      this.useEncryption = original.useEncryption;
      this.disableCertificateVerification = original.disableCertificateVerification;
      this.useSystemTrustStore = original.useSystemTrustStore;
      this.trustStorePath = original.trustStorePath;
      this.trustStorePassword = original.trustStorePassword;
      this.tlsRootCertificatesPath = original.tlsRootCertificatesPath;
      this.clientCertificatePath = original.clientCertificatePath;
      this.clientKeyPath = original.clientKeyPath;

      // Share the stateful middleware factories only when the original asks for it.
      if (original.retainAuth) {
        this.authFactory = original.authFactory;
      }
      if (original.retainCookies) {
        this.cookieFactory = original.cookieFactory;
      }
    }

    /**
     * Sets the host this handler connects to.
     *
     * @param host the host.
     * @return this instance.
     */
    public Builder withHost(final String host) {
      this.host = host;
      return this;
    }

    /**
     * Sets the port this handler connects to.
     *
     * @param port the port.
     * @return this instance.
     */
    public Builder withPort(final int port) {
      this.port = port;
      return this;
    }

    /**
     * Sets the username for this handler.
     *
     * @param username the username.
     * @return this instance.
     */
    public Builder withUsername(final String username) {
      this.username = username;
      return this;
    }

    /**
     * Sets the password for this handler.
     *
     * @param password the password.
     * @return this instance.
     */
    public Builder withPassword(final String password) {
      this.password = password;
      return this;
    }

    /**
     * Sets the path of the trust store (a KeyStore file) for this handler.
     *
     * @param trustStorePath the trust store path.
     * @return this instance.
     */
    public Builder withTrustStorePath(final String trustStorePath) {
      this.trustStorePath = trustStorePath;
      return this;
    }

    /**
     * Sets the password of the trust store (a KeyStore file) for this handler.
     *
     * @param trustStorePassword the trust store password.
     * @return this instance.
     */
    public Builder withTrustStorePassword(final String trustStorePassword) {
      this.trustStorePassword = trustStorePassword;
      return this;
    }

    /**
     * Sets whether to use TLS encryption in this handler. Encryption is enabled by default.
     *
     * @param useEncryption whether to use TLS encryption.
     * @return this instance.
     */
    public Builder withEncryption(final boolean useEncryption) {
      this.useEncryption = useEncryption;
      return this;
    }

    /**
     * Sets whether to disable the certificate verification in this handler. Verification is not
     * disabled by default.
     *
     * @param disableCertificateVerification whether to disable certificate verification.
     * @return this instance.
     */
    public Builder withDisableCertificateVerification(
        final boolean disableCertificateVerification) {
      this.disableCertificateVerification = disableCertificateVerification;
      return this;
    }

    /**
     * Sets whether to use the certificates from the operating system. The system trust store is
     * used by default.
     *
     * @param useSystemTrustStore whether to use the operating system certificates.
     * @return this instance.
     */
    public Builder withSystemTrustStore(final boolean useSystemTrustStore) {
      this.useSystemTrustStore = useSystemTrustStore;
      return this;
    }

    /**
     * Sets the TLS root certificate path as an alternative to using the System or other Trust
     * Store. The path must point to a valid PEM file.
     *
     * @param tlsRootCertificatesPath the TLS root certificate path (if TLS is required).
     * @return this instance.
     */
    public Builder withTlsRootCertificates(final String tlsRootCertificatesPath) {
      this.tlsRootCertificatesPath = tlsRootCertificatesPath;
      return this;
    }

    /**
     * Sets the path of the client certificate used for mTLS.
     *
     * @param clientCertificatePath the mTLS client certificate path (if mTLS is required).
     * @return this instance.
     */
    public Builder withClientCertificate(final String clientCertificatePath) {
      this.clientCertificatePath = clientCertificatePath;
      return this;
    }

    /**
     * Configures the path of the private key belonging to the mTLS client certificate.
     *
     * <p>{@link #build()} uses it only when encryption is enabled and a client certificate path is
     * set as well.
     *
     * @param clientKeyPath path of the mTLS client private key (if mTLS is required).
     * @return this instance.
     */
    public Builder withClientKey(final String clientKeyPath) {
      this.clientKeyPath = clientKeyPath;
      return this;
    }

    /**
     * Configures the bearer token used for token authentication.
     *
     * <p>A non-null token makes {@link #build()} skip username/password authentication, even if a
     * username was also supplied.
     *
     * @param token the token value.
     * @return this builder instance.
     */
    public Builder withToken(final String token) {
      this.token = token;
      return this;
    }

    /**
     * Derives the {@link BufferAllocator} used by this handler from the given parent.
     *
     * <p>The handler does not use {@code allocator} directly: a child allocator named {@code
     * "ArrowFlightSqlClientHandler"} is created from it, using {@code 0} and the parent's current
     * limit as the remaining arguments.
     *
     * @param allocator the parent allocator.
     * @return this instance.
     */
    public Builder withBufferAllocator(final BufferAllocator allocator) {
      final long parentLimit = allocator.getLimit();
      this.allocator = allocator.newChildAllocator("ArrowFlightSqlClientHandler", 0, parentLimit);
      return this;
    }

    /**
     * Controls whether connections spawned for getStreams() calls re-use this connection's
     * cookies.
     *
     * @param retainCookies {@code true} if cookies should be re-used.
     * @return this builder instance.
     */
    public Builder withRetainCookies(final boolean retainCookies) {
      this.retainCookies = retainCookies;
      return this;
    }

    /**
     * Controls whether connections spawned for getStreams() calls re-use the bearer tokens that
     * were already negotiated.
     *
     * @param retainAuth {@code true} if auth tokens should be re-used.
     * @return this builder instance.
     */
    public Builder withRetainAuth(final boolean retainAuth) {
      this.retainAuth = retainAuth;
      return this;
    }

    /**
     * Varargs convenience for {@link #withMiddlewareFactories(Collection)}: appends the given
     * {@code factories} to this handler's {@link #middlewareFactories}.
     *
     * @param factories the factories to add.
     * @return this instance.
     */
    public Builder withMiddlewareFactories(final FlightClientMiddleware.Factory... factories) {
      final List<FlightClientMiddleware.Factory> factoryList = Arrays.asList(factories);
      return withMiddlewareFactories(factoryList);
    }

    /**
     * Appends every element of {@code factories} to this handler's {@link #middlewareFactories}.
     *
     * @param factories the factories to add.
     * @return this instance.
     */
    public Builder withMiddlewareFactories(
        final Collection<FlightClientMiddleware.Factory> factories) {
      middlewareFactories.addAll(factories);
      return this;
    }

    /**
     * Varargs convenience for {@link #withCallOptions(Collection)}: appends the given {@link
     * CallOption}s to this handler.
     *
     * @param options the options
     * @return this instance.
     */
    public Builder withCallOptions(final CallOption... options) {
      final List<CallOption> optionList = Arrays.asList(options);
      return withCallOptions(optionList);
    }

    /**
     * Appends every given {@link CallOption} to the options already registered on this handler.
     *
     * @param options the options
     * @return this instance.
     */
    public Builder withCallOptions(final Collection<CallOption> options) {
      // The parameter shadows the field, so the field must be qualified with "this".
      this.options.addAll(options);
      return this;
    }

    /**
     * Sets the catalog for this handler.
     *
     * <p>The value is stored as an {@link Optional}; a {@code null} catalog is recorded as an
     * empty {@link Optional}.
     *
     * @param catalog the catalog, may be {@code null}
     * @return this instance.
     */
    public Builder withCatalog(@Nullable final String catalog) {
      final Optional<String> wrappedCatalog = Optional.ofNullable(catalog);
      this.catalog = wrappedCatalog;
      return this;
    }

    /**
     * Sets the {@link FlightClientCache} that {@link #build()} passes on to the created handler.
     *
     * @param flightClientCache the client cache.
     * @return this instance.
     */
    public Builder withClientCache(final FlightClientCache flightClientCache) {
      this.flightClientCache = flightClientCache;
      return this;
    }

    /**
     * Sets the connect timeout for the channel created by {@link #build()}.
     *
     * <p>When {@code null}, {@link #build()} does not set a connect timeout option on the channel.
     *
     * @param connectTimeout the connect timeout, or {@code null} to leave it unset.
     * @return this instance.
     */
    public Builder withConnectTimeout(final Duration connectTimeout) {
      this.connectTimeout = connectTimeout;
      return this;
    }

    /**
     * Configures the driver version reported by this handler.
     *
     * <p>{@link #build()} uses its version string for the user agent when both the version and
     * its version string are non-null, and falls back to the default version otherwise.
     *
     * @param driverVersion the driver version to set
     * @return this builder instance
     */
    public Builder withDriverVersion(final DriverVersion driverVersion) {
      this.driverVersion = driverVersion;
      return this;
    }

    /**
     * Configures OAuth authentication for this handler.
     *
     * <p>When non-null, {@link #build()} authenticates with a token provider created from this
     * configuration, in preference to token and username/password authentication.
     *
     * @param oauthConfig the OAuth configuration
     * @return this builder instance
     */
    public Builder withOAuthConfiguration(final OAuthConfiguration oauthConfig) {
      this.oauthConfig = oauthConfig;
      return this;
    }

    /**
     * Returns the cache key for this builder's current configuration.
     *
     * @return the string form of {@link #getLocation()}.
     */
    public String getCacheKey() {
      final Location target = getLocation();
      return target.toString();
    }

    /**
     * Resolves the location that this client will connect to from the configured host and port.
     *
     * @return a TLS gRPC location when encryption is enabled, otherwise an insecure gRPC location.
     */
    public Location getLocation() {
      return useEncryption
          ? Location.forGrpcTls(host, port)
          : Location.forGrpcInsecure(host, port);
    }

    /**
     * Builds a new {@link ArrowFlightSqlClientHandler} from the provided fields.
     *
     * <p>The method configures a Netty-based Flight client (middleware, TLS/mTLS material, user
     * agent, connect timeout), creates the client, resolves the credential call options, and
     * finally creates the handler. Authentication is chosen in the order OAuth, then token, then
     * username/password.
     *
     * <p>If an {@link IllegalArgumentException}, {@link GeneralSecurityException}, {@link
     * IOException} or {@link FlightRuntimeException} is raised, an already-created client is
     * closed and the failure is rethrown wrapped in a {@link SQLException}.
     *
     * @return a new client handler.
     * @throws SQLException on error.
     */
    public ArrowFlightSqlClientHandler build() throws SQLException {
      // Work on a copy of the middleware so that building does not change the state of the
      // builder fields themselves.
      Set<FlightClientMiddleware.Factory> buildTimeMiddlewareFactories =
          new HashSet<>(this.middlewareFactories);
      FlightClient client = null;
      // A token takes priority, since some apps pass in a username/password even when a token is
      // provided.
      boolean isUsingUserPasswordAuth = username != null && token == null;

      try {
        if (isUsingUserPasswordAuth) {
          buildTimeMiddlewareFactories.add(authFactory);
        }
        final NettyClientBuilder clientBuilder = new NettyClientBuilder();
        clientBuilder.allocator(allocator);

        // Prefer the configured driver version for the user agent; otherwise use the default.
        String userAgent = String.format(USER_AGENT_TEMPLATE, DEFAULT_VERSION);
        if (driverVersion != null && driverVersion.versionString != null) {
          userAgent = String.format(USER_AGENT_TEMPLATE, driverVersion.versionString);
        }

        buildTimeMiddlewareFactories.add(new ClientCookieMiddleware.Factory());
        buildTimeMiddlewareFactories.forEach(clientBuilder::intercept);
        if (useEncryption) {
          clientBuilder.useTls();
        }
        Location location = getLocation();
        clientBuilder.location(location);

        if (useEncryption) {
          if (disableCertificateVerification) {
            clientBuilder.verifyServer(false);
          } else {
            // Trusted certificate source, in order of precedence: explicit PEM root certificates,
            // the system trust store, then a trust store file.
            if (tlsRootCertificatesPath != null) {
              clientBuilder.trustedCertificates(
                  ClientAuthenticationUtils.getTlsRootCertificatesStream(tlsRootCertificatesPath));
            } else if (useSystemTrustStore) {
              clientBuilder.trustedCertificates(
                  ClientAuthenticationUtils.getCertificateInputStreamFromSystem(
                      trustStorePassword));
            } else if (trustStorePath != null) {
              clientBuilder.trustedCertificates(
                  ClientAuthenticationUtils.getCertificateStream(
                      trustStorePath, trustStorePassword));
            }
          }

          // mTLS is only configured when both the certificate and its key are available.
          if (clientCertificatePath != null && clientKeyPath != null) {
            clientBuilder.clientCertificate(
                ClientAuthenticationUtils.getClientCertificateStream(clientCertificatePath),
                ClientAuthenticationUtils.getClientKeyStream(clientKeyPath));
          }
        }

        NettyChannelBuilder channelBuilder = clientBuilder.build();

        channelBuilder.userAgent(userAgent);

        if (connectTimeout != null) {
          // The channel option is an int number of milliseconds. Clamp instead of casting
          // blindly, so that very long durations do not overflow into a negative/garbage value.
          final int connectTimeoutMillis =
              (int) Math.min(connectTimeout.toMillis(), Integer.MAX_VALUE);
          channelBuilder.withOption(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeoutMillis);
        }
        client =
            FlightGrpcUtils.createFlightClient(
                allocator, channelBuilder.build(), clientBuilder.middleware());
        final ArrayList<CallOption> credentialOptions = new ArrayList<>();
        // Authentication priority: OAuth > token > username/password.
        // (isUsingUserPasswordAuth is already false whenever a token is set, so checking it
        // before the token branch does not change that order.)
        if (oauthConfig != null) {
          OAuthTokenProvider tokenProvider = oauthConfig.createTokenProvider();
          credentialOptions.add(new CredentialCallOption(new OAuthCredentialWriter(tokenProvider)));
        } else if (isUsingUserPasswordAuth) {
          // If the authFactory has already been used for a handshake, use the existing token.
          // This can occur if the authFactory is being re-used for a new connection spawned for
          // getStream().
          if (authFactory.getCredentialCallOption() != null) {
            credentialOptions.add(authFactory.getCredentialCallOption());
          } else {
            // Otherwise do the handshake and get the token if possible.
            credentialOptions.add(
                ClientAuthenticationUtils.getAuthenticate(
                    client, username, password, authFactory, options.toArray(new CallOption[0])));
          }
        } else if (token != null) {
          credentialOptions.add(
              ClientAuthenticationUtils.getAuthenticate(
                  client,
                  new CredentialCallOption(new BearerCredentialWriter(token)),
                  options.toArray(new CallOption[0])));
        }
        return ArrowFlightSqlClientHandler.createNewHandler(
            getCacheKey(), client, this, credentialOptions, catalog, flightClientCache);

      } catch (final IllegalArgumentException
          | GeneralSecurityException
          | IOException
          | FlightRuntimeException e) {
        final SQLException originalException = new SQLException(e);
        if (client != null) {
          try {
            client.close();
          } catch (final InterruptedException interruptedException) {
            // Restore the interrupt status so callers further up the stack can still observe it.
            Thread.currentThread().interrupt();
            originalException.addSuppressed(interruptedException);
          }
        }
        throw originalException;
      }
    }
  }
}