FlightServerTestExtension.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.driver.jdbc;
import static org.apache.arrow.driver.jdbc.utils.FlightSqlTestCertificates.CertKeyPair;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Method;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
import org.apache.arrow.driver.jdbc.authentication.Authentication;
import org.apache.arrow.driver.jdbc.authentication.TokenAuthentication;
import org.apache.arrow.driver.jdbc.authentication.UserPasswordAuthentication;
import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl;
import org.apache.arrow.flight.CallHeaders;
import org.apache.arrow.flight.CallInfo;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.FlightMethod;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.FlightServerMiddleware;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.RequestContext;
import org.apache.arrow.flight.sql.FlightSqlProducer;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.util.Preconditions;
import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Utility class for unit tests that need to instantiate a {@link FlightServer} and interact with
* it.
*/
public class FlightServerTestExtension
    implements BeforeAllCallback, AfterAllCallback, AutoCloseable {
  public static final String DEFAULT_USER = "flight-test-user";
  public static final String DEFAULT_PASSWORD = "flight-test-password";
  private static final Logger LOGGER = LoggerFactory.getLogger(FlightServerTestExtension.class);

  /** Key under which the running server is kept in the JUnit extension store. */
  private static final String SERVER_STORE_KEY = "flightServer";

  /** Number of attempts made to start the server before giving up. */
  private static final int START_ATTEMPTS = 3;

  private final Properties properties;
  private final ArrowFlightConnectionConfigImpl config;
  private final BufferAllocator allocator;
  private final FlightSqlProducer producer;
  private final Authentication authentication;
  private final CertKeyPair certKeyPair;
  private final File mTlsCACert;
  private final InterceptorMiddleware.Factory interceptorFactory =
      new InterceptorMiddleware.Factory();

  private FlightServerTestExtension(
      final Properties properties,
      final ArrowFlightConnectionConfigImpl config,
      final BufferAllocator allocator,
      final FlightSqlProducer producer,
      final Authentication authentication,
      final CertKeyPair certKeyPair,
      final File mTlsCACert) {
    this.properties = Preconditions.checkNotNull(properties);
    this.config = Preconditions.checkNotNull(config);
    this.allocator = Preconditions.checkNotNull(allocator);
    this.producer = Preconditions.checkNotNull(producer);
    this.authentication = authentication;
    this.certKeyPair = certKeyPair;
    this.mTlsCACert = mTlsCACert;
  }

  /**
   * Create a {@link FlightServerTestExtension} with standard values such as: user, password,
   * localhost.
   *
   * @param producer the producer used to create the FlightServerTestExtension.
   * @return the FlightServerTestExtension.
   */
  public static FlightServerTestExtension createStandardTestExtension(
      final FlightSqlProducer producer) {
    UserPasswordAuthentication authentication =
        new UserPasswordAuthentication.Builder().user(DEFAULT_USER, DEFAULT_PASSWORD).build();
    return new Builder().authentication(authentication).producer(producer).build();
  }

  /**
   * Creates a data source from this extension's current connection properties.
   *
   * @return a new {@link ArrowFlightJdbcDataSource}.
   */
  ArrowFlightJdbcDataSource createDataSource() {
    return ArrowFlightJdbcDataSource.createNewDataSource(properties);
  }

  /**
   * Creates a pooled data source from this extension's current connection properties.
   *
   * @return a new {@link ArrowFlightJdbcConnectionPoolDataSource}.
   */
  public ArrowFlightJdbcConnectionPoolDataSource createConnectionPoolDataSource() {
    return ArrowFlightJdbcConnectionPoolDataSource.createNewDataSource(properties);
  }

  /**
   * Sets the shared {@code useEncryption} property and creates a pooled data source.
   *
   * @param useEncryption value stored under {@code useEncryption}; it persists for later calls.
   * @return a new {@link ArrowFlightJdbcConnectionPoolDataSource}.
   */
  public ArrowFlightJdbcConnectionPoolDataSource createConnectionPoolDataSource(
      boolean useEncryption) {
    setUseEncryption(useEncryption);
    return ArrowFlightJdbcConnectionPoolDataSource.createNewDataSource(properties);
  }

  /**
   * Stores {@code token} in the shared properties and opens a connection.
   *
   * <p>The token stays in the properties, so later connections from this extension also use it.
   *
   * @param useEncryption value stored under {@code useEncryption}.
   * @param token value stored under {@code token}.
   * @return a new connection.
   * @throws SQLException if the connection cannot be opened.
   */
  public Connection getConnection(boolean useEncryption, String token) throws SQLException {
    properties.put("token", token);
    return getConnection(useEncryption);
  }

  /**
   * Sets the shared {@code useEncryption} property and opens a connection.
   *
   * @param useEncryption value stored under {@code useEncryption}.
   * @return a new connection.
   * @throws SQLException if the connection cannot be opened.
   */
  public Connection getConnection(boolean useEncryption) throws SQLException {
    setUseEncryption(useEncryption);
    return this.createDataSource().getConnection();
  }

  /**
   * Opens an unencrypted connection after storing {@code timezone} in the shared properties.
   *
   * @param timezone value stored under {@code timezone}; it persists for later connections.
   * @return a new connection.
   * @throws SQLException if the connection cannot be opened.
   */
  public Connection getConnection(String timezone) throws SQLException {
    setUseEncryption(false);
    properties.put("timezone", timezone);
    return this.createDataSource().getConnection();
  }

  private void setUseEncryption(boolean useEncryption) {
    properties.put("useEncryption", useEncryption);
  }

  /**
   * Returns the middleware factory installed on the server, for inspecting received headers.
   *
   * @return the interceptor factory shared by every server this extension builds.
   */
  public InterceptorMiddleware.Factory getInterceptorFactory() {
    return interceptorFactory;
  }

  /**
   * A function that may throw {@link IOException}.
   *
   * @param <T> input type.
   * @param <R> result type.
   */
  @FunctionalInterface
  public interface CheckedFunction<T, R> {
    R apply(T t) throws IOException;
  }

  /**
   * Builds (but does not start) a server for {@code location}.
   *
   * <p>The server uses this extension's authentication and interceptor middleware. It uses TLS when
   * a certificate pair was configured, and mTLS client verification when a CA cert was configured.
   */
  private FlightServer initiateServer(Location location) throws IOException {
    FlightServer.Builder builder =
        FlightServer.builder(allocator, location, producer)
            .headerAuthenticator(authentication.authenticate())
            .middleware(FlightServerMiddleware.Key.of("KEY"), interceptorFactory);
    if (certKeyPair != null) {
      builder.useTls(certKeyPair.cert, certKeyPair.key);
    }
    if (mTlsCACert != null) {
      builder.useMTlsClientVerification(mTlsCACert);
    }
    return builder.build();
  }

  /**
   * Starts the server, records its port in the connection properties, and keeps it in the
   * extension store so {@link #afterAll(ExtensionContext)} can shut it down.
   */
  @Override
  public void beforeAll(ExtensionContext context) throws Exception {
    try {
      final FlightServer flightServer = getStartServer(this::initiateServer, START_ATTEMPTS);
      properties.put("port", flightServer.getPort());
      LOGGER.info("Started {} as {}", FlightServer.class.getName(), flightServer);
      context.getStore(ExtensionContext.Namespace.GLOBAL).put(SERVER_STORE_KEY, flightServer);
    } catch (Exception e) {
      LOGGER.error("Failed to start FlightServer", e);
      throw e;
    }
  }

  /**
   * Shuts down the server started in {@link #beforeAll(ExtensionContext)} and then releases the
   * allocator.
   *
   * <p>The allocator is released even if shutting the server down fails.
   */
  @Override
  public void afterAll(ExtensionContext context) throws Exception {
    final FlightServer flightServer =
        context
            .getStore(ExtensionContext.Namespace.GLOBAL)
            .get(SERVER_STORE_KEY, FlightServer.class);
    try {
      if (flightServer != null) {
        flightServer.close();
      }
    } finally {
      close();
    }
  }

  /**
   * Builds and starts a server on an ephemeral localhost port, making up to {@code retries}
   * attempts.
   *
   * @param newServerFromLocation factory that builds an unstarted server for a location.
   * @param retries maximum number of start attempts; must be positive.
   * @return the started server.
   * @throws IOException if building a server fails, or if every start attempt fails. In the latter
   *     case the exception wraps the failure from the first attempt.
   */
  private FlightServer getStartServer(
      CheckedFunction<Location, FlightServer> newServerFromLocation, int retries)
      throws IOException {
    Preconditions.checkArgument(retries > 0, "retries must be positive but was %s", retries);
    final Deque<ReflectiveOperationException> exceptions = new ArrayDeque<>();
    for (int attempt = 0; attempt < retries; attempt++) {
      final FlightServer server =
          newServerFromLocation.apply(Location.forGrpcInsecure("localhost", 0));
      try {
        // start() is invoked reflectively; anything it throws arrives wrapped in an
        // InvocationTargetException.
        final Method start = server.getClass().getMethod("start");
        start.setAccessible(true);
        start.invoke(server);
        return server;
      } catch (ReflectiveOperationException e) {
        exceptions.add(e);
        // The unstarted server was still built; release it before trying again.
        closeAfterFailedStart(server, e);
      }
    }
    exceptions.forEach(e -> LOGGER.error("Failed to start FlightServer", e));
    final ReflectiveOperationException first = exceptions.getFirst();
    // Unwrap InvocationTargetException to surface the real failure. A purely reflective error,
    // such as NoSuchMethodException, has no cause, so wrap that error itself instead.
    throw new IOException(first.getCause() != null ? first.getCause() : first);
  }

  /**
   * Closes a server whose start attempt failed.
   *
   * <p>Any failure while closing is attached to {@code failure} as a suppressed exception rather
   * than thrown. The interrupt flag is restored if closing was interrupted.
   */
  private static void closeAfterFailedStart(FlightServer server, Exception failure) {
    try {
      server.close();
    } catch (Exception closeFailure) {
      if (closeFailure instanceof InterruptedException) {
        Thread.currentThread().interrupt();
      }
      failure.addSuppressed(closeFailure);
    }
  }

  /**
   * Returns the port reported by the connection configuration.
   *
   * @return the port value.
   */
  public int getPort() {
    return config.getPort();
  }

  /**
   * Returns the host reported by the connection configuration.
   *
   * @return the host value.
   */
  public String getHost() {
    return config.getHost();
  }

  /**
   * Closes every child allocator, then the root allocator.
   *
   * @throws Exception if closing the allocator fails.
   */
  @Override
  public void close() throws Exception {
    allocator.getChildAllocators().forEach(BufferAllocator::close);
    AutoCloseables.close(allocator);
  }

  /** Builder for {@link FlightServerTestExtension}. */
  public static final class Builder {
    private final Properties properties;
    private FlightSqlProducer producer;
    private Authentication authentication;
    private CertKeyPair certKeyPair;
    private File mTlsCACert;

    public Builder() {
      this.properties = new Properties();
      this.properties.put("host", "localhost");
    }

    /**
     * Sets the producer that will be used in the server rule.
     *
     * @param producer the flight sql producer.
     * @return the Builder.
     */
    public Builder producer(final FlightSqlProducer producer) {
      this.producer = producer;
      return this;
    }

    /**
     * Sets the type of the authentication that will be used in the server rules. There are two
     * types of authentication: {@link UserPasswordAuthentication} and {@link TokenAuthentication}.
     *
     * @param authentication the type of authentication.
     * @return the Builder.
     */
    public Builder authentication(final Authentication authentication) {
      this.authentication = authentication;
      return this;
    }

    /**
     * Enable TLS on the server.
     *
     * @param certChain The certificate chain to use.
     * @param key The private key to use.
     * @return the Builder.
     */
    public Builder useEncryption(final File certChain, final File key) {
      certKeyPair = new CertKeyPair(certChain, key);
      return this;
    }

    /**
     * Enable Client Verification via mTLS on the server.
     *
     * @param mTlsCACert The CA certificate to use for client verification.
     * @return the Builder.
     */
    public Builder useMTlsClientVerification(final File mTlsCACert) {
      this.mTlsCACert = mTlsCACert;
      return this;
    }

    /**
     * Builds the {@link FlightServerTestExtension} using the provided values.
     *
     * <p>Each extension gets its own copy of the connection properties. Extensions built from the
     * same builder therefore do not see each other's property changes.
     *
     * @return a {@link FlightServerTestExtension}.
     * @throws NullPointerException if no producer or no authentication was set.
     */
    public FlightServerTestExtension build() {
      // Validate before allocating. Otherwise the RootAllocator below would leak when a missing
      // producer is rejected by the constructor.
      Preconditions.checkNotNull(producer, "producer must be set before build()");
      Preconditions.checkNotNull(authentication, "authentication must be set before build()");
      final Properties extensionProperties = new Properties();
      extensionProperties.putAll(properties);
      authentication.populateProperties(extensionProperties);
      return new FlightServerTestExtension(
          extensionProperties,
          new ArrowFlightConnectionConfigImpl(extensionProperties),
          new RootAllocator(Long.MAX_VALUE),
          producer,
          authentication,
          certKeyPair,
          mTlsCACert);
    }
  }

  /**
   * Server middleware used to test cookie handling.
   *
   * <p>It sends {@code Set-Cookie: k=v} until the most recent call arrives carrying a
   * {@code Cookie} header.
   */
  static class InterceptorMiddleware implements FlightServerMiddleware {
    private final Factory factory;

    public InterceptorMiddleware(Factory factory) {
      this.factory = factory;
    }

    @Override
    public void onBeforeSendingHeaders(CallHeaders callHeaders) {
      if (!factory.receivedCookieHeader) {
        callHeaders.insert("Set-Cookie", "k=v");
      }
    }

    @Override
    public void onCallCompleted(CallStatus callStatus) {}

    @Override
    public void onCallErrored(Throwable throwable) {}

    /**
     * A factory for {@link InterceptorMiddleware}.
     *
     * <p>It records the last {@code Cookie} header seen and the headers of the latest call per
     * method.
     */
    static class Factory implements FlightServerMiddleware.Factory<InterceptorMiddleware> {
      // onCallStarted runs for each incoming server call while the getters are used by test code,
      // so map access is synchronized and the scalar fields are volatile for visibility.
      private final Map<FlightMethod, CallHeaders> receivedCallHeaders = new HashMap<>();
      private volatile boolean receivedCookieHeader = false;
      private volatile String cookie;

      @Override
      public InterceptorMiddleware onCallStarted(
          CallInfo callInfo, CallHeaders callHeaders, RequestContext requestContext) {
        // Read once into a local so the flag agrees with the cookie value of this same call.
        final String requestCookie = callHeaders.get("Cookie");
        cookie = requestCookie;
        receivedCookieHeader = null != requestCookie;
        synchronized (receivedCallHeaders) {
          receivedCallHeaders.put(callInfo.method(), callHeaders);
        }
        return new InterceptorMiddleware(this);
      }

      /** Returns the {@code Cookie} header of the most recent call, or {@code null} if absent. */
      public String getCookie() {
        return cookie;
      }

      /**
       * Returns header {@code key} from the latest call to {@code method}.
       *
       * <p>Returns {@code null} if that method has not been called, or if the header is absent.
       */
      public String getHeader(FlightMethod method, String key) {
        final CallHeaders headers;
        synchronized (receivedCallHeaders) {
          headers = receivedCallHeaders.get(method);
        }
        return headers == null ? null : headers.get(key);
      }
    }
  }
}