TestBasicAuth.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.flight.auth;

import static org.apache.arrow.flight.FlightTestUtil.LOCALHOST;
import static org.apache.arrow.flight.Location.forGrpcInsecure;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;

import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Optional;
import org.apache.arrow.flight.Criteria;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.FlightStatusCode;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.FlightTestUtil;
import org.apache.arrow.flight.NoOpFlightProducer;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/**
 * Tests for basic (username/password) authentication between a {@link FlightClient} and a {@link
 * FlightServer} configured with a {@link BasicServerAuthHandler}.
 *
 * <p>A single server is shared by all tests; every test gets a freshly built client on which no
 * authentication call has been made yet.
 */
public class TestBasicAuth {

  private static final String USERNAME = "flight";
  private static final String PASSWORD = "woohoo";
  private static final byte[] VALID_TOKEN = "my_token".getBytes(StandardCharsets.UTF_8);

  /** Row count of the single batch the test producer sends from {@code getStream}. */
  private static final int ROWS_PER_BATCH = 4095;

  private static FlightServer server;
  private static BufferAllocator allocator;

  private FlightClient client;

  /** Creates the shared allocator and starts the shared server with basic auth enabled. */
  @BeforeAll
  public static void setup() throws IOException {
    allocator = new RootAllocator(Long.MAX_VALUE);
    // The client is pointed at server.getLocation() rather than at the location passed here.
    server =
        FlightServer.builder(allocator, forGrpcInsecure(LOCALHOST, 0), newProducer())
            .authHandler(new BasicServerAuthHandler(newValidator()))
            .build()
            .start();
  }

  /** Builds a new client against the shared server before each test. */
  @BeforeEach
  public void testSetup() throws IOException {
    client = FlightClient.builder(allocator, server.getLocation()).build();
  }

  /** Closes the per-test client. */
  @AfterEach
  public void tearDown() throws Exception {
    AutoCloseables.close(client);
  }

  /** Closes the shared server, then any child allocators, then the root allocator. */
  @AfterAll
  public static void shutdown() throws Exception {
    AutoCloseables.close(server);

    // Children are closed explicitly before the root allocator itself is closed.
    allocator.getChildAllocators().forEach(BufferAllocator::close);
    AutoCloseables.close(allocator);
  }

  /** With valid credentials, {@code listFlights} succeeds and yields no flights. */
  @Test
  public void validAuth() {
    client.authenticateBasic(USERNAME, PASSWORD);
    assertEquals(0, ImmutableList.copyOf(client.listFlights(Criteria.ALL)).size());
  }

  /**
   * After authenticating, issues a {@code listFlights} call whose result is never consumed and then
   * reads a stream, checking the row count of every received batch.
   */
  @Test
  public void asyncCall() throws Exception {
    client.authenticateBasic(USERNAME, PASSWORD);
    // Result intentionally ignored; the stream below is requested on the same client afterwards.
    client.listFlights(Criteria.ALL);

    int batches = 0;
    try (FlightStream stream = client.getStream(new Ticket(new byte[1]))) {
      while (stream.next()) {
        assertEquals(ROWS_PER_BATCH, stream.getRoot().getRowCount());
        batches++;
      }
    }
    // Without this guard the loop above would pass vacuously if no batch ever arrived.
    if (batches == 0) {
      fail("expected at least one batch from getStream");
    }
  }

  /**
   * A wrong password is rejected with {@code UNAUTHENTICATED}, and so is a subsequent call made on
   * the same client.
   */
  @Test
  public void invalidAuth() {
    FlightTestUtil.assertCode(
        FlightStatusCode.UNAUTHENTICATED, () -> client.authenticateBasic(USERNAME, "WRONG"));

    FlightTestUtil.assertCode(
        FlightStatusCode.UNAUTHENTICATED,
        () ->
            client
                .listFlights(Criteria.ALL)
                .forEach(info -> fail("listFlights returned a result without valid auth")));
  }

  /** A call made without authenticating at all is rejected with {@code UNAUTHENTICATED}. */
  @Test
  public void didntAuth() {
    FlightTestUtil.assertCode(
        FlightStatusCode.UNAUTHENTICATED,
        () ->
            client
                .listFlights(Criteria.ALL)
                .forEach(info -> fail("listFlights returned a result without authenticating")));
  }

  /**
   * Creates a validator that hands out {@link #VALID_TOKEN} for the {@link #USERNAME}/{@link
   * #PASSWORD} pair and maps that token back to {@link #USERNAME}.
   */
  private static BasicServerAuthHandler.BasicAuthValidator newValidator() {
    return new BasicServerAuthHandler.BasicAuthValidator() {

      @Override
      public Optional<String> isValid(byte[] token) {
        return Arrays.equals(token, VALID_TOKEN) ? Optional.of(USERNAME) : Optional.empty();
      }

      @Override
      public byte[] getToken(String username, String password) {
        if (!USERNAME.equals(username) || !PASSWORD.equals(password)) {
          throw new IllegalArgumentException("invalid credentials");
        }
        return VALID_TOKEN;
      }
    };
  }

  /**
   * Creates a producer whose {@code listFlights} and {@code getStream} report an error unless the
   * peer identity is {@link #USERNAME}.
   */
  private static NoOpFlightProducer newProducer() {
    return new NoOpFlightProducer() {

      /** Completes immediately with no flights for the expected peer. */
      @Override
      public void listFlights(
          CallContext context, Criteria criteria, StreamListener<FlightInfo> listener) {
        // Constant-first comparison: a null identity is reported as an error, not an NPE.
        if (!USERNAME.equals(context.peerIdentity())) {
          listener.onError(new IllegalArgumentException("Invalid username"));
          return;
        }
        listener.onCompleted();
      }

      /** Sends one batch of {@link #ROWS_PER_BATCH} rows in a single nullable BIGINT column. */
      @Override
      public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
        if (!USERNAME.equals(context.peerIdentity())) {
          listener.error(new IllegalArgumentException("Invalid username"));
          return;
        }
        final Schema schema =
            new Schema(ImmutableList.of(Field.nullable("a", Types.MinorType.BIGINT.getType())));
        try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
          listener.start(root);
          root.allocateNew();
          root.setRowCount(ROWS_PER_BATCH);
          listener.putNext();
          listener.completed();
        }
      }
    };
  }
}