PrestoSparkRowBatch.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.spark.execution;
import com.facebook.presto.spark.classloader_interface.MutablePartitionId;
import com.facebook.presto.spark.classloader_interface.PrestoSparkMutableRow;
import com.google.common.annotations.VisibleForTesting;
import io.airlift.slice.DynamicSliceOutput;
import io.airlift.slice.SliceOutput;
import org.openjdk.jol.info.ClassLayout;
import scala.Tuple2;
import javax.annotation.Nullable;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.util.Arrays;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static io.airlift.slice.SizeOf.sizeOf;
import static java.lang.Math.max;
import static java.lang.Math.min;
import static java.nio.ByteOrder.LITTLE_ENDIAN;
import static java.util.Arrays.fill;
import static java.util.Objects.requireNonNull;
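/**
* A batch of rows serialized into a single byte array and exposed to Spark as
* {@code Tuple2<MutablePartitionId, PrestoSparkMutableRow>} pairs.
*
* The buffer is a sequence of entries. Each entry starts with a 2-byte little-endian
* row count followed by the serialized row data; {@code rowOffsets} records where each
* entry begins and {@code rowPartitions} records its target partition, with
* {@code REPLICATED_ROW_PARTITION_ID} (-1) marking rows that must be delivered to
* every partition.
*/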
public class PrestoSparkRowBatch
implements PrestoSparkBufferedResult
{
private static final int INSTANCE_SIZE = ClassLayout.parseClass(PrestoSparkRowBatch.class).instanceSize();
private static final int MIN_TARGET_SIZE_IN_BYTES = 1024 * 1024;
private static final int MAX_TARGET_SIZE_IN_BYTES = 10 * 1024 * 1024;
private static final int DEFAULT_EXPECTED_ROWS_COUNT = 10000;
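// sentinel partition id marking rows that are replicated (broadcast) to all partitions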
private static final int REPLICATED_ROW_PARTITION_ID = -1;
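// limits for a single multi-row entry produced by createGroupedRowBatch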
private static final short MULTI_ROW_ENTRY_MAX_SIZE_IN_BYTES = 10 * 1024;
private static final short MULTI_ROW_ENTRY_MAX_ROW_COUNT = 10 * 1024;
private final int partitionCount;
private final int rowCount;
private final byte[] rowData;
private final int[] rowPartitions;
private final int[] rowOffsets;
private final int totalSizeInBytes;
private final long retainedSizeInBytes;
private PrestoSparkRowBatch(int partitionCount, int rowCount, byte[] rowData, int[] rowPartitions, int[] rowOffsets, int totalSizeInBytes)
{
this.partitionCount = partitionCount;
this.rowCount = rowCount;
this.rowData = requireNonNull(rowData, "rowData is null");
this.rowPartitions = requireNonNull(rowPartitions, "rowPartitions is null");
this.rowOffsets = requireNonNull(rowOffsets, "rowOffsets is null");
this.retainedSizeInBytes = INSTANCE_SIZE
+ sizeOf(rowData)
+ sizeOf(rowPartitions)
+ sizeOf(rowOffsets);
this.totalSizeInBytes = totalSizeInBytes;
}
public RowTupleSupplier createRowTupleSupplier()
{
return new RowTupleSupplier(partitionCount, rowCount, rowData, rowPartitions, rowOffsets, totalSizeInBytes);
}
public long getRetainedSizeInBytes()
{
return retainedSizeInBytes;
}
@Override
public int getPositionCount()
{
return rowCount;
}
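/**
* Sizes the builder for the given partition count: the target buffer size is
* {@code partitionCount * targetAverageRowSizeInBytes} clamped to the 1MB - 10MB range,
* and the average row size is then re-derived so it never exceeds the per-partition
* share of the clamped buffer. As an illustration (numbers are not from this code),
* 10,000 partitions with a 5KB average row would yield a raw target of roughly 50MB,
* clamped to 10MB, giving an effective average row size of 10MB / 10,000 = ~1KB.
*/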
public static PrestoSparkRowBatchBuilder builder(int partitionCount, int targetAverageRowSizeInBytes)
{
checkArgument(partitionCount > 0, "partitionCount must be greater than zero: %s", partitionCount);
int targetSizeInBytes = partitionCount * targetAverageRowSizeInBytes;
targetSizeInBytes = max(targetSizeInBytes, MIN_TARGET_SIZE_IN_BYTES);
targetSizeInBytes = min(targetSizeInBytes, MAX_TARGET_SIZE_IN_BYTES);
targetAverageRowSizeInBytes = min(targetSizeInBytes / partitionCount, targetAverageRowSizeInBytes);
return builder(
partitionCount,
targetSizeInBytes,
DEFAULT_EXPECTED_ROWS_COUNT,
targetAverageRowSizeInBytes,
MULTI_ROW_ENTRY_MAX_SIZE_IN_BYTES,
MULTI_ROW_ENTRY_MAX_ROW_COUNT);
}
@VisibleForTesting
static PrestoSparkRowBatchBuilder builder(
int partitionCount,
int targetSizeInBytes,
int expectedRowsCount,
int targetAverageRowSizeInBytes,
int maxEntrySizeInBytes,
int maxRowsPerEntry)
{
return new PrestoSparkRowBatchBuilder(
partitionCount,
targetSizeInBytes,
expectedRowsCount,
targetAverageRowSizeInBytes,
maxEntrySizeInBytes,
maxRowsPerEntry);
}
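/**
* Accumulates serialized rows until the batch is full. A minimal usage sketch; the
* {@code hasMoreRows}, {@code writeRowTo}, and {@code partition} names are illustrative
* and not part of this class:
*
* <pre>{@code
* PrestoSparkRowBatchBuilder builder = PrestoSparkRowBatch.builder(partitionCount, averageRowSize);
* while (!builder.isFull() && hasMoreRows()) {
*     SliceOutput output = builder.beginRowEntry();
*     writeRowTo(output); // serialize a single row into the entry
*     builder.closeEntryForNonReplicatedRow(partition);
* }
* PrestoSparkRowBatch batch = builder.build();
* }</pre>
*/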
public static class PrestoSparkRowBatchBuilder
{
private static final int BUILDER_INSTANCE_SIZE = ClassLayout.parseClass(PrestoSparkRowBatchBuilder.class).instanceSize();
private final int partitionCount;
private final int targetSizeInBytes;
private final int targetAverageRowSizeInBytes;
private final int maxEntrySizeInBytes;
private final int maxRowsPerEntry;
private final DynamicSliceOutput sliceOutput;
private int[] rowOffsets;
private int totalSizeInBytes;
private int[] rowPartitions;
private int rowCount;
private int currentRowOffset;
private boolean openEntry;
private PrestoSparkRowBatchBuilder(
int partitionCount,
int targetSizeInBytes,
int expectedRowsCount,
int targetAverageRowSizeInBytes,
int maxEntrySizeInBytes,
int maxRowsPerEntry)
{
checkArgument(partitionCount > 0, "partitionCount must be greater than zero: %s", partitionCount);
this.partitionCount = partitionCount;
this.targetSizeInBytes = targetSizeInBytes;
this.targetAverageRowSizeInBytes = targetAverageRowSizeInBytes;
this.maxEntrySizeInBytes = maxEntrySizeInBytes;
this.maxRowsPerEntry = maxRowsPerEntry;
sliceOutput = new DynamicSliceOutput((int) (targetSizeInBytes * 1.2f));
rowOffsets = new int[expectedRowsCount];
rowPartitions = new int[expectedRowsCount];
}
public long getRetainedSizeInBytes()
{
return BUILDER_INSTANCE_SIZE + sliceOutput.getRetainedSize() + sizeOf(rowOffsets) + sizeOf(rowPartitions);
}
public boolean isFull()
{
return sliceOutput.size() >= targetSizeInBytes;
}
public boolean isEmpty()
{
return rowCount == 0;
}
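// Starts a new single-row entry by writing a 2-byte row count of 1 and returning the
// output the caller serializes the row into. The entry must be closed with one of the
// closeEntryFor* methods before the next entry is started.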
public SliceOutput beginRowEntry()
{
checkState(!openEntry, "previous entry must be closed before creating a new entry");
openEntry = true;
currentRowOffset = sliceOutput.size();
sliceOutput.writeShort(1);
return sliceOutput;
}
public void closeEntryForNonReplicatedRow(int partition)
{
closeEntry(partition);
}
public void closeEntryForReplicatedRow()
{
closeEntry(REPLICATED_ROW_PARTITION_ID);
}
private void closeEntry(int partitionId)
{
checkState(openEntry, "entry must be opened first");
openEntry = false;
rowOffsets = ensureCapacity(rowOffsets, rowCount + 1);
rowOffsets[rowCount] = currentRowOffset;
rowPartitions = ensureCapacity(rowPartitions, rowCount + 1);
rowPartitions[rowCount] = partitionId;
rowCount++;
totalSizeInBytes += sliceOutput.size() - currentRowOffset;
}
private static int[] ensureCapacity(int[] array, int capacity)
{
if (array.length >= capacity) {
return array;
}
return Arrays.copyOf(array, capacity * 2);
}
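// When rows are small on average, regroup them so that rows targeting the same
// partition are packed into multi-row entries, reducing the number of tuples handed
// to Spark; otherwise ship the buffer as built, one row per entry.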
public PrestoSparkRowBatch build()
{
checkState(!openEntry, "entry must be closed before creating a row batch");
if (rowCount == 0) {
return createDirectRowBatch();
}
int averageRowSize = totalSizeInBytes / rowCount;
if (averageRowSize < targetAverageRowSizeInBytes) {
return createGroupedRowBatch();
}
return createDirectRowBatch();
}
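// Rewrites the buffer, packing rows that target the same partition into multi-row
// entries. RowIndex yields the rows of each partition in insertion order; rows are
// appended to the current entry until adding another one would exceed
// maxEntrySizeInBytes or maxRowsPerEntry, at which point a new entry is started.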
private PrestoSparkRowBatch createDirectRowBatch()
{
return new PrestoSparkRowBatch(
partitionCount,
rowCount,
sliceOutput.getUnderlyingSlice().byteArray(),
rowPartitions,
rowOffsets,
totalSizeInBytes);
}
private PrestoSparkRowBatch createGroupedRowBatch()
{
RowIndex rowIndex = RowIndex.create(rowCount, partitionCount, rowPartitions);
byte[] data = sliceOutput.getUnderlyingSlice().byteArray();
DynamicSliceOutput output = new DynamicSliceOutput((int) (totalSizeInBytes * 1.2f));
int expectedEntriesCount = (int) ((totalSizeInBytes / targetAverageRowSizeInBytes) * 1.2f);
int[] entryOffsets = new int[expectedEntriesCount];
int[] entryPartitions = new int[expectedEntriesCount];
int entriesCount = 0;
for (int partition = REPLICATED_ROW_PARTITION_ID; partition < partitionCount; partition++) {
while (rowIndex.hasNextRow(partition)) {
// start entry
short currentEntrySize = 0;
short currentEntryRowCount = 0;
int currentEntryOffset = output.size();
// Reserve space for the row count; the actual value is written once the entry is complete
output.writeShort(0);
entryOffsets = ensureCapacity(entryOffsets, entriesCount + 1);
entryOffsets[entriesCount] = currentEntryOffset;
entryPartitions = ensureCapacity(entryPartitions, entriesCount + 1);
entryPartitions[entriesCount] = partition;
while (rowIndex.hasNextRow(partition)) {
int row = rowIndex.peekRow(partition);
int followingRow = row + 1;
int rowOffset = rowOffsets[row];
int followingRowOffset = followingRow < rowCount ? rowOffsets[followingRow] : totalSizeInBytes;
int rowSize = followingRowOffset - rowOffset;
verify(rowSize >= 2, "rowSize is expected to be greater than or equal to 2: %s", rowSize);
// strip the per-row count header (always 1 for entries written by beginRowEntry);
// the multi-row entry header written above carries the combined count
rowOffset += 2;
rowSize -= 2;
if (currentEntryRowCount > 0 && (currentEntrySize + rowSize > maxEntrySizeInBytes || currentEntryRowCount + 1 > maxRowsPerEntry)) {
break;
}
output.writeBytes(data, rowOffset, rowSize);
currentEntrySize += rowSize;
currentEntryRowCount++;
rowIndex.nextRow(partition);
}
// entry is done
output.getUnderlyingSlice().setShort(currentEntryOffset, currentEntryRowCount);
entriesCount++;
}
}
return new PrestoSparkRowBatch(
partitionCount,
entriesCount,
output.getUnderlyingSlice().byteArray(),
entryPartitions,
entryOffsets,
output.size());
}
}
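/**
* Iterates over the batch producing one (partition, row) tuple at a time. The same
* mutable tuple instance is reused across calls, so consumers must not retain it.
* A minimal consumption sketch; {@code consume} is an illustrative placeholder:
*
* <pre>{@code
* RowTupleSupplier supplier = batch.createRowTupleSupplier();
* Tuple2<MutablePartitionId, PrestoSparkMutableRow> next;
* while ((next = supplier.getNext()) != null) {
*     consume(next._1(), next._2());
* }
* }</pre>
*/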
public static class RowTupleSupplier
{
private final int partitionCount;
private final int rowCount;
private final int[] rowPartitions;
private final int[] rowOffsets;
private final int totalSizeInBytes;
private int remainingReplicasCount;
private int currentRow;
private final ByteBuffer rowData;
private final MutablePartitionId mutablePartitionId;
private final PrestoSparkMutableRow row;
private final Tuple2<MutablePartitionId, PrestoSparkMutableRow> tuple;
private RowTupleSupplier(int partitionCount, int rowCount, byte[] rowData, int[] rowPartitions, int[] rowOffsets, int totalSizeInBytes)
{
this.partitionCount = partitionCount;
this.rowCount = rowCount;
this.rowPartitions = requireNonNull(rowPartitions, "rowPartitions is null");
this.rowOffsets = requireNonNull(rowOffsets, "rowOffsets is null");
this.totalSizeInBytes = totalSizeInBytes;
this.rowData = ByteBuffer.wrap(requireNonNull(rowData, "rowData is null"));
this.rowData.order(LITTLE_ENDIAN);
mutablePartitionId = new MutablePartitionId();
row = new PrestoSparkMutableRow();
row.setBuffer(this.rowData);
tuple = new Tuple2<>(mutablePartitionId, row);
}
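// Returns the next tuple, or null when the batch is exhausted. A replicated entry
// (partition id -1) is returned once for every partition, with the partition id
// counting down from partitionCount - 1 to 0; currentRow only advances past the
// entry after the last copy has been handed out.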
@Nullable
public Tuple2<MutablePartitionId, PrestoSparkMutableRow> getNext()
{
if (currentRow >= rowCount) {
return null;
}
int currentRowOffset = rowOffsets[currentRow];
int nextRow = currentRow + 1;
int nextRowOffset = nextRow < rowCount ? rowOffsets[nextRow] : totalSizeInBytes;
int rowSize = nextRowOffset - currentRowOffset;
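// cast through Buffer: on JDK 9+ position/limit have covariant ByteBuffer overrides,
// and calling them directly would produce bytecode that fails to link on JDK 8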
((Buffer) rowData).limit(currentRowOffset + rowSize);
((Buffer) rowData).position(currentRowOffset);
short rowsCount = rowData.getShort(currentRowOffset);
row.setPositionCount(rowsCount);
int partition = rowPartitions[currentRow];
if (partition == REPLICATED_ROW_PARTITION_ID) {
if (remainingReplicasCount == 0) {
remainingReplicasCount = partitionCount;
}
mutablePartitionId.setPartition(remainingReplicasCount - 1);
remainingReplicasCount--;
if (remainingReplicasCount == 0) {
currentRow++;
}
}
else {
mutablePartitionId.setPartition(partition);
currentRow++;
}
return tuple;
}
}
/*
* Partitions rows into disjoint sets based on the partitions assigned
*
* int[] rowIndex - links rows that belong to the same partition
*
* For example, for 4 rows with partitions assigned [2, 1, 2, 1] the
* row index will look like:
*
* [2, 3, -1, -1]
*
* int[] nextRow - contains the pointer to the first remaining row for each
* partition (the extra trailing slot for replicated rows is omitted here):
*
* [-1, 1, 0]
*
* note: there are no rows with partition 0, so nextRow[0] is -1
*
* To get all rows for a single partition, first read the current head of
* the list of rows for that partition:
*
* int row = nextRow[partition]
*
* and then iterate over the linked list to collect all the rows that
* belong to the same partition:
*
* while (rowIndex[row] != -1)
* row = rowIndex[row]
*/
public static class RowIndex
{
private static final int NIL = -1;
private final int[] nextRow;
private final int[] rowIndex;
public static RowIndex create(int rowCount, int partitionCount, int[] partitions)
{
// one extra slot at the end for the replicated partition (id -1)
int[] nextRow = new int[partitionCount + 1];
fill(nextRow, NIL);
int[] rowIndex = new int[rowCount];
fill(rowIndex, NIL);
for (int row = rowCount - 1; row >= 0; row--) {
int partition = partitions[row];
int partitionIndex = getPartitionIndex(partition, nextRow);
int currentPointer = nextRow[partitionIndex];
nextRow[partitionIndex] = row;
rowIndex[row] = currentPointer;
}
return new RowIndex(nextRow, rowIndex);
}
private RowIndex(int[] nextRow, int[] rowIndex)
{
this.nextRow = requireNonNull(nextRow, "nextRow is null");
this.rowIndex = requireNonNull(rowIndex, "rowIndex is null");
}
public boolean hasNextRow(int partition)
{
return peekRow(partition) != NIL;
}
public int peekRow(int partition)
{
return nextRow[getPartitionIndex(partition, nextRow)];
}
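// removes the head row of the partition's list and returns it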
public int nextRow(int partition)
{
int partitionIndex = getPartitionIndex(partition, nextRow);
int result = nextRow[partitionIndex];
nextRow[partitionIndex] = rowIndex[result];
return result;
}
private static int getPartitionIndex(int partition, int[] nextRow)
{
if (partition == REPLICATED_ROW_PARTITION_ID) {
return nextRow.length - 1;
}
return partition;
}
}
}