// TestThriftIndexPageSource.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.connector.thrift;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.predicate.TupleDomain;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.InMemoryRecordSet;
import com.facebook.presto.spi.SchemaTableName;
import com.facebook.presto.thrift.api.connector.PrestoThriftId;
import com.facebook.presto.thrift.api.connector.PrestoThriftNullableColumnSet;
import com.facebook.presto.thrift.api.connector.PrestoThriftNullableSchemaName;
import com.facebook.presto.thrift.api.connector.PrestoThriftNullableTableMetadata;
import com.facebook.presto.thrift.api.connector.PrestoThriftNullableToken;
import com.facebook.presto.thrift.api.connector.PrestoThriftPageResult;
import com.facebook.presto.thrift.api.connector.PrestoThriftSchemaTableName;
import com.facebook.presto.thrift.api.connector.PrestoThriftService;
import com.facebook.presto.thrift.api.connector.PrestoThriftServiceException;
import com.facebook.presto.thrift.api.connector.PrestoThriftSplit;
import com.facebook.presto.thrift.api.connector.PrestoThriftSplitBatch;
import com.facebook.presto.thrift.api.connector.PrestoThriftTupleDomain;
import com.facebook.presto.thrift.api.datatypes.PrestoThriftInteger;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Ints;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import org.testng.annotations.Test;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.stream.IntStream;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.thrift.api.datatypes.PrestoThriftBlock.integerData;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.util.concurrent.Futures.immediateFuture;
import static java.util.Collections.shuffle;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertNull;
import static org.testng.Assert.assertTrue;
public class TestThriftIndexPageSource
{
private static final long MAX_BYTES_PER_RESPONSE = 16_000_000;
/**
 * Walks a {@link ThriftIndexPageSource} configured with a lookup concurrency of two through three splits
 * whose row responses are completed by hand.
 * <p>
 * The test asserts that only the first two row requests are sent up front, that the third is sent once one
 * of them completes, that pages come back in completion order, and that the index page size statistic
 * equals the accumulated size of the pages returned so far.
 *
 * @throws Exception if waiting on a latch is interrupted or closing the page source fails
 */
@Test
public void testGetNextPageTwoConcurrentRequests()
        throws Exception
{
    final int splits = 3;
    final int lookupRequestsConcurrency = 2;
    final int rowsPerSplit = 1;

    // one manually completed response future and one "request was sent" latch per split
    List<SettableFuture<PrestoThriftPageResult>> futures = IntStream.range(0, splits)
            .mapToObj(i -> SettableFuture.<PrestoThriftPageResult>create())
            .collect(toImmutableList());
    List<CountDownLatch> signals = IntStream.range(0, splits)
            .mapToObj(i -> new CountDownLatch(1))
            .collect(toImmutableList());

    TestingThriftService client = new TestingThriftService(rowsPerSplit, false, false)
    {
        @Override
        public ListenableFuture<PrestoThriftPageResult> getRows(PrestoThriftId splitId, List<String> columns, long maxBytes, PrestoThriftNullableToken nextToken)
        {
            // split ids are built from the lookup keys 0..2, so the decoded id indexes futures/signals directly
            int key = Ints.fromByteArray(splitId.getId());
            signals.get(key).countDown();
            return futures.get(key);
        }
    };

    ThriftConnectorStats stats = new ThriftConnectorStats();
    long pageSizeReceived = 0;
    ThriftIndexPageSource pageSource = new ThriftIndexPageSource(
            (context, headers) -> client,
            ImmutableMap.of(),
            stats,
            new ThriftIndexHandle(new SchemaTableName("default", "table1"), TupleDomain.all()),
            ImmutableList.of(column("a", INTEGER)),
            ImmutableList.of(column("b", INTEGER)),
            new InMemoryRecordSet(ImmutableList.of(INTEGER), generateKeys(0, splits)),
            MAX_BYTES_PER_RESPONSE,
            lookupRequestsConcurrency);

    assertNull(pageSource.getNextPage());
    assertEquals((long) stats.getIndexPageSize().getAllTime().getTotal(), 0);

    // the wait on the third latch is expected to time out: its count is asserted to still be 1 below
    signals.get(0).await(1, SECONDS);
    signals.get(1).await(1, SECONDS);
    signals.get(2).await(1, SECONDS);
    assertEquals(signals.get(0).getCount(), 0, "first request wasn't sent");
    assertEquals(signals.get(1).getCount(), 0, "second request wasn't sent");
    assertEquals(signals.get(2).getCount(), 1, "third request shouldn't be sent");

    // at this point first two requests were sent
    assertFalse(pageSource.isFinished());
    assertNull(pageSource.getNextPage());
    assertEquals((long) stats.getIndexPageSize().getAllTime().getTotal(), 0);

    // completing the second request
    futures.get(1).set(pageResult(20, null));
    Page page = pageSource.getNextPage();
    // check for null before touching the page so a missing page fails as an assertion, not as an NPE
    assertNotNull(page);
    pageSizeReceived += page.getSizeInBytes();
    assertEquals((long) stats.getIndexPageSize().getAllTime().getTotal(), pageSizeReceived);
    assertEquals(page.getPositionCount(), 1);
    assertEquals(page.getBlock(0).getInt(0), 20);
    // not complete yet
    assertFalse(pageSource.isFinished());

    // once one of the requests completes the next one should be sent
    signals.get(2).await(1, SECONDS);
    assertEquals(signals.get(2).getCount(), 0, "third request wasn't sent");

    // completing the first request
    futures.get(0).set(pageResult(10, null));
    page = pageSource.getNextPage();
    assertNotNull(page);
    pageSizeReceived += page.getSizeInBytes();
    assertEquals((long) stats.getIndexPageSize().getAllTime().getTotal(), pageSizeReceived);
    assertEquals(page.getPositionCount(), 1);
    assertEquals(page.getBlock(0).getInt(0), 10);
    // still not complete
    assertFalse(pageSource.isFinished());

    // completing the third request
    futures.get(2).set(pageResult(30, null));
    page = pageSource.getNextPage();
    assertNotNull(page);
    pageSizeReceived += page.getSizeInBytes();
    assertEquals((long) stats.getIndexPageSize().getAllTime().getTotal(), pageSizeReceived);
    assertEquals(page.getPositionCount(), 1);
    assertEquals(page.getBlock(0).getInt(0), 30);
    // finished now
    assertTrue(pageSource.isFinished());

    // after completion
    assertNull(pageSource.getNextPage());
    pageSource.close();
}
/**
 * Runs the general scenario with five lookup keys, two concurrent lookups and two rows per split,
 * with the stub service handing out the index splits in two batches.
 */
@Test
public void testGetNextPageMultipleSplitRequest()
        throws Exception
{
    int splits = 5;
    int lookupRequestsConcurrency = 2;
    int rowsPerSplit = 2;
    boolean twoSplitBatches = true;
    runGeneralTest(splits, lookupRequestsConcurrency, rowsPerSplit, twoSplitBatches);
}
/**
 * Runs the general scenario with no lookup keys at all, so no values are expected back.
 */
@Test
public void testGetNextPageNoSplits()
        throws Exception
{
    int splits = 0;
    int lookupRequestsConcurrency = 2;
    int rowsPerSplit = 2;
    boolean twoSplitBatches = false;
    runGeneralTest(splits, lookupRequestsConcurrency, rowsPerSplit, twoSplitBatches);
}
/**
 * Runs the general scenario with three lookup keys and three rows per split while allowing
 * only one lookup request at a time.
 */
@Test
public void testGetNextPageOneConcurrentRequest()
        throws Exception
{
    int splits = 3;
    int lookupRequestsConcurrency = 1;
    int rowsPerSplit = 3;
    boolean twoSplitBatches = false;
    runGeneralTest(splits, lookupRequestsConcurrency, rowsPerSplit, twoSplitBatches);
}
/**
 * Runs the general scenario with a concurrency limit (four) above the number of lookup keys (two).
 * With one row per split the stub service never returns a continuation token for rows.
 */
@Test
public void testGetNextPageMoreConcurrencyThanRequestsNoContinuation()
        throws Exception
{
    int splits = 2;
    int lookupRequestsConcurrency = 4;
    int rowsPerSplit = 1;
    boolean twoSplitBatches = false;
    runGeneralTest(splits, lookupRequestsConcurrency, rowsPerSplit, twoSplitBatches);
}
/**
 * Drains a {@link ThriftIndexPageSource} backed by a {@link TestingThriftService} that shuffles its splits,
 * then checks that the sorted values read from the first block of each page are exactly
 * {@code split * 10 + row} for every split in {@code 1..splits} and every row in {@code 0..rowsPerSplit-1}.
 *
 * @param splits number of lookup keys; the keys used are {@code 1..splits}
 * @param lookupRequestsConcurrency concurrency limit handed to the page source
 * @param rowsPerSplit number of rows the stub service serves for each split
 * @param twoSplitBatches whether the stub service hands out index splits in two batches
 * @throws Exception if the page source stays blocked for more than a second, or closing it fails
 */
private static void runGeneralTest(int splits, int lookupRequestsConcurrency, int rowsPerSplit, boolean twoSplitBatches)
        throws Exception
{
    TestingThriftService service = new TestingThriftService(rowsPerSplit, true, twoSplitBatches);
    ThriftIndexPageSource source = new ThriftIndexPageSource(
            (context, headers) -> service,
            ImmutableMap.of(),
            new ThriftConnectorStats(),
            new ThriftIndexHandle(new SchemaTableName("default", "table1"), TupleDomain.all()),
            ImmutableList.of(column("a", INTEGER)),
            ImmutableList.of(column("b", INTEGER)),
            new InMemoryRecordSet(ImmutableList.of(INTEGER), generateKeys(1, splits + 1)),
            MAX_BYTES_PER_RESPONSE,
            lookupRequestsConcurrency);

    // collect every value produced until the source reports completion
    List<Integer> collected = new ArrayList<>();
    while (!source.isFinished()) {
        CompletableFuture<?> unblocked = source.isBlocked();
        unblocked.get(1, SECONDS);
        Page next = source.getNextPage();
        if (next == null) {
            continue;
        }
        Block values = next.getBlock(0);
        int index = 0;
        while (index < values.getPositionCount()) {
            collected.add(values.getInt(index));
            index++;
        }
    }
    // splits are shuffled by the service, so compare in sorted order
    Collections.sort(collected);

    List<Integer> wanted = new ArrayList<>(splits * rowsPerSplit);
    for (int key = 1; key <= splits; key++) {
        int base = key * 10;
        for (int rowOffset = 0; rowOffset < rowsPerSplit; rowOffset++) {
            wanted.add(base + rowOffset);
        }
    }
    assertEquals(collected, wanted);

    // must be null after finish
    assertNull(source.getNextPage());
    source.close();
}
/**
 * Stub {@link PrestoThriftService} that answers index lookups with already-completed futures.
 * <p>
 * Every lookup key becomes one split whose id is the key's byte encoding. Rows are served one per page,
 * with value {@code key * 10 + offset}, where the offset travels in the continuation token.
 * The metadata and plain split methods are not needed by these tests and throw.
 */
private static class TestingThriftService
        implements PrestoThriftService
{
    private final int rowsPerSplit;
    private final boolean shuffleSplits;
    private final boolean twoSplitBatches;

    /**
     * @param rowsPerSplit number of single-row pages served per split; zero yields one empty page
     * @param shuffleSplits whether each returned split batch is shuffled
     * @param twoSplitBatches whether splits are returned as two batches linked by a token
     */
    public TestingThriftService(int rowsPerSplit, boolean shuffleSplits, boolean twoSplitBatches)
    {
        this.twoSplitBatches = twoSplitBatches;
        this.shuffleSplits = shuffleSplits;
        this.rowsPerSplit = rowsPerSplit;
    }

    /**
     * Builds one split per integer key in the first column block of {@code keys}.
     * In two-batch mode a call with an empty token returns the first half of the keys plus a continuation
     * token, and a call carrying a token returns the second half with no token.
     */
    @Override
    public ListenableFuture<PrestoThriftSplitBatch> getIndexSplits(PrestoThriftSchemaTableName schemaTableName, List<String> indexColumnNames, List<String> outputColumnNames, PrestoThriftPageResult keys, PrestoThriftTupleDomain outputConstraint, int maxSplitCount, PrestoThriftNullableToken nextToken)
    {
        if (keys.getRowCount() == 0) {
            return immediateFuture(new PrestoThriftSplitBatch(ImmutableList.of(), null));
        }

        int[] keyValues = keys.getColumnBlocks().get(0).getIntegerData().getInts();
        int middle = keyValues.length / 2;
        // the token is only inspected in two-batch mode
        boolean firstOfTwo = twoSplitBatches && nextToken.getToken() == null;
        boolean secondOfTwo = twoSplitBatches && !firstOfTwo;

        int from = secondOfTwo ? middle : 0;
        int to = firstOfTwo ? middle : keyValues.length;
        PrestoThriftId continuation = firstOfTwo ? new PrestoThriftId(Ints.toByteArray(1)) : null;

        // copy into a mutable list so the batch can be shuffled in place
        List<PrestoThriftSplit> batch = new ArrayList<>(IntStream.range(from, to)
                .mapToObj(index -> new PrestoThriftSplit(new PrestoThriftId(Ints.toByteArray(keyValues[index])), ImmutableList.of()))
                .collect(toImmutableList()));
        if (shuffleSplits) {
            Collections.shuffle(batch);
        }
        return immediateFuture(new PrestoThriftSplitBatch(batch, continuation));
    }

    /**
     * Returns a single-row page holding {@code key * 10 + offset}, with a continuation token
     * carrying {@code offset + 1} while more rows remain for the split.
     */
    @Override
    public ListenableFuture<PrestoThriftPageResult> getRows(PrestoThriftId splitId, List<String> columns, long maxBytes, PrestoThriftNullableToken nextToken)
    {
        if (rowsPerSplit == 0) {
            return immediateFuture(new PrestoThriftPageResult(ImmutableList.of(), 0, null));
        }

        int key = Ints.fromByteArray(splitId.getId());
        int offset;
        if (nextToken.getToken() == null) {
            offset = 0;
        }
        else {
            offset = Ints.fromByteArray(nextToken.getToken().getId());
        }

        int following = offset + 1;
        PrestoThriftId continuation = null;
        if (following < rowsPerSplit) {
            continuation = new PrestoThriftId(Ints.toByteArray(following));
        }
        return immediateFuture(pageResult(key * 10 + offset, continuation));
    }

    // methods below are not used for the test

    @Override
    public List<String> listSchemaNames()
            throws PrestoThriftServiceException
    {
        throw new UnsupportedOperationException();
    }

    @Override
    public List<PrestoThriftSchemaTableName> listTables(PrestoThriftNullableSchemaName schemaNameOrNull)
            throws PrestoThriftServiceException
    {
        throw new UnsupportedOperationException();
    }

    @Override
    public PrestoThriftNullableTableMetadata getTableMetadata(PrestoThriftSchemaTableName schemaTableName)
            throws PrestoThriftServiceException
    {
        throw new UnsupportedOperationException();
    }

    @Override
    public ListenableFuture<PrestoThriftSplitBatch> getSplits(PrestoThriftSchemaTableName schemaTableName, PrestoThriftNullableColumnSet desiredColumns, PrestoThriftTupleDomain outputConstraint, int maxSplitCount, PrestoThriftNullableToken nextToken)
    {
        throw new UnsupportedOperationException();
    }
}
/**
 * Creates a column handle for the given name and type, passing {@code null} and {@code false}
 * for the remaining two constructor arguments.
 * NOTE(review): those two arguments look like an optional comment and a "hidden" flag; confirm against
 * the ThriftColumnHandle constructor.
 */
private static ThriftColumnHandle column(String name, Type type)
{
    ThriftColumnHandle handle = new ThriftColumnHandle(name, type, null, false);
    return handle;
}
/**
 * Produces one single-element row per integer in {@code [beginInclusive, endExclusive)}, in ascending order.
 * The result is empty when the range is empty.
 */
private static List<List<Integer>> generateKeys(int beginInclusive, int endExclusive)
{
    ImmutableList.Builder<List<Integer>> rows = ImmutableList.builder();
    for (int key = beginInclusive; key < endExclusive; key++) {
        rows.add(ImmutableList.of(key));
    }
    return rows.build();
}
/**
 * Builds a page result containing one integer column block with the single value {@code value},
 * a row count of one, and the supplied continuation token (which may be {@code null}).
 */
private static PrestoThriftPageResult pageResult(int value, PrestoThriftId nextToken)
{
    int[] ints = {value};
    PrestoThriftInteger integerColumn = new PrestoThriftInteger(null, ints);
    return new PrestoThriftPageResult(ImmutableList.of(integerData(integerColumn)), 1, nextToken);
}
}