TestServerMiddleware.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.flight;
import static org.apache.arrow.flight.FlightTestUtil.LOCALHOST;
import static org.apache.arrow.flight.Location.forGrpcInsecure;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.fail;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiConsumer;
import org.apache.arrow.flight.FlightClient.ClientStreamListener;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.jupiter.api.Test;
public class TestServerMiddleware {

  /** Make sure errors in DoPut are intercepted. */
  @Test
  public void doPutErrors() {
    test(
        new ErrorProducer(new RuntimeException("test")),
        (allocator, client) -> {
          final FlightDescriptor descriptor = FlightDescriptor.path("test");
          try (final VectorSchemaRoot root =
              VectorSchemaRoot.create(new Schema(Collections.emptyList()), allocator)) {
            final ClientStreamListener listener =
                client.startPut(descriptor, root, new SyncPutListener());
            listener.completed();
            FlightTestUtil.assertCode(FlightStatusCode.INTERNAL, listener::getResult);
          }
        },
        (recorder) -> {
          // statusFuture is completed from the middleware's onCallCompleted, so get() blocks
          // until the server side has finished the call (the server is still running here).
          final CallStatus status = recorder.statusFuture.get();
          assertNotNull(status);
          assertNotNull(status.cause());
          assertEquals(FlightStatusCode.INTERNAL, status.code());
        });
  }

  /** Make sure custom error codes in DoPut are intercepted. */
  @Test
  public void doPutCustomCode() {
    test(
        new ErrorProducer(
            CallStatus.UNAVAILABLE.withDescription("description").toRuntimeException()),
        (allocator, client) -> {
          final FlightDescriptor descriptor = FlightDescriptor.path("test");
          try (final VectorSchemaRoot root =
              VectorSchemaRoot.create(new Schema(Collections.emptyList()), allocator)) {
            final ClientStreamListener listener =
                client.startPut(descriptor, root, new SyncPutListener());
            listener.completed();
            FlightTestUtil.assertCode(FlightStatusCode.UNAVAILABLE, listener::getResult);
          }
        },
        (recorder) -> {
          final CallStatus status = recorder.statusFuture.get();
          assertNotNull(status);
          assertNull(status.cause());
          assertEquals(FlightStatusCode.UNAVAILABLE, status.code());
          assertEquals("description", status.description());
        });
  }

  /** Make sure uncaught exceptions in DoPut are intercepted. */
  @Test
  public void doPutUncaught() {
    test(
        new ServerErrorProducer(new RuntimeException("test")),
        (allocator, client) -> {
          final FlightDescriptor descriptor = FlightDescriptor.path("test");
          try (final VectorSchemaRoot root =
              VectorSchemaRoot.create(new Schema(Collections.emptyList()), allocator)) {
            final ClientStreamListener listener =
                client.startPut(descriptor, root, new SyncPutListener());
            listener.completed();
            listener.getResult();
          }
        },
        TestServerMiddleware::assertUncaughtAfterSuccess);
  }

  /** Make sure an exception thrown after ListFlights completes is reported to the middleware. */
  @Test
  public void listFlightsUncaught() {
    test(
        new ServerErrorProducer(new RuntimeException("test")),
        (allocator, client) ->
            client.listFlights(new Criteria(new byte[0])).forEach((action) -> {}),
        TestServerMiddleware::assertUncaughtAfterSuccess);
  }

  /** Make sure an exception thrown after DoAction completes is reported to the middleware. */
  @Test
  public void doActionUncaught() {
    test(
        new ServerErrorProducer(new RuntimeException("test")),
        (allocator, client) -> client.doAction(new Action("test")).forEachRemaining(result -> {}),
        TestServerMiddleware::assertUncaughtAfterSuccess);
  }

  /** Make sure an exception thrown after ListActions completes is reported to the middleware. */
  @Test
  public void listActionsUncaught() {
    test(
        new ServerErrorProducer(new RuntimeException("test")),
        (allocator, client) -> client.listActions().forEach(result -> {}),
        TestServerMiddleware::assertUncaughtAfterSuccess);
  }

  /**
   * Make sure an exception thrown by GetFlightInfo (which has no way to signal success first)
   * surfaces as an INTERNAL status carrying the original cause.
   */
  @Test
  public void getFlightInfoUncaught() {
    test(
        new ServerErrorProducer(new RuntimeException("test")),
        (allocator, client) -> {
          FlightTestUtil.assertCode(
              FlightStatusCode.INTERNAL, () -> client.getInfo(FlightDescriptor.path("test")));
        },
        (recorder) -> {
          final CallStatus status = recorder.statusFuture.get();
          assertNotNull(status);
          assertEquals(FlightStatusCode.INTERNAL, status.code());
          assertNotNull(status.cause());
          assertEquals("test", status.cause().getMessage());
        });
  }

  /** Make sure an exception thrown after DoGet completes is reported to the middleware. */
  @Test
  public void doGetUncaught() {
    test(
        new ServerErrorProducer(new RuntimeException("test")),
        (allocator, client) -> {
          try (final FlightStream stream = client.getStream(new Ticket(new byte[0]))) {
            while (stream.next()) {}
          } catch (Exception e) {
            // FlightStream#close throws a checked Exception; keep the cause for diagnosis.
            fail(e.toString(), e);
          }
        },
        TestServerMiddleware::assertUncaughtAfterSuccess);
  }

  /**
   * Shared verification for the {@link ServerErrorProducer} "uncaught" tests: the middleware must
   * have recorded a successful (OK, no cause) completion AND, separately, the producer's {@code
   * RuntimeException("test")} via {@code onCallErrored}.
   *
   * @param recorder the middleware instance that observed the call.
   * @throws Exception if waiting on the recorder's futures fails.
   */
  private static void assertUncaughtAfterSuccess(ErrorRecorder recorder) throws Exception {
    final CallStatus status = recorder.statusFuture.get();
    final Throwable err = recorder.errFuture.get();
    assertNotNull(status);
    assertEquals(FlightStatusCode.OK, status.code());
    assertNull(status.cause());
    assertNotNull(err);
    assertEquals("test", err.getMessage());
  }

  /** A middleware that records the last error on any call. */
  static class ErrorRecorder implements FlightServerMiddleware {
    /** Completed with the status passed to {@link #onCallCompleted}. */
    final CompletableFuture<CallStatus> statusFuture = new CompletableFuture<>();
    /** Completed with the throwable passed to {@link #onCallErrored}. */
    final CompletableFuture<Throwable> errFuture = new CompletableFuture<>();

    @Override
    public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {}

    @Override
    public void onCallCompleted(CallStatus status) {
      statusFuture.complete(status);
    }

    @Override
    public void onCallErrored(Throwable err) {
      errFuture.complete(err);
    }

    /** Factory that hands the same {@link ErrorRecorder} to every call so tests can inspect it. */
    static class Factory implements FlightServerMiddleware.Factory<ErrorRecorder> {
      final ErrorRecorder instance = new ErrorRecorder();

      @Override
      public ErrorRecorder onCallStarted(
          CallInfo info, CallHeaders incomingHeaders, RequestContext context) {
        return instance;
      }
    }
  }

  /** A producer that throws the given exception on a call. */
  static class ErrorProducer extends NoOpFlightProducer {
    final RuntimeException error;

    ErrorProducer(RuntimeException t) {
      error = t;
    }

    @Override
    public Runnable acceptPut(
        CallContext context, FlightStream flightStream, StreamListener<PutResult> ackStream) {
      return () -> {
        // Drain queue to avoid FlightStream#close cancelling the call
        while (flightStream.next()) {}
        throw error;
      };
    }
  }

  /**
   * A producer that throws the given exception on a call, but only after sending a success to the
   * client.
   */
  static class ServerErrorProducer extends NoOpFlightProducer {
    final RuntimeException error;

    ServerErrorProducer(RuntimeException t) {
      error = t;
    }

    @Override
    public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
      try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
          final VectorSchemaRoot root =
              VectorSchemaRoot.create(new Schema(Collections.emptyList()), allocator)) {
        listener.start(root);
        listener.completed();
      }
      throw error;
    }

    @Override
    public void listFlights(
        CallContext context, Criteria criteria, StreamListener<FlightInfo> listener) {
      listener.onCompleted();
      throw error;
    }

    @Override
    public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) {
      throw error;
    }

    @Override
    public Runnable acceptPut(
        CallContext context, FlightStream flightStream, StreamListener<PutResult> ackStream) {
      return () -> {
        while (flightStream.next()) {}
        ackStream.onCompleted();
        throw error;
      };
    }

    @Override
    public void doAction(CallContext context, Action action, StreamListener<Result> listener) {
      listener.onCompleted();
      throw error;
    }

    @Override
    public void listActions(CallContext context, StreamListener<ActionType> listener) {
      listener.onCompleted();
      throw error;
    }
  }

  /** A middleware key paired with the factory to register under it. */
  static class ServerMiddlewarePair<T extends FlightServerMiddleware> {
    final FlightServerMiddleware.Key<T> key;
    final FlightServerMiddleware.Factory<T> factory;

    ServerMiddlewarePair(
        FlightServerMiddleware.Key<T> key, FlightServerMiddleware.Factory<T> factory) {
      this.key = key;
      this.factory = factory;
    }
  }

  /**
   * Spin up a service with the given middleware and producer.
   *
   * <p>The server is closed even if {@code start()} fails, and resources are released in the
   * order client, server, allocator.
   *
   * @param producer The Flight producer to use.
   * @param middleware A list of middleware to register.
   * @param body A function to run as the body of the test.
   * @param <T> The middleware type.
   * @throws RuntimeException wrapping any {@link IOException} from starting the server or any
   *     {@link InterruptedException} from closing it (the interrupt flag is restored).
   */
  static <T extends FlightServerMiddleware> void test(
      FlightProducer producer,
      List<ServerMiddlewarePair<T>> middleware,
      BiConsumer<BufferAllocator, FlightClient> body) {
    try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
      final FlightServer.Builder builder =
          FlightServer.builder(allocator, forGrpcInsecure(LOCALHOST, 0), producer);
      middleware.forEach(pair -> builder.middleware(pair.key, pair.factory));
      // Own the server before starting it so a failed start() still releases it.
      try (final FlightServer server = builder.build()) {
        server.start();
        try (final FlightClient client =
            FlightClient.builder(allocator, server.getLocation()).build()) {
          body.accept(allocator, client);
        }
      }
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new RuntimeException(e);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Run {@code body} against a server with a single {@link ErrorRecorder} middleware, then run
   * {@code verify} on that recorder while the server is still up.
   *
   * @param producer The Flight producer to use.
   * @param body A function to run as the body of the test.
   * @param verify Assertions to run against the recorded middleware state.
   */
  static void test(
      FlightProducer producer,
      BiConsumer<BufferAllocator, FlightClient> body,
      ErrorConsumer<ErrorRecorder> verify) {
    final ErrorRecorder.Factory factory = new ErrorRecorder.Factory();
    final List<ServerMiddlewarePair<ErrorRecorder>> middleware =
        Collections.singletonList(
            new ServerMiddlewarePair<>(FlightServerMiddleware.Key.of("m"), factory));
    test(
        producer,
        middleware,
        (allocator, client) -> {
          body.accept(allocator, client);
          try {
            verify.accept(factory.instance);
          } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
          } catch (Exception e) {
            throw new RuntimeException(e);
          }
        });
  }

  /** A consumer whose body may throw checked exceptions (e.g. from {@code Future#get}). */
  @FunctionalInterface
  interface ErrorConsumer<T> {
    void accept(T obj) throws Exception;
  }
}