TestDictionaryRowGroupBuilder.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.orc.writer;

import com.google.common.collect.ImmutableList;
import org.testng.annotations.Test;

import java.util.Random;

import static com.facebook.airlift.testing.Assertions.assertLessThan;
import static io.airlift.slice.SizeOf.sizeOf;
import static java.util.Arrays.fill;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertNull;

/**
 * Tests for {@link DictionaryRowGroupBuilder}.
 *
 * <p>Verifies that indexes added via {@code addIndexes} end up in byte, short or int segments
 * according to the maximum index passed in. Also checks partial-length adds, retained-size
 * accounting and {@code reset}.
 *
 * <p>Note: TestNG's {@code assertEquals} takes {@code (actual, expected)}, and all assertions
 * below follow that order so failure messages label the values correctly.
 */
public class TestDictionaryRowGroupBuilder
{
    private static final int MAX_DICTIONARY_INDEX = 10_000;
    private static final int NULL_INDEX_ENTRY = -1;

    // Unseeded, so every run draws different index values; a failure may not reproduce on rerun.
    private final Random random = new Random();

    /**
     * An empty add with a null (-1) max index is expected to produce a single, empty byte segment.
     */
    @Test
    public void testEmptyDictionary()
    {
        DictionaryRowGroupBuilder rowGroupBuilder = new DictionaryRowGroupBuilder();
        rowGroupBuilder.addIndexes(NULL_INDEX_ENTRY, new int[0], 0);

        byte[] byteIndexes = getByteIndexes(rowGroupBuilder);
        assertEquals(byteIndexes.length, 0);
    }

    /**
     * With a max index of {@link Byte#MAX_VALUE}, indexes are expected in a single byte segment.
     * Covers an empty add, a partial add and a full add; only the first {@code length} entries
     * should be copied.
     */
    @Test
    public void testByteIndexes()
    {
        int[] dictionaryIndexes = createIndexArray(Byte.MAX_VALUE + 1, MAX_DICTIONARY_INDEX);

        for (int length : ImmutableList.of(0, 10, dictionaryIndexes.length)) {
            DictionaryRowGroupBuilder rowGroupBuilder = new DictionaryRowGroupBuilder();
            rowGroupBuilder.addIndexes(Byte.MAX_VALUE, dictionaryIndexes, length);
            byte[] byteIndexes = getByteIndexes(rowGroupBuilder);
            assertEquals(byteIndexes.length, length);
            for (int i = 0; i < length; i++) {
                assertEquals(byteIndexes[i], dictionaryIndexes[i]);
            }
        }
    }

    /**
     * With a max index of {@link Short#MAX_VALUE}, indexes are expected in a single short segment.
     */
    @Test
    public void testShortIndexes()
    {
        int[] dictionaryIndexes = createIndexArray(Short.MAX_VALUE + 1, MAX_DICTIONARY_INDEX);

        for (int length : ImmutableList.of(0, 10, dictionaryIndexes.length)) {
            DictionaryRowGroupBuilder rowGroupBuilder = new DictionaryRowGroupBuilder();
            rowGroupBuilder.addIndexes(Short.MAX_VALUE, dictionaryIndexes, length);
            short[] shortIndexes = getShortIndexes(rowGroupBuilder);
            assertEquals(shortIndexes.length, length);
            for (int i = 0; i < length; i++) {
                assertEquals(shortIndexes[i], dictionaryIndexes[i]);
            }
        }
    }

    /**
     * With a max index of {@link Integer#MAX_VALUE}, indexes are expected in a single int segment.
     */
    @Test
    public void testIntegerIndexes()
    {
        int[] dictionaryIndexes = createIndexArray(Integer.MAX_VALUE, MAX_DICTIONARY_INDEX);

        for (int length : ImmutableList.of(0, 10, dictionaryIndexes.length)) {
            DictionaryRowGroupBuilder rowGroupBuilder = new DictionaryRowGroupBuilder();
            rowGroupBuilder.addIndexes(Integer.MAX_VALUE, dictionaryIndexes, length);
            int[] intIndexes = getIntegerIndexes(rowGroupBuilder);
            assertEquals(intIndexes.length, length);
            for (int i = 0; i < length; i++) {
                assertEquals(intIndexes[i], dictionaryIndexes[i]);
            }
        }
    }

    /**
     * {@code addIndexes} is expected to throw {@link IllegalStateException} when the max index
     * passed in decreases compared to an earlier call.
     */
    @Test(expectedExceptions = {IllegalStateException.class})
    public void testDecreasingMaxThrows()
    {
        DictionaryRowGroupBuilder rowGroupBuilder = new DictionaryRowGroupBuilder();
        rowGroupBuilder.addIndexes(5, new int[0], 0);
        rowGroupBuilder.addIndexes(3, new int[1], 1);
    }

    /**
     * A dictionary made only of null entries is expected to be stored as a byte segment.
     * A subsequent zero-length add must not change it.
     */
    @Test
    public void testNullDictionary()
    {
        int[] indexes = new int[MAX_DICTIONARY_INDEX];
        fill(indexes, NULL_INDEX_ENTRY);
        DictionaryRowGroupBuilder rowGroupBuilder = new DictionaryRowGroupBuilder();
        rowGroupBuilder.addIndexes(NULL_INDEX_ENTRY, indexes, indexes.length);

        byte[] byteIndexes = getByteIndexes(rowGroupBuilder);
        compareIntAndByteArrays(indexes, byteIndexes);

        // Adding 0 element list should be ignored.
        rowGroupBuilder.addIndexes(NULL_INDEX_ENTRY, indexes, 0);
        byteIndexes = getByteIndexes(rowGroupBuilder);
        compareIntAndByteArrays(indexes, byteIndexes);
    }

    /**
     * Adds several byte, short and int segments to one builder and verifies their contents.
     * Also checks the index retained-size accounting, and that {@code reset} returns the builder
     * to its initial retained size.
     */
    @Test
    public void testMultipleSegments()
    {
        int byteSegmentCount = 7;
        int shortSegmentCount = 6;
        int intSegmentCount = 17;

        int[][] segments = new int[byteSegmentCount + shortSegmentCount + intSegmentCount][];

        DictionaryRowGroupBuilder rowGroupBuilder = new DictionaryRowGroupBuilder();
        long emptyRetainedSizeInBytes = rowGroupBuilder.getRetainedSizeInBytes();
        int index = 0;

        // The max index may not decrease between addIndexes calls (see testDecreasingMaxThrows),
        // so segments are added in byte, short, int order.
        for (int i = 0; i < byteSegmentCount; i++, index++) {
            segments[index] = createIndexArray(Byte.MAX_VALUE + 1, segmentArrayLength(i));
            rowGroupBuilder.addIndexes(Byte.MAX_VALUE, segments[index], segmentUsedLength(i));
        }

        for (int i = 0; i < shortSegmentCount; i++, index++) {
            segments[index] = createIndexArray(Short.MAX_VALUE + 1, segmentArrayLength(i));
            rowGroupBuilder.addIndexes(Short.MAX_VALUE, segments[index], segmentUsedLength(i));
        }

        for (int i = 0; i < intSegmentCount; i++, index++) {
            segments[index] = createIndexArray(Integer.MAX_VALUE, segmentArrayLength(i));
            rowGroupBuilder.addIndexes(Integer.MAX_VALUE, segments[index], segmentUsedLength(i));
        }

        byte[][] byteSegments = rowGroupBuilder.getByteSegments();
        short[][] shortSegments = rowGroupBuilder.getShortSegments();
        int[][] intSegments = rowGroupBuilder.getIntegerSegments();

        long indexSize = verifySegments(byteSegmentCount, shortSegmentCount, intSegmentCount, segments, byteSegments, shortSegments, intSegments);
        assertEquals(rowGroupBuilder.getIndexRetainedBytes(), indexSize);
        long retainedBytesBeforeReset = rowGroupBuilder.getRetainedSizeInBytes();

        rowGroupBuilder.reset();
        assertEquals(rowGroupBuilder.getIndexRetainedBytes(), 0);
        assertNull(rowGroupBuilder.getByteSegments());
        assertNull(rowGroupBuilder.getShortSegments());
        assertNull(rowGroupBuilder.getIntegerSegments());
        long retainedBytesAfterReset = rowGroupBuilder.getRetainedSizeInBytes();
        assertLessThan(retainedBytesAfterReset, retainedBytesBeforeReset);
        assertEquals(retainedBytesAfterReset, emptyRetainedSizeInBytes);
    }

    /**
     * Length of the backing array created for the {@code segment}-th segment of each type.
     * It is deliberately longer than {@link #segmentUsedLength(int)}, so the test checks that
     * only the requested prefix is stored.
     */
    private static int segmentArrayLength(int segment)
    {
        return 1_000 + segment * 5;
    }

    /**
     * Number of entries passed to {@code addIndexes} for the {@code segment}-th segment of each
     * type. This is also the length the stored segment is expected to have.
     */
    private static int segmentUsedLength(int segment)
    {
        return 1_000 + segment * 3;
    }

    /**
     * Asserts that {@code byteIndexes} has the same length as {@code indexes} and holds the same
     * values element by element.
     */
    private void compareIntAndByteArrays(int[] indexes, byte[] byteIndexes)
    {
        assertEquals(byteIndexes.length, indexes.length);
        for (int i = 0; i < byteIndexes.length; i++) {
            assertEquals(byteIndexes[i], indexes[i]);
        }
    }

    /**
     * Creates {@code length} random entries. Each entry is {@link #NULL_INDEX_ENTRY} with
     * probability 1/2; otherwise it is drawn uniformly from {@code [0, maxValue)}.
     */
    private int[] createIndexArray(int maxValue, int length)
    {
        int[] dictionaryIndexes = new int[length];
        for (int i = 0; i < length; i++) {
            if (random.nextBoolean()) {
                dictionaryIndexes[i] = NULL_INDEX_ENTRY;
            }
            else {
                dictionaryIndexes[i] = random.nextInt(maxValue);
            }
        }
        return dictionaryIndexes;
    }

    /**
     * Asserts the builder holds exactly one non-null byte segment and no short or int segments,
     * then returns that byte segment.
     */
    private byte[] getByteIndexes(DictionaryRowGroupBuilder rowGroupBuilder)
    {
        byte[][] byteSegments = rowGroupBuilder.getByteSegments();
        assertNotNull(byteSegments);
        assertEquals(byteSegments.length, 1);

        assertNull(rowGroupBuilder.getShortSegments());
        assertNull(rowGroupBuilder.getIntegerSegments());
        assertNotNull(byteSegments[0]);
        return byteSegments[0];
    }

    /**
     * Asserts the builder holds exactly one non-null short segment and no byte or int segments,
     * then returns that short segment.
     */
    private short[] getShortIndexes(DictionaryRowGroupBuilder rowGroupBuilder)
    {
        short[][] shortSegments = rowGroupBuilder.getShortSegments();
        assertNotNull(shortSegments);
        assertEquals(shortSegments.length, 1);

        assertNull(rowGroupBuilder.getByteSegments());
        assertNull(rowGroupBuilder.getIntegerSegments());
        assertNotNull(shortSegments[0]);
        return shortSegments[0];
    }

    /**
     * Asserts the builder holds exactly one non-null int segment and no byte or short segments,
     * then returns that int segment.
     */
    private int[] getIntegerIndexes(DictionaryRowGroupBuilder rowGroupBuilder)
    {
        int[][] integerSegments = rowGroupBuilder.getIntegerSegments();
        assertNotNull(integerSegments);
        assertEquals(integerSegments.length, 1);

        assertNull(rowGroupBuilder.getByteSegments());
        assertNull(rowGroupBuilder.getShortSegments());
        assertNotNull(integerSegments[0]);
        return integerSegments[0];
    }

    /**
     * Checks the byte, short and int segments against the source arrays in {@code segments}.
     * The source arrays are laid out byte segments first, then short, then int. For each stored
     * segment, the method checks its count, its length and its contents.
     *
     * @return the summed {@code sizeOf} of all stored segment arrays
     */
    private long verifySegments(int byteSegmentCount, int shortSegmentCount, int intSegmentCount, int[][] segments, byte[][] byteSegments, short[][] shortSegments, int[][] intSegments)
    {
        // Position in `segments`; carried across all three loops.
        int index = 0;
        long totalSize = 0;
        assertEquals(byteSegments.length, byteSegmentCount);
        for (int i = 0; i < byteSegments.length; i++, index++) {
            assertEquals(byteSegments[i].length, segmentUsedLength(i));
            totalSize += sizeOf(byteSegments[i]);
            for (int j = 0; j < byteSegments[i].length; j++) {
                assertEquals(byteSegments[i][j], segments[index][j]);
            }
        }

        assertEquals(shortSegments.length, shortSegmentCount);
        for (int i = 0; i < shortSegments.length; i++, index++) {
            assertEquals(shortSegments[i].length, segmentUsedLength(i));
            totalSize += sizeOf(shortSegments[i]);
            for (int j = 0; j < shortSegments[i].length; j++) {
                assertEquals(shortSegments[i][j], segments[index][j]);
            }
        }

        assertEquals(intSegments.length, intSegmentCount);
        for (int i = 0; i < intSegments.length; i++, index++) {
            assertEquals(intSegments[i].length, segmentUsedLength(i));
            totalSize += sizeOf(intSegments[i]);
            for (int j = 0; j < intSegments[i].length; j++) {
                assertEquals(intSegments[i][j], segments[index][j]);
            }
        }
        return totalSize;
    }
}