// TestArrowReaderWriterWithCompression.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.compression;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.channels.Channels;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.GenerateSampleData;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.compression.CompressionUtil;
import org.apache.arrow.vector.compression.NoCompressionCodec;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.ArrowFileReader;
import org.apache.arrow.vector.ipc.ArrowFileWriter;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
/**
 * Round-trip tests for the Arrow IPC file and stream formats when record batches are written with
 * ZSTD compression through {@link CommonsCompressionFactory}.
 *
 * <p>Every test writes one batch into an in-memory buffer and then reads it back twice:
 *
 * <ul>
 *   <li>with {@link CommonsCompressionFactory}, expecting exactly one batch equal to the written
 *       root;
 *   <li>with {@link NoCompressionCodec.Factory}, expecting {@code loadNextBatch()} to throw an
 *       {@link IllegalArgumentException} carrying {@link #MISSING_CODEC_MESSAGE}.
 * </ul>
 */
public class TestArrowReaderWriterWithCompression {

  /** Message the tests expect when a ZSTD payload is read without the compression module. */
  private static final String MISSING_CODEC_MESSAGE =
      "Please add arrow-compression module to use CommonsCompressionFactory for ZSTD";

  /** Number of rows generated into the single test column. */
  private static final int ROW_COUNT = 10;

  /** Compression level handed to the writers. */
  private static final int COMPRESSION_LEVEL = 7;

  private BufferAllocator allocator;
  private ByteArrayOutputStream out;
  private VectorSchemaRoot root;

  /** Creates a fresh allocator and output buffer for the test about to run. */
  @BeforeEach
  public void setup() {
    // tearDown() always closes the allocator, so an allocator left over from a previous test is
    // never safe to reuse; create a new one unconditionally.
    allocator = new RootAllocator(Integer.MAX_VALUE);
    out = new ByteArrayOutputStream();
    root = null;
  }

  /**
   * Releases the root, then the allocator, then clears the output buffer.
   *
   * <p>The steps are chained with {@code finally} so that a failure while closing the root does
   * not prevent the allocator from being closed.
   */
  @AfterEach
  public void tearDown() {
    try {
      if (root != null) {
        root.close();
      }
    } finally {
      root = null;
      try {
        if (allocator != null) {
          allocator.close();
        }
      } finally {
        allocator = null;
        if (out != null) {
          out.reset();
        }
      }
    }
  }

  /**
   * Builds {@link #root} with a single non-nullable UTF-8 column named {@code col} and fills it
   * with {@link #ROW_COUNT} generated values.
   *
   * <p>The field is assigned before data generation so that {@link #tearDown()} still closes the
   * root if generation fails.
   */
  private void createSampleRoot() {
    List<Field> fields = new ArrayList<>();
    fields.add(new Field("col", FieldType.notNullable(new ArrowType.Utf8()), new ArrayList<>()));
    root = VectorSchemaRoot.create(new Schema(fields), allocator);
    GenerateSampleData.generateTestData(root.getVector(0), ROW_COUNT);
    root.setRowCount(ROW_COUNT);
  }

  /**
   * Writes one batch of sample data to {@link #out} in the Arrow file format.
   *
   * @param provider dictionary provider passed to the writer; the tests also pass {@code null}
   * @param codecType compression codec used for the batch
   * @throws IOException if the writer fails
   */
  private void createAndWriteArrowFile(
      DictionaryProvider provider, CompressionUtil.CodecType codecType) throws IOException {
    createSampleRoot();
    try (final ArrowFileWriter writer =
        new ArrowFileWriter(
            root,
            provider,
            Channels.newChannel(out),
            new HashMap<>(),
            IpcOption.DEFAULT,
            CommonsCompressionFactory.INSTANCE,
            codecType,
            Optional.of(COMPRESSION_LEVEL))) {
      writer.start();
      writer.writeBatch();
      writer.end();
    }
  }

  /**
   * Writes one batch of sample data to {@link #out} in the Arrow stream format.
   *
   * @param provider dictionary provider passed to the writer; the tests also pass {@code null}
   * @param codecType compression codec used for the batch
   * @throws IOException if the writer fails
   */
  private void createAndWriteArrowStream(
      DictionaryProvider provider, CompressionUtil.CodecType codecType) throws IOException {
    createSampleRoot();
    try (final ArrowStreamWriter writer =
        new ArrowStreamWriter(
            root,
            provider,
            Channels.newChannel(out),
            IpcOption.DEFAULT,
            CommonsCompressionFactory.INSTANCE,
            codecType,
            Optional.of(COMPRESSION_LEVEL))) {
      writer.start();
      writer.writeBatch();
      writer.end();
    }
  }

  /**
   * Fills {@code dictionaryVector} with the values {@code foo}, {@code bar} and {@code baz} and
   * wraps it in an unordered dictionary with id {@code 1}.
   *
   * @param dictionaryVector vector that receives the dictionary values; the caller keeps ownership
   *     and must close it
   * @return the dictionary backed by {@code dictionaryVector}
   */
  private Dictionary createDictionary(VarCharVector dictionaryVector) {
    setVector(
        dictionaryVector,
        "foo".getBytes(StandardCharsets.UTF_8),
        "bar".getBytes(StandardCharsets.UTF_8),
        "baz".getBytes(StandardCharsets.UTF_8));
    return new Dictionary(
        dictionaryVector,
        new DictionaryEncoding(/* id= */ 1L, /* ordered= */ false, /* indexType= */ null));
  }

  /**
   * Creates a map-backed provider holding the single dictionary built from {@code
   * dictionaryVector}.
   *
   * <p>NOTE(review): the written schema's only column is a plain UTF-8 field that does not
   * reference this dictionary, so the dictionary is presumably never serialized - confirm whether
   * the "WithDictionary" tests cover what they intend to.
   */
  private DictionaryProvider createProvider(VarCharVector dictionaryVector) {
    DictionaryProvider.MapDictionaryProvider provider =
        new DictionaryProvider.MapDictionaryProvider();
    provider.put(createDictionary(dictionaryVector));
    return provider;
  }

  /**
   * Reads back the bytes in {@link #out} as an Arrow file, once with and once without the
   * compression factory, and asserts the outcomes described in the class documentation.
   */
  private void verifyWrittenFile() throws Exception {
    final byte[] written = out.toByteArray();
    // with compression
    try (ArrowFileReader reader =
        new ArrowFileReader(
            new ByteArrayReadableSeekableByteChannel(written),
            allocator,
            CommonsCompressionFactory.INSTANCE)) {
      assertEquals(1, reader.getRecordBlocks().size());
      assertTrue(reader.loadNextBatch());
      // NOTE(review): kept as root.equals(...) rather than assertEquals; VectorSchemaRoot's
      // equals appears to be an overload taking VectorSchemaRoot, not an Object.equals
      // override - confirm before changing.
      assertTrue(root.equals(reader.getVectorSchemaRoot()));
      assertFalse(reader.loadNextBatch());
    }
    // without compression
    try (ArrowFileReader reader =
        new ArrowFileReader(
            new ByteArrayReadableSeekableByteChannel(written),
            allocator,
            NoCompressionCodec.Factory.INSTANCE)) {
      assertEquals(1, reader.getRecordBlocks().size());
      Exception exception = assertThrows(IllegalArgumentException.class, reader::loadNextBatch);
      assertEquals(MISSING_CODEC_MESSAGE, exception.getMessage());
    }
  }

  /**
   * Reads back the bytes in {@link #out} as an Arrow stream, once with and once without the
   * compression factory, and asserts the outcomes described in the class documentation.
   */
  private void verifyWrittenStream() throws Exception {
    final byte[] written = out.toByteArray();
    // with compression
    try (ArrowStreamReader reader =
        new ArrowStreamReader(
            new ByteArrayReadableSeekableByteChannel(written),
            allocator,
            CommonsCompressionFactory.INSTANCE)) {
      assertTrue(reader.loadNextBatch());
      // See the note in verifyWrittenFile() on why this is not assertEquals.
      assertTrue(root.equals(reader.getVectorSchemaRoot()));
      assertFalse(reader.loadNextBatch());
    }
    // without compression
    try (ArrowStreamReader reader =
        new ArrowStreamReader(
            new ByteArrayReadableSeekableByteChannel(written),
            allocator,
            NoCompressionCodec.Factory.INSTANCE)) {
      Exception exception = assertThrows(IllegalArgumentException.class, reader::loadNextBatch);
      assertEquals(MISSING_CODEC_MESSAGE, exception.getMessage());
    }
  }

  @Test
  public void testArrowFileZstdRoundTrip() throws Exception {
    createAndWriteArrowFile(null, CompressionUtil.CodecType.ZSTD);
    verifyWrittenFile();
  }

  @Test
  public void testArrowStreamZstdRoundTrip() throws Exception {
    createAndWriteArrowStream(null, CompressionUtil.CodecType.ZSTD);
    verifyWrittenStream();
  }

  @Test
  public void testArrowFileZstdRoundTripWithDictionary() throws Exception {
    // try-with-resources: the vector must be released even when an assertion fails, otherwise
    // closing the allocator in tearDown() would fail on the outstanding buffers and hide the
    // real cause.
    try (VarCharVector dictionaryVector =
        (VarCharVector)
            FieldType.nullable(new ArrowType.Utf8())
                .createNewSingleVector("f1_file", allocator, null)) {
      createAndWriteArrowFile(createProvider(dictionaryVector), CompressionUtil.CodecType.ZSTD);
      verifyWrittenFile();
    }
  }

  @Test
  public void testArrowStreamZstdRoundTripWithDictionary() throws Exception {
    // try-with-resources: the vector must be released even when an assertion fails, otherwise
    // closing the allocator in tearDown() would fail on the outstanding buffers and hide the
    // real cause.
    try (VarCharVector dictionaryVector =
        (VarCharVector)
            FieldType.nullable(new ArrowType.Utf8())
                .createNewSingleVector("f1_stream", allocator, null)) {
      createAndWriteArrowStream(createProvider(dictionaryVector), CompressionUtil.CodecType.ZSTD);
      verifyWrittenStream();
    }
  }

  /**
   * Allocates {@code vector} and stores {@code values} at consecutive indices starting at 0.
   *
   * <p>A {@code null} entry is skipped, leaving that slot unset; the value count is always set to
   * {@code values.length}.
   *
   * @param vector vector to (re)allocate and populate
   * @param values byte contents for each slot, in order
   */
  public static void setVector(VarCharVector vector, byte[]... values) {
    final int length = values.length;
    vector.allocateNewSafe();
    for (int i = 0; i < length; i++) {
      if (values[i] != null) {
        vector.set(i, values[i]);
      }
    }
    vector.setValueCount(length);
  }
}