AvroTestBase.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.adapter.avro;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.util.Text;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.io.BinaryDecoder;
import org.apache.avro.io.BinaryEncoder;
import org.apache.avro.io.DatumWriter;
import org.apache.avro.io.DecoderFactory;
import org.apache.avro.io.EncoderFactory;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.io.TempDir;

/**
 * Shared fixtures and assertion helpers for the Avro-to-Arrow adapter tests.
 *
 * <p>Subclasses get a fresh {@link AvroToArrowConfig} per test, a way to round-trip Avro datums
 * through a temporary file ({@link #writeAndRead}), and helpers that compare Avro-side expected
 * values with the values read back from Arrow vectors.
 */
public class AvroTestBase {

  @TempDir public File TMP;

  protected AvroToArrowConfig config;

  /**
   * Builds a fresh {@link AvroToArrowConfig} backed by a new {@link RootAllocator} (limit {@code
   * Long.MAX_VALUE}) before each test.
   *
   * <p>NOTE(review): the allocator is never closed. Vectors produced through {@link #config} are
   * presumably allocated from it, so closing it here would require subclasses to release every
   * returned root first.
   */
  @BeforeEach
  public void init() {
    BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
    config = new AvroToArrowConfigBuilder(allocator).build();
  }

  /**
   * Loads and parses the Avro schema stored as the test resource {@code /schema/<schemaName>}.
   *
   * @param schemaName file name of the schema under the {@code /schema/} resource directory
   * @return the parsed schema
   * @throws IllegalArgumentException if no such resource exists
   * @throws Exception if reflection or parsing fails
   */
  public static Schema getSchema(String schemaName) throws Exception {
    final String resourcePath = "/schema/" + schemaName;
    try {
      // Attempt to use JDK 9 behavior of getting the module then the resource stream from the
      // module. Note that this code is caller-sensitive. Resources are resolved relative to this
      // class so the base class does not depend on any particular test subclass.
      Method getModuleMethod = Class.class.getMethod("getModule");
      Object module = getModuleMethod.invoke(AvroTestBase.class);
      Method getResourceAsStreamFromModule =
          module.getClass().getMethod("getResourceAsStream", String.class);
      try (InputStream is =
          (InputStream) getResourceAsStreamFromModule.invoke(module, resourcePath)) {
        return parseSchema(is, resourcePath);
      }
    } catch (NoSuchMethodException ex) {
      // Use JDK8 behavior.
      try (InputStream is = AvroTestBase.class.getResourceAsStream(resourcePath)) {
        return parseSchema(is, resourcePath);
      }
    }
  }

  /**
   * Parses a schema from {@code is}. A {@code null} stream (resource lookup failed) is rejected
   * with a message naming the missing resource, instead of failing inside the parser.
   */
  private static Schema parseSchema(InputStream is, String resourcePath) throws IOException {
    if (is == null) {
      throw new IllegalArgumentException("Schema resource not found: " + resourcePath);
    }
    return new Schema.Parser().parse(is);
  }

  /**
   * Writes every element of {@code data} as a raw Avro binary datum (no container header) into a
   * temporary file. It then decodes that file back into Arrow with {@link #config}.
   *
   * @param schema the Avro schema used both for writing and for conversion
   * @param data the datums to write, in order
   * @return the converted root
   */
  protected VectorSchemaRoot writeAndRead(Schema schema, List<?> data) throws Exception {
    File dataFile = new File(TMP, "test.avro");

    try (FileOutputStream fos = new FileOutputStream(dataFile);
        FileInputStream fis = new FileInputStream(dataFile)) {

      BinaryEncoder encoder = new EncoderFactory().directBinaryEncoder(fos, null);
      DatumWriter<Object> writer = new GenericDatumWriter<>(schema);
      for (Object value : data) {
        writer.write(value, encoder);
      }
      // Ensure every datum has reached the file before the decoder starts reading it.
      encoder.flush();

      BinaryDecoder decoder = new DecoderFactory().directBinaryDecoder(fis, null);
      return AvroToArrow.avroToArrow(schema, decoder, config);
    }
  }

  /** Asserts that {@code vector} holds exactly the lists in {@code expected}, in order. */
  protected void checkArrayResult(List<List<?>> expected, ListVector vector) {
    assertEquals(expected.size(), vector.getValueCount());
    for (int i = 0; i < expected.size(); i++) {
      checkArrayElement(expected.get(i), vector.getObject(i));
    }
  }

  /**
   * Asserts element-wise equality of one expected list and the list read from a vector. A {@code
   * null} expected list (e.g. a nullable union) requires a {@code null} actual list.
   */
  protected void checkArrayElement(List<?> expected, List<?> actual) {
    if (expected == null) {
      assertNull(actual);
      return;
    }
    assertEquals(expected.size(), actual.size());
    for (int i = 0; i < expected.size(); i++) {
      assertValueEquals(expected.get(i), actual.get(i));
    }
  }

  /** Asserts that {@code vector} holds exactly the values in {@code data}, in order. */
  protected void checkPrimitiveResult(List<?> data, FieldVector vector) {
    assertEquals(data.size(), vector.getValueCount());
    for (int i = 0; i < data.size(); i++) {
      assertValueEquals(data.get(i), vector.getObject(i));
    }
  }

  /**
   * Asserts that {@code root} has one row per record and one vector per schema field. Each column
   * must hold the corresponding field values of {@code data}.
   */
  protected void checkRecordResult(Schema schema, List<GenericRecord> data, VectorSchemaRoot root) {
    assertEquals(data.size(), root.getRowCount());
    assertEquals(schema.getFields().size(), root.getFieldVectors().size());

    for (int i = 0; i < schema.getFields().size(); i++) {
      checkPrimitiveResult(columnValues(data, i), root.getFieldVectors().get(i));
    }
  }

  /**
   * Checks a schema whose single top-level field is itself a record. The fields of that nested
   * record must match the children of the first {@link StructVector} in {@code root}.
   */
  protected void checkNestedRecordResult(
      Schema schema, List<GenericRecord> data, VectorSchemaRoot root) {
    assertEquals(data.size(), root.getRowCount());
    assertTrue(
        schema.getFields().size() == 1,
        "nested record check expects exactly one top-level field");

    final Schema nestedSchema = schema.getFields().get(0).schema();
    final StructVector structVector = (StructVector) root.getFieldVectors().get(0);
    final List<FieldVector> children = structVector.getChildrenFromFields();
    assertEquals(nestedSchema.getFields().size(), children.size());

    // Extract the nested record of every row once, instead of once per nested field.
    List<GenericRecord> nestedRecords = new ArrayList<>(data.size());
    for (GenericRecord record : data) {
      nestedRecords.add((GenericRecord) record.get(0));
    }

    for (int i = 0; i < nestedSchema.getFields().size(); i++) {
      checkPrimitiveResult(columnValues(nestedRecords, i), children.get(i));
    }
  }

  // Below are the helpers for the iterator API, where a result is split across several batches.

  /** Asserts that the concatenation of {@code vectors} holds exactly the lists in {@code expected}. */
  protected void checkArrayResult(List<List<?>> expected, List<ListVector> vectors) {
    int valueCount = vectors.stream().mapToInt(ListVector::getValueCount).sum();
    assertEquals(expected.size(), valueCount);

    int index = 0;
    for (ListVector vector : vectors) {
      for (int i = 0; i < vector.getValueCount(); i++) {
        checkArrayElement(expected.get(index++), vector.getObject(i));
      }
    }
  }

  /**
   * Asserts that every batch in {@code roots} has one vector per schema field. For each field, the
   * concatenated column across all batches must hold the field values of {@code data}.
   */
  protected void checkRecordResult(
      Schema schema, List<GenericRecord> data, List<VectorSchemaRoot> roots) {
    final int fieldCount = schema.getFields().size();
    for (VectorSchemaRoot root : roots) {
      assertEquals(fieldCount, root.getFieldVectors().size());
    }

    for (int i = 0; i < fieldCount; i++) {
      List<FieldVector> vectors = new ArrayList<>(roots.size());
      for (VectorSchemaRoot root : roots) {
        vectors.add(root.getFieldVectors().get(i));
      }
      checkPrimitiveResult(columnValues(data, i), vectors);
    }
  }

  /** Asserts that the concatenation of {@code vectors} holds exactly the values in {@code data}. */
  protected void checkPrimitiveResult(List<?> data, List<FieldVector> vectors) {
    int valueCount = vectors.stream().mapToInt(FieldVector::getValueCount).sum();
    assertEquals(data.size(), valueCount);

    int index = 0;
    for (FieldVector vector : vectors) {
      for (int i = 0; i < vector.getValueCount(); i++) {
        assertValueEquals(data.get(index++), vector.getObject(i));
      }
    }
  }

  /** Collects the value of field {@code fieldIndex} from every record, preserving order. */
  private static List<Object> columnValues(List<GenericRecord> records, int fieldIndex) {
    List<Object> values = new ArrayList<>(records.size());
    for (GenericRecord record : records) {
      values.add(record.get(fieldIndex));
    }
    return values;
  }

  /**
   * Compares one expected value with the value read from an Arrow vector. All the check helpers
   * use this so they normalize representations in the same way:
   *
   * <ul>
   *   <li>{@code byte[]} values on either side are wrapped in {@link ByteBuffer}s;
   *   <li>{@link Text} is compared as its {@code String} form;
   *   <li>a {@code Byte} is widened to {@code int} unless the expected value is itself a {@code
   *       Byte}.
   * </ul>
   *
   * A {@code null} expected value requires a {@code null} actual value.
   */
  private static void assertValueEquals(Object expected, Object actual) {
    if (expected == null) {
      assertNull(actual);
      return;
    }
    Object normalizedExpected = expected;
    Object normalizedActual = actual;
    if (actual instanceof byte[]) {
      normalizedActual = ByteBuffer.wrap((byte[]) actual);
      if (expected instanceof byte[]) {
        normalizedExpected = ByteBuffer.wrap((byte[]) expected);
      }
    } else if (actual instanceof Text) {
      normalizedActual = actual.toString();
    } else if (actual instanceof Byte && !(expected instanceof Byte)) {
      normalizedActual = ((Byte) actual).intValue();
    }
    assertEquals(normalizedExpected, normalizedActual);
  }
}