TestHashTableDictionaryEncoder.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.algorithm.dictionary;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Random;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryEncoder;
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
/**
 * Test cases for {@link HashTableDictionaryEncoder}.
 *
 * <p>Each test builds a dictionary vector and a vector of raw values, encodes the raw values into
 * dictionary indices with {@link HashTableDictionaryEncoder}, checks the indices, and (where
 * applicable) decodes them back with {@link DictionaryEncoder#decode} to verify the round trip.
 */
public class TestHashTableDictionaryEncoder {

  /** Number of values in the randomly generated raw vectors. */
  private static final int VECTOR_LENGTH = 50;

  /** Number of entries in the dictionaries whose entry {@code i} is the string {@code "i"}. */
  private static final int DICTIONARY_LENGTH = 10;

  /** Fixed seed so that a failure in the randomized tests can be reproduced. */
  private static final long RANDOM_SEED = 20240101L;

  private final byte[] zero = "000".getBytes(StandardCharsets.UTF_8);
  private final byte[] one = "111".getBytes(StandardCharsets.UTF_8);
  private final byte[] two = "222".getBytes(StandardCharsets.UTF_8);

  /** The three sample values; {@code data[i]} is stored at dictionary index {@code i}. */
  private final byte[][] data = new byte[][] {zero, one, two};

  private BufferAllocator allocator;

  @BeforeEach
  public void prepare() {
    allocator = new RootAllocator(1024 * 1024);
  }

  @AfterEach
  public void shutdown() {
    allocator.close();
  }

  @Test
  public void testEncodeAndDecode() {
    Random random = new Random(RANDOM_SEED);
    try (VarCharVector rawVector = new VarCharVector("original vector", allocator);
        IntVector encodedVector = new IntVector("encoded vector", allocator);
        VarCharVector dictionary = new VarCharVector("dictionary", allocator)) {
      // set up dictionary: the string "i" is stored at index i
      dictionary.allocateNew();
      for (int i = 0; i < DICTIONARY_LENGTH; i++) {
        dictionary.setSafe(i, String.valueOf(i).getBytes(StandardCharsets.UTF_8));
      }
      dictionary.setValueCount(DICTIONARY_LENGTH);

      // set up raw vector with random members of the dictionary
      rawVector.allocateNew(10 * VECTOR_LENGTH, VECTOR_LENGTH);
      for (int i = 0; i < VECTOR_LENGTH; i++) {
        int val = random.nextInt(DICTIONARY_LENGTH);
        rawVector.set(i, String.valueOf(val).getBytes(StandardCharsets.UTF_8));
      }
      rawVector.setValueCount(VECTOR_LENGTH);

      HashTableDictionaryEncoder<IntVector, VarCharVector> encoder =
          new HashTableDictionaryEncoder<>(dictionary, false);

      // perform encoding
      encodedVector.allocateNew();
      encoder.encode(rawVector, encodedVector);

      // verify encoding results: the index must be the numeric value of the raw string
      assertEquals(rawVector.getValueCount(), encodedVector.getValueCount());
      for (int i = 0; i < VECTOR_LENGTH; i++) {
        assertArrayEquals(
            rawVector.get(i),
            String.valueOf(encodedVector.get(i)).getBytes(StandardCharsets.UTF_8));
      }

      // perform decoding
      Dictionary dict = new Dictionary(dictionary, new DictionaryEncoding(1L, false, null));
      try (VarCharVector decodedVector =
          (VarCharVector) DictionaryEncoder.decode(encodedVector, dict)) {
        // verify the round trip reproduces the original raw vector
        assertEquals(rawVector.getValueCount(), decodedVector.getValueCount());
        for (int i = 0; i < VECTOR_LENGTH; i++) {
          assertArrayEquals(rawVector.get(i), decodedVector.get(i));
        }
      }
    }
  }

  @Test
  public void testEncodeAndDecodeWithNull() {
    Random random = new Random(RANDOM_SEED);
    try (VarCharVector rawVector = new VarCharVector("original vector", allocator);
        IntVector encodedVector = new IntVector("encoded vector", allocator);
        VarCharVector dictionary = new VarCharVector("dictionary", allocator)) {
      // set up dictionary: null at index 0, the string "i" at index i for i >= 1
      dictionary.allocateNew();
      dictionary.setNull(0);
      for (int i = 1; i < DICTIONARY_LENGTH; i++) {
        dictionary.setSafe(i, String.valueOf(i).getBytes(StandardCharsets.UTF_8));
      }
      dictionary.setValueCount(DICTIONARY_LENGTH);

      // set up raw vector: every 10th slot is null, the rest are non-null dictionary members
      rawVector.allocateNew(10 * VECTOR_LENGTH, VECTOR_LENGTH);
      for (int i = 0; i < VECTOR_LENGTH; i++) {
        if (i % 10 == 0) {
          rawVector.setNull(i);
        } else {
          int val = random.nextInt(DICTIONARY_LENGTH - 1) + 1;
          rawVector.set(i, String.valueOf(val).getBytes(StandardCharsets.UTF_8));
        }
      }
      rawVector.setValueCount(VECTOR_LENGTH);

      // encodeNull = true: nulls are looked up in the dictionary like any other value
      HashTableDictionaryEncoder<IntVector, VarCharVector> encoder =
          new HashTableDictionaryEncoder<>(dictionary, true);

      // perform encoding
      encodedVector.allocateNew();
      encoder.encode(rawVector, encodedVector);

      // verify encoding results: nulls map to index 0, other values to their numeric value
      assertEquals(rawVector.getValueCount(), encodedVector.getValueCount());
      for (int i = 0; i < VECTOR_LENGTH; i++) {
        if (i % 10 == 0) {
          assertEquals(0, encodedVector.get(i));
        } else {
          assertArrayEquals(
              rawVector.get(i),
              String.valueOf(encodedVector.get(i)).getBytes(StandardCharsets.UTF_8));
        }
      }

      // perform decoding
      Dictionary dict = new Dictionary(dictionary, new DictionaryEncoding(1L, false, null));
      try (VarCharVector decodedVector =
          (VarCharVector) DictionaryEncoder.decode(encodedVector, dict)) {
        // verify the round trip reproduces the original raw vector, including nulls
        assertEquals(rawVector.getValueCount(), decodedVector.getValueCount());
        for (int i = 0; i < VECTOR_LENGTH; i++) {
          if (i % 10 == 0) {
            assertTrue(decodedVector.isNull(i));
          } else {
            assertArrayEquals(rawVector.get(i), decodedVector.get(i));
          }
        }
      }
    }
  }

  @Test
  public void testEncodeNullWithoutNullInDictionary() {
    try (VarCharVector rawVector = new VarCharVector("original vector", allocator);
        IntVector encodedVector = new IntVector("encoded vector", allocator);
        VarCharVector dictionary = new VarCharVector("dictionary", allocator)) {
      // set up dictionary, with no null in it.
      dictionary.allocateNew();
      for (int i = 0; i < DICTIONARY_LENGTH; i++) {
        dictionary.setSafe(i, String.valueOf(i).getBytes(StandardCharsets.UTF_8));
      }
      dictionary.setValueCount(DICTIONARY_LENGTH);

      // the vector to encode has a null inside.
      rawVector.allocateNew(1);
      rawVector.setNull(0);
      rawVector.setValueCount(1);

      encodedVector.allocateNew();
      HashTableDictionaryEncoder<IntVector, VarCharVector> encoder =
          new HashTableDictionaryEncoder<>(dictionary, true);

      // the encoder is asked to encode null, but the dictionary has no null entry,
      // so an exception should be thrown.
      assertThrows(
          IllegalArgumentException.class, () -> encoder.encode(rawVector, encodedVector));
    }
  }

  @Test
  public void testEncodeStrings() {
    try (final VarCharVector vector = new VarCharVector("foo", allocator);
        final IntVector encoded = new IntVector("encoded", allocator);
        final VarCharVector dictionaryVector = new VarCharVector("dict", allocator)) {
      vector.allocateNew(512, 5);
      encoded.allocateNew();

      // set some values
      vector.setSafe(0, zero, 0, zero.length);
      vector.setSafe(1, one, 0, one.length);
      vector.setSafe(2, one, 0, one.length);
      vector.setSafe(3, two, 0, two.length);
      vector.setSafe(4, zero, 0, zero.length);
      vector.setValueCount(5);

      // set some dictionary values; each entry is written with its own length
      dictionaryVector.allocateNew(512, 3);
      dictionaryVector.setSafe(0, zero, 0, zero.length);
      dictionaryVector.setSafe(1, one, 0, one.length);
      dictionaryVector.setSafe(2, two, 0, two.length);
      dictionaryVector.setValueCount(3);

      HashTableDictionaryEncoder<IntVector, VarCharVector> encoder =
          new HashTableDictionaryEncoder<>(dictionaryVector);
      encoder.encode(vector, encoded);

      // verify indices
      assertEquals(5, encoded.getValueCount());
      assertEquals(0, encoded.get(0));
      assertEquals(1, encoded.get(1));
      assertEquals(1, encoded.get(2));
      assertEquals(2, encoded.get(3));
      assertEquals(0, encoded.get(4));

      // now run through the decoder and verify we get the original back
      Dictionary dict = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null));
      try (VarCharVector decoded = (VarCharVector) DictionaryEncoder.decode(encoded, dict)) {
        assertEquals(vector.getValueCount(), decoded.getValueCount());
        for (int i = 0; i < vector.getValueCount(); i++) {
          assertEquals(vector.getObject(i), decoded.getObject(i));
        }
      }
    }
  }

  @Test
  public void testEncodeLargeVector() {
    try (final VarCharVector vector = new VarCharVector("foo", allocator);
        final IntVector encoded = new IntVector("encoded", allocator);
        final VarCharVector dictionaryVector = new VarCharVector("dict", allocator)) {
      vector.allocateNew();
      encoded.allocateNew();

      // cycle through the sample values so that value i is data[i % 3]
      final int count = 10000;
      for (int i = 0; i < count; ++i) {
        vector.setSafe(i, data[i % 3], 0, data[i % 3].length);
      }
      vector.setValueCount(count);

      // each dictionary entry is written with its own length
      dictionaryVector.allocateNew(512, 3);
      dictionaryVector.setSafe(0, zero, 0, zero.length);
      dictionaryVector.setSafe(1, one, 0, one.length);
      dictionaryVector.setSafe(2, two, 0, two.length);
      dictionaryVector.setValueCount(3);

      HashTableDictionaryEncoder<IntVector, VarCharVector> encoder =
          new HashTableDictionaryEncoder<>(dictionaryVector);
      encoder.encode(vector, encoded);

      assertEquals(count, encoded.getValueCount());
      for (int i = 0; i < count; ++i) {
        assertEquals(i % 3, encoded.get(i));
      }

      // now run through the decoder and verify we get the original back
      Dictionary dict = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null));
      try (VarCharVector decoded = (VarCharVector) DictionaryEncoder.decode(encoded, dict)) {
        assertEquals(vector.getClass(), decoded.getClass());
        assertEquals(vector.getValueCount(), decoded.getValueCount());
        for (int i = 0; i < count; ++i) {
          assertEquals(vector.getObject(i), decoded.getObject(i));
        }
      }
    }
  }

  @Test
  public void testEncodeBinaryVector() {
    try (final VarBinaryVector vector = new VarBinaryVector("foo", allocator);
        final VarBinaryVector dictionaryVector = new VarBinaryVector("dict", allocator);
        final IntVector encoded = new IntVector("encoded", allocator)) {
      // a single sized allocation; a second allocateNew() here would just discard it
      vector.allocateNew(512, 5);
      encoded.allocateNew();

      // set some values
      vector.setSafe(0, zero, 0, zero.length);
      vector.setSafe(1, one, 0, one.length);
      vector.setSafe(2, one, 0, one.length);
      vector.setSafe(3, two, 0, two.length);
      vector.setSafe(4, zero, 0, zero.length);
      vector.setValueCount(5);

      // set some dictionary values; each entry is written with its own length
      dictionaryVector.allocateNew(512, 3);
      dictionaryVector.setSafe(0, zero, 0, zero.length);
      dictionaryVector.setSafe(1, one, 0, one.length);
      dictionaryVector.setSafe(2, two, 0, two.length);
      dictionaryVector.setValueCount(3);

      HashTableDictionaryEncoder<IntVector, VarBinaryVector> encoder =
          new HashTableDictionaryEncoder<>(dictionaryVector);
      encoder.encode(vector, encoded);

      assertEquals(5, encoded.getValueCount());
      assertEquals(0, encoded.get(0));
      assertEquals(1, encoded.get(1));
      assertEquals(1, encoded.get(2));
      assertEquals(2, encoded.get(3));
      assertEquals(0, encoded.get(4));

      // now run through the decoder and verify we get the original back
      Dictionary dict = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null));
      try (VarBinaryVector decoded = (VarBinaryVector) DictionaryEncoder.decode(encoded, dict)) {
        assertEquals(vector.getClass(), decoded.getClass());
        assertEquals(vector.getValueCount(), decoded.getValueCount());
        for (int i = 0; i < vector.getValueCount(); i++) {
          assertTrue(Arrays.equals(vector.getObject(i), decoded.getObject(i)));
        }
      }
    }
  }
}