TestBasicAuth2.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.flight.auth2;
import static org.apache.arrow.flight.FlightTestUtil.LOCALHOST;
import static org.apache.arrow.flight.Location.forGrpcInsecure;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import java.io.IOException;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.Criteria;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightProducer;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.FlightStatusCode;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.FlightTestUtil;
import org.apache.arrow.flight.NoOpFlightProducer;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.flight.grpc.CredentialCallOption;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
public class TestBasicAuth2 {
private static final String USERNAME_1 = "flight1";
private static final String USERNAME_2 = "flight2";
private static final String NO_USERNAME = "";
private static final String PASSWORD_1 = "woohoo1";
private static final String PASSWORD_2 = "woohoo2";
private static BufferAllocator allocator;
private static FlightServer server;
private static FlightClient client;
private static FlightClient client2;
@BeforeAll
public static void setup() throws Exception {
allocator = new RootAllocator(Long.MAX_VALUE);
startServerAndClient();
}
private static FlightProducer getFlightProducer() {
return new NoOpFlightProducer() {
@Override
public void listFlights(
CallContext context, Criteria criteria, StreamListener<FlightInfo> listener) {
if (!context.peerIdentity().equals(USERNAME_1)
&& !context.peerIdentity().equals(USERNAME_2)) {
listener.onError(new IllegalArgumentException("Invalid username"));
return;
}
listener.onCompleted();
}
@Override
public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
if (!context.peerIdentity().equals(USERNAME_1)
&& !context.peerIdentity().equals(USERNAME_2)) {
listener.error(new IllegalArgumentException("Invalid username"));
return;
}
final Schema pojoSchema =
new Schema(ImmutableList.of(Field.nullable("a", Types.MinorType.BIGINT.getType())));
try (VectorSchemaRoot root = VectorSchemaRoot.create(pojoSchema, allocator)) {
listener.start(root);
root.allocateNew();
root.setRowCount(4095);
listener.putNext();
listener.completed();
}
}
};
}
private static void startServerAndClient() throws IOException {
final FlightProducer flightProducer = getFlightProducer();
server =
FlightServer.builder(allocator, forGrpcInsecure(LOCALHOST, 0), flightProducer)
.headerAuthenticator(
new GeneratedBearerTokenAuthenticator(
new BasicCallHeaderAuthenticator(TestBasicAuth2::validate)))
.build()
.start();
client = FlightClient.builder(allocator, server.getLocation()).build();
}
@AfterAll
public static void shutdown() throws Exception {
AutoCloseables.close(client, client2, server);
client = null;
client2 = null;
server = null;
allocator.getChildAllocators().forEach(BufferAllocator::close);
AutoCloseables.close(allocator);
allocator = null;
}
private void startClient2() throws IOException {
client2 = FlightClient.builder(allocator, server.getLocation()).build();
}
private static CallHeaderAuthenticator.AuthResult validate(String username, String password) {
if (Strings.isNullOrEmpty(username)) {
throw CallStatus.UNAUTHENTICATED
.withDescription("Credentials not supplied.")
.toRuntimeException();
}
final String identity;
if (USERNAME_1.equals(username) && PASSWORD_1.equals(password)) {
identity = USERNAME_1;
} else if (USERNAME_2.equals(username) && PASSWORD_2.equals(password)) {
identity = USERNAME_2;
} else {
throw CallStatus.UNAUTHENTICATED
.withDescription("Username or password is invalid.")
.toRuntimeException();
}
return () -> identity;
}
@Test
public void validAuthWithBearerAuthServer() throws IOException {
testValidAuth(client);
}
@Test
public void validAuthWithMultipleClientsWithSameCredentialsWithBearerAuthServer()
throws IOException {
startClient2();
testValidAuthWithMultipleClientsWithSameCredentials(client, client2);
}
@Test
public void validAuthWithMultipleClientsWithDifferentCredentialsWithBearerAuthServer()
throws IOException {
startClient2();
testValidAuthWithMultipleClientsWithDifferentCredentials(client, client2);
}
@Test
public void asyncCall() throws Exception {
final CredentialCallOption bearerToken =
client.authenticateBasicToken(USERNAME_1, PASSWORD_1).get();
client.listFlights(Criteria.ALL, bearerToken);
try (final FlightStream s = client.getStream(new Ticket(new byte[1]), bearerToken)) {
while (s.next()) {
assertEquals(4095, s.getRoot().getRowCount());
}
}
}
@Test
public void invalidAuthWithBearerAuthServer() throws IOException {
testInvalidAuth(client);
}
@Test
public void didntAuthWithBearerAuthServer() throws IOException {
didntAuth(client);
}
private void testValidAuth(FlightClient client) {
final CredentialCallOption bearerToken =
client.authenticateBasicToken(USERNAME_1, PASSWORD_1).get();
assertTrue(ImmutableList.copyOf(client.listFlights(Criteria.ALL, bearerToken)).isEmpty());
}
private void testValidAuthWithMultipleClientsWithSameCredentials(
FlightClient client1, FlightClient client2) {
final CredentialCallOption bearerToken1 =
client1.authenticateBasicToken(USERNAME_1, PASSWORD_1).get();
final CredentialCallOption bearerToken2 =
client2.authenticateBasicToken(USERNAME_1, PASSWORD_1).get();
assertTrue(ImmutableList.copyOf(client1.listFlights(Criteria.ALL, bearerToken1)).isEmpty());
assertTrue(ImmutableList.copyOf(client2.listFlights(Criteria.ALL, bearerToken2)).isEmpty());
}
private void testValidAuthWithMultipleClientsWithDifferentCredentials(
FlightClient client1, FlightClient client2) {
final CredentialCallOption bearerToken1 =
client1.authenticateBasicToken(USERNAME_1, PASSWORD_1).get();
final CredentialCallOption bearerToken2 =
client2.authenticateBasicToken(USERNAME_2, PASSWORD_2).get();
assertTrue(ImmutableList.copyOf(client1.listFlights(Criteria.ALL, bearerToken1)).isEmpty());
assertTrue(ImmutableList.copyOf(client2.listFlights(Criteria.ALL, bearerToken2)).isEmpty());
}
private void testInvalidAuth(FlightClient client) {
FlightTestUtil.assertCode(
FlightStatusCode.UNAUTHENTICATED, () -> client.authenticateBasicToken(USERNAME_1, "WRONG"));
FlightTestUtil.assertCode(
FlightStatusCode.UNAUTHENTICATED,
() -> client.authenticateBasicToken(NO_USERNAME, PASSWORD_1));
FlightTestUtil.assertCode(
FlightStatusCode.UNAUTHENTICATED,
() -> client.listFlights(Criteria.ALL).forEach(action -> fail()));
}
private void didntAuth(FlightClient client) {
FlightTestUtil.assertCode(
FlightStatusCode.UNAUTHENTICATED,
() -> client.listFlights(Criteria.ALL).forEach(action -> fail()));
}
}