TestPrestoSparkRowBatch.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.spark.execution;

import com.facebook.presto.spark.classloader_interface.MutablePartitionId;
import com.facebook.presto.spark.classloader_interface.PrestoSparkMutableRow;
import com.facebook.presto.spark.execution.PrestoSparkRowBatch.PrestoSparkRowBatchBuilder;
import com.facebook.presto.spark.execution.PrestoSparkRowBatch.RowIndex;
import com.facebook.presto.spark.execution.PrestoSparkRowBatch.RowTupleSupplier;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.SliceOutput;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import org.testng.annotations.Test;
import scala.Tuple2;

import java.nio.ByteBuffer;
import java.util.List;
import java.util.Objects;
import java.util.stream.IntStream;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.lang.Integer.BYTES;
import static java.lang.Integer.MAX_VALUE;
import static java.nio.ByteOrder.LITTLE_ENDIAN;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Objects.requireNonNull;
import static org.assertj.core.api.Assertions.assertThat;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;

/**
 * Tests for {@link PrestoSparkRowBatch}: round-tripping rows through the batch builder,
 * fan-out of replicated rows to every partition, the per-partition {@link RowIndex},
 * and grouping of rows into multi-row entries.
 */
public class TestPrestoSparkRowBatch
{
    // partition id this test uses to mark a row as replicated (see Row#isReplicated)
    private static final int REPLICATED_ROW_PARTITION_ID = -1;

    private static final int DEFAULT_TARGET_SIZE = 1024 * 1024;
    private static final int DEFAULT_EXPECTED_ROWS = 10000;
    private static final int NO_TARGET_ENTRY_SIZE_REQUIREMENT = 0;
    // target entry size for the round-trip builder that is meant to exercise multi-row entries
    private static final int MULTI_ROW_TARGET_ENTRY_SIZE = 1024;
    private static final int UNLIMITED_MAX_ENTRY_ROW_COUNT = MAX_VALUE;
    private static final int UNLIMITED_MAX_ENTRY_SIZE = MAX_VALUE;

    @Test
    public void testRoundTrip()
    {
        assertRoundTrip(ImmutableList.of());
        assertRoundTrip(ImmutableList.of(
                createRow(1, "row_data_1")));
        assertRoundTrip(ImmutableList.of(
                createRow(1, "")));
        assertRoundTrip(ImmutableList.of(
                createRow(1, ""),
                createRow(1, "")));
        assertRoundTrip(ImmutableList.of(
                createRow(1, ""),
                createRow(1, ""),
                createRow(1, "")));
        assertRoundTrip(ImmutableList.of(
                createRow(1, "row_data_1"),
                createRow(1, "row_data_2")));
        assertRoundTrip(ImmutableList.of(
                createRow(1, "row_data_1"),
                createRow(2, "row_data_2")));
        // same rows as above, added in descending partition order
        assertRoundTrip(ImmutableList.of(
                createRow(2, "row_data_2"),
                createRow(1, "row_data_1")));
        assertRoundTrip(IntStream.range(0, 4)
                .mapToObj(i -> createRow(i, "row_data_" + i))
                .collect(toImmutableList()));
        assertRoundTrip(IntStream.range(0, 5)
                .mapToObj(i -> createRow(i, "row_data"))
                .collect(toImmutableList()));
        assertRoundTrip(IntStream.range(0, 20)
                .mapToObj(i -> createRow(i, ""))
                .collect(toImmutableList()));
        assertRoundTrip(IntStream.range(0, 20)
                .mapToObj(i -> createRow(0, ""))
                .collect(toImmutableList()));
    }

    @Test
    public void testBuilderFull()
    {
        // target size of 5 bytes: a single row ("12345" plus its 4-byte length prefix)
        // is enough to make the builder report full
        PrestoSparkRowBatchBuilder builder = PrestoSparkRowBatch.builder(
                10,
                5,
                10,
                NO_TARGET_ENTRY_SIZE_REQUIREMENT,
                UNLIMITED_MAX_ENTRY_SIZE,
                UNLIMITED_MAX_ENTRY_ROW_COUNT);
        assertFalse(builder.isFull());
        assertTrue(builder.isEmpty());
        addRow(builder, createRow(1, "12345"));
        assertTrue(builder.isFull());
        assertFalse(builder.isEmpty());
    }

    @Test
    public void testReplicatedRows()
    {
        assertRoundTrip(
                ImmutableList.of(
                        createReplicatedRow("replicated")),
                1,
                ImmutableList.of(
                        createRow(0, "replicated")));
        assertRoundTrip(
                ImmutableList.of(
                        createReplicatedRow("replicated")),
                2,
                ImmutableList.of(
                        createRow(0, "replicated"),
                        createRow(1, "replicated")));
        assertRoundTrip(
                ImmutableList.of(
                        createReplicatedRow("replicated")),
                3,
                ImmutableList.of(
                        createRow(0, "replicated"),
                        createRow(1, "replicated"),
                        createRow(2, "replicated")));
        assertRoundTrip(
                ImmutableList.of(
                        createReplicatedRow("replicated")),
                4,
                ImmutableList.of(
                        createRow(0, "replicated"),
                        createRow(1, "replicated"),
                        createRow(2, "replicated"),
                        createRow(3, "replicated")));
        assertRoundTrip(
                ImmutableList.of(
                        createReplicatedRow("replicated"),
                        createRow(1, "non_replicated_1")),
                3,
                ImmutableList.of(
                        createRow(0, "replicated"),
                        createRow(1, "replicated"),
                        createRow(2, "replicated"),
                        createRow(1, "non_replicated_1")));
        assertRoundTrip(
                ImmutableList.of(
                        createRow(2, "non_replicated_22"),
                        createReplicatedRow("replicated")),
                3,
                ImmutableList.of(
                        createRow(0, "replicated"),
                        createRow(1, "replicated"),
                        createRow(2, "replicated"),
                        createRow(2, "non_replicated_22")));
        assertRoundTrip(
                ImmutableList.of(
                        createRow(1, "non_replicated_22"),
                        createReplicatedRow("replicated"),
                        createRow(0, "non_replicated_1")),
                2,
                ImmutableList.of(
                        createRow(0, "replicated"),
                        createRow(1, "replicated"),
                        createRow(1, "non_replicated_22"),
                        createRow(0, "non_replicated_1")));
        assertRoundTrip(
                ImmutableList.of(
                        createRow(1, "non_replicated_22"),
                        createReplicatedRow("replicated1"),
                        createReplicatedRow("replicated2"),
                        createRow(0, "non_replicated_1")),
                2,
                ImmutableList.of(
                        createRow(0, "replicated1"),
                        createRow(1, "replicated1"),
                        createRow(0, "replicated2"),
                        createRow(1, "replicated2"),
                        createRow(1, "non_replicated_22"),
                        createRow(0, "non_replicated_1")));
    }

    @Test
    public void testRowIndex()
    {
        assertRowIndex(new int[] {}, new int[][] {}, new int[] {});
        assertRowIndex(new int[] {}, new int[][] {new int[] {}, new int[] {}}, new int[] {});
        assertRowIndex(new int[] {0}, new int[][] {new int[] {0}, new int[] {}}, new int[] {});
        assertRowIndex(new int[] {1}, new int[][] {new int[] {}, new int[] {0}}, new int[] {});
        assertRowIndex(new int[] {0, 1}, new int[][] {new int[] {0}, new int[] {1}}, new int[] {});
        assertRowIndex(new int[] {0, 1, 1, 0, 0, 1}, new int[][] {new int[] {0, 3, 4}, new int[] {1, 2, 5}}, new int[] {});
        assertRowIndex(new int[] {0, 1, 1, 0, 0, 1}, new int[][] {new int[] {0, 3, 4}, new int[] {1, 2, 5}, new int[] {}}, new int[] {});
        assertRowIndex(new int[] {0, 1, 1, 2, 0, 2, 0, 1, 2}, new int[][] {new int[] {0, 4, 6}, new int[] {1, 2, 7}, new int[] {3, 5, 8}}, new int[] {});
        assertRowIndex(new int[] {1, 1, 1, 2, 1, 2, 1, 1, 2}, new int[][] {new int[] {}, new int[] {0, 1, 2, 4, 6, 7}, new int[] {3, 5, 8}}, new int[] {});
        assertRowIndex(new int[] {-1}, new int[][] {}, new int[] {0});
        assertRowIndex(new int[] {-1, -1}, new int[][] {}, new int[] {0, 1});
        assertRowIndex(new int[] {0, 1, -1}, new int[][] {new int[] {0}, new int[] {1}}, new int[] {2});
        assertRowIndex(new int[] {1, 1, 1, 2, -1, 1, 2, 1, 1, 2, -1}, new int[][] {new int[] {}, new int[] {0, 1, 2, 5, 7, 8}, new int[] {3, 6, 9}}, new int[] {4, 10});
    }

    @Test
    public void testMultiRowEntries()
    {
        Row row01 = createRow(0, "p0_1");
        Row row02 = createRow(0, "p0_11");
        Row row03 = createRow(0, "p0_111");
        Row row11 = createRow(1, "p1_1");
        Row row21 = createRow(2, "p2_1");

        List<Row> rows = ImmutableList.of(row01, row02, row03, row11, row21);

        assertEntries(
                rows,
                3,
                4,
                UNLIMITED_MAX_ENTRY_SIZE,
                UNLIMITED_MAX_ENTRY_ROW_COUNT,
                ImmutableList.of(
                        ImmutableList.of(row01),
                        ImmutableList.of(row02),
                        ImmutableList.of(row03),
                        ImmutableList.of(row11),
                        ImmutableList.of(row21)));
        assertEntries(
                rows,
                3,
                11,
                UNLIMITED_MAX_ENTRY_SIZE,
                UNLIMITED_MAX_ENTRY_ROW_COUNT,
                ImmutableList.of(
                        ImmutableList.of(row01, row02, row03),
                        ImmutableList.of(row11),
                        ImmutableList.of(row21)));
        // entry size is unlimited; only the row count per entry is capped at 2
        assertEntries(
                rows,
                3,
                11,
                UNLIMITED_MAX_ENTRY_SIZE,
                2,
                ImmutableList.of(
                        ImmutableList.of(row01, row02),
                        ImmutableList.of(row03),
                        ImmutableList.of(row11),
                        ImmutableList.of(row21)));
        assertEntries(
                rows,
                3,
                11,
                18,
                UNLIMITED_MAX_ENTRY_ROW_COUNT,
                ImmutableList.of(
                        ImmutableList.of(row01, row02),
                        ImmutableList.of(row03),
                        ImmutableList.of(row11),
                        ImmutableList.of(row21)));
        assertEntries(
                rows,
                4,
                10,
                0,
                UNLIMITED_MAX_ENTRY_ROW_COUNT,
                ImmutableList.of(
                        ImmutableList.of(row01),
                        ImmutableList.of(row02),
                        ImmutableList.of(row03),
                        ImmutableList.of(row11),
                        ImmutableList.of(row21)));
        assertEntries(
                rows,
                3,
                10,
                UNLIMITED_MAX_ENTRY_SIZE,
                0,
                ImmutableList.of(
                        ImmutableList.of(row01),
                        ImmutableList.of(row02),
                        ImmutableList.of(row03),
                        ImmutableList.of(row11),
                        ImmutableList.of(row21)));
    }

    /**
     * Round-trips non-replicated rows through a 20-partition batch and expects them back unchanged.
     *
     * @param rows rows to write; all must have a non-negative partition
     */
    private static void assertRoundTrip(List<Row> rows)
    {
        // replicated rows are not allowed: they would fan out and not match the input
        assertThat(rows).allMatch(row -> row.getPartition() >= 0);
        assertRoundTrip(rows, 20, rows);
    }

    /**
     * Round-trips {@code inputRows} through two builders, one with no target entry size
     * and one with {@link #MULTI_ROW_TARGET_ENTRY_SIZE}. Both must yield exactly
     * {@code expectedOutputRows}, in any order.
     */
    private static void assertRoundTrip(List<Row> inputRows, int partitionCount, List<Row> expectedOutputRows)
    {
        PrestoSparkRowBatchBuilder singleRowEntryBuilder = PrestoSparkRowBatch.builder(
                partitionCount,
                DEFAULT_TARGET_SIZE,
                DEFAULT_EXPECTED_ROWS,
                NO_TARGET_ENTRY_SIZE_REQUIREMENT,
                UNLIMITED_MAX_ENTRY_SIZE,
                UNLIMITED_MAX_ENTRY_ROW_COUNT);
        assertRoundTrip(singleRowEntryBuilder, inputRows, partitionCount, expectedOutputRows);

        PrestoSparkRowBatchBuilder multiRowEntryBuilder = PrestoSparkRowBatch.builder(
                partitionCount,
                DEFAULT_TARGET_SIZE,
                DEFAULT_EXPECTED_ROWS,
                MULTI_ROW_TARGET_ENTRY_SIZE,
                UNLIMITED_MAX_ENTRY_SIZE,
                UNLIMITED_MAX_ENTRY_ROW_COUNT);
        assertRoundTrip(multiRowEntryBuilder, inputRows, partitionCount, expectedOutputRows);
    }

    /**
     * Writes {@code inputRows} into an empty {@code builder}, builds the batch and asserts
     * its rows equal {@code expectedOutputRows} regardless of order.
     */
    private static void assertRoundTrip(PrestoSparkRowBatchBuilder builder, List<Row> inputRows, int partitionCount, List<Row> expectedOutputRows)
    {
        assertThat(inputRows).allMatch(row -> row.getPartition() < partitionCount);
        assertTrue(builder.isEmpty());
        for (Row row : inputRows) {
            addRow(builder, row);
        }
        assertFalse(builder.isFull());
        PrestoSparkRowBatch rowBatch = builder.build();
        assertContains(rowBatch, expectedOutputRows);
    }

    /**
     * Builds a batch from {@code rows} with the given entry limits. Asserts that the entries
     * produced, in order and with their grouping of rows, equal {@code expectedEntries}.
     */
    private static void assertEntries(
            List<Row> rows,
            int partitionCount,
            int targetEntrySize,
            int maxEntrySize,
            int maxRowsPerEntry,
            List<List<Row>> expectedEntries)
    {
        PrestoSparkRowBatchBuilder builder = PrestoSparkRowBatch.builder(
                partitionCount,
                DEFAULT_TARGET_SIZE,
                DEFAULT_EXPECTED_ROWS,
                targetEntrySize,
                maxEntrySize,
                maxRowsPerEntry);
        assertTrue(builder.isEmpty());
        for (Row row : rows) {
            addRow(builder, row);
        }
        assertFalse(builder.isFull());

        PrestoSparkRowBatch rowBatch = builder.build();

        List<List<Row>> actualEntries = getEntries(rowBatch);
        assertEquals(actualEntries, expectedEntries);
    }

    /**
     * Serializes {@code row} into {@code builder}. The row is written as a little-endian
     * int length followed by its UTF-8 bytes, which is the layout {@link #readRowData}
     * expects.
     */
    private static void addRow(PrestoSparkRowBatchBuilder builder, Row row)
    {
        SliceOutput output = builder.beginRowEntry();
        byte[] data = row.getData().getBytes(UTF_8);
        int bufferSize = data.length + BYTES;
        ByteBuffer buffer = ByteBuffer.allocate(bufferSize);
        buffer.order(LITTLE_ENDIAN);
        buffer.putInt(data.length);
        buffer.put(data);
        output.writeBytes(buffer.array(), 0, bufferSize);
        if (row.isReplicated()) {
            builder.closeEntryForReplicatedRow();
        }
        else {
            builder.closeEntryForNonReplicatedRow(row.getPartition());
        }
    }

    /**
     * Asserts that the rows in {@code rowBatch}, flattened across all entries, are exactly
     * {@code expected}, ignoring order.
     */
    private static void assertContains(PrestoSparkRowBatch rowBatch, List<Row> expected)
    {
        List<List<Row>> entries = getEntries(rowBatch);
        List<Row> rows = entries.stream()
                .flatMap(List::stream)
                .collect(toImmutableList());
        assertThat(rows)
                .containsExactlyInAnyOrder(expected.toArray(new Row[0]));
    }

    /**
     * Decodes every entry produced by the batch's {@link RowTupleSupplier}. Each entry
     * becomes the list of rows it contains, tagged with the entry's partition.
     */
    private static List<List<Row>> getEntries(PrestoSparkRowBatch rowBatch)
    {
        ImmutableList.Builder<List<Row>> entries = ImmutableList.builder();
        RowTupleSupplier rowTupleSupplier = rowBatch.createRowTupleSupplier();
        while (true) {
            Tuple2<MutablePartitionId, PrestoSparkMutableRow> next = rowTupleSupplier.getNext();
            // a null tuple signals that the supplier is exhausted
            if (next == null) {
                break;
            }
            ImmutableList.Builder<Row> entry = ImmutableList.builder();
            int partition = next._1.getPartition();
            PrestoSparkMutableRow mutableRow = next._2;
            ByteBuffer buffer = mutableRow.getBuffer();
            buffer.order(LITTLE_ENDIAN);
            // each entry starts with a short row count, which must agree with the row's position count
            short rowCount = buffer.getShort();
            assertEquals(mutableRow.getPositionCount(), rowCount);
            for (int i = 0; i < rowCount; i++) {
                entry.add(new Row(partition, readRowData(buffer)));
            }
            entries.add(entry.build());
        }
        return entries.build();
    }

    /**
     * Reads one row written by {@link #addRow}: an int length followed by that many UTF-8 bytes.
     * Advances the buffer position past the row. Assumes the buffer is array-backed.
     */
    private static String readRowData(ByteBuffer buffer)
    {
        int size = buffer.getInt();
        String data = new String(buffer.array(), buffer.arrayOffset() + buffer.position(), size, UTF_8);
        buffer.position(buffer.position() + size);
        return data;
    }

    private static Row createRow(int partition, String data)
    {
        return new Row(partition, data);
    }

    private static Row createReplicatedRow(String data)
    {
        return new Row(REPLICATED_ROW_PARTITION_ID, data);
    }

    /**
     * Immutable test row: a target partition plus a string payload. A partition of
     * {@link #REPLICATED_ROW_PARTITION_ID} marks the row as replicated.
     */
    private static final class Row
    {
        private final int partition;
        private final String data;

        private Row(int partition, String data)
        {
            this.partition = partition;
            this.data = requireNonNull(data, "data is null");
        }

        public int getPartition()
        {
            return partition;
        }

        public String getData()
        {
            return data;
        }

        public boolean isReplicated()
        {
            return partition == REPLICATED_ROW_PARTITION_ID;
        }

        @Override
        public boolean equals(Object o)
        {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            Row row = (Row) o;
            return partition == row.partition &&
                    Objects.equals(data, row.data);
        }

        @Override
        public int hashCode()
        {
            return Objects.hash(partition, data);
        }

        @Override
        public String toString()
        {
            return toStringHelper(this)
                    .add("partition", partition)
                    .add("data", data)
                    .toString();
        }
    }

    /**
     * Builds a {@link RowIndex} over {@code partitions}, where the value at index i is row i's
     * partition and -1 marks a replicated row. It then drains the index partition by partition.
     *
     * @param partitions partition assigned to each row
     * @param expected expected row indices per partition; its length is the partition count
     * @param expectedReplicated expected row indices of replicated rows
     */
    private static void assertRowIndex(int[] partitions, int[][] expected, int[] expectedReplicated)
    {
        RowIndex rowIndex = RowIndex.create(partitions.length, expected.length, partitions);
        int[][] actual = new int[expected.length][];
        for (int partition = 0; partition < expected.length; partition++) {
            IntArrayList partitionRows = new IntArrayList();
            while (rowIndex.hasNextRow(partition)) {
                partitionRows.add(rowIndex.nextRow(partition));
            }
            actual[partition] = partitionRows.toIntArray();
        }
        assertThat(actual).isEqualTo(expected);

        IntArrayList replicatedRows = new IntArrayList();
        while (rowIndex.hasNextRow(REPLICATED_ROW_PARTITION_ID)) {
            replicatedRows.add(rowIndex.nextRow(REPLICATED_ROW_PARTITION_ID));
        }
        assertThat(replicatedRows.toIntArray()).isEqualTo(expectedReplicated);
    }
}