TestClientMiddleware.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.flight;
import static org.apache.arrow.flight.FlightTestUtil.LOCALHOST;
import static org.apache.arrow.flight.Location.forGrpcInsecure;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.junit.jupiter.api.Test;
/** A basic test of client middleware using a simplified OpenTracing-like example. */
public class TestClientMiddleware {
/**
* Test that a client middleware can fail a call before it starts by throwing a {@link
* FlightRuntimeException}.
*/
@Test
public void clientMiddleware_failCallBeforeSending() {
test(
new NoOpFlightProducer(),
null,
Collections.singletonList(new CallRejector.Factory()),
(allocator, client) -> {
FlightTestUtil.assertCode(FlightStatusCode.UNAVAILABLE, client::listActions);
});
}
/**
* Test an OpenTracing-like scenario where client and server middleware work together to propagate
* a request ID without explicit intervention from the service implementation.
*/
@Test
public void middleware_propagateHeader() {
final Context context = new Context("span id");
test(
new NoOpFlightProducer(),
new TestServerMiddleware.ServerMiddlewarePair<>(
FlightServerMiddleware.Key.of("test"), new ServerSpanInjector.Factory()),
Collections.singletonList(new ClientSpanInjector.Factory(context)),
(allocator, client) -> {
FlightTestUtil.assertCode(
FlightStatusCode.UNIMPLEMENTED, () -> client.listActions().forEach(actionType -> {}));
});
assertEquals(context.outgoingSpanId, context.incomingSpanId);
assertNotNull(context.finalStatus);
assertEquals(FlightStatusCode.UNIMPLEMENTED, context.finalStatus.code());
}
/**
* Ensure both server and client can send and receive multi-valued headers (both binary and text
* values).
*/
@Test
public void testMultiValuedHeaders() {
final MultiHeaderClientMiddlewareFactory clientFactory =
new MultiHeaderClientMiddlewareFactory();
test(
new NoOpFlightProducer(),
new TestServerMiddleware.ServerMiddlewarePair<>(
FlightServerMiddleware.Key.of("test"), new MultiHeaderServerMiddlewareFactory()),
Collections.singletonList(clientFactory),
(allocator, client) -> {
FlightTestUtil.assertCode(
FlightStatusCode.UNIMPLEMENTED, () -> client.listActions().forEach(actionType -> {}));
});
// The server echoes the headers we send back to us, so ensure all the ones we sent are present
// with the correct
// values in the correct order.
for (final Map.Entry<String, List<byte[]>> entry : EXPECTED_BINARY_HEADERS.entrySet()) {
// Compare header values entry-by-entry because byte arrays don't compare via equals
final List<byte[]> receivedValues = clientFactory.lastBinaryHeaders.get(entry.getKey());
assertNotNull(receivedValues, "Missing for header: " + entry.getKey());
assertEquals(
entry.getValue().size(),
receivedValues.size(),
"Missing or wrong value for header: " + entry.getKey());
for (int i = 0; i < entry.getValue().size(); i++) {
assertArrayEquals(entry.getValue().get(i), receivedValues.get(i));
}
}
for (final Map.Entry<String, List<String>> entry : EXPECTED_TEXT_HEADERS.entrySet()) {
assertEquals(
entry.getValue(),
clientFactory.lastTextHeaders.get(entry.getKey()),
"Missing or wrong value for header: " + entry.getKey());
}
}
private static <T extends FlightServerMiddleware> void test(
FlightProducer producer,
TestServerMiddleware.ServerMiddlewarePair<T> serverMiddleware,
List<FlightClientMiddleware.Factory> clientMiddleware,
BiConsumer<BufferAllocator, FlightClient> body) {
try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
final FlightServer.Builder serverBuilder =
FlightServer.builder(allocator, forGrpcInsecure(LOCALHOST, 0), producer);
if (serverMiddleware != null) {
serverBuilder.middleware(serverMiddleware.key, serverMiddleware.factory);
}
final FlightServer server = serverBuilder.build().start();
FlightClient.Builder clientBuilder = FlightClient.builder(allocator, server.getLocation());
clientMiddleware.forEach(clientBuilder::intercept);
try (final FlightServer ignored = server;
final FlightClient client = clientBuilder.build()) {
body.accept(allocator, client);
}
} catch (InterruptedException | IOException e) {
throw new RuntimeException(e);
}
}
/**
* A server middleware component that reads a request ID from incoming headers and sends the
* request ID back on outgoing headers.
*/
static class ServerSpanInjector implements FlightServerMiddleware {
private final String spanId;
public ServerSpanInjector(String spanId) {
this.spanId = spanId;
}
@Override
public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
outgoingHeaders.insert("x-span", spanId);
}
@Override
public void onCallCompleted(CallStatus status) {}
@Override
public void onCallErrored(Throwable err) {}
static class Factory implements FlightServerMiddleware.Factory<ServerSpanInjector> {
@Override
public ServerSpanInjector onCallStarted(
CallInfo info, CallHeaders incomingHeaders, RequestContext context) {
return new ServerSpanInjector(incomingHeaders.get("x-span"));
}
}
}
/**
* A client middleware component that, given a mock OpenTracing-like "request context", sends the
* request ID in the context on outgoing headers and reads it from incoming headers.
*/
static class ClientSpanInjector implements FlightClientMiddleware {
private final Context context;
public ClientSpanInjector(Context context) {
this.context = context;
}
@Override
public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
outgoingHeaders.insert("x-span", context.outgoingSpanId);
}
@Override
public void onHeadersReceived(CallHeaders incomingHeaders) {
context.incomingSpanId = incomingHeaders.get("x-span");
}
@Override
public void onCallCompleted(CallStatus status) {
context.finalStatus = status;
}
static class Factory implements FlightClientMiddleware.Factory {
private final Context context;
Factory(Context context) {
this.context = context;
}
@Override
public FlightClientMiddleware onCallStarted(CallInfo info) {
return new ClientSpanInjector(context);
}
}
}
/** A mock OpenTracing-like "request context". */
static class Context {
final String outgoingSpanId;
String incomingSpanId;
CallStatus finalStatus;
Context(String spanId) {
this.outgoingSpanId = spanId;
}
}
/** A client middleware that fails outgoing calls. */
static class CallRejector implements FlightClientMiddleware {
@Override
public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {}
@Override
public void onHeadersReceived(CallHeaders incomingHeaders) {}
@Override
public void onCallCompleted(CallStatus status) {}
static class Factory implements FlightClientMiddleware.Factory {
@Override
public FlightClientMiddleware onCallStarted(CallInfo info) {
throw CallStatus.UNAVAILABLE.withDescription("Rejecting call.").toRuntimeException();
}
}
}
// Used to test that middleware can send and receive multi-valued text and binary headers.
static Map<String, List<byte[]>> EXPECTED_BINARY_HEADERS = new HashMap<String, List<byte[]>>();
static Map<String, List<String>> EXPECTED_TEXT_HEADERS = new HashMap<String, List<String>>();
static {
EXPECTED_BINARY_HEADERS.put("x-binary-bin", Arrays.asList(new byte[] {0}, new byte[] {1}));
EXPECTED_TEXT_HEADERS.put("x-text", Arrays.asList("foo", "bar"));
}
static class MultiHeaderServerMiddlewareFactory
implements FlightServerMiddleware.Factory<MultiHeaderServerMiddleware> {
@Override
public MultiHeaderServerMiddleware onCallStarted(
CallInfo info, CallHeaders incomingHeaders, RequestContext context) {
// Echo the headers back to the client. Copy values out of CallHeaders since the underlying
// gRPC metadata
// object isn't safe to use after this function returns.
Map<String, List<byte[]>> binaryHeaders = new HashMap<>();
Map<String, List<String>> textHeaders = new HashMap<>();
for (final String key : incomingHeaders.keys()) {
if (key.endsWith("-bin")) {
binaryHeaders.compute(
key,
(ignored, values) -> {
if (values == null) {
values = new ArrayList<>();
}
incomingHeaders.getAllByte(key).forEach(values::add);
return values;
});
} else {
textHeaders.compute(
key,
(ignored, values) -> {
if (values == null) {
values = new ArrayList<>();
}
incomingHeaders.getAll(key).forEach(values::add);
return values;
});
}
}
return new MultiHeaderServerMiddleware(binaryHeaders, textHeaders);
}
}
static class MultiHeaderServerMiddleware implements FlightServerMiddleware {
private final Map<String, List<byte[]>> binaryHeaders;
private final Map<String, List<String>> textHeaders;
MultiHeaderServerMiddleware(
Map<String, List<byte[]>> binaryHeaders, Map<String, List<String>> textHeaders) {
this.binaryHeaders = binaryHeaders;
this.textHeaders = textHeaders;
}
@Override
public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
binaryHeaders.forEach(
(key, values) -> values.forEach(value -> outgoingHeaders.insert(key, value)));
textHeaders.forEach(
(key, values) -> values.forEach(value -> outgoingHeaders.insert(key, value)));
}
@Override
public void onCallCompleted(CallStatus status) {}
@Override
public void onCallErrored(Throwable err) {}
}
static class MultiHeaderClientMiddlewareFactory implements FlightClientMiddleware.Factory {
Map<String, List<byte[]>> lastBinaryHeaders = null;
Map<String, List<String>> lastTextHeaders = null;
@Override
public FlightClientMiddleware onCallStarted(CallInfo info) {
return new MultiHeaderClientMiddleware(this);
}
}
static class MultiHeaderClientMiddleware implements FlightClientMiddleware {
private final MultiHeaderClientMiddlewareFactory factory;
public MultiHeaderClientMiddleware(MultiHeaderClientMiddlewareFactory factory) {
this.factory = factory;
}
@Override
public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
for (final Map.Entry<String, List<byte[]>> entry : EXPECTED_BINARY_HEADERS.entrySet()) {
entry.getValue().forEach((value) -> outgoingHeaders.insert(entry.getKey(), value));
assertTrue(outgoingHeaders.containsKey(entry.getKey()));
}
for (final Map.Entry<String, List<String>> entry : EXPECTED_TEXT_HEADERS.entrySet()) {
entry.getValue().forEach((value) -> outgoingHeaders.insert(entry.getKey(), value));
assertTrue(outgoingHeaders.containsKey(entry.getKey()));
}
}
@Override
public void onHeadersReceived(CallHeaders incomingHeaders) {
factory.lastBinaryHeaders = new HashMap<>();
factory.lastTextHeaders = new HashMap<>();
incomingHeaders
.keys()
.forEach(
header -> {
if (header.endsWith("-bin")) {
final List<byte[]> values = new ArrayList<>();
incomingHeaders.getAllByte(header).forEach(values::add);
factory.lastBinaryHeaders.put(header, values);
} else {
final List<String> values = new ArrayList<>();
incomingHeaders.getAll(header).forEach(values::add);
factory.lastTextHeaders.put(header, values);
}
});
}
@Override
public void onCallCompleted(CallStatus status) {}
}
}