ArrowFlightSqlClientHandler.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.driver.jdbc.client;
import com.google.common.collect.ImmutableMap;
import io.grpc.netty.NettyChannelBuilder;
import io.netty.channel.ChannelOption;
import java.io.IOException;
import java.net.URI;
import java.security.GeneralSecurityException;
import java.sql.SQLException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.apache.arrow.driver.jdbc.client.oauth.OAuthConfiguration;
import org.apache.arrow.driver.jdbc.client.oauth.OAuthCredentialWriter;
import org.apache.arrow.driver.jdbc.client.oauth.OAuthTokenProvider;
import org.apache.arrow.driver.jdbc.client.utils.ClientAuthenticationUtils;
import org.apache.arrow.driver.jdbc.client.utils.FlightClientCache;
import org.apache.arrow.driver.jdbc.client.utils.FlightLocationQueue;
import org.apache.arrow.flight.CallOption;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.CloseSessionRequest;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightClientMiddleware;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightGrpcUtils;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.FlightStatusCode;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.LocationSchemes;
import org.apache.arrow.flight.SessionOptionValueFactory;
import org.apache.arrow.flight.SetSessionOptionsRequest;
import org.apache.arrow.flight.SetSessionOptionsResult;
import org.apache.arrow.flight.auth2.BearerCredentialWriter;
import org.apache.arrow.flight.auth2.ClientBearerHeaderHandler;
import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware;
import org.apache.arrow.flight.client.ClientCookieMiddleware;
import org.apache.arrow.flight.grpc.CredentialCallOption;
import org.apache.arrow.flight.grpc.NettyClientBuilder;
import org.apache.arrow.flight.sql.FlightSqlClient;
import org.apache.arrow.flight.sql.impl.FlightSql.SqlInfo;
import org.apache.arrow.flight.sql.util.TableRef;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.util.VisibleForTesting;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.calcite.avatica.DriverVersion;
import org.apache.calcite.avatica.Meta.StatementType;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A {@link FlightSqlClient} handler.
 *
 * <p>Wraps a {@link FlightSqlClient} together with the {@link CallOption}s (the builder's options
 * plus any credential options) that accompany every RPC issued through this handler, and opens the
 * result streams for the endpoints described by a {@link FlightInfo}.
 */
public final class ArrowFlightSqlClientHandler implements AutoCloseable {
  private static final Logger LOGGER = LoggerFactory.getLogger(ArrowFlightSqlClientHandler.class);

  // JDBC connection string query parameter; also used as the session option name in setCatalog().
  private static final String CATALOG = "catalog";

  private final String cacheKey;
  private final FlightSqlClient sqlClient;
  private final Set<CallOption> options = new HashSet<>();
  private final Builder builder;
  private final Optional<String> catalog;
  private final @Nullable FlightClientCache flightClientCache;

  ArrowFlightSqlClientHandler(
      final String cacheKey,
      final FlightSqlClient sqlClient,
      final Builder builder,
      final Collection<CallOption> credentialOptions,
      final Optional<String> catalog,
      final @Nullable FlightClientCache flightClientCache) {
    this.options.addAll(builder.options);
    this.options.addAll(credentialOptions);
    this.cacheKey = Preconditions.checkNotNull(cacheKey);
    this.sqlClient = Preconditions.checkNotNull(sqlClient);
    this.builder = builder;
    this.catalog = catalog;
    this.flightClientCache = flightClientCache;
  }

  /**
   * Creates a new {@link ArrowFlightSqlClientHandler} from the provided {@code client} and {@code
   * options}, setting the catalog as a session option when one is present.
   *
   * @param cacheKey the key identifying this handler's location.
   * @param client the {@link FlightClient} to manage under a {@link FlightSqlClient} wrapper.
   * @param builder the builder this handler was created from; copied when spawning endpoint
   *     clients.
   * @param options the {@link CallOption}s to persist in between subsequent client calls.
   * @param catalog the catalog to set in the session, if any.
   * @param flightClientCache the cache shared with endpoint clients, or {@code null}.
   * @return a new {@link ArrowFlightSqlClientHandler}.
   */
  static ArrowFlightSqlClientHandler createNewHandler(
      final String cacheKey,
      final FlightClient client,
      final Builder builder,
      final Collection<CallOption> options,
      final Optional<String> catalog,
      final @Nullable FlightClientCache flightClientCache) {
    final ArrowFlightSqlClientHandler handler =
        new ArrowFlightSqlClientHandler(
            cacheKey, new FlightSqlClient(client), builder, options, catalog, flightClientCache);
    handler.setSetCatalogInSessionIfPresent();
    return handler;
  }

  /**
   * Gets the {@link #options} for the subsequent calls from this handler.
   *
   * @return the {@link CallOption}s.
   */
  private CallOption[] getOptions() {
    return options.toArray(new CallOption[0]);
  }

  /**
   * Makes RPC "getStream" requests for every endpoint of the provided {@link FlightInfo},
   * retrieving the result of the query previously prepared with "getInfo".
   *
   * <p>Endpoints without locations are read through this handler's own client. For the others,
   * each reported location is tried in turn until one yields a stream whose schema can be read.
   *
   * @param flightInfo The {@link FlightInfo} instance from which to fetch results.
   * @return one {@link CloseableEndpointStreamPair} per endpoint, in endpoint order.
   * @throws SQLException if any endpoint cannot be read; streams opened so far are closed first.
   */
  public List<CloseableEndpointStreamPair> getStreams(final FlightInfo flightInfo)
      throws SQLException {
    final List<CloseableEndpointStreamPair> endpoints =
        new ArrayList<>(flightInfo.getEndpoints().size());
    try {
      for (FlightEndpoint endpoint : flightInfo.getEndpoints()) {
        if (endpoint.getLocations().isEmpty()) {
          // Create a stream using the current client only and do not close the client at the end
          // (hence the null client in the pair).
          endpoints.add(
              new CloseableEndpointStreamPair(
                  sqlClient.getStream(endpoint.getTicket(), getOptions()), null));
        } else {
          endpoints.add(openStreamFromLocations(endpoint));
        }
      }
    } catch (Exception outerException) {
      try {
        AutoCloseables.close(endpoints);
      } catch (Exception innerEx) {
        addSuppressedSafely(outerException, innerEx);
      }
      if (outerException instanceof SQLException) {
        throw (SQLException) outerException;
      }
      throw new SQLException(outerException);
    }
    return endpoints;
  }

  /**
   * Tries each location of {@code endpoint} (in the order given by {@link FlightLocationQueue})
   * and returns the first stream that can actually be read.
   *
   * <p>GH-38574: currently a new FlightClient is made for each partition that returns a non-empty,
   * non-reuse Location and is disposed of afterwards. It may be better to cache clients because a
   * server may report the same Locations. It would also be good to identify when the reported
   * location is the same as the original connection's Location and skip creating a FlightClient in
   * that scenario. The {@link FlightClientCache} is shared with the spawned clients and also
   * records negative (unreachable) attempts.
   *
   * @param endpoint an endpoint with at least one location.
   * @return the stream for the first location that worked.
   * @throws Exception the first failure encountered (later failures attached as suppressed) when
   *     no location worked, or {@link IllegalStateException} if no location was tried at all.
   */
  private CloseableEndpointStreamPair openStreamFromLocations(final FlightEndpoint endpoint)
      throws Exception {
    final List<Exception> failures = new ArrayList<>();
    final FlightLocationQueue locations =
        new FlightLocationQueue(flightClientCache, endpoint.getLocations());
    while (locations.hasNext()) {
      final Location location = locations.next();
      final URI endpointUri = location.getUri();
      if (endpointUri.getScheme().equals(LocationSchemes.REUSE_CONNECTION)) {
        return new CloseableEndpointStreamPair(
            sqlClient.getStream(endpoint.getTicket(), getOptions()), null);
      }
      // Clone the builder and then point the copy at the new endpoint.
      final Builder builderForEndpoint =
          new Builder(this.builder)
              .withHost(endpointUri.getHost())
              .withPort(endpointUri.getPort())
              .withEncryption(endpointUri.getScheme().equals(LocationSchemes.GRPC_TLS))
              .withClientCache(flightClientCache)
              .withConnectTimeout(builder.connectTimeout);
      ArrowFlightSqlClientHandler endpointHandler = null;
      CloseableEndpointStreamPair stream = null;
      try {
        endpointHandler = builderForEndpoint.build();
        stream =
            new CloseableEndpointStreamPair(
                endpointHandler.sqlClient.getStream(
                    endpoint.getTicket(), endpointHandler.getOptions()),
                endpointHandler.sqlClient);
        // Make sure we actually get data from the server.
        stream.getStream().getSchema();
      } catch (Exception ex) {
        // The failed stream is released here and never returned.
        discardFailedAttempt(ex, location, endpointHandler, stream);
        failures.add(ex);
        continue;
      }
      if (flightClientCache != null) {
        flightClientCache.markLocationAsReachable(location.toString());
      }
      return stream;
    }
    if (failures.isEmpty()) {
      // This should never happen...
      throw new IllegalStateException("Could not connect to endpoint and no errors occurred");
    }
    final Exception primary = failures.get(0);
    for (Exception other : failures.subList(1, failures.size())) {
      addSuppressedSafely(primary, other);
    }
    throw primary;
  }

  /**
   * Releases whatever a failed endpoint attempt managed to open. If a handler was built and the
   * failure looks like a connectivity problem, the location is marked as a dud in the client cache.
   *
   * @param failure the failure of the attempt; cleanup errors are attached to it as suppressed.
   * @param location the location that was attempted.
   * @param endpointHandler the handler built for the attempt, or {@code null} if building failed.
   * @param stream the stream opened for the attempt, or {@code null} if none was opened.
   */
  private void discardFailedAttempt(
      final Exception failure,
      final Location location,
      final @Nullable ArrowFlightSqlClientHandler endpointHandler,
      final @Nullable CloseableEndpointStreamPair stream) {
    if (stream != null) {
      // Close the stream itself before closing the handler whose client produced it.
      try {
        stream.getStream().close();
      } catch (Exception closeFailure) {
        addSuppressedSafely(failure, closeFailure);
      }
    }
    if (endpointHandler == null) {
      return;
    }
    if (flightClientCache != null && isConnectivityFailure(failure)) {
      flightClientCache.markLocationAsDud(location.toString());
    }
    try {
      endpointHandler.close();
    } catch (Exception closeFailure) {
      addSuppressedSafely(failure, closeFailure);
    }
  }

  /**
   * Returns whether {@code ex} is an UNAVAILABLE Flight error caused by an {@link IOException}.
   *
   * <p>IOException covers SocketException and Netty's (private) AnnotatedSocketException; we are
   * looking for things like "Network is unreachable".
   */
  private static boolean isConnectivityFailure(final Exception ex) {
    return ex instanceof FlightRuntimeException
        && ((FlightRuntimeException) ex).status().code() == FlightStatusCode.UNAVAILABLE
        && ex.getCause() instanceof IOException;
  }

  /**
   * Attaches {@code secondary} to {@code primary} as a suppressed exception. A self-reference is
   * skipped because {@link Throwable#addSuppressed} rejects it.
   */
  private static void addSuppressedSafely(final Throwable primary, final Throwable secondary) {
    if (primary != secondary) {
      primary.addSuppressed(secondary);
    }
  }

  /**
   * Makes an RPC "getInfo" request based on the provided {@code query} object.
   *
   * @param query The query.
   * @return the {@link FlightInfo} describing where to fetch the query results.
   */
  public FlightInfo getInfo(final String query) {
    return sqlClient.execute(query, getOptions());
  }

  /**
   * Closes the Flight SQL session (only when a catalog was set) and then the underlying client.
   *
   * <p>The client is closed even if closing the session fails. Benign gRPC shutdown errors are
   * logged and suppressed. The first genuine failure is thrown, with any later one attached as
   * suppressed.
   *
   * @throws SQLException if closing the session or the client genuinely failed.
   */
  @Override
  public void close() throws SQLException {
    SQLException failure = null;
    if (catalog.isPresent()) {
      try {
        sqlClient.closeSession(new CloseSessionRequest(), getOptions());
      } catch (FlightRuntimeException fre) {
        try {
          handleBenignCloseException(
              fre, "Failed to close Flight SQL session.", "closing Flight SQL session");
        } catch (SQLException e) {
          // Remember the failure, but still fall through so the client itself is closed.
          failure = e;
        }
      }
    }
    try {
      AutoCloseables.close(sqlClient);
    } catch (FlightRuntimeException fre) {
      try {
        handleBenignCloseException(
            fre, "Failed to clean up client resources.", "closing Flight SQL client");
      } catch (SQLException e) {
        failure = mergeFailure(failure, e);
      }
    } catch (final Exception e) {
      if (e instanceof InterruptedException) {
        // Preserve the caller's interrupt status.
        Thread.currentThread().interrupt();
      }
      failure = mergeFailure(failure, new SQLException("Failed to clean up client resources.", e));
    }
    if (failure != null) {
      throw failure;
    }
  }

  /**
   * Combines two close failures, keeping the first as primary.
   *
   * @return {@code next} if {@code primary} is null, otherwise {@code primary} with {@code next}
   *     suppressed.
   */
  private static SQLException mergeFailure(
      final @Nullable SQLException primary, final SQLException next) {
    if (primary == null) {
      return next;
    }
    primary.addSuppressed(next);
    return primary;
  }

  /**
   * Handles FlightRuntimeException during close operations, suppressing benign gRPC shutdown errors
   * while re-throwing genuine failures.
   *
   * @param fre the FlightRuntimeException to handle
   * @param sqlErrorMessage the SQLException message to use for genuine failures
   * @param operationDescription description of the operation for logging
   * @throws SQLException if the exception represents a genuine failure
   */
  private void handleBenignCloseException(
      FlightRuntimeException fre, String sqlErrorMessage, String operationDescription)
      throws SQLException {
    if (isBenignCloseException(fre)) {
      logSuppressedCloseException(fre, operationDescription);
    } else {
      throw new SQLException(sqlErrorMessage, fre);
    }
  }

  /**
   * Handles FlightRuntimeException during close operations, suppressing benign gRPC shutdown errors
   * while re-throwing genuine failures as FlightRuntimeException.
   *
   * @param fre the FlightRuntimeException to handle
   * @param operationDescription description of the operation for logging
   * @throws FlightRuntimeException if the exception represents a genuine failure
   */
  private void handleBenignCloseException(FlightRuntimeException fre, String operationDescription)
      throws FlightRuntimeException {
    if (isBenignCloseException(fre)) {
      logSuppressedCloseException(fre, operationDescription);
    } else {
      throw fre;
    }
  }

  /**
   * Determines if a FlightRuntimeException represents a benign close operation error that should be
   * suppressed: any UNAVAILABLE error, or an INTERNAL error reporting "Connection closed after
   * GOAWAY".
   *
   * @param fre the FlightRuntimeException to check
   * @return true if the exception should be suppressed, false otherwise
   */
  private boolean isBenignCloseException(FlightRuntimeException fre) {
    return fre.status().code().equals(FlightStatusCode.UNAVAILABLE)
        || (fre.status().code().equals(FlightStatusCode.INTERNAL)
            && fre.getMessage() != null
            && fre.getMessage().contains("Connection closed after GOAWAY"));
  }

  /**
   * Logs a suppressed close exception at debug level.
   *
   * @param fre the FlightRuntimeException being suppressed
   * @param operationDescription description of the operation for logging
   */
  private void logSuppressedCloseException(
      FlightRuntimeException fre, String operationDescription) {
    // ARROW-17785 and GH-863: suppress exceptions caused by flaky gRPC layer during shutdown.
    LOGGER.debug("Suppressed error {}", operationDescription, fre);
  }

  /** A prepared statement handler. */
  public interface PreparedStatement extends AutoCloseable {
    /**
     * Executes this {@link PreparedStatement}.
     *
     * @return the {@link FlightInfo} representing the outcome of this query execution.
     * @throws SQLException on error.
     */
    FlightInfo executeQuery() throws SQLException;

    /**
     * Executes a {@link StatementType#UPDATE} query.
     *
     * @return the number of rows affected.
     */
    long executeUpdate();

    /**
     * Gets the {@link StatementType} of this {@link PreparedStatement}.
     *
     * @return the Statement Type.
     */
    StatementType getType();

    /**
     * Gets the {@link Schema} of this {@link PreparedStatement}.
     *
     * @return {@link Schema}.
     */
    Schema getDataSetSchema();

    /**
     * Gets the {@link Schema} of the parameters for this {@link PreparedStatement}.
     *
     * @return {@link Schema}.
     */
    Schema getParameterSchema();

    /**
     * Binds the given parameter values to this {@link PreparedStatement}.
     *
     * @param parameters the parameter values.
     */
    void setParameters(VectorSchemaRoot parameters);

    /** Closes this {@link PreparedStatement}; declared without checked exceptions. */
    @Override
    void close();
  }

  /**
   * Sets the catalog as a session option when one was configured, converting a failure into an
   * INVALID_ARGUMENT {@link FlightRuntimeException}.
   */
  private void setSetCatalogInSessionIfPresent() {
    if (catalog.isPresent()) {
      try {
        setCatalog(catalog.get());
      } catch (SQLException e) {
        throw CallStatus.INVALID_ARGUMENT
            .withDescription(e.getMessage())
            .withCause(e)
            .toRuntimeException();
      }
    }
  }

  /**
   * Sets the catalog for the current session.
   *
   * @param catalog the catalog to set.
   * @throws SQLException if the server reports errors or the RPC fails; individual server errors
   *     are logged at warn level.
   */
  public void setCatalog(final String catalog) throws SQLException {
    final SetSessionOptionsRequest request =
        new SetSessionOptionsRequest(
            ImmutableMap.of(CATALOG, SessionOptionValueFactory.makeSessionOptionValue(catalog)));
    try {
      final SetSessionOptionsResult result = sqlClient.setSessionOptions(request, getOptions());
      if (result.hasErrors()) {
        final Map<String, SetSessionOptionsResult.Error> errors = result.getErrors();
        for (final Map.Entry<String, SetSessionOptionsResult.Error> error : errors.entrySet()) {
          LOGGER.warn(error.toString());
        }
        throw new SQLException(
            String.format(
                "Cannot set session option for catalog = %s. Check log for details.", catalog));
      }
    } catch (final FlightRuntimeException e) {
      throw new SQLException(e);
    }
  }

  /**
   * Creates a new {@link PreparedStatement} for the given {@code query}.
   *
   * @param query the SQL query.
   * @return a new prepared statement; every call on it reuses this handler's call options.
   */
  public PreparedStatement prepare(final String query) {
    final FlightSqlClient.PreparedStatement preparedStatement =
        sqlClient.prepare(query, getOptions());
    return new PreparedStatement() {
      @Override
      public FlightInfo executeQuery() throws SQLException {
        return preparedStatement.execute(getOptions());
      }

      @Override
      public long executeUpdate() {
        return preparedStatement.executeUpdate(getOptions());
      }

      @Override
      public StatementType getType() {
        // An empty result-set schema is treated as an update.
        final Schema schema = preparedStatement.getResultSetSchema();
        return schema.getFields().isEmpty() ? StatementType.UPDATE : StatementType.SELECT;
      }

      @Override
      public Schema getDataSetSchema() {
        return preparedStatement.getResultSetSchema();
      }

      @Override
      public Schema getParameterSchema() {
        return preparedStatement.getParameterSchema();
      }

      @Override
      public void setParameters(VectorSchemaRoot parameters) {
        preparedStatement.setParameters(parameters);
      }

      @Override
      public void close() {
        try {
          preparedStatement.close(getOptions());
        } catch (FlightRuntimeException fre) {
          handleBenignCloseException(fre, "closing PreparedStatement");
        }
      }
    };
  }

  /**
   * Makes an RPC "getCatalogs" request.
   *
   * @return the {@link FlightInfo} describing where to fetch the results.
   */
  public FlightInfo getCatalogs() {
    return sqlClient.getCatalogs(getOptions());
  }

  /**
   * Makes an RPC "getImportedKeys" request based on the provided info.
   *
   * @param catalog The catalog name. Must match the catalog name as it is stored in the database.
   *     "" retrieves those without a catalog. Null means that the catalog name should not be used
   *     to narrow the search.
   * @param schema The schema name. Must match the schema name as it is stored in the database. ""
   *     retrieves those without a schema. Null means that the schema name should not be used to
   *     narrow the search.
   * @param table The table name. Must match the table name as it is stored in the database.
   * @return the {@link FlightInfo} describing where to fetch the results.
   */
  public FlightInfo getImportedKeys(final String catalog, final String schema, final String table) {
    return sqlClient.getImportedKeys(TableRef.of(catalog, schema, table), getOptions());
  }

  /**
   * Makes an RPC "getExportedKeys" request based on the provided info.
   *
   * @param catalog The catalog name. Must match the catalog name as it is stored in the database.
   *     "" retrieves those without a catalog. Null means that the catalog name should not be used
   *     to narrow the search.
   * @param schema The schema name. Must match the schema name as it is stored in the database. ""
   *     retrieves those without a schema. Null means that the schema name should not be used to
   *     narrow the search.
   * @param table The table name. Must match the table name as it is stored in the database.
   * @return the {@link FlightInfo} describing where to fetch the results.
   */
  public FlightInfo getExportedKeys(final String catalog, final String schema, final String table) {
    return sqlClient.getExportedKeys(TableRef.of(catalog, schema, table), getOptions());
  }

  /**
   * Makes an RPC "getSchemas" request based on the provided info.
   *
   * @param catalog The catalog name. Must match the catalog name as it is stored in the database.
   *     "" retrieves those without a catalog. Null means that the catalog name should not be used
   *     to narrow the search.
   * @param schemaPattern The schema name pattern. Must match the schema name as it is stored in the
   *     database. Null means that schema name should not be used to narrow down the search.
   * @return the {@link FlightInfo} describing where to fetch the results.
   */
  public FlightInfo getSchemas(final String catalog, final String schemaPattern) {
    return sqlClient.getSchemas(catalog, schemaPattern, getOptions());
  }

  /**
   * Makes an RPC "getTableTypes" request.
   *
   * @return the {@link FlightInfo} describing where to fetch the results.
   */
  public FlightInfo getTableTypes() {
    return sqlClient.getTableTypes(getOptions());
  }

  /**
   * Makes an RPC "getTables" request based on the provided info.
   *
   * @param catalog The catalog name. Must match the catalog name as it is stored in the database.
   *     "" retrieves those without a catalog. Null means that the catalog name should not be used
   *     to narrow the search.
   * @param schemaPattern The schema name pattern. Must match the schema name as it is stored in the
   *     database. "" retrieves those without a schema. Null means that the schema name should not
   *     be used to narrow the search.
   * @param tableNamePattern The table name pattern. Must match the table name as it is stored in
   *     the database.
   * @param types The list of table types, which must be from the list of table types to include.
   *     Null returns all types.
   * @param includeSchema Whether to include schema.
   * @return the {@link FlightInfo} describing where to fetch the results.
   */
  public FlightInfo getTables(
      final String catalog,
      final String schemaPattern,
      final String tableNamePattern,
      final List<String> types,
      final boolean includeSchema) {
    return sqlClient.getTables(
        catalog, schemaPattern, tableNamePattern, types, includeSchema, getOptions());
  }

  /**
   * Gets SQL info.
   *
   * @param info the {@link SqlInfo} entries to request.
   * @return the {@link FlightInfo} describing where to fetch the SQL info.
   */
  public FlightInfo getSqlInfo(SqlInfo... info) {
    return sqlClient.getSqlInfo(info, getOptions());
  }

  /**
   * Makes an RPC "getPrimaryKeys" request based on the provided info.
   *
   * @param catalog The catalog name; must match the catalog name as it is stored in the database.
   *     "" retrieves those without a catalog. Null means that the catalog name should not be used
   *     to narrow the search.
   * @param schema The schema name; must match the schema name as it is stored in the database. ""
   *     retrieves those without a schema. Null means that the schema name should not be used to
   *     narrow the search.
   * @param table The table name. Must match the table name as it is stored in the database.
   * @return the {@link FlightInfo} describing where to fetch the results.
   */
  public FlightInfo getPrimaryKeys(final String catalog, final String schema, final String table) {
    return sqlClient.getPrimaryKeys(TableRef.of(catalog, schema, table), getOptions());
  }

  /**
   * Makes an RPC "getCrossReference" request based on the provided info.
   *
   * @param pkCatalog The catalog name. Must match the catalog name as it is stored in the database.
   *     "" retrieves those without a catalog. Null means that the catalog name should not be used
   *     to narrow the search.
   * @param pkSchema The schema name. Must match the schema name as it is stored in the database. ""
   *     retrieves those without a schema. Null means that the schema name should not be used to
   *     narrow the search.
   * @param pkTable The table name. Must match the table name as it is stored in the database.
   * @param fkCatalog The catalog name. Must match the catalog name as it is stored in the database.
   *     "" retrieves those without a catalog. Null means that the catalog name should not be used
   *     to narrow the search.
   * @param fkSchema The schema name. Must match the schema name as it is stored in the database. ""
   *     retrieves those without a schema. Null means that the schema name should not be used to
   *     narrow the search.
   * @param fkTable The table name. Must match the table name as it is stored in the database.
   * @return the {@link FlightInfo} describing where to fetch the results.
   */
  public FlightInfo getCrossReference(
      String pkCatalog,
      String pkSchema,
      String pkTable,
      String fkCatalog,
      String fkSchema,
      String fkTable) {
    return sqlClient.getCrossReference(
        TableRef.of(pkCatalog, pkSchema, pkTable),
        TableRef.of(fkCatalog, fkSchema, fkTable),
        getOptions());
  }

  /** Builder for {@link ArrowFlightSqlClientHandler}. */
  public static final class Builder {
    static final String USER_AGENT_TEMPLATE = "JDBC Flight SQL Driver %s";
    static final String DEFAULT_VERSION = "(unknown or development build)";

    private final Set<FlightClientMiddleware.Factory> middlewareFactories = new HashSet<>();
    private final Set<CallOption> options = new HashSet<>();
    private String host;
    private int port;

    @VisibleForTesting String username;
    @VisibleForTesting String password;
    @VisibleForTesting String trustStorePath;
    @VisibleForTesting String trustStorePassword;
    @VisibleForTesting String token;
    @VisibleForTesting boolean useEncryption = true;
    @VisibleForTesting boolean disableCertificateVerification;
    @VisibleForTesting boolean useSystemTrustStore = true;
    @VisibleForTesting String tlsRootCertificatesPath;
    @VisibleForTesting String clientCertificatePath;
    @VisibleForTesting String clientKeyPath;
    @VisibleForTesting private BufferAllocator allocator;
    @VisibleForTesting boolean retainCookies = true;
    @VisibleForTesting boolean retainAuth = true;
    @VisibleForTesting Optional<String> catalog = Optional.empty();
    @VisibleForTesting @Nullable FlightClientCache flightClientCache;
    @VisibleForTesting @Nullable Duration connectTimeout;
    @VisibleForTesting @Nullable OAuthConfiguration oauthConfig;

    // These two middleware are for internal use within build() and should not be exposed by
    // builder APIs. Note that these middleware may not necessarily be registered.
    @VisibleForTesting
    ClientIncomingAuthHeaderMiddleware.Factory authFactory =
        new ClientIncomingAuthHeaderMiddleware.Factory(new ClientBearerHeaderHandler());

    @VisibleForTesting
    ClientCookieMiddleware.Factory cookieFactory = new ClientCookieMiddleware.Factory();

    DriverVersion driverVersion;

    public Builder() {}

    /**
     * Copies the builder.
     *
     * <p>The cookie and auth middleware factories are shared with {@code original} only when its
     * {@link #retainCookies} / {@link #retainAuth} flags are set; otherwise fresh factories are
     * used.
     *
     * @param original The builder to base this copy off of.
     */
    @VisibleForTesting
    Builder(Builder original) {
      this.middlewareFactories.addAll(original.middlewareFactories);
      this.options.addAll(original.options);
      this.host = original.host;
      this.port = original.port;
      this.username = original.username;
      this.password = original.password;
      this.trustStorePath = original.trustStorePath;
      this.trustStorePassword = original.trustStorePassword;
      this.token = original.token;
      this.useEncryption = original.useEncryption;
      this.disableCertificateVerification = original.disableCertificateVerification;
      this.useSystemTrustStore = original.useSystemTrustStore;
      this.tlsRootCertificatesPath = original.tlsRootCertificatesPath;
      this.clientCertificatePath = original.clientCertificatePath;
      this.clientKeyPath = original.clientKeyPath;
      this.allocator = original.allocator;
      this.retainCookies = original.retainCookies;
      this.retainAuth = original.retainAuth;
      this.catalog = original.catalog;
      this.flightClientCache = original.flightClientCache;
      this.connectTimeout = original.connectTimeout;
      this.oauthConfig = original.oauthConfig;
      if (original.retainCookies) {
        this.cookieFactory = original.cookieFactory;
      }
      if (original.retainAuth) {
        this.authFactory = original.authFactory;
      }
      this.driverVersion = original.driverVersion;
    }

    /**
     * Sets the host for this handler.
     *
     * @param host the host.
     * @return this instance.
     */
    public Builder withHost(final String host) {
      this.host = host;
      return this;
    }

    /**
     * Sets the port for this handler.
     *
     * @param port the port.
     * @return this instance.
     */
    public Builder withPort(final int port) {
      this.port = port;
      return this;
    }

    /**
     * Sets the username for this handler.
     *
     * @param username the username.
     * @return this instance.
     */
    public Builder withUsername(final String username) {
      this.username = username;
      return this;
    }

    /**
     * Sets the password for this handler.
     *
     * @param password the password.
     * @return this instance.
     */
    public Builder withPassword(final String password) {
      this.password = password;
      return this;
    }

    /**
     * Sets the trust store (KeyStore) path for this handler.
     *
     * @param trustStorePath the KeyStore path.
     * @return this instance.
     */
    public Builder withTrustStorePath(final String trustStorePath) {
      this.trustStorePath = trustStorePath;
      return this;
    }

    /**
     * Sets the trust store (KeyStore) password for this handler.
     *
     * @param trustStorePassword the KeyStore password.
     * @return this instance.
     */
    public Builder withTrustStorePassword(final String trustStorePassword) {
      this.trustStorePassword = trustStorePassword;
      return this;
    }

    /**
     * Sets whether to use TLS encryption in this handler.
     *
     * @param useEncryption whether to use TLS encryption.
     * @return this instance.
     */
    public Builder withEncryption(final boolean useEncryption) {
      this.useEncryption = useEncryption;
      return this;
    }

    /**
     * Sets whether to disable the certificate verification in this handler.
     *
     * @param disableCertificateVerification whether to disable certificate verification.
     * @return this instance.
     */
    public Builder withDisableCertificateVerification(
        final boolean disableCertificateVerification) {
      this.disableCertificateVerification = disableCertificateVerification;
      return this;
    }

    /**
     * Sets whether to use the certificates from the operating system.
     *
     * @param useSystemTrustStore whether to use the system operating certificates.
     * @return this instance.
     */
    public Builder withSystemTrustStore(final boolean useSystemTrustStore) {
      this.useSystemTrustStore = useSystemTrustStore;
      return this;
    }

    /**
     * Sets the TLS root certificate path as an alternative to using the System or other Trust
     * Store. The path must contain a valid PEM file.
     *
     * @param tlsRootCertificatesPath the TLS root certificate path (if TLS is required).
     * @return this instance.
     */
    public Builder withTlsRootCertificates(final String tlsRootCertificatesPath) {
      this.tlsRootCertificatesPath = tlsRootCertificatesPath;
      return this;
    }

    /**
     * Sets the mTLS client certificate path (if mTLS is required).
     *
     * @param clientCertificatePath the mTLS client certificate path (if mTLS is required).
     * @return this instance.
     */
    public Builder withClientCertificate(final String clientCertificatePath) {
      this.clientCertificatePath = clientCertificatePath;
      return this;
    }

    /**
     * Sets the mTLS client certificate private key path (if mTLS is required).
     *
     * @param clientKeyPath the mTLS client certificate private key path (if mTLS is required).
     * @return this instance.
     */
    public Builder withClientKey(final String clientKeyPath) {
      this.clientKeyPath = clientKeyPath;
      return this;
    }

    /**
     * Sets the token used in the token authentication.
     *
     * @param token the token value.
     * @return this builder instance.
     */
    public Builder withToken(final String token) {
      this.token = token;
      return this;
    }

    /**
     * Sets the {@link BufferAllocator} to use in this handler. A child allocator named
     * "ArrowFlightSqlClientHandler", with the parent's limit, is created from {@code allocator}.
     *
     * @param allocator the allocator.
     * @return this instance.
     */
    public Builder withBufferAllocator(final BufferAllocator allocator) {
      this.allocator =
          allocator.newChildAllocator("ArrowFlightSqlClientHandler", 0, allocator.getLimit());
      return this;
    }

    /**
     * Indicates if cookies should be re-used by connections spawned for getStreams() calls.
     *
     * @param retainCookies The flag indicating if cookies should be re-used.
     * @return this builder instance.
     */
    public Builder withRetainCookies(boolean retainCookies) {
      this.retainCookies = retainCookies;
      return this;
    }

    /**
     * Indicates if bearer tokens negotiated should be re-used by connections spawned for
     * getStreams() calls.
     *
     * @param retainAuth The flag indicating if auth tokens should be re-used.
     * @return this builder instance.
     */
    public Builder withRetainAuth(boolean retainAuth) {
      this.retainAuth = retainAuth;
      return this;
    }

    /**
     * Adds the provided {@code factories} to the list of {@link #middlewareFactories} of this
     * handler.
     *
     * @param factories the factories to add.
     * @return this instance.
     */
    public Builder withMiddlewareFactories(final FlightClientMiddleware.Factory... factories) {
      return withMiddlewareFactories(Arrays.asList(factories));
    }

    /**
     * Adds the provided {@code factories} to the list of {@link #middlewareFactories} of this
     * handler.
     *
     * @param factories the factories to add.
     * @return this instance.
     */
    public Builder withMiddlewareFactories(
        final Collection<FlightClientMiddleware.Factory> factories) {
      this.middlewareFactories.addAll(factories);
      return this;
    }

    /**
     * Adds the provided {@link CallOption}s to this handler.
     *
     * @param options the options
     * @return this instance.
     */
    public Builder withCallOptions(final CallOption... options) {
      return withCallOptions(Arrays.asList(options));
    }

    /**
     * Adds the provided {@link CallOption}s to this handler.
     *
     * @param options the options
     * @return this instance.
     */
    public Builder withCallOptions(final Collection<CallOption> options) {
      this.options.addAll(options);
      return this;
    }

    /**
     * Sets the catalog for this handler if it is not null.
     *
     * @param catalog the catalog
     * @return this instance.
     */
    public Builder withCatalog(@Nullable final String catalog) {
      this.catalog = Optional.ofNullable(catalog);
      return this;
    }

    /**
     * Sets the client cache shared with the connections spawned for getStreams() calls.
     *
     * @param flightClientCache the cache, or {@code null} for none.
     * @return this instance.
     */
    public Builder withClientCache(FlightClientCache flightClientCache) {
      this.flightClientCache = flightClientCache;
      return this;
    }

    /**
     * Sets the connect timeout, applied as the Netty {@code CONNECT_TIMEOUT_MILLIS} channel option.
     *
     * @param connectTimeout the timeout, or {@code null} to leave the channel default.
     * @return this instance.
     */
    public Builder withConnectTimeout(Duration connectTimeout) {
      this.connectTimeout = connectTimeout;
      return this;
    }

    /**
     * Sets the driver version for this handler.
     *
     * @param driverVersion the driver version to set
     * @return this builder instance
     */
    public Builder withDriverVersion(DriverVersion driverVersion) {
      this.driverVersion = driverVersion;
      return this;
    }

    /**
     * Sets the OAuth configuration for this handler.
     *
     * @param oauthConfig the OAuth configuration
     * @return this builder instance
     */
    public Builder withOAuthConfiguration(final OAuthConfiguration oauthConfig) {
      this.oauthConfig = oauthConfig;
      return this;
    }

    /**
     * Gets the key identifying the location this builder connects to.
     *
     * @return the string form of {@link #getLocation()}.
     */
    public String getCacheKey() {
      return getLocation().toString();
    }

    /**
     * Get the location that this client will connect to.
     *
     * @return a TLS location when encryption is enabled, an insecure one otherwise.
     */
    public Location getLocation() {
      if (useEncryption) {
        return Location.forGrpcTls(host, port);
      }
      return Location.forGrpcInsecure(host, port);
    }

    /**
     * Builds a new {@link ArrowFlightSqlClientHandler} from the provided fields.
     *
     * <p>Authentication priority is OAuth, then token, then username/password. If anything fails
     * after the underlying {@link FlightClient} has been created, that client is closed before the
     * failure propagates.
     *
     * @return a new client handler.
     * @throws SQLException on error.
     */
    public ArrowFlightSqlClientHandler build() throws SQLException {
      // Copy middleware so that the build method doesn't change the state of the builder fields
      // itself.
      final Set<FlightClientMiddleware.Factory> buildTimeMiddlewareFactories =
          new HashSet<>(this.middlewareFactories);
      FlightClient client = null;
      // Token should take priority since some apps pass in a username/password even when a token
      // is provided.
      final boolean isUsingUserPasswordAuth = username != null && token == null;
      try {
        if (isUsingUserPasswordAuth) {
          buildTimeMiddlewareFactories.add(authFactory);
        }
        final NettyClientBuilder clientBuilder = new NettyClientBuilder();
        clientBuilder.allocator(allocator);

        final String versionString =
            driverVersion != null && driverVersion.versionString != null
                ? driverVersion.versionString
                : DEFAULT_VERSION;
        final String userAgent = String.format(USER_AGENT_TEMPLATE, versionString);

        // NOTE(review): a fresh cookie factory is registered here instead of `cookieFactory`, so
        // the retainCookies flag (which only decides whether copies share `cookieFactory`) does
        // not currently affect the cookies sent by built clients. Confirm whether this is intended.
        buildTimeMiddlewareFactories.add(new ClientCookieMiddleware.Factory());
        buildTimeMiddlewareFactories.forEach(clientBuilder::intercept);

        if (useEncryption) {
          clientBuilder.useTls();
        }
        final Location location = getLocation();
        clientBuilder.location(location);
        if (useEncryption) {
          if (disableCertificateVerification) {
            clientBuilder.verifyServer(false);
          } else if (tlsRootCertificatesPath != null) {
            clientBuilder.trustedCertificates(
                ClientAuthenticationUtils.getTlsRootCertificatesStream(tlsRootCertificatesPath));
          } else if (useSystemTrustStore) {
            // The system trust store takes precedence over an explicit trustStorePath.
            clientBuilder.trustedCertificates(
                ClientAuthenticationUtils.getCertificateInputStreamFromSystem(trustStorePassword));
          } else if (trustStorePath != null) {
            clientBuilder.trustedCertificates(
                ClientAuthenticationUtils.getCertificateStream(trustStorePath, trustStorePassword));
          }
          if (clientCertificatePath != null && clientKeyPath != null) {
            clientBuilder.clientCertificate(
                ClientAuthenticationUtils.getClientCertificateStream(clientCertificatePath),
                ClientAuthenticationUtils.getClientKeyStream(clientKeyPath));
          }
        }

        final NettyChannelBuilder channelBuilder = clientBuilder.build();
        channelBuilder.userAgent(userAgent);
        if (connectTimeout != null) {
          channelBuilder.withOption(
              ChannelOption.CONNECT_TIMEOUT_MILLIS, (int) connectTimeout.toMillis());
        }
        client =
            FlightGrpcUtils.createFlightClient(
                allocator, channelBuilder.build(), clientBuilder.middleware());

        final List<CallOption> credentialOptions = new ArrayList<>();
        // Authentication priority: OAuth > token > username/password
        if (oauthConfig != null) {
          final OAuthTokenProvider tokenProvider = oauthConfig.createTokenProvider();
          credentialOptions.add(new CredentialCallOption(new OAuthCredentialWriter(tokenProvider)));
        } else if (isUsingUserPasswordAuth) {
          // If the authFactory has already been used for a handshake, use the existing token.
          // This can occur if the authFactory is being re-used for a new connection spawned for
          // getStream().
          if (authFactory.getCredentialCallOption() != null) {
            credentialOptions.add(authFactory.getCredentialCallOption());
          } else {
            // Otherwise do the handshake and get the token if possible.
            credentialOptions.add(
                ClientAuthenticationUtils.getAuthenticate(
                    client, username, password, authFactory, options.toArray(new CallOption[0])));
          }
        } else if (token != null) {
          credentialOptions.add(
              ClientAuthenticationUtils.getAuthenticate(
                  client,
                  new CredentialCallOption(new BearerCredentialWriter(token)),
                  options.toArray(new CallOption[0])));
        }
        return ArrowFlightSqlClientHandler.createNewHandler(
            getCacheKey(), client, this, credentialOptions, catalog, flightClientCache);
      } catch (final IllegalArgumentException
          | GeneralSecurityException
          | IOException
          | FlightRuntimeException e) {
        final SQLException originalException = new SQLException(e);
        closeClientAfterFailure(client, originalException);
        throw originalException;
      } catch (final Exception e) {
        // Any other exception thrown by the try block must not leak the client either; it is
        // rethrown unchanged (precise rethrow keeps the declared exception types intact).
        closeClientAfterFailure(client, e);
        throw e;
      }
    }

    /**
     * Closes {@code client} after a failed build. An interruption while closing is recorded as
     * suppressed on {@code failure}, and the thread's interrupt status is restored.
     *
     * @param client the client to close, or {@code null} if it was never created.
     * @param failure the failure that aborted the build.
     */
    private static void closeClientAfterFailure(
        final @Nullable FlightClient client, final Exception failure) {
      if (client == null) {
        return;
      }
      try {
        client.close();
      } catch (final InterruptedException interruptedException) {
        Thread.currentThread().interrupt();
        failure.addSuppressed(interruptedException);
      }
    }
  }
}