TestFlightService.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.flight;

import static org.apache.arrow.flight.FlightTestUtil.LOCALHOST;
import static org.apache.arrow.flight.Location.forGrpcInsecure;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.fail;

import io.grpc.stub.ServerCallStreamObserver;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Optional;
import java.util.Random;
import org.apache.arrow.flight.impl.Flight;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link FlightService} used directly, and for basic round trips between a {@link
 * FlightClient} and a {@link FlightServer}.
 */
public class TestFlightService {

  /** Characters used by {@link #generateRandom(int)}: upper-case ASCII letters and digits. */
  private static final String RANDOM_ALPHABET = "ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890";

  private BufferAllocator allocator;

  @BeforeEach
  public void setup() {
    allocator = new RootAllocator(Long.MAX_VALUE);
  }

  @AfterEach
  public void cleanup() throws Exception {
    AutoCloseables.close(allocator);
  }

  /**
   * Regression test for ARROW-10491.
   *
   * <p>FlightService used to access the RequestContext, which was null when FlightService was used
   * directly instead of through a started FlightServer, because no interceptors were active to
   * initialize it. This test drives {@code doGetCustom} directly and fails if the response observer
   * receives an error.
   */
  @Test
  public void testFlightServiceWithNoAuthHandlerOrInterceptors() {
    // Arrange
    final FlightProducer producer =
        new NoOpFlightProducer() {
          @Override
          public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
            listener.completed();
          }
        };

    // This response observer notifies that the test failed if onError() is called.
    final ServerCallStreamObserver<ArrowMessage> observer =
        new ServerCallStreamObserver<ArrowMessage>() {
          @Override
          public boolean isCancelled() {
            return false;
          }

          @Override
          public void setOnCancelHandler(Runnable runnable) {}

          @Override
          public void setCompression(String s) {}

          @Override
          public boolean isReady() {
            return false;
          }

          @Override
          public void setOnReadyHandler(Runnable runnable) {}

          @Override
          public void disableAutoInboundFlowControl() {}

          @Override
          public void request(int i) {}

          @Override
          public void setMessageCompression(boolean b) {}

          @Override
          public void onNext(ArrowMessage arrowMessage) {}

          @Override
          public void onError(Throwable throwable) {
            fail(throwable);
          }

          @Override
          public void onCompleted() {}
        };
    // Auth handler and executor-related arguments are null: no interceptors are configured.
    final FlightService flightService = new FlightService(allocator, producer, null, null);

    // Act
    flightService.doGetCustom(Flight.Ticket.newBuilder().build(), observer);

    // fail() would have been called if an error happened during doGetCustom(), so this test passed.
  }

  /**
   * A producer may return a FlightInfo with a null schema. The client sees an empty optional
   * schema (and an empty {@link Schema} from {@code getSchema()}), app metadata survives the round
   * trip, and the GetSchema RPC fails with a descriptive error.
   */
  @Test
  public void supportsNullSchemas() throws Exception {
    final FlightProducer producer =
        new NoOpFlightProducer() {
          @Override
          public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) {
            return new FlightInfo(
                null,
                descriptor,
                Collections.emptyList(),
                0,
                0,
                false,
                IpcOption.DEFAULT,
                "foo".getBytes(StandardCharsets.UTF_8));
          }
        };

    try (final FlightServer s =
            FlightServer.builder(allocator, forGrpcInsecure(LOCALHOST, 0), producer)
                .build()
                .start();
        final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) {
      FlightInfo flightInfo = client.getInfo(FlightDescriptor.path("test"));
      assertEquals(Optional.empty(), flightInfo.getSchemaOptional());
      assertEquals(new Schema(Collections.emptyList()), flightInfo.getSchema());
      // JUnit's argument order is (expected, actual).
      assertArrayEquals("foo".getBytes(StandardCharsets.UTF_8), flightInfo.getAppMetadata());

      FlightRuntimeException e =
          assertThrows(
              FlightRuntimeException.class, () -> client.getSchema(FlightDescriptor.path("test")));
      assertEquals("No schema is present in FlightInfo", e.getMessage());
    }
  }

  /**
   * Regression test for GH-41584, where the Flight default for the max header size was not in sync
   * between client and server.
   *
   * <p>The producer echoes the {@code long-header} request header back as the FlightInfo app
   * metadata. A 10 KiB header must pass with default limits, and also with an explicit 15 KiB
   * server limit. Adding a further 20 KiB header must then be rejected.
   */
  @Test
  public void testHeaderSizeExchangeInService() throws Exception {
    final FlightProducer producer =
        new NoOpFlightProducer() {
          @Override
          public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) {
            String longHeader =
                context.getMiddleware(FlightConstants.HEADER_KEY).headers().get("long-header");
            return new FlightInfo(
                null,
                descriptor,
                Collections.emptyList(),
                0,
                0,
                false,
                IpcOption.DEFAULT,
                longHeader.getBytes(StandardCharsets.UTF_8));
          }
        };

    String headerVal = generateRandom(1024 * 10);
    byte[] expectedMetadata = headerVal.getBytes(StandardCharsets.UTF_8);
    FlightCallHeaders callHeaders = new FlightCallHeaders();
    callHeaders.insert("long-header", headerVal);

    // Server with the default header limit, same as the client's default.
    try (final FlightServer s =
            FlightServer.builder(allocator, forGrpcInsecure(LOCALHOST, 0), producer)
                .build()
                .start();
        final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) {
      FlightInfo flightInfo =
          client.getInfo(FlightDescriptor.path("test"), new HeaderCallOption(callHeaders));
      assertEquals(Optional.empty(), flightInfo.getSchemaOptional());
      assertEquals(new Schema(Collections.emptyList()), flightInfo.getSchema());
      assertArrayEquals(expectedMetadata, flightInfo.getAppMetadata());
    }

    // Server with a 15 KiB header list limit.
    try (final FlightServer s =
            FlightServer.builder(allocator, forGrpcInsecure(LOCALHOST, 0), producer)
                .setMaxHeaderListSize(1024 * 15)
                .build()
                .start();
        final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) {
      FlightInfo flightInfo =
          client.getInfo(FlightDescriptor.path("test"), new HeaderCallOption(callHeaders));
      assertEquals(Optional.empty(), flightInfo.getSchemaOptional());
      assertEquals(new Schema(Collections.emptyList()), flightInfo.getSchema());
      assertArrayEquals(expectedMetadata, flightInfo.getAppMetadata());

      // 10 KiB + 20 KiB of header values exceeds the 15 KiB server limit.
      callHeaders.insert("another-header", headerVal + headerVal);
      FlightRuntimeException e =
          assertThrows(
              FlightRuntimeException.class,
              () ->
                  client.getInfo(FlightDescriptor.path("test"), new HeaderCallOption(callHeaders)));
      // NOTE(review): this exact message comes from the transport layer; confirm it stays stable.
      assertEquals("http2 exception", e.getMessage());
    }
  }

  /**
   * Builds a string of {@code size} characters, each chosen uniformly at random from {@link
   * #RANDOM_ALPHABET}.
   *
   * @param size number of characters to generate; expected to be non-negative
   * @return the random string
   */
  private static String generateRandom(int size) {
    Random random = new Random();
    StringBuilder res = new StringBuilder(size);
    for (int i = 0; i < size; i++) {
      res.append(RANDOM_ALPHABET.charAt(random.nextInt(RANDOM_ALPHABET.length())));
    }
    return res.toString();
  }
}