GroupByHashYieldAssertion.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator;
import com.facebook.airlift.stats.TestingGcMonitor;
import com.facebook.presto.RowPagesBuilder;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.memory.MemoryPool;
import com.facebook.presto.memory.QueryContext;
import com.facebook.presto.spi.QueryId;
import com.facebook.presto.spi.memory.MemoryPoolId;
import com.facebook.presto.spiller.SpillSpaceTracker;
import com.google.common.collect.ImmutableList;
import io.airlift.units.DataSize;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import java.util.function.Function;
import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed;
import static com.facebook.airlift.json.JsonCodec.listJsonCodec;
import static com.facebook.airlift.testing.Assertions.assertBetweenInclusive;
import static com.facebook.airlift.testing.Assertions.assertGreaterThan;
import static com.facebook.airlift.testing.Assertions.assertLessThan;
import static com.facebook.presto.RowPagesBuilder.rowPagesBuilder;
import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.operator.OperatorAssertion.finishOperator;
import static com.facebook.presto.testing.TestingTaskContext.createTaskContext;
import static com.facebook.presto.testing.assertions.Assert.assertEquals;
import static io.airlift.units.DataSize.Unit.GIGABYTE;
import static io.airlift.units.DataSize.Unit.MEGABYTE;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.Executors.newCachedThreadPool;
import static java.util.concurrent.Executors.newScheduledThreadPool;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertNull;
import static org.testng.Assert.assertTrue;
public final class GroupByHashYieldAssertion
{
private static final ExecutorService EXECUTOR = newCachedThreadPool(daemonThreadsNamed("test-executor-%s"));
private static final ScheduledExecutorService SCHEDULED_EXECUTOR = newScheduledThreadPool(2, daemonThreadsNamed("test-scheduledExecutor-%s"));
private GroupByHashYieldAssertion() {}
public static List<Page> createPagesWithDistinctHashKeys(Type type, int pageCount, int positionCountPerPage)
{
RowPagesBuilder rowPagesBuilder = rowPagesBuilder(true, ImmutableList.of(0), type);
for (int i = 0; i < pageCount; i++) {
rowPagesBuilder.addSequencePage(positionCountPerPage, positionCountPerPage * i);
}
return rowPagesBuilder.build();
}
/**
* @param operatorFactory creates an Operator that should directly or indirectly contain GroupByHash
* @param getHashCapacity returns the hash table capacity for the input operator
* @param additionalMemoryInBytes the memory used in addition to the GroupByHash in the operator (e.g., aggregator)
*/
public static GroupByHashYieldResult finishOperatorWithYieldingGroupByHash(List<Page> input, Type hashKeyType, OperatorFactory operatorFactory, Function<Operator, Integer> getHashCapacity, long additionalMemoryInBytes)
{
assertLessThan(additionalMemoryInBytes, 1L << 21, "additionalMemoryInBytes should be a relatively small number");
List<Page> result = new LinkedList<>();
// mock an adjustable memory pool
QueryId queryId1 = new QueryId("test_query1");
QueryId queryId2 = new QueryId("test_query2");
MemoryPool memoryPool = new MemoryPool(new MemoryPoolId("test"), new DataSize(1, GIGABYTE));
QueryContext queryContext = new QueryContext(
queryId2,
new DataSize(512, MEGABYTE),
new DataSize(1024, MEGABYTE),
new DataSize(512, MEGABYTE),
new DataSize(1, GIGABYTE),
memoryPool,
new TestingGcMonitor(),
EXECUTOR,
SCHEDULED_EXECUTOR,
new DataSize(512, MEGABYTE),
new SpillSpaceTracker(new DataSize(512, MEGABYTE)),
listJsonCodec(TaskMemoryReservationSummary.class));
DriverContext driverContext = createTaskContext(queryContext, EXECUTOR, TEST_SESSION)
.addPipelineContext(0, true, true, false)
.addDriverContext();
Operator operator = operatorFactory.createOperator(driverContext);
// run operator
int yieldCount = 0;
long expectedReservedExtraBytes = 0;
for (Page page : input) {
// unblocked
assertTrue(operator.needsInput());
// saturate the pool with a tiny memory left
long reservedMemoryInBytes = memoryPool.getFreeBytes() - additionalMemoryInBytes;
memoryPool.reserve(queryId1, "test", reservedMemoryInBytes);
long oldMemoryUsage = operator.getOperatorContext().getDriverContext().getMemoryUsage();
int oldCapacity = getHashCapacity.apply(operator);
// add a page and verify different behaviors
operator.addInput(page);
// get output to consume the input
Page output = operator.getOutput();
if (output != null) {
result.add(output);
}
long newMemoryUsage = operator.getOperatorContext().getDriverContext().getMemoryUsage();
// Skip if the memory usage is not large enough since we cannot distinguish
// between rehash and memory used by aggregator
if (newMemoryUsage < new DataSize(3, MEGABYTE).toBytes()) {
// free the pool for the next iteration
memoryPool.free(queryId1, "test", reservedMemoryInBytes);
// this required in case input is blocked
operator.getOutput();
continue;
}
long actualIncreasedMemory = newMemoryUsage - oldMemoryUsage;
if (operator.needsInput()) {
// We have successfully added a page
// Assert we are not blocked
assertTrue(operator.getOperatorContext().isWaitingForMemory().isDone());
// assert the hash capacity is not changed; otherwise, we should have yielded
assertEquals(oldCapacity, getHashCapacity.apply(operator));
// We are not going to rehash; therefore, assert the memory increase only comes from the aggregator
assertLessThan(actualIncreasedMemory, additionalMemoryInBytes);
// free the pool for the next iteration
memoryPool.free(queryId1, "test", reservedMemoryInBytes);
}
else {
// We failed to finish the page processing i.e. we yielded
yieldCount++;
// Assert we are blocked
assertFalse(operator.getOperatorContext().isWaitingForMemory().isDone());
// Hash table capacity should not change
assertEquals(oldCapacity, (long) getHashCapacity.apply(operator));
expectedReservedExtraBytes = getHashTableSizeInBytes(hashKeyType, oldCapacity * 2) + page.getRetainedSizeInBytes();
// Increased memory is no smaller than the hash table size and no greater than the hash table size + the memory used by aggregator
assertBetweenInclusive(actualIncreasedMemory, expectedReservedExtraBytes, 2 * expectedReservedExtraBytes + additionalMemoryInBytes);
// Output should be blocked as well
assertNull(operator.getOutput());
// Free the pool to unblock
memoryPool.free(queryId1, "test", reservedMemoryInBytes);
// Trigger a process through getOutput() or needsInput()
output = operator.getOutput();
if (output != null) {
result.add(output);
}
assertTrue(operator.needsInput());
// Hash table capacity has increased
assertGreaterThan(getHashCapacity.apply(operator), oldCapacity);
// Assert the estimated reserved memory after rehash is lower than the one before rehash (extra memory allocation has been released)
long rehashedMemoryUsage = operator.getOperatorContext().getDriverContext().getMemoryUsage();
long previousHashTableSizeInBytes = getHashTableSizeInBytes(hashKeyType, oldCapacity);
long expectedMemoryUsageAfterRehash = newMemoryUsage - previousHashTableSizeInBytes;
double memoryUsageErrorUpperBound = 1.02;
double memoryUsageError = rehashedMemoryUsage * 1.0 / expectedMemoryUsageAfterRehash;
assertBetweenInclusive(memoryUsageError, 0.98, memoryUsageErrorUpperBound);
// unblocked
assertTrue(operator.needsInput());
}
}
result.addAll(finishOperator(operator));
return new GroupByHashYieldResult(yieldCount, expectedReservedExtraBytes, result);
}
private static long getHashTableSizeInBytes(Type hashKeyType, int capacity)
{
if (hashKeyType == BIGINT) {
// groupIds and values double by hashCapacity; while valuesByGroupId double by maxFill = hashCapacity / 0.75
return capacity * (long) (Long.BYTES * 1.75 + Integer.BYTES);
}
// groupAddressByHash, groupIdsByHash, and rawHashByHashPosition double by hashCapacity; while groupAddressByGroupId double by maxFill = hashCapacity / 0.75
return capacity * (long) (Long.BYTES * 1.75 + Integer.BYTES + Byte.BYTES);
}
public static final class GroupByHashYieldResult
{
private final int yieldCount;
private final long maxReservedBytes;
private final List<Page> output;
public GroupByHashYieldResult(int yieldCount, long maxReservedBytes, List<Page> output)
{
this.yieldCount = yieldCount;
this.maxReservedBytes = maxReservedBytes;
this.output = requireNonNull(output, "output is null");
}
public int getYieldCount()
{
return yieldCount;
}
public long getMaxReservedBytes()
{
return maxReservedBytes;
}
public List<Page> getOutput()
{
return output;
}
}
}