EchoServerTest.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.tools;

import static java.util.Arrays.asList;
import static org.apache.arrow.vector.types.Types.MinorType.TINYINT;
import static org.apache.arrow.vector.types.Types.MinorType.VARCHAR;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.net.Socket;
import java.net.UnknownHostException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.TinyIntVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.impl.UnionListWriter;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.testing.ValueVectorDataPopulator;
import org.apache.arrow.vector.types.Types.MinorType;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.ArrowType.Int;
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.Text;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

/**
 * Round-trip tests for {@link EchoServer}. Each test writes an Arrow IPC stream to a shared server
 * over a local socket, reads the echoed stream back, and checks schema, data and dictionaries.
 */
public class EchoServerTest {

  private static EchoServer server;
  private static int serverPort;
  private static Thread serverThread;

  /** Starts one {@link EchoServer} shared by every test in this class, running on its own thread. */
  @BeforeAll
  public static void startEchoServer() throws IOException {
    // Port 0 presumably asks for an ephemeral port; the actual port is read back from the server.
    server = new EchoServer(0);
    serverPort = server.port();
    serverThread =
        new Thread(
            () -> {
              try {
                server.run();
              } catch (IOException e) {
                // No logger is available in this test; report on stderr instead of dropping it.
                e.printStackTrace();
              }
            },
            "echo-server");
    serverThread.start();
  }

  /** Closes the shared server and waits for its thread to finish. */
  @AfterAll
  public static void stopEchoServer() throws IOException, InterruptedException {
    // @AfterAll still runs when @BeforeAll failed; guard so an NPE does not mask that failure.
    if (server != null) {
      server.close();
    }
    if (serverThread != null) {
      serverThread.join();
    }
  }

  /**
   * Sends {@code batches} record batches of a single TINYINT column through the echo server and
   * verifies that the echoed stream matches what was written.
   *
   * <p>Every batch has 16 rows. Rows 0-7 hold {@code j + i} (row index plus batch index), and
   * rows 8-15 are null.
   *
   * @param serverPort port the echo server listens on
   * @param field schema field describing {@code vector}
   * @param vector caller-owned vector; it is reallocated for every batch and not closed here
   * @param batches number of record batches to send (0 writes only the stream header and end)
   */
  private void testEchoServer(int serverPort, Field field, TinyIntVector vector, int batches)
      throws UnknownHostException, IOException {
    // The root merely wraps the caller-owned vector, so it is intentionally not closed here:
    // closing it would release the vector that the caller reuses across invocations.
    VectorSchemaRoot root = new VectorSchemaRoot(asList(field), asList((FieldVector) vector), 0);
    try (BufferAllocator alloc = new RootAllocator(Long.MAX_VALUE);
        Socket socket = new Socket("localhost", serverPort);
        ArrowStreamWriter writer = new ArrowStreamWriter(root, null, socket.getOutputStream());
        ArrowStreamReader reader = new ArrowStreamReader(socket.getInputStream(), alloc)) {
      writer.start();
      for (int i = 0; i < batches; i++) {
        vector.allocateNew(16);
        for (int j = 0; j < 8; j++) {
          vector.set(j, j + i);
          // isSet == 0 marks the slot as null; the value byte is irrelevant.
          vector.set(j + 8, 0, (byte) (j + i));
        }
        vector.setValueCount(16);
        root.setRowCount(16);
        writer.writeBatch();
      }
      writer.end();

      assertEquals(new Schema(asList(field)), reader.getVectorSchemaRoot().getSchema());

      // The reader loads each batch into the same vectors, so this reference stays valid.
      TinyIntVector readVector =
          (TinyIntVector) reader.getVectorSchemaRoot().getFieldVectors().get(0);
      for (int i = 0; i < batches; i++) {
        assertTrue(reader.loadNextBatch());
        assertEquals(16, reader.getVectorSchemaRoot().getRowCount());
        assertEquals(16, readVector.getValueCount());
        for (int j = 0; j < 8; j++) {
          assertEquals(j + i, readVector.get(j));
          assertTrue(readVector.isNull(j + 8));
        }
      }
      assertFalse(reader.loadNextBatch());
      assertEquals(0, reader.getVectorSchemaRoot().getRowCount());
      assertEquals(writer.bytesWritten(), reader.bytesRead());
    }
  }

  /** Echoes a TINYINT stream with zero, one, and several record batches. */
  @Test
  public void basicTest() throws IOException {
    Field field =
        new Field(
            "testField",
            new FieldType(true, new ArrowType.Int(8, true), null, null),
            Collections.<Field>emptyList());

    // The vector is closed before the allocator (reverse declaration order), so the allocator
    // sees all of its memory returned when it is closed.
    try (BufferAllocator alloc = new RootAllocator(Long.MAX_VALUE);
        TinyIntVector vector =
            new TinyIntVector("testField", FieldType.nullable(TINYINT.getType()), alloc)) {

      // Try an empty stream, just the header.
      testEchoServer(serverPort, field, vector, 0);

      // Try with one batch.
      testEchoServer(serverPort, field, vector, 1);

      // Try with a few.
      testEchoServer(serverPort, field, vector, 10);
    }
  }

  /** Echoes a dictionary-encoded INT column whose dictionary is a VARCHAR vector. */
  @Test
  public void testFlatDictionary() throws IOException {
    DictionaryEncoding writeEncoding = new DictionaryEncoding(1L, false, null);
    try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
        IntVector writeVector =
            new IntVector(
                "varchar",
                new FieldType(true, MinorType.INT.getType(), writeEncoding, null),
                allocator);
        VarCharVector writeDictionaryVector =
            new VarCharVector("dict", FieldType.nullable(VARCHAR.getType()), allocator)) {

      // Indices into the dictionary below; index 2 is null.
      ValueVectorDataPopulator.setVector(writeVector, 0, 1, null, 2, 1, 2);
      ValueVectorDataPopulator.setVector(
          writeDictionaryVector,
          "foo".getBytes(StandardCharsets.UTF_8),
          "bar".getBytes(StandardCharsets.UTF_8),
          "baz".getBytes(StandardCharsets.UTF_8));

      List<Field> fields = ImmutableList.of(writeVector.getField());
      List<FieldVector> vectors = ImmutableList.of((FieldVector) writeVector);
      // Wraps vectors closed by the enclosing try-with-resources, so it is not closed itself.
      VectorSchemaRoot root = new VectorSchemaRoot(fields, vectors, 6);

      DictionaryProvider writeProvider =
          new MapDictionaryProvider(new Dictionary(writeDictionaryVector, writeEncoding));

      try (Socket socket = new Socket("localhost", serverPort);
          ArrowStreamWriter writer =
              new ArrowStreamWriter(root, writeProvider, socket.getOutputStream());
          ArrowStreamReader reader = new ArrowStreamReader(socket.getInputStream(), allocator)) {
        writer.start();
        writer.writeBatch();
        writer.end();

        assertTrue(reader.loadNextBatch());
        VectorSchemaRoot readerRoot = reader.getVectorSchemaRoot();
        assertEquals(6, readerRoot.getRowCount());

        FieldVector readVector = readerRoot.getFieldVectors().get(0);
        assertNotNull(readVector);

        DictionaryEncoding readEncoding = readVector.getField().getDictionary();
        assertNotNull(readEncoding);
        assertEquals(1L, readEncoding.getId());

        assertEquals(6, readVector.getValueCount());
        assertEquals(0, readVector.getObject(0));
        assertEquals(1, readVector.getObject(1));
        assertNull(readVector.getObject(2));
        assertEquals(2, readVector.getObject(3));
        assertEquals(1, readVector.getObject(4));
        assertEquals(2, readVector.getObject(5));

        Dictionary dictionary = reader.lookup(1L);
        assertNotNull(dictionary);
        VarCharVector dictionaryVector = ((VarCharVector) dictionary.getVector());
        assertEquals(3, dictionaryVector.getValueCount());
        assertEquals(new Text("foo"), dictionaryVector.getObject(0));
        assertEquals(new Text("bar"), dictionaryVector.getObject(1));
        assertEquals(new Text("baz"), dictionaryVector.getObject(2));

        // Exactly one record batch was written; the stream must end after it.
        assertFalse(reader.loadNextBatch());
      }
    }
  }

  /** Echoes a LIST column whose INT child is dictionary-encoded (the list itself is not). */
  @Test
  public void testNestedDictionary() throws IOException {
    DictionaryEncoding writeEncoding = new DictionaryEncoding(2L, false, null);
    try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
        VarCharVector writeDictionaryVector =
            new VarCharVector("dictionary", FieldType.nullable(VARCHAR.getType()), allocator);
        ListVector writeVector = ListVector.empty("list", allocator)) {

      // data being written:
      // [['foo', 'bar'], ['foo'], ['bar']] -> [[0, 1], [0], [1]]

      writeDictionaryVector.allocateNew();
      writeDictionaryVector.set(0, "foo".getBytes(StandardCharsets.UTF_8));
      writeDictionaryVector.set(1, "bar".getBytes(StandardCharsets.UTF_8));
      writeDictionaryVector.setValueCount(2);

      // The dictionary encoding is attached to the list's child, not to the list field.
      writeVector.addOrGetVector(new FieldType(true, MinorType.INT.getType(), writeEncoding, null));
      writeVector.allocateNew();
      UnionListWriter listWriter = new UnionListWriter(writeVector);
      listWriter.startList();
      listWriter.writeInt(0);
      listWriter.writeInt(1);
      listWriter.endList();
      listWriter.startList();
      listWriter.writeInt(0);
      listWriter.endList();
      listWriter.startList();
      listWriter.writeInt(1);
      listWriter.endList();
      listWriter.setValueCount(3);

      List<Field> fields = ImmutableList.of(writeVector.getField());
      List<FieldVector> vectors = ImmutableList.of((FieldVector) writeVector);
      // Wraps vectors closed by the enclosing try-with-resources, so it is not closed itself.
      VectorSchemaRoot root = new VectorSchemaRoot(fields, vectors, 3);

      DictionaryProvider writeProvider =
          new MapDictionaryProvider(new Dictionary(writeDictionaryVector, writeEncoding));

      try (Socket socket = new Socket("localhost", serverPort);
          ArrowStreamWriter writer =
              new ArrowStreamWriter(root, writeProvider, socket.getOutputStream());
          ArrowStreamReader reader = new ArrowStreamReader(socket.getInputStream(), allocator)) {
        writer.start();
        writer.writeBatch();
        writer.end();

        assertTrue(reader.loadNextBatch());
        VectorSchemaRoot readerRoot = reader.getVectorSchemaRoot();
        assertEquals(3, readerRoot.getRowCount());

        ListVector readVector = (ListVector) readerRoot.getFieldVectors().get(0);
        assertNotNull(readVector);

        // Only the nested INT child carries the dictionary encoding.
        assertNull(readVector.getField().getDictionary());
        Field nestedField = readVector.getField().getChildren().get(0);
        DictionaryEncoding readEncoding = nestedField.getDictionary();
        assertNotNull(readEncoding);
        assertEquals(2L, readEncoding.getId());
        assertEquals(new Int(32, true), readEncoding.getIndexType());

        assertEquals(3, readVector.getValueCount());
        assertEquals(Arrays.asList(0, 1), readVector.getObject(0));
        assertEquals(Arrays.asList(0), readVector.getObject(1));
        assertEquals(Arrays.asList(1), readVector.getObject(2));

        Dictionary readDictionary = reader.lookup(2L);
        assertNotNull(readDictionary);
        VarCharVector dictionaryVector = ((VarCharVector) readDictionary.getVector());
        assertEquals(2, dictionaryVector.getValueCount());
        assertEquals(new Text("foo"), dictionaryVector.getObject(0));
        assertEquals(new Text("bar"), dictionaryVector.getObject(1));

        // Exactly one record batch was written; the stream must end after it.
        assertFalse(reader.loadNextBatch());
      }
    }
  }
}