TestTls.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.flight;
import static org.apache.arrow.flight.FlightTestUtil.LOCALHOST;
import static org.apache.arrow.flight.Location.forGrpcInsecure;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import java.util.function.Consumer;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.junit.jupiter.api.Test;
/** Tests for TLS in Flight. */
public class TestTls {
/** Test a basic request over TLS. */
@Test
public void connectTls() {
test(
(builder) -> {
try (final InputStream roots =
new FileInputStream(FlightTestUtil.exampleTlsRootCert().toFile());
final FlightClient client = builder.trustedCertificates(roots).build()) {
final Iterator<Result> responses = client.doAction(new Action("hello-world"));
final byte[] response = responses.next().getBody();
assertEquals("Hello, world!", new String(response, StandardCharsets.UTF_8));
assertFalse(responses.hasNext());
} catch (InterruptedException | IOException e) {
throw new RuntimeException(e);
}
});
}
/** Make sure that connections are rejected when the root certificate isn't trusted. */
@Test
public void rejectInvalidCert() {
test(
(builder) -> {
try (final FlightClient client = builder.build()) {
final Iterator<Result> responses = client.doAction(new Action("hello-world"));
FlightTestUtil.assertCode(
FlightStatusCode.UNAVAILABLE, () -> responses.next().getBody());
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
});
}
/** Make sure that connections are rejected when the hostname doesn't match. */
@Test
public void rejectHostname() {
test(
(builder) -> {
try (final InputStream roots =
new FileInputStream(FlightTestUtil.exampleTlsRootCert().toFile());
final FlightClient client =
builder.trustedCertificates(roots).overrideHostname("fakehostname").build()) {
final Iterator<Result> responses = client.doAction(new Action("hello-world"));
FlightTestUtil.assertCode(
FlightStatusCode.UNAVAILABLE, () -> responses.next().getBody());
} catch (InterruptedException | IOException e) {
throw new RuntimeException(e);
}
});
}
/** Test a basic request over TLS. */
@Test
public void connectTlsDisableServerVerification() {
test(
(builder) -> {
try (final FlightClient client = builder.verifyServer(false).build()) {
final Iterator<Result> responses = client.doAction(new Action("hello-world"));
final byte[] response = responses.next().getBody();
assertEquals("Hello, world!", new String(response, StandardCharsets.UTF_8));
assertFalse(responses.hasNext());
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
});
}
void test(Consumer<FlightClient.Builder> testFn) {
final FlightTestUtil.CertKeyPair certKey = FlightTestUtil.exampleTlsCerts().get(0);
try (BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
Producer producer = new Producer();
FlightServer s =
FlightServer.builder(a, forGrpcInsecure(LOCALHOST, 0), producer)
.useTls(certKey.cert, certKey.key)
.build()
.start()) {
final FlightClient.Builder builder =
FlightClient.builder(a, Location.forGrpcTls(FlightTestUtil.LOCALHOST, s.getPort()));
testFn.accept(builder);
} catch (InterruptedException | IOException e) {
throw new RuntimeException(e);
}
}
static class Producer extends NoOpFlightProducer implements AutoCloseable {
@Override
public void doAction(CallContext context, Action action, StreamListener<Result> listener) {
if (action.getType().equals("hello-world")) {
listener.onNext(new Result("Hello, world!".getBytes(StandardCharsets.UTF_8)));
listener.onCompleted();
return;
}
listener.onError(
CallStatus.UNIMPLEMENTED
.withDescription("Invalid action " + action.getType())
.toRuntimeException());
}
@Override
public void close() {}
}
}