TestBasicOperation.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.flight;

import static org.apache.arrow.flight.FlightTestUtil.LOCALHOST;
import static org.apache.arrow.flight.Location.forGrpcInsecure;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

import com.google.common.base.Charsets;
import com.google.protobuf.ByteString;
import io.grpc.MethodDescriptor;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.channels.Channels;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import org.apache.arrow.flight.FlightClient.ClientStreamListener;
import org.apache.arrow.flight.impl.Flight;
import org.apache.arrow.flight.impl.Flight.FlightDescriptor.DescriptorType;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.ipc.WriteChannel;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.DisabledOnOs;
import org.junit.jupiter.api.condition.OS;

/** Test the operations of a basic flight service. */
public class TestBasicOperation {

  @Test
  public void fastPathDefaults() {
    assertTrue(ArrowMessage.ENABLE_ZERO_COPY_READ);
    assertFalse(ArrowMessage.ENABLE_ZERO_COPY_WRITE);
  }

  @Test
  public void fallbackLocation() {
    assertEquals(
        "arrow-flight-reuse-connection://?", Location.reuseConnection().getUri().toString());
  }

  /** ARROW-6017: we should be able to construct locations for unknown schemes. */
  @Test
  public void unknownScheme() throws URISyntaxException {
    final Location location = new Location("s3://unknown");
    assertEquals("s3", location.getUri().getScheme());
  }

  @Test
  public void unknownSchemeRemote() throws Exception {
    test(
        c -> {
          try {
            final FlightInfo info = c.getInfo(FlightDescriptor.path("test"));
            assertEquals(
                new URI("https://example.com"),
                info.getEndpoints().get(0).getLocations().get(0).getUri());
          } catch (URISyntaxException e) {
            throw new RuntimeException(e);
          }
        });
  }

  @Test
  public void roundTripTicket() throws Exception {
    final Ticket ticket = new Ticket(new byte[] {0, 1, 2, 3, 4, 5});
    assertEquals(ticket, Ticket.deserialize(ticket.serialize()));
  }

  @Test
  public void roundTripInfo() throws Exception {
    final Map<String, String> metadata = new HashMap<>();
    metadata.put("foo", "bar");
    final Schema schema =
        new Schema(
            Arrays.asList(
                Field.nullable("a", new ArrowType.Int(32, true)),
                Field.nullable("b", new ArrowType.FixedSizeBinary(32))),
            metadata);
    final FlightInfo info1 =
        FlightInfo.builder(schema, FlightDescriptor.path(), Collections.emptyList())
            .setAppMetadata("foo".getBytes(StandardCharsets.UTF_8))
            .build();
    final FlightInfo info2 =
        new FlightInfo(
            schema,
            FlightDescriptor.command(new byte[2]),
            Collections.singletonList(
                FlightEndpoint.builder(
                        new Ticket(new byte[10]), Location.forGrpcDomainSocket("/tmp/test.sock"))
                    .setAppMetadata("bar".getBytes(StandardCharsets.UTF_8))
                    .build()),
            200,
            500);
    final FlightInfo info3 =
        new FlightInfo(
            schema,
            FlightDescriptor.path("a", "b"),
            Arrays.asList(
                new FlightEndpoint(
                    new Ticket(new byte[10]), Location.forGrpcDomainSocket("/tmp/test.sock")),
                new FlightEndpoint(
                    new Ticket(new byte[10]),
                    Location.forGrpcDomainSocket("/tmp/test.sock"),
                    forGrpcInsecure("localhost", 50051))),
            200,
            500);
    final FlightInfo info4 =
        new FlightInfo(
            schema,
            FlightDescriptor.path("a", "b"),
            Arrays.asList(
                new FlightEndpoint(
                    new Ticket(new byte[10]), Location.forGrpcDomainSocket("/tmp/test.sock")),
                new FlightEndpoint(
                    new Ticket(new byte[10]),
                    Location.forGrpcDomainSocket("/tmp/test.sock"),
                    forGrpcInsecure("localhost", 50051))),
            200,
            500, /*ordered*/
            true,
            IpcOption.DEFAULT);

    assertEquals(info1, FlightInfo.deserialize(info1.serialize()));
    assertEquals(info2, FlightInfo.deserialize(info2.serialize()));
    assertEquals(info3, FlightInfo.deserialize(info3.serialize()));
    assertEquals(info4, FlightInfo.deserialize(info4.serialize()));

    assertNotEquals(info3, info4);

    assertFalse(info1.getOrdered());
    assertFalse(info2.getOrdered());
    assertFalse(info3.getOrdered());
    assertTrue(info4.getOrdered());
  }

  @Test
  public void roundTripDescriptor() throws Exception {
    final FlightDescriptor cmd =
        FlightDescriptor.command("test command".getBytes(StandardCharsets.UTF_8));
    assertEquals(cmd, FlightDescriptor.deserialize(cmd.serialize()));
    final FlightDescriptor path = FlightDescriptor.path("foo", "bar", "test.arrow");
    assertEquals(path, FlightDescriptor.deserialize(path.serialize()));
  }

  @Test
  public void getDescriptors() throws Exception {
    test(
        c -> {
          int count = 0;
          for (FlightInfo unused : c.listFlights(Criteria.ALL)) {
            count += 1;
          }
          assertEquals(1, count);
        });
  }

  @Test
  public void getDescriptorsWithCriteria() throws Exception {
    test(
        c -> {
          int count = 0;
          for (FlightInfo unused : c.listFlights(new Criteria(new byte[] {1}))) {

            count += 1;
          }
          assertEquals(0, count);
        });
  }

  @Test
  public void getDescriptor() throws Exception {
    test(
        c -> {
          System.out.println(c.getInfo(FlightDescriptor.path("hello")).getDescriptor());
        });
  }

  @Test
  public void getSchema() throws Exception {
    test(
        c -> {
          System.out.println(c.getSchema(FlightDescriptor.path("hello")).getSchema());
        });
  }

  @Test
  public void listActions() throws Exception {
    test(
        c -> {
          for (ActionType at : c.listActions()) {
            System.out.println(at.getType());
          }
        });
  }

  @Test
  public void doAction() throws Exception {
    test(
        c -> {
          Iterator<Result> stream = c.doAction(new Action("hello"));

          assertTrue(stream.hasNext());
          Result r = stream.next();
          assertArrayEquals("world".getBytes(Charsets.UTF_8), r.getBody());
        });
    test(
        c -> {
          Iterator<Result> stream = c.doAction(new Action("hellooo"));

          assertTrue(stream.hasNext());
          Result r = stream.next();
          assertArrayEquals("world".getBytes(Charsets.UTF_8), r.getBody());

          assertTrue(stream.hasNext());
          r = stream.next();
          assertArrayEquals("!".getBytes(Charsets.UTF_8), r.getBody());
          assertFalse(stream.hasNext());
        });
  }

  @Test
  public void putStream() throws Exception {
    test(
        (c, a) -> {
          final int size = 10;

          IntVector iv = new IntVector("c1", a);

          try (VectorSchemaRoot root = VectorSchemaRoot.of(iv)) {
            ClientStreamListener listener =
                c.startPut(FlightDescriptor.path("hello"), root, new AsyncPutListener());

            // batch 1
            root.allocateNew();
            for (int i = 0; i < size; i++) {
              iv.set(i, i);
            }
            iv.setValueCount(size);
            root.setRowCount(size);
            listener.putNext();

            // batch 2

            root.allocateNew();
            for (int i = 0; i < size; i++) {
              iv.set(i, i + size);
            }
            iv.setValueCount(size);
            root.setRowCount(size);
            listener.putNext();
            root.clear();
            listener.completed();

            // wait for ack to avoid memory leaks.
            listener.getResult();
          }
        });
  }

  @Test
  public void propagateErrors() throws Exception {
    test(
        client -> {
          FlightTestUtil.assertCode(
              FlightStatusCode.UNIMPLEMENTED,
              () -> {
                client.doAction(new Action("invalid-action")).forEachRemaining(action -> fail());
              });
        });
  }

  @Test
  public void getStream() throws Exception {
    test(
        c -> {
          try (final FlightStream stream = c.getStream(new Ticket(new byte[0]))) {
            VectorSchemaRoot root = stream.getRoot();
            IntVector iv = (IntVector) root.getVector("c1");
            int value = 0;
            while (stream.next()) {
              for (int i = 0; i < root.getRowCount(); i++) {
                assertEquals(value, iv.get(i));
                value++;
              }
            }
          } catch (Exception e) {
            throw new RuntimeException(e);
          }
        });
  }

  /** Ensure the client is configured to accept large messages. */
  @Test
  @DisabledOnOs(
      value = {OS.WINDOWS},
      disabledReason = "https://github.com/apache/arrow/issues/33237: flaky test")
  public void getStreamLargeBatch() throws Exception {
    test(
        c -> {
          try (final FlightStream stream = c.getStream(new Ticket(Producer.TICKET_LARGE_BATCH))) {
            assertEquals(128, stream.getRoot().getFieldVectors().size());
            assertTrue(stream.next());
            assertEquals(65536, stream.getRoot().getRowCount());
            assertTrue(stream.next());
            assertEquals(65536, stream.getRoot().getRowCount());
            assertFalse(stream.next());
          } catch (Exception e) {
            throw new RuntimeException(e);
          }
        });
  }

  /** Ensure the server is configured to accept large messages. */
  @Test
  public void startPutLargeBatch() throws Exception {
    try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
      final List<FieldVector> vectors = new ArrayList<>();
      for (int col = 0; col < 128; col++) {
        final BigIntVector vector = new BigIntVector("f" + col, allocator);
        for (int row = 0; row < 65536; row++) {
          vector.setSafe(row, row);
        }
        vectors.add(vector);
      }
      test(
          c -> {
            try (final VectorSchemaRoot root = new VectorSchemaRoot(vectors)) {
              root.setRowCount(65536);
              final ClientStreamListener stream =
                  c.startPut(FlightDescriptor.path(""), root, new SyncPutListener());
              stream.putNext();
              stream.putNext();
              stream.completed();
              stream.getResult();
            } catch (Exception e) {
              throw new RuntimeException(e);
            }
          });
    }
  }

  private void test(Consumer<FlightClient> consumer) throws Exception {
    test(
        (c, a) -> {
          consumer.accept(c);
        });
  }

  private void test(BiConsumer<FlightClient, BufferAllocator> consumer) throws Exception {
    try (BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
        Producer producer = new Producer(a);
        FlightServer s =
            FlightServer.builder(a, forGrpcInsecure(LOCALHOST, 0), producer).build().start()) {

      try (FlightClient c = FlightClient.builder(a, s.getLocation()).build()) {
        try (BufferAllocator testAllocator = a.newChildAllocator("testcase", 0, Long.MAX_VALUE)) {
          consumer.accept(c, testAllocator);
        }
      }
    }
  }

  /** Helper method to convert an ArrowMessage into a Protobuf message. */
  private Flight.FlightData arrowMessageToProtobuf(
      MethodDescriptor.Marshaller<ArrowMessage> marshaller, ArrowMessage message)
      throws IOException {
    final ByteArrayOutputStream baos = new ByteArrayOutputStream();
    try (final InputStream serialized = marshaller.stream(message)) {
      final byte[] buf = new byte[1024];
      while (true) {
        int read = serialized.read(buf);
        if (read < 0) {
          break;
        }
        baos.write(buf, 0, read);
      }
    }
    final byte[] serializedMessage = baos.toByteArray();
    return Flight.FlightData.parseFrom(serializedMessage);
  }

  /**
   * ARROW-10962: accept FlightData messages generated by Protobuf (which can omit empty fields).
   */
  @Test
  public void testProtobufRecordBatchCompatibility() throws Exception {
    final Schema schema =
        new Schema(Collections.singletonList(Field.nullable("foo", new ArrowType.Int(32, true))));
    try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
        final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
      final VectorUnloader unloader = new VectorUnloader(root);
      root.setRowCount(0);
      final MethodDescriptor.Marshaller<ArrowMessage> marshaller =
          ArrowMessage.createMarshaller(allocator);
      try (final ArrowMessage message =
          new ArrowMessage(
              unloader.getRecordBatch(), /* appMetadata */
              null, /* tryZeroCopy */
              false,
              IpcOption.DEFAULT)) {
        assertEquals(ArrowMessage.HeaderType.RECORD_BATCH, message.getMessageType());
        // Should have at least one empty body buffer (there may be multiple for e.g. data and
        // validity)
        Iterator<ArrowBuf> iterator = message.getBufs().iterator();
        assertTrue(iterator.hasNext());
        while (iterator.hasNext()) {
          assertEquals(0, iterator.next().capacity());
        }
        final Flight.FlightData protobufData =
            arrowMessageToProtobuf(marshaller, message).toBuilder().clearDataBody().build();
        assertEquals(0, protobufData.getDataBody().size());
        ArrowMessage parsedMessage =
            marshaller.parse(new ByteArrayInputStream(protobufData.toByteArray()));
        // Should have an empty body buffer
        Iterator<ArrowBuf> parsedIterator = parsedMessage.getBufs().iterator();
        assertTrue(parsedIterator.hasNext());
        assertEquals(0, parsedIterator.next().capacity());
        // Should have only one (the parser synthesizes exactly one); in the case of empty buffers,
        // this is equivalent
        assertFalse(parsedIterator.hasNext());
        // Should not throw
        final ArrowRecordBatch rb = parsedMessage.asRecordBatch();
        assertEquals(rb.computeBodyLength(), 0);
      }
    }
  }

  /**
   * ARROW-10962: accept FlightData messages generated by Protobuf (which can omit empty fields).
   */
  @Test
  public void testProtobufSchemaCompatibility() throws Exception {
    final Schema schema =
        new Schema(Collections.singletonList(Field.nullable("foo", new ArrowType.Int(32, true))));
    try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
      final MethodDescriptor.Marshaller<ArrowMessage> marshaller =
          ArrowMessage.createMarshaller(allocator);
      Flight.FlightDescriptor descriptor = FlightDescriptor.command(new byte[0]).toProtocol();
      try (final ArrowMessage message = new ArrowMessage(descriptor, schema, IpcOption.DEFAULT)) {
        assertEquals(ArrowMessage.HeaderType.SCHEMA, message.getMessageType());
        // Should have no body buffers
        assertFalse(message.getBufs().iterator().hasNext());
        final Flight.FlightData protobufData =
            arrowMessageToProtobuf(marshaller, message).toBuilder()
                .setDataBody(ByteString.EMPTY)
                .build();
        assertEquals(0, protobufData.getDataBody().size());
        final ArrowMessage parsedMessage =
            marshaller.parse(new ByteArrayInputStream(protobufData.toByteArray()));
        // Should have no body buffers
        assertFalse(parsedMessage.getBufs().iterator().hasNext());
        // Should not throw
        parsedMessage.asSchema();
      }
    }
  }

  @Test
  public void testGrpcInsecureLocation() throws Exception {
    Location location = Location.forGrpcInsecure(LOCALHOST, 9000);
    assertEquals(
        new URI(LocationSchemes.GRPC_INSECURE, null, LOCALHOST, 9000, null, null, null),
        location.getUri());
    assertEquals(new InetSocketAddress(LOCALHOST, 9000), location.toSocketAddress());
  }

  @Test
  public void testGrpcTlsLocation() throws Exception {
    Location location = Location.forGrpcTls(LOCALHOST, 9000);
    assertEquals(
        new URI(LocationSchemes.GRPC_TLS, null, LOCALHOST, 9000, null, null, null),
        location.getUri());
    assertEquals(new InetSocketAddress(LOCALHOST, 9000), location.toSocketAddress());
  }

  /** An example FlightProducer for test purposes. */
  public static class Producer implements FlightProducer, AutoCloseable {
    static final byte[] TICKET_LARGE_BATCH = "large-batch".getBytes(StandardCharsets.UTF_8);

    private final BufferAllocator allocator;

    public Producer(BufferAllocator allocator) {
      super();
      this.allocator = allocator;
    }

    @Override
    public void listFlights(
        CallContext context, Criteria criteria, StreamListener<FlightInfo> listener) {
      if (criteria.getExpression().length > 0) {
        // Don't send anything if criteria are set
        listener.onCompleted();
      }

      Flight.FlightInfo getInfo =
          Flight.FlightInfo.newBuilder()
              .setFlightDescriptor(
                  Flight.FlightDescriptor.newBuilder()
                      .setType(DescriptorType.CMD)
                      .setCmd(ByteString.copyFrom("cool thing", Charsets.UTF_8)))
              .build();
      try {
        listener.onNext(new FlightInfo(getInfo));
      } catch (URISyntaxException e) {
        listener.onError(e);
        return;
      }
      listener.onCompleted();
    }

    @Override
    public Runnable acceptPut(
        CallContext context, FlightStream flightStream, StreamListener<PutResult> ackStream) {
      return () -> {
        while (flightStream.next()) {
          // Drain the stream
        }
      };
    }

    @Override
    public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
      if (Arrays.equals(TICKET_LARGE_BATCH, ticket.getBytes())) {
        getLargeBatch(listener);
        return;
      }
      final int size = 10;

      IntVector iv = new IntVector("c1", allocator);
      VectorSchemaRoot root = VectorSchemaRoot.of(iv);
      listener.start(root);

      // batch 1
      root.allocateNew();
      for (int i = 0; i < size; i++) {
        iv.set(i, i);
      }
      iv.setValueCount(size);
      root.setRowCount(size);
      listener.putNext();

      // batch 2

      root.allocateNew();
      for (int i = 0; i < size; i++) {
        iv.set(i, i + size);
      }
      iv.setValueCount(size);
      root.setRowCount(size);
      listener.putNext();
      root.clear();
      listener.completed();
    }

    private void getLargeBatch(ServerStreamListener listener) {
      final List<FieldVector> vectors = new ArrayList<>();
      for (int col = 0; col < 128; col++) {
        final BigIntVector vector = new BigIntVector("f" + col, allocator);
        for (int row = 0; row < 65536; row++) {
          vector.setSafe(row, row);
        }
        vectors.add(vector);
      }
      try (final VectorSchemaRoot root = new VectorSchemaRoot(vectors)) {
        root.setRowCount(65536);
        listener.start(root);
        listener.putNext();
        listener.putNext();
        listener.completed();
      }
    }

    @Override
    public void close() throws Exception {
      allocator.close();
    }

    @Override
    public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) {
      try {
        Flight.FlightInfo getInfo =
            Flight.FlightInfo.newBuilder()
                .setSchema(schemaToByteString(new Schema(Collections.emptyList())))
                .setFlightDescriptor(
                    Flight.FlightDescriptor.newBuilder()
                        .setType(DescriptorType.CMD)
                        .setCmd(ByteString.copyFrom("cool thing", Charsets.UTF_8)))
                .addEndpoint(
                    Flight.FlightEndpoint.newBuilder()
                        .addLocation(new Location("https://example.com").toProtocol()))
                .build();
        return new FlightInfo(getInfo);
      } catch (URISyntaxException e) {
        throw new RuntimeException(e);
      }
    }

    private static ByteString schemaToByteString(Schema schema) {
      try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
        MessageSerializer.serialize(
            new WriteChannel(Channels.newChannel(baos)), schema, IpcOption.DEFAULT);
        return ByteString.copyFrom(baos.toByteArray());
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    }

    @Override
    public void doAction(CallContext context, Action action, StreamListener<Result> listener) {
      switch (action.getType()) {
        case "hello":
          {
            listener.onNext(new Result("world".getBytes(Charsets.UTF_8)));
            listener.onCompleted();
            break;
          }
        case "hellooo":
          {
            listener.onNext(new Result("world".getBytes(Charsets.UTF_8)));
            listener.onNext(new Result("!".getBytes(Charsets.UTF_8)));
            listener.onCompleted();
            break;
          }
        default:
          listener.onError(
              CallStatus.UNIMPLEMENTED
                  .withDescription("Action not implemented: " + action.getType())
                  .toRuntimeException());
      }
    }

    @Override
    public void listActions(CallContext context, StreamListener<ActionType> listener) {
      listener.onNext(new ActionType("get", ""));
      listener.onNext(new ActionType("put", ""));
      listener.onNext(new ActionType("hello", ""));
      listener.onCompleted();
    }
  }
}