// TestCompressionCodec.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.compression;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.channels.Channels;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
import java.util.function.BiConsumer;
import java.util.stream.Stream;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.compression.CompressionCodec;
import org.apache.arrow.vector.compression.CompressionUtil;
import org.apache.arrow.vector.compression.NoCompressionCodec;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.ArrowFileReader;
import org.apache.arrow.vector.ipc.ArrowFileWriter;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

/** Test cases for {@link CompressionCodec}s. */
class TestCompressionCodec {
  private BufferAllocator allocator;

  @BeforeEach
  void init() {
    allocator = new RootAllocator(Integer.MAX_VALUE);
  }

  @AfterEach
  void terminate() {
    allocator.close();
  }

  static Collection<Arguments> codecs() {
    List<Arguments> params = new ArrayList<>();

    int[] lengths = new int[] {10, 100, 1000};
    for (int len : lengths) {
      CompressionCodec dumbCodec = NoCompressionCodec.INSTANCE;
      params.add(Arguments.arguments(len, dumbCodec));

      CompressionCodec lz4Codec = new Lz4CompressionCodec();
      params.add(Arguments.arguments(len, lz4Codec));

      CompressionCodec zstdCodec = new ZstdCompressionCodec();
      params.add(Arguments.arguments(len, zstdCodec));

      CompressionCodec zstdCodecAndCompressionLevel = new ZstdCompressionCodec(7);
      params.add(Arguments.arguments(len, zstdCodecAndCompressionLevel));
    }
    return params;
  }

  private List<ArrowBuf> compressBuffers(CompressionCodec codec, List<ArrowBuf> inputBuffers) {
    List<ArrowBuf> outputBuffers = new ArrayList<>(inputBuffers.size());
    for (ArrowBuf buf : inputBuffers) {
      outputBuffers.add(codec.compress(allocator, buf));
    }
    return outputBuffers;
  }

  private List<ArrowBuf> deCompressBuffers(CompressionCodec codec, List<ArrowBuf> inputBuffers) {
    List<ArrowBuf> outputBuffers = new ArrayList<>(inputBuffers.size());
    for (ArrowBuf buf : inputBuffers) {
      outputBuffers.add(codec.decompress(allocator, buf));
    }
    return outputBuffers;
  }

  private void assertWriterIndex(List<ArrowBuf> decompressedBuffers) {
    for (ArrowBuf decompressedBuf : decompressedBuffers) {
      assertTrue(decompressedBuf.writerIndex() > 0);
    }
  }

  @ParameterizedTest
  @MethodSource("codecs")
  void testCompressFixedWidthBuffers(int vectorLength, CompressionCodec codec) throws Exception {
    // prepare vector to compress
    IntVector origVec = new IntVector("vec", allocator);
    origVec.allocateNew(vectorLength);
    for (int i = 0; i < vectorLength; i++) {
      if (i % 10 == 0) {
        origVec.setNull(i);
      } else {
        origVec.set(i, i);
      }
    }
    origVec.setValueCount(vectorLength);
    int nullCount = origVec.getNullCount();

    // compress & decompress
    List<ArrowBuf> origBuffers = origVec.getFieldBuffers();
    List<ArrowBuf> compressedBuffers = compressBuffers(codec, origBuffers);
    List<ArrowBuf> decompressedBuffers = deCompressBuffers(codec, compressedBuffers);

    assertEquals(2, decompressedBuffers.size());
    assertWriterIndex(decompressedBuffers);

    // orchestrate new vector
    IntVector newVec = new IntVector("new vec", allocator);
    newVec.loadFieldBuffers(new ArrowFieldNode(vectorLength, nullCount), decompressedBuffers);

    // verify new vector
    assertEquals(vectorLength, newVec.getValueCount());
    for (int i = 0; i < vectorLength; i++) {
      if (i % 10 == 0) {
        assertTrue(newVec.isNull(i));
      } else {
        assertEquals(i, newVec.get(i));
      }
    }

    newVec.close();
    AutoCloseables.close(decompressedBuffers);
  }

  @ParameterizedTest
  @MethodSource("codecs")
  void testCompressVariableWidthBuffers(int vectorLength, CompressionCodec codec) throws Exception {
    // prepare vector to compress
    VarCharVector origVec = new VarCharVector("vec", allocator);
    origVec.allocateNew();
    for (int i = 0; i < vectorLength; i++) {
      if (i % 10 == 0) {
        origVec.setNull(i);
      } else {
        origVec.setSafe(i, String.valueOf(i).getBytes(StandardCharsets.UTF_8));
      }
    }
    origVec.setValueCount(vectorLength);
    int nullCount = origVec.getNullCount();

    // compress & decompress
    List<ArrowBuf> origBuffers = origVec.getFieldBuffers();
    List<ArrowBuf> compressedBuffers = compressBuffers(codec, origBuffers);
    List<ArrowBuf> decompressedBuffers = deCompressBuffers(codec, compressedBuffers);

    assertEquals(3, decompressedBuffers.size());
    assertWriterIndex(decompressedBuffers);

    // orchestrate new vector
    VarCharVector newVec = new VarCharVector("new vec", allocator);
    newVec.loadFieldBuffers(new ArrowFieldNode(vectorLength, nullCount), decompressedBuffers);

    // verify new vector
    assertEquals(vectorLength, newVec.getValueCount());
    for (int i = 0; i < vectorLength; i++) {
      if (i % 10 == 0) {
        assertTrue(newVec.isNull(i));
      } else {
        assertArrayEquals(String.valueOf(i).getBytes(StandardCharsets.UTF_8), newVec.get(i));
      }
    }

    newVec.close();
    AutoCloseables.close(decompressedBuffers);
  }

  @ParameterizedTest
  @MethodSource("codecs")
  void testEmptyBuffer(int vectorLength, CompressionCodec codec) throws Exception {
    final VarBinaryVector origVec = new VarBinaryVector("vec", allocator);

    origVec.allocateNew(vectorLength);

    // Do not set any values (all missing)
    origVec.setValueCount(vectorLength);

    final List<ArrowBuf> origBuffers = origVec.getFieldBuffers();
    final List<ArrowBuf> compressedBuffers = compressBuffers(codec, origBuffers);
    final List<ArrowBuf> decompressedBuffers = deCompressBuffers(codec, compressedBuffers);

    // orchestrate new vector
    VarBinaryVector newVec = new VarBinaryVector("new vec", allocator);
    newVec.loadFieldBuffers(new ArrowFieldNode(vectorLength, vectorLength), decompressedBuffers);

    // verify new vector
    assertEquals(vectorLength, newVec.getValueCount());
    for (int i = 0; i < vectorLength; i++) {
      assertTrue(newVec.isNull(i));
    }

    newVec.close();
    AutoCloseables.close(decompressedBuffers);
  }

  private static Stream<CompressionUtil.CodecType> codecTypes() {
    return Arrays.stream(CompressionUtil.CodecType.values());
  }

  @ParameterizedTest
  @MethodSource("codecTypes")
  void testReadWriteStream(CompressionUtil.CodecType codec) throws Exception {
    withRoot(
        codec,
        (factory, root) -> {
          ByteArrayOutputStream compressedStream = new ByteArrayOutputStream();
          try (final ArrowStreamWriter writer =
              new ArrowStreamWriter(
                  root,
                  new DictionaryProvider.MapDictionaryProvider(),
                  Channels.newChannel(compressedStream),
                  IpcOption.DEFAULT,
                  factory,
                  codec,
                  Optional.of(7))) {
            writer.start();
            writer.writeBatch();
            writer.end();
          } catch (IOException e) {
            throw new RuntimeException(e);
          }

          try (ArrowStreamReader reader =
              new ArrowStreamReader(
                  new ByteArrayReadableSeekableByteChannel(compressedStream.toByteArray()),
                  allocator,
                  factory)) {
            assertTrue(reader.loadNextBatch());
            assertTrue(root.equals(reader.getVectorSchemaRoot()));
            assertFalse(reader.loadNextBatch());
          } catch (IOException e) {
            throw new RuntimeException(e);
          }
        });
  }

  @ParameterizedTest
  @MethodSource("codecTypes")
  void testReadWriteFile(CompressionUtil.CodecType codec) throws Exception {
    withRoot(
        codec,
        (factory, root) -> {
          ByteArrayOutputStream compressedStream = new ByteArrayOutputStream();
          try (final ArrowFileWriter writer =
              new ArrowFileWriter(
                  root,
                  new DictionaryProvider.MapDictionaryProvider(),
                  Channels.newChannel(compressedStream),
                  new HashMap<>(),
                  IpcOption.DEFAULT,
                  factory,
                  codec,
                  Optional.of(7))) {
            writer.start();
            writer.writeBatch();
            writer.end();
          } catch (IOException e) {
            throw new RuntimeException(e);
          }

          try (ArrowFileReader reader =
              new ArrowFileReader(
                  new ByteArrayReadableSeekableByteChannel(compressedStream.toByteArray()),
                  allocator,
                  factory)) {
            assertTrue(reader.loadNextBatch());
            assertTrue(root.equals(reader.getVectorSchemaRoot()));
            assertFalse(reader.loadNextBatch());
          } catch (IOException e) {
            throw new RuntimeException(e);
          }
        });
  }

  /** Unloading a vector should not free source buffers. */
  @ParameterizedTest
  @MethodSource("codecTypes")
  void testUnloadCompressed(CompressionUtil.CodecType codec) {
    withRoot(
        codec,
        (factory, root) -> {
          root.getFieldVectors()
              .forEach(
                  (vector) -> {
                    Arrays.stream(vector.getBuffers(/*clear*/ false))
                        .forEach(
                            (buf) -> {
                              assertNotEquals(0, buf.getReferenceManager().getRefCount());
                            });
                  });

          final VectorUnloader unloader =
              new VectorUnloader(
                  root, /*includeNullCount*/
                  true,
                  factory.createCodec(codec), /*alignBuffers*/
                  true);
          unloader.getRecordBatch().close();

          root.getFieldVectors()
              .forEach(
                  (vector) -> {
                    Arrays.stream(vector.getBuffers(/*clear*/ false))
                        .forEach(
                            (buf) -> {
                              assertNotEquals(0, buf.getReferenceManager().getRefCount());
                            });
                  });
        });
  }

  void withRoot(
      CompressionUtil.CodecType codec,
      BiConsumer<CompressionCodec.Factory, VectorSchemaRoot> testBody) {
    final Schema schema =
        new Schema(
            Arrays.asList(
                Field.nullable("ints", new ArrowType.Int(32, true)),
                Field.nullable("strings", ArrowType.Utf8.INSTANCE)));
    CompressionCodec.Factory factory =
        codec == CompressionUtil.CodecType.NO_COMPRESSION
            ? NoCompressionCodec.Factory.INSTANCE
            : CommonsCompressionFactory.INSTANCE;
    try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
      final IntVector ints = (IntVector) root.getVector(0);
      final VarCharVector strings = (VarCharVector) root.getVector(1);
      // Doesn't get compressed
      ints.setSafe(0, 0x4a3e);
      ints.setSafe(1, 0x8aba);
      ints.setSafe(2, 0x4362);
      ints.setSafe(3, 0x383f);
      // Gets compressed
      String compressibleString = "                "; // 16 bytes
      compressibleString = compressibleString + compressibleString;
      compressibleString = compressibleString + compressibleString;
      compressibleString = compressibleString + compressibleString;
      compressibleString = compressibleString + compressibleString;
      compressibleString = compressibleString + compressibleString; // 512 bytes
      byte[] compressibleData = compressibleString.getBytes(StandardCharsets.UTF_8);
      strings.setSafe(0, compressibleData);
      strings.setSafe(1, compressibleData);
      strings.setSafe(2, compressibleData);
      strings.setSafe(3, compressibleData);
      root.setRowCount(4);

      testBody.accept(factory, root);
    }
  }
}