CustomHeaderTest.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.flight.client;
import static org.junit.jupiter.api.Assertions.assertEquals;
import com.google.common.collect.ImmutableMap;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.apache.arrow.flight.Action;
import org.apache.arrow.flight.CallHeaders;
import org.apache.arrow.flight.CallInfo;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.Criteria;
import org.apache.arrow.flight.FlightCallHeaders;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightClient.ClientStreamListener;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightMethod;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.FlightServerMiddleware;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.FlightTestUtil;
import org.apache.arrow.flight.HeaderCallOption;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.NoOpFlightProducer;
import org.apache.arrow.flight.RequestContext;
import org.apache.arrow.flight.SyncPutListener;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
/** Tests to ensure custom headers are passed along to the server for each command. */
public class CustomHeaderTest {
FlightServer server;
FlightClient client;
BufferAllocator allocator;
TestCustomHeaderMiddleware.Factory headersMiddleware;
HeaderCallOption headers;
Map<String, String> testHeaders =
ImmutableMap.of(
"foo", "bar",
"bar", "foo",
"answer", "42");
@BeforeEach
public void setUp() throws Exception {
allocator = new RootAllocator(Integer.MAX_VALUE);
headersMiddleware = new TestCustomHeaderMiddleware.Factory();
FlightCallHeaders callHeaders = new FlightCallHeaders();
for (Map.Entry<String, String> entry : testHeaders.entrySet()) {
callHeaders.insert(entry.getKey(), entry.getValue());
}
headers = new HeaderCallOption(callHeaders);
server =
FlightServer.builder(
allocator,
Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, /*port*/ 0),
new NoOpFlightProducer())
.middleware(FlightServerMiddleware.Key.of("customHeader"), headersMiddleware)
.build();
server.start();
client = FlightClient.builder(allocator, server.getLocation()).build();
}
@AfterEach
public void tearDown() throws Exception {
allocator.getChildAllocators().forEach(BufferAllocator::close);
AutoCloseables.close(allocator, server, client);
}
@Test
public void testHandshake() {
try {
client.handshake(headers);
} catch (Exception ignored) {
}
assertHeadersMatch(FlightMethod.HANDSHAKE);
}
@Test
public void testGetSchema() {
try {
client.getSchema(FlightDescriptor.command(new byte[0]), headers);
} catch (Exception ignored) {
}
assertHeadersMatch(FlightMethod.GET_SCHEMA);
}
@Test
public void testGetFlightInfo() {
try {
client.getInfo(FlightDescriptor.command(new byte[0]), headers);
} catch (Exception ignored) {
}
assertHeadersMatch(FlightMethod.GET_FLIGHT_INFO);
}
@Test
public void testListActions() {
try {
client.listActions(headers).iterator().next();
} catch (Exception ignored) {
}
assertHeadersMatch(FlightMethod.LIST_ACTIONS);
}
@Test
public void testListFlights() {
try {
client.listFlights(new Criteria(new byte[] {1}), headers).iterator().next();
} catch (Exception ignored) {
}
assertHeadersMatch(FlightMethod.LIST_FLIGHTS);
}
@Test
public void testDoAction() {
try {
client.doAction(new Action("test"), headers).next();
} catch (Exception ignored) {
}
assertHeadersMatch(FlightMethod.DO_ACTION);
}
@Test
public void testStartPut() {
try {
final ClientStreamListener listener =
client.startPut(FlightDescriptor.command(new byte[0]), new SyncPutListener(), headers);
listener.getResult();
} catch (Exception ignored) {
}
assertHeadersMatch(FlightMethod.DO_PUT);
}
@Test
public void testGetStream() {
try (final FlightStream stream = client.getStream(new Ticket(new byte[0]), headers)) {
stream.next();
} catch (Exception ignored) {
}
assertHeadersMatch(FlightMethod.DO_GET);
}
@Test
public void testDoExchange() {
try (final FlightClient.ExchangeReaderWriter stream =
client.doExchange(FlightDescriptor.command(new byte[0]), headers)) {
stream.getReader().next();
} catch (Exception ignored) {
}
assertHeadersMatch(FlightMethod.DO_EXCHANGE);
}
private void assertHeadersMatch(FlightMethod method) {
for (Map.Entry<String, String> entry : testHeaders.entrySet()) {
assertEquals(entry.getValue(), headersMiddleware.getCustomHeader(method, entry.getKey()));
}
}
/** A middleware used to test if customHeaders are being sent to the server properly. */
static class TestCustomHeaderMiddleware implements FlightServerMiddleware {
public TestCustomHeaderMiddleware() {}
@Override
public void onBeforeSendingHeaders(CallHeaders callHeaders) {}
@Override
public void onCallCompleted(CallStatus callStatus) {}
@Override
public void onCallErrored(Throwable throwable) {}
/**
* A factory for the middleware that keeps track of the received headers and provides a way to
* check those values for a given Flight Method.
*/
static class Factory implements FlightServerMiddleware.Factory<TestCustomHeaderMiddleware> {
private final Map<FlightMethod, CallHeaders> receivedCallHeaders = new HashMap<>();
@Override
public TestCustomHeaderMiddleware onCallStarted(
CallInfo callInfo, CallHeaders callHeaders, RequestContext requestContext) {
receivedCallHeaders.put(callInfo.method(), callHeaders);
return new TestCustomHeaderMiddleware();
}
public String getCustomHeader(FlightMethod method, String key) {
CallHeaders headers = receivedCallHeaders.get(method);
if (headers == null) {
return null;
}
return headers.get(key);
}
}
}
}