OperatorAssertion.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator;
import com.facebook.presto.Session;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.block.RowBlockBuilder;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.testing.MaterializedResult;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.units.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.IntStream;
import static com.facebook.airlift.concurrent.MoreFutures.getFutureValue;
import static com.facebook.airlift.concurrent.MoreFutures.tryGetFutureValue;
import static com.facebook.airlift.testing.Assertions.assertEqualsIgnoreOrder;
import static com.facebook.presto.operator.PageAssertions.assertPageEquals;
import static com.facebook.presto.testing.assertions.Assert.assertEquals;
import static com.facebook.presto.util.StructuralTestUtil.appendToBlockBuilder;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Throwables.throwIfUnchecked;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.testng.Assert.fail;
/**
 * Test helpers that drive an {@link Operator} through its lifecycle (feeding input pages,
 * draining output, finishing) and assert on the pages it produces.
 */
public final class OperatorAssertion
{
    private static final Duration BLOCKED_DEFAULT_TIMEOUT = new Duration(10, MILLISECONDS);
    private static final Duration UNBLOCKED_DEFAULT_TIMEOUT = new Duration(5, SECONDS);

    // Upper bound on consecutive loop iterations without progress in the drive loops below.
    // In toPagesPartial, progress means accepting an input page or producing a non-empty output page.
    // In finishOperator, progress means producing a non-empty output page.
    private static final int MAX_LOOPS_WITHOUT_PROGRESS = 1_000;

    private OperatorAssertion()
    {
    }

    /**
     * Feeds all of {@code input} to {@code operator} (revoking memory while adding pages), then
     * finishes it, returning every non-empty output page in production order.
     */
    public static List<Page> toPages(Operator operator, Iterator<Page> input)
    {
        return ImmutableList.<Page>builder()
                .addAll(toPagesPartial(operator, input))
                .addAll(finishOperator(operator))
                .build();
    }

    /**
     * Feeds all of {@code input} to {@code operator}, then finishes it, returning every non-empty
     * output page in production order.
     *
     * @param revokeMemoryWhenAddingPages whether revocable memory is revoked between input pages
     */
    public static List<Page> toPages(Operator operator, Iterator<Page> input, boolean revokeMemoryWhenAddingPages)
    {
        return ImmutableList.<Page>builder()
                .addAll(toPagesPartial(operator, input, revokeMemoryWhenAddingPages))
                .addAll(finishOperator(operator))
                .build();
    }

    /**
     * Same as {@link #toPagesPartial(Operator, Iterator, boolean)} with memory revocation enabled.
     */
    public static List<Page> toPagesPartial(Operator operator, Iterator<Page> input)
    {
        return toPagesPartial(operator, input, true);
    }

    /**
     * Drives {@code operator} with pages from {@code input} without calling {@code finish()}.
     * Loops until {@link #MAX_LOOPS_WITHOUT_PROGRESS} consecutive iterations pass with no input
     * accepted and no non-empty output produced.
     *
     * @param revokeMemory whether revocable memory is revoked on each loop iteration
     * @return the non-empty output pages produced so far
     */
    public static List<Page> toPagesPartial(Operator operator, Iterator<Page> input, boolean revokeMemory)
    {
        // verify initial state
        assertEquals(operator.isFinished(), false);

        ImmutableList.Builder<Page> outputPages = ImmutableList.builder();
        for (int loopsSinceLastPage = 0; loopsSinceLastPage < MAX_LOOPS_WITHOUT_PROGRESS; loopsSinceLastPage++) {
            if (handledBlocked(operator)) {
                continue;
            }

            if (revokeMemory) {
                handleMemoryRevoking(operator);
            }

            if (input.hasNext() && operator.needsInput()) {
                operator.addInput(input.next());
                loopsSinceLastPage = 0;
            }

            Page outputPage = operator.getOutput();
            if (outputPage != null && outputPage.getPositionCount() != 0) {
                outputPages.add(outputPage);
                loopsSinceLastPage = 0;
            }
        }

        return outputPages.build();
    }

    /**
     * Repeatedly calls {@code finish()} and drains output until the operator reports finished, or
     * {@link #MAX_LOOPS_WITHOUT_PROGRESS} iterations pass without a non-empty output page.
     * Afterwards, asserts that the operator is finished, no longer needs input, and is not blocked.
     *
     * @return the non-empty output pages produced while finishing
     */
    public static List<Page> finishOperator(Operator operator)
    {
        ImmutableList.Builder<Page> outputPages = ImmutableList.builder();
        for (int loopsSinceLastPage = 0; !operator.isFinished() && loopsSinceLastPage < MAX_LOOPS_WITHOUT_PROGRESS; loopsSinceLastPage++) {
            if (handledBlocked(operator)) {
                continue;
            }

            operator.finish();
            Page outputPage = operator.getOutput();
            if (outputPage != null && outputPage.getPositionCount() != 0) {
                outputPages.add(outputPage);
                loopsSinceLastPage = 0;
            }

            // revoke memory when output pages have started being produced
            handleMemoryRevoking(operator);
        }

        assertEquals(operator.isFinished(), true, "Operator did not finish");
        assertEquals(operator.needsInput(), false, "Operator still wants input");
        assertEquals(operator.isBlocked().isDone(), true, "Operator is blocked");

        return outputPages.build();
    }

    /**
     * If the operator is blocked, waits up to 1 ms for it to unblock and returns {@code true}
     * so the caller skips the rest of this loop iteration; otherwise returns {@code false}.
     */
    private static boolean handledBlocked(Operator operator)
    {
        ListenableFuture<?> isBlocked = operator.isBlocked();
        if (!isBlocked.isDone()) {
            tryGetFutureValue(isBlocked, 1, TimeUnit.MILLISECONDS);
            return true;
        }
        return false;
    }

    /**
     * When the operator holds revocable memory, synchronously runs a full revoke cycle
     * ({@code startMemoryRevoke} awaited, then {@code finishMemoryRevoke}).
     */
    private static void handleMemoryRevoking(Operator operator)
    {
        if (operator.getOperatorContext().getReservedRevocableBytes() > 0) {
            getFutureValue(operator.startMemoryRevoke());
            operator.finishMemoryRevoke();
        }
    }

    /**
     * Creates an operator from {@code operatorFactory}, runs {@code input} through it with memory
     * revocation enabled, and closes it.
     */
    public static List<Page> toPages(OperatorFactory operatorFactory, DriverContext driverContext, List<Page> input)
    {
        return toPages(operatorFactory, driverContext, input, true);
    }

    /**
     * Creates an operator from {@code operatorFactory}, runs {@code input} through it, and closes
     * it. Checked exceptions (including those from {@code close()}) are wrapped in
     * {@link RuntimeException}.
     */
    public static List<Page> toPages(OperatorFactory operatorFactory, DriverContext driverContext, List<Page> input, boolean revokeMemoryWhenAddingPages)
    {
        try (Operator operator = operatorFactory.createOperator(driverContext)) {
            return toPages(operator, input.iterator(), revokeMemoryWhenAddingPages);
        }
        catch (Exception e) {
            throwIfUnchecked(e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Creates an operator with no input, finishes it, and returns its output.
     */
    public static List<Page> toPages(OperatorFactory operatorFactory, DriverContext driverContext)
    {
        return toPages(operatorFactory, driverContext, ImmutableList.of());
    }

    /**
     * Materializes {@code pages} into a {@link MaterializedResult} of the given column types.
     */
    public static MaterializedResult toMaterializedResult(Session session, List<Type> types, List<Page> pages)
    {
        // materialize pages
        MaterializedResult.Builder resultBuilder = MaterializedResult.resultBuilder(session, types);
        for (Page outputPage : pages) {
            resultBuilder.page(outputPage);
        }
        return resultBuilder.build();
    }

    /**
     * Builds a single row value of an anonymous {@link RowType} whose fields have
     * {@code parameterTypes} and {@code values}.
     *
     * @throws IllegalArgumentException if the number of types and values differ
     */
    public static Block toRow(List<Type> parameterTypes, Object... values)
    {
        // Template form so the message is only formatted when the check fails.
        checkArgument(parameterTypes.size() == values.length, "parameterTypes.size(%s) does not equal to values.length(%s)", parameterTypes.size(), values.length);

        RowType rowType = RowType.anonymous(parameterTypes);
        BlockBuilder blockBuilder = new RowBlockBuilder(parameterTypes, null, 1);
        BlockBuilder singleRowBlockWriter = blockBuilder.beginBlockEntry();
        for (int i = 0; i < values.length; i++) {
            appendToBlockBuilder(parameterTypes.get(i), values[i], singleRowBlockWriter);
        }
        blockBuilder.closeEntry();
        return rowType.getObject(blockBuilder, 0);
    }

    /**
     * Asserts that the operator's output matches {@code expected} page by page, in order.
     */
    public static void assertOperatorEquals(OperatorFactory operatorFactory, List<Type> types, DriverContext driverContext, List<Page> input, List<Page> expected)
    {
        List<Page> actual = toPages(operatorFactory, driverContext, input);
        assertEquals(actual.size(), expected.size(), "Unexpected number of output pages");
        for (int i = 0; i < actual.size(); i++) {
            assertPageEquals(types, actual.get(i), expected.get(i));
        }
    }

    /**
     * Asserts that the operator's materialized output equals {@code expected} (order-sensitive),
     * revoking memory while adding pages.
     */
    public static void assertOperatorEquals(OperatorFactory operatorFactory, DriverContext driverContext, List<Page> input, MaterializedResult expected)
    {
        assertOperatorEquals(operatorFactory, driverContext, input, expected, true);
    }

    /**
     * Asserts that the operator's materialized output equals {@code expected} (order-sensitive).
     */
    public static void assertOperatorEquals(
            OperatorFactory operatorFactory,
            DriverContext driverContext,
            List<Page> input,
            MaterializedResult expected,
            boolean revokeMemoryWhenAddingPages)
    {
        assertOperatorEquals(operatorFactory, driverContext, input, expected, false, ImmutableList.of(), revokeMemoryWhenAddingPages);
    }

    /**
     * Asserts order-sensitive equality, optionally dropping hash channels from the output first,
     * revoking memory while adding pages.
     */
    public static void assertOperatorEquals(OperatorFactory operatorFactory, DriverContext driverContext, List<Page> input, MaterializedResult expected, boolean hashEnabled, List<Integer> hashChannels)
    {
        assertOperatorEquals(operatorFactory, driverContext, input, expected, hashEnabled, hashChannels, true);
    }

    /**
     * Asserts that the operator's materialized output equals {@code expected} (order-sensitive).
     * When {@code hashEnabled} is set, the given {@code hashChannels} are removed from every
     * output page before comparison.
     */
    public static void assertOperatorEquals(
            OperatorFactory operatorFactory,
            DriverContext driverContext,
            List<Page> input,
            MaterializedResult expected,
            boolean hashEnabled,
            List<Integer> hashChannels,
            boolean revokeMemoryWhenAddingPages)
    {
        List<Page> pages = toPages(operatorFactory, driverContext, input, revokeMemoryWhenAddingPages);
        if (hashEnabled && !hashChannels.isEmpty()) {
            // Drop the hashChannel for all pages
            pages = dropChannel(pages, hashChannels);
        }
        MaterializedResult actual = toMaterializedResult(driverContext.getSession(), expected.getTypes(), pages);
        assertEquals(actual, expected);
    }

    /**
     * Asserts that the operator's materialized rows equal {@code expected}, ignoring order.
     */
    public static void assertOperatorEqualsIgnoreOrder(
            OperatorFactory operatorFactory,
            DriverContext driverContext,
            List<Page> input,
            MaterializedResult expected)
    {
        // NOTE(review): unlike the 4-argument assertOperatorEquals (which defaults to true), this
        // overload does NOT revoke memory while adding pages. Confirm the asymmetry is intentional.
        assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected, false);
    }

    /**
     * Asserts that the operator's materialized rows equal {@code expected}, ignoring order.
     */
    public static void assertOperatorEqualsIgnoreOrder(
            OperatorFactory operatorFactory,
            DriverContext driverContext,
            List<Page> input,
            MaterializedResult expected,
            boolean revokeMemoryWhenAddingPages)
    {
        assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected, false, Optional.empty(), revokeMemoryWhenAddingPages);
    }

    /**
     * Order-insensitive comparison, optionally dropping a hash channel first, revoking memory
     * while adding pages.
     */
    public static void assertOperatorEqualsIgnoreOrder(
            OperatorFactory operatorFactory,
            DriverContext driverContext,
            List<Page> input,
            MaterializedResult expected,
            boolean hashEnabled,
            Optional<Integer> hashChannel)
    {
        assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected, hashEnabled, hashChannel, true);
    }

    /**
     * Runs the operator and asserts its rows equal {@code expected}, ignoring order.
     */
    public static void assertOperatorEqualsIgnoreOrder(
            OperatorFactory operatorFactory,
            DriverContext driverContext,
            List<Page> input,
            MaterializedResult expected,
            boolean hashEnabled,
            Optional<Integer> hashChannel,
            boolean revokeMemoryWhenAddingPages)
    {
        assertPagesEqualIgnoreOrder(
                driverContext,
                toPages(operatorFactory, driverContext, input, revokeMemoryWhenAddingPages),
                expected,
                hashEnabled,
                hashChannel);
    }

    /**
     * Asserts that the rows of {@code actualPages} equal {@code expected}, ignoring order. When
     * {@code hashEnabled} is set and {@code hashChannel} is present, that channel is removed from
     * every page before comparison.
     */
    public static void assertPagesEqualIgnoreOrder(
            DriverContext driverContext,
            List<Page> actualPages,
            MaterializedResult expected,
            boolean hashEnabled,
            Optional<Integer> hashChannel)
    {
        if (hashEnabled && hashChannel.isPresent()) {
            // Drop the hashChannel for all pages
            actualPages = dropChannel(actualPages, ImmutableList.of(hashChannel.get()));
        }
        MaterializedResult actual = toMaterializedResult(driverContext.getSession(), expected.getTypes(), actualPages);
        assertEqualsIgnoreOrder(actual.getMaterializedRows(), expected.getMaterializedRows());
    }

    /**
     * Asserts that the operator stays blocked for at least the default (10 ms) timeout.
     */
    public static void assertOperatorIsBlocked(Operator operator)
    {
        assertOperatorIsBlocked(operator, BLOCKED_DEFAULT_TIMEOUT);
    }

    /**
     * Asserts that the operator's blocked future does not complete within {@code timeout}.
     */
    public static void assertOperatorIsBlocked(Operator operator, Duration timeout)
    {
        if (waitForOperatorToUnblock(operator, timeout)) {
            fail("Operator is expected to be blocked for at least " + timeout.toString());
        }
    }

    /**
     * Asserts that the operator becomes unblocked within the default (5 s) timeout.
     */
    public static void assertOperatorIsUnblocked(Operator operator)
    {
        assertOperatorIsUnblocked(operator, UNBLOCKED_DEFAULT_TIMEOUT);
    }

    /**
     * Asserts that the operator's blocked future completes within {@code timeout}.
     */
    public static void assertOperatorIsUnblocked(Operator operator, Duration timeout)
    {
        if (!waitForOperatorToUnblock(operator, timeout)) {
            fail("Operator is expected to be unblocked within " + timeout.toString());
        }
    }

    /**
     * Waits up to {@code timeout} for the operator's {@code isBlocked()} future.
     *
     * @return {@code true} if the future completed in time, {@code false} on timeout
     * @throws RuntimeException if interrupted (the interrupt flag is restored) or if the future
     * failed (wrapping the failure's cause)
     */
    private static boolean waitForOperatorToUnblock(Operator operator, Duration timeout)
    {
        try {
            operator.isBlocked().get(timeout.toMillis(), TimeUnit.MILLISECONDS);
            return true;
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("interrupted", e);
        }
        catch (ExecutionException e) {
            throw new RuntimeException(e.getCause());
        }
        catch (TimeoutException expected) {
            return false;
        }
    }

    /**
     * Returns an immutable copy of {@code list} without the elements at {@code indexes}.
     * Out-of-range indexes are ignored.
     */
    static <T> List<T> without(List<T> list, Collection<Integer> indexes)
    {
        Set<Integer> indexesSet = ImmutableSet.copyOf(indexes);

        return IntStream.range(0, list.size())
                .filter(index -> !indexesSet.contains(index))
                .mapToObj(list::get)
                .collect(toImmutableList());
    }

    /**
     * Returns copies of {@code pages} with the given channels removed. The remaining channels keep
     * their relative order. Duplicate entries in {@code channels} are treated as one.
     *
     * @throws IllegalArgumentException if a channel is outside {@code [0, page.getChannelCount())}
     */
    static List<Page> dropChannel(List<Page> pages, List<Integer> channels)
    {
        // Deduplicate so the output width is computed correctly even if a channel is listed twice
        // (previously this overflowed the block array), and to get O(1) membership checks.
        Set<Integer> channelsToDrop = ImmutableSet.copyOf(channels);

        List<Page> actualPages = new ArrayList<>();
        for (Page page : pages) {
            int channelCount = page.getChannelCount();
            for (int dropped : channelsToDrop) {
                checkArgument(dropped >= 0 && dropped < channelCount, "channel %s is out of range for a page with %s channels", dropped, channelCount);
            }

            Block[] blocks = new Block[channelCount - channelsToDrop.size()];
            int channel = 0;
            for (int i = 0; i < channelCount; i++) {
                if (channelsToDrop.contains(i)) {
                    continue;
                }
                blocks[channel++] = page.getBlock(i);
            }
            actualPages.add(new Page(blocks));
        }
        return actualPages;
    }
}