TestMemoryRevokingScheduler.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.execution;
import com.facebook.airlift.stats.CounterStat;
import com.facebook.airlift.stats.TestingGcMonitor;
import com.facebook.presto.common.block.BlockEncodingManager;
import com.facebook.presto.execution.TestSqlTaskManager.MockExchangeClientSupplier;
import com.facebook.presto.execution.buffer.OutputBuffers;
import com.facebook.presto.execution.buffer.SpoolingOutputBufferFactory;
import com.facebook.presto.execution.executor.TaskExecutor;
import com.facebook.presto.execution.scheduler.TableWriteInfo;
import com.facebook.presto.memory.MemoryPool;
import com.facebook.presto.memory.QueryContext;
import com.facebook.presto.memory.context.LocalMemoryContext;
import com.facebook.presto.memory.context.MemoryTrackingContext;
import com.facebook.presto.operator.DriverContext;
import com.facebook.presto.operator.OperatorContext;
import com.facebook.presto.operator.PipelineContext;
import com.facebook.presto.operator.TaskContext;
import com.facebook.presto.operator.TaskMemoryReservationSummary;
import com.facebook.presto.spi.QueryId;
import com.facebook.presto.spi.memory.MemoryPoolId;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spiller.SpillSpaceTracker;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.gen.OrderingCompiler;
import com.facebook.presto.sql.planner.LocalExecutionPlanner;
import com.google.common.base.Functions;
import com.google.common.base.Ticker;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import io.airlift.units.DataSize;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;
import java.net.URI;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import static com.facebook.airlift.concurrent.Threads.threadsNamed;
import static com.facebook.airlift.json.JsonCodec.listJsonCodec;
import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.execution.SqlTask.createSqlTask;
import static com.facebook.presto.execution.TaskManagerConfig.TaskPriorityTracking.TASK_FAIR;
import static com.facebook.presto.execution.TaskTestUtils.PLAN_FRAGMENT;
import static com.facebook.presto.execution.TaskTestUtils.SPLIT;
import static com.facebook.presto.execution.TaskTestUtils.TABLE_SCAN_NODE_ID;
import static com.facebook.presto.execution.TaskTestUtils.createTestSplitMonitor;
import static com.facebook.presto.execution.TaskTestUtils.createTestingPlanner;
import static com.facebook.presto.execution.TaskTestUtils.updateTask;
import static com.facebook.presto.execution.buffer.OutputBuffers.BufferType.PARTITIONED;
import static com.facebook.presto.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers;
import static com.facebook.presto.memory.LocalMemoryManager.GENERAL_POOL;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.TaskSpillingStrategy.ORDER_BY_CREATE_TIME;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.TaskSpillingStrategy.ORDER_BY_REVOCABLE_BYTES;
import static io.airlift.units.DataSize.Unit.BYTE;
import static io.airlift.units.DataSize.Unit.GIGABYTE;
import static io.airlift.units.DataSize.Unit.KILOBYTE;
import static io.airlift.units.DataSize.Unit.MEGABYTE;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static java.util.concurrent.Executors.newScheduledThreadPool;
import static java.util.concurrent.Executors.newSingleThreadExecutor;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;
@Test(singleThreaded = true)
public class TestMemoryRevokingScheduler
{
// Id of the single output buffer that the helpers below register on every task's initial output buffers.
public static final OutputBuffers.OutputBufferId OUT = new OutputBuffers.OutputBufferId(0);
// Supplies a fresh number for each TaskId built in newSqlTask().
private final AtomicInteger idGenerator = new AtomicInteger();
// Shared by every QueryContext created in getOrCreateQueryContext(); constructed with a 10 GB DataSize.
private final SpillSpaceTracker spillSpaceTracker = new SpillSpaceTracker(new DataSize(10, GIGABYTE));
// Query contexts by query id; filled lazily by getOrCreateQueryContext() and cleared in tearDown().
private final Map<QueryId, QueryContext> queryContexts = new HashMap<>();
// The three executors are created in setUp() and shut down in tearDown().
// setUp() notes that singleThreadedExecutor must be single threaded.
private ExecutorService singleThreadedExecutor;
// One-thread scheduled pool handed to TaskThresholdMemoryRevokingScheduler; the tests also queue a no-op on it as a barrier.
private ScheduledExecutorService singleThreadedScheduledExecutor;
// Two-thread scheduled pool passed to every QueryContext.
private ScheduledExecutorService scheduledExecutor;
// Built in setUp(); used by newSqlTask() to create SqlTask instances.
private SqlTaskExecutionFactory sqlTaskExecutionFactory;
// Default pool created in setUp() with a 10 byte DataSize; set to null in tearDown().
private MemoryPool memoryPool;
// Every operator context the current test cares about; assertMemoryRevokingRequestedFor() checks all of them.
private Set<OperatorContext> allOperatorContexts;
/**
 * Rebuilds the per-test fixtures: the 10-byte general memory pool, the three executors and the
 * {@link SqlTaskExecutionFactory} that {@code newSqlTask} relies on. Also clears the state that
 * tests share through {@code allOperatorContexts} and {@code TestOperatorContext.firstOperator}.
 */
@BeforeMethod
public void setUp()
{
    memoryPool = new MemoryPool(GENERAL_POOL, new DataSize(10, BYTE));

    // NOTE(review): this TaskExecutor is started but nothing in tearDown() stops it; stopping it
    // would require keeping a reference in a field.
    TaskExecutor taskExecutor = new TaskExecutor(8, 16, 3, 4, TASK_FAIR, Ticker.systemTicker());
    taskExecutor.start();

    // Must be single threaded
    singleThreadedExecutor = newSingleThreadExecutor(threadsNamed("task-notification-%s"));
    singleThreadedScheduledExecutor = newScheduledThreadPool(1, threadsNamed("task-notification-%s"));
    scheduledExecutor = newScheduledThreadPool(2, threadsNamed("task-notification-%s"));

    LocalExecutionPlanner planner = createTestingPlanner();
    sqlTaskExecutionFactory = new SqlTaskExecutionFactory(
            singleThreadedExecutor,
            taskExecutor,
            planner,
            new BlockEncodingManager(),
            new OrderingCompiler(),
            createTestSplitMonitor(),
            // each tracking option is enabled exactly once
            new TaskManagerConfig()
                    .setPerOperatorAllocationTrackingEnabled(true)
                    .setTaskCpuTimerEnabled(true)
                    .setTaskAllocationTrackingEnabled(true));

    // reset state left behind by the previous test
    allOperatorContexts = null;
    TestOperatorContext.firstOperator = null;
}
/**
 * Releases the per-test fixtures. Because of {@code alwaysRun = true} this also executes when
 * {@code setUp()} failed part-way, so every executor is null-checked to avoid hiding the
 * original failure behind a NullPointerException.
 */
@AfterMethod(alwaysRun = true)
public void tearDown()
{
    queryContexts.clear();
    memoryPool = null;

    // shutdownNow() on all three executors: queued callbacks must not keep running after the
    // fixtures above have been cleared
    if (singleThreadedExecutor != null) {
        singleThreadedExecutor.shutdownNow();
    }
    if (singleThreadedScheduledExecutor != null) {
        singleThreadedScheduledExecutor.shutdownNow();
    }
    if (scheduledExecutor != null) {
        scheduledExecutor.shutdownNow();
    }
}
/**
 * Drives pool-level revoking for two queries sharing the 10-byte {@code memoryPool} and checks,
 * after every change in revocable reservations, which operators have been asked to revoke.
 */
@Test
public void testMemoryPoolRevoking()
        throws Exception
{
    QueryContext q1 = getOrCreateQueryContext(new QueryId("q1"), memoryPool);
    QueryContext q2 = getOrCreateQueryContext(new QueryId("q2"), memoryPool);

    SqlTask sqlTask1 = newSqlTask(q1.getQueryId(), memoryPool);
    SqlTask sqlTask2 = newSqlTask(q2.getQueryId(), memoryPool);

    // task 1: one pipeline, two drivers, operators 1-3
    TaskContext taskContext1 = getOrCreateTaskContext(sqlTask1);
    PipelineContext pipelineContext11 = taskContext1.addPipelineContext(0, false, false, false);
    DriverContext driverContext111 = pipelineContext11.addDriverContext();
    OperatorContext operatorContext1 = driverContext111.addOperatorContext(1, new PlanNodeId("na"), "na");
    OperatorContext operatorContext2 = driverContext111.addOperatorContext(2, new PlanNodeId("na"), "na");
    DriverContext driverContext112 = pipelineContext11.addDriverContext();
    OperatorContext operatorContext3 = driverContext112.addOperatorContext(3, new PlanNodeId("na"), "na");

    // task 2: one pipeline, one driver, operators 4-5
    TaskContext taskContext2 = getOrCreateTaskContext(sqlTask2);
    PipelineContext pipelineContext21 = taskContext2.addPipelineContext(1, false, false, false);
    DriverContext driverContext211 = pipelineContext21.addDriverContext();
    OperatorContext operatorContext4 = driverContext211.addOperatorContext(4, new PlanNodeId("na"), "na");
    OperatorContext operatorContext5 = driverContext211.addOperatorContext(5, new PlanNodeId("na"), "na");

    List<SqlTask> tasks = ImmutableList.of(sqlTask1, sqlTask2);
    MemoryRevokingScheduler scheduler = new MemoryRevokingScheduler(
            singletonList(memoryPool),
            () -> tasks,
            queryContexts::get,
            1.0,
            1.0,
            ORDER_BY_CREATE_TIME,
            false);
    try {
        scheduler.start();
        allOperatorContexts = ImmutableSet.of(operatorContext1, operatorContext2, operatorContext3, operatorContext4, operatorContext5);
        assertMemoryRevokingNotRequested();

        // TestNG's assertEquals takes (actual, expected)
        assertEquals(memoryPool.getFreeBytes(), 10);
        scheduler.awaitAsynchronousCallbacksRun();
        assertMemoryRevokingNotRequested();

        LocalMemoryContext revocableMemory1 = operatorContext1.localRevocableMemoryContext();
        LocalMemoryContext revocableMemory3 = operatorContext3.localRevocableMemoryContext();
        LocalMemoryContext revocableMemory4 = operatorContext4.localRevocableMemoryContext();
        LocalMemoryContext revocableMemory5 = operatorContext5.localRevocableMemoryContext();

        revocableMemory1.setBytes(3);
        revocableMemory3.setBytes(6);
        assertEquals(memoryPool.getFreeBytes(), 1);
        scheduler.awaitAsynchronousCallbacksRun();
        // we are still good - no revoking needed
        assertMemoryRevokingNotRequested();

        revocableMemory4.setBytes(7);
        assertEquals(memoryPool.getFreeBytes(), -6);
        scheduler.awaitAsynchronousCallbacksRun();
        // we need to revoke 3 and 6
        assertMemoryRevokingRequestedFor(operatorContext1, operatorContext3);

        // lets revoke some bytes
        revocableMemory1.setBytes(0);
        operatorContext1.resetMemoryRevokingRequested();
        scheduler.awaitAsynchronousCallbacksRun();
        assertMemoryRevokingRequestedFor(operatorContext3);
        assertEquals(memoryPool.getFreeBytes(), -3);

        // and allocate some more
        revocableMemory5.setBytes(3);
        assertEquals(memoryPool.getFreeBytes(), -6);
        scheduler.awaitAsynchronousCallbacksRun();
        // we are still good with just OC3 in process of revoking
        assertMemoryRevokingRequestedFor(operatorContext3);

        // and allocate some more
        revocableMemory5.setBytes(4);
        assertEquals(memoryPool.getFreeBytes(), -7);
        scheduler.awaitAsynchronousCallbacksRun();
        // now we have to trigger revoking for OC4
        assertMemoryRevokingRequestedFor(operatorContext3, operatorContext4);
    }
    finally {
        scheduler.stop();
    }
}
/**
 * Two tasks live in two separate 10-byte pools. Over-filling the first pool requests revoking
 * from its operator only; over-filling the second pool afterwards must request revoking from
 * the second operator as well.
 */
@Test
public void testCountAlreadyRevokedMemoryWithinAPool()
        throws Exception
{
    // Given: one task in a dedicated pool and one task in the default pool
    MemoryPool dedicatedPool = new MemoryPool(new MemoryPoolId("test"), new DataSize(10, BYTE));
    SqlTask dedicatedPoolTask = newSqlTask(new QueryId("q1"), dedicatedPool);
    OperatorContext dedicatedPoolOperator = createContexts(dedicatedPoolTask);
    SqlTask defaultPoolTask = newSqlTask(new QueryId("q2"), memoryPool);
    OperatorContext defaultPoolOperator = createContexts(defaultPoolTask);

    List<SqlTask> allTasks = ImmutableList.of(dedicatedPoolTask, defaultPoolTask);
    MemoryRevokingScheduler scheduler = new MemoryRevokingScheduler(
            asList(memoryPool, dedicatedPool), () -> allTasks, queryContexts::get, 1.0, 1.0, ORDER_BY_CREATE_TIME, false);

    try {
        scheduler.start();
        allOperatorContexts = ImmutableSet.of(dedicatedPoolOperator, defaultPoolOperator);

        // the first task reserves more than its pool holds
        dedicatedPoolOperator.localRevocableMemoryContext().setBytes(12);
        scheduler.awaitAsynchronousCallbacksRun();
        assertMemoryRevokingRequestedFor(dedicatedPoolOperator);

        // When: the second task reserves more than its own pool holds
        defaultPoolOperator.localRevocableMemoryContext().setBytes(12);
        scheduler.awaitAsynchronousCallbacksRun();

        // Then: the second task is asked to revoke too
        assertMemoryRevokingRequestedFor(dedicatedPoolOperator, defaultPoolOperator);
    }
    finally {
        scheduler.stop();
    }
}
/**
 * Checks that the task which starts revoking first is chosen according to the configured
 * {@link FeaturesConfig.TaskSpillingStrategy}; here the strategy is {@code ORDER_BY_CREATE_TIME}.
 */
@Test
public void testTaskRevokingOrderForCreateTime()
        throws Exception
{
    SqlTask olderTask = newSqlTask(new QueryId("query"), memoryPool);
    TestOperatorContext olderTaskOperator = createTestingOperatorContexts(olderTask, "operator1");
    SqlTask newerTask = newSqlTask(new QueryId("query"), memoryPool);
    TestOperatorContext newerTaskOperator = createTestingOperatorContexts(newerTask, "operator2");
    allOperatorContexts = ImmutableSet.of(olderTaskOperator, newerTaskOperator);

    List<SqlTask> allTasks = ImmutableList.of(olderTask, newerTask);
    MemoryRevokingScheduler scheduler = new MemoryRevokingScheduler(
            singletonList(memoryPool), () -> allTasks, queryContexts::get, 1.0, 1.0, ORDER_BY_CREATE_TIME, false);

    try {
        scheduler.start();
        assertMemoryRevokingNotRequested();

        // both reservations exceed the 10-byte pool
        olderTaskOperator.localRevocableMemoryContext().setBytes(11);
        newerTaskOperator.localRevocableMemoryContext().setBytes(12);
        scheduler.awaitAsynchronousCallbacksRun();

        assertMemoryRevokingRequestedFor(olderTaskOperator, newerTaskOperator);
        // operator1 belongs to the task that was created earlier, so it must have been asked first
        assertEquals(TestOperatorContext.firstOperator, "operator1");
    }
    finally {
        scheduler.stop();
    }
}
/**
 * With {@code ORDER_BY_REVOCABLE_BYTES}, the task holding more revocable bytes must be the
 * first one asked to revoke.
 */
@Test
public void testTaskRevokingOrderForRevocableBytes()
        throws Exception
{
    SqlTask smallerTask = newSqlTask(new QueryId("query"), memoryPool);
    TestOperatorContext smallerTaskOperator = createTestingOperatorContexts(smallerTask, "operator1");
    SqlTask largerTask = newSqlTask(new QueryId("query"), memoryPool);
    TestOperatorContext largerTaskOperator = createTestingOperatorContexts(largerTask, "operator2");
    allOperatorContexts = ImmutableSet.of(smallerTaskOperator, largerTaskOperator);

    List<SqlTask> allTasks = ImmutableList.of(smallerTask, largerTask);
    MemoryRevokingScheduler scheduler = new MemoryRevokingScheduler(
            singletonList(memoryPool), () -> allTasks, queryContexts::get, 1.0, 1.0, ORDER_BY_REVOCABLE_BYTES, false);

    try {
        scheduler.start();

        // drain whatever is already queued on the scheduler's memoryRevocationExecutor
        scheduler.awaitAsynchronousCallbacksRun();
        assertMemoryRevokingNotRequested();

        // Park a task on the scheduler's executor that finishes only after both reservations
        // below have been made. That executor is a single threaded pool, so no asynchronous
        // revoking can slip in between the two reservations.
        CompletableFuture<Void> bothReservationsMade = new CompletableFuture<>();
        scheduler.submitAsynchronousCallable(() -> bothReservationsMade.get());

        smallerTaskOperator.localRevocableMemoryContext().setBytes(11);
        largerTaskOperator.localRevocableMemoryContext().setBytes(12);
        bothReservationsMade.complete(null);

        // drain the executor again so every revoking callback has run
        scheduler.awaitAsynchronousCallbacksRun();

        assertMemoryRevokingRequestedFor(smallerTaskOperator, largerTaskOperator);
        // operator2 (and its enclosing task) has allocated more bytes, so it must revoke first
        assertEquals(TestOperatorContext.firstOperator, "operator2");
    }
    finally {
        scheduler.stop();
    }
}
/**
 * Exercises {@link TaskThresholdMemoryRevokingScheduler} through explicit revoking passes
 * (see {@code requestMemoryRevoking}) with the scheduler constructed with a threshold argument of 5.
 */
@Test
public void testTaskThresholdRevokingScheduler()
        throws Exception
{
    SqlTask task1 = newSqlTask(new QueryId("query"), memoryPool);
    TestOperatorContext task1FirstOperator = createTestingOperatorContexts(task1, "operator11");
    TestOperatorContext task1SecondOperator = createTestingOperatorContexts(task1, "operator12");
    SqlTask task2 = newSqlTask(new QueryId("query"), memoryPool);
    TestOperatorContext task2Operator = createTestingOperatorContexts(task2, "operator2");
    allOperatorContexts = ImmutableSet.of(task1FirstOperator, task1SecondOperator, task2Operator);

    List<SqlTask> allTasks = ImmutableList.of(task1, task2);
    ImmutableMap<TaskId, SqlTask> tasksById = ImmutableMap.of(task1.getTaskId(), task1, task2.getTaskId(), task2);
    TaskThresholdMemoryRevokingScheduler scheduler = new TaskThresholdMemoryRevokingScheduler(
            singletonList(memoryPool),
            () -> allTasks,
            tasksById::get,
            singleThreadedScheduledExecutor,
            5L);
    assertMemoryRevokingNotRequested();

    // Task1 = 3 bytes, Task2 = 2 bytes: both within the threshold
    task1FirstOperator.localRevocableMemoryContext().setBytes(3);
    task2Operator.localRevocableMemoryContext().setBytes(2);
    requestMemoryRevoking(scheduler);
    assertMemoryRevokingNotRequested();

    // Task1 = 6 bytes, Task2 = 2 bytes
    task1SecondOperator.localRevocableMemoryContext().setBytes(3);
    requestMemoryRevoking(scheduler);
    // Task1 is over by (3 + 3) - 5 = 1 byte, so asking operator11 alone is enough
    assertMemoryRevokingRequestedFor(task1FirstOperator);

    // operator11 releases 2 bytes: Task1 = 1 + 3 = 4 bytes, Task2 = 2 bytes
    task1FirstOperator.localRevocableMemoryContext().setBytes(1);
    task1FirstOperator.resetMemoryRevokingRequested();
    requestMemoryRevoking(scheduler);
    assertMemoryRevokingNotRequested();

    // operator12 fills up: Task1 = 1 + 6 = 7 bytes, Task2 = 2 bytes
    task1SecondOperator.localRevocableMemoryContext().setBytes(6);
    requestMemoryRevoking(scheduler);
    // operators are asked in creation order within the task until it is back under the
    // threshold, so both operator11 and operator12 are revoking
    assertMemoryRevokingRequestedFor(task1FirstOperator, task1SecondOperator);

    // Task1 = 2 + 2 = 4 bytes, Task2 = 2 bytes
    task1FirstOperator.localRevocableMemoryContext().setBytes(2);
    task1FirstOperator.resetMemoryRevokingRequested();
    task1SecondOperator.localRevocableMemoryContext().setBytes(2);
    task1SecondOperator.resetMemoryRevokingRequested();
    requestMemoryRevoking(scheduler);
    // nothing to revoke
    assertMemoryRevokingNotRequested();

    // Task1 = 4 bytes, Task2 = 6 bytes: Task2's operator must be revoked
    task2Operator.localRevocableMemoryContext().setBytes(6);
    requestMemoryRevoking(scheduler);
    assertMemoryRevokingRequestedFor(task2Operator);
}
/**
 * Same scenario as {@code testTaskThresholdRevokingScheduler}, but revoking is triggered by the
 * pool listeners registered through {@code registerPoolListeners()} rather than by calling
 * {@code revokeHighMemoryTasksIfNeeded()} from the test.
 */
@Test
public void testTaskThresholdRevokingSchedulerImmediate()
        throws Exception
{
    SqlTask task1 = newSqlTask(new QueryId("query"), memoryPool);
    TestOperatorContext task1FirstOperator = createTestingOperatorContexts(task1, "operator11");
    TestOperatorContext task1SecondOperator = createTestingOperatorContexts(task1, "operator12");
    SqlTask task2 = newSqlTask(new QueryId("query"), memoryPool);
    TestOperatorContext task2Operator = createTestingOperatorContexts(task2, "operator2");
    allOperatorContexts = ImmutableSet.of(task1FirstOperator, task1SecondOperator, task2Operator);

    List<SqlTask> allTasks = ImmutableList.of(task1, task2);
    ImmutableMap<TaskId, SqlTask> tasksById = ImmutableMap.of(task1.getTaskId(), task1, task2.getTaskId(), task2);
    TaskThresholdMemoryRevokingScheduler scheduler = new TaskThresholdMemoryRevokingScheduler(
            singletonList(memoryPool),
            () -> allTasks,
            tasksById::get,
            singleThreadedScheduledExecutor,
            5L);
    // only the pool listeners are registered; the test never asks for a revoking pass itself
    scheduler.registerPoolListeners();
    assertMemoryRevokingNotRequested();

    // Task1 = 3 bytes, Task2 = 2 bytes
    task1FirstOperator.localRevocableMemoryContext().setBytes(3);
    task2Operator.localRevocableMemoryContext().setBytes(2);
    // waiting on the executor proves we rely on the listener callback, not on polling
    awaitTaskThresholdAsynchronousCallbacksRun();
    assertMemoryRevokingNotRequested();

    // Task1 = 6 bytes, Task2 = 2 bytes
    task1SecondOperator.localRevocableMemoryContext().setBytes(3);
    awaitTaskThresholdAsynchronousCallbacksRun();
    // Task1 is over by (3 + 3) - 5 = 1 byte, so asking operator11 alone is enough
    assertMemoryRevokingRequestedFor(task1FirstOperator);

    // operator11 releases 2 bytes: Task1 = 1 + 3 = 4 bytes, Task2 = 2 bytes
    task1FirstOperator.localRevocableMemoryContext().setBytes(1);
    task1FirstOperator.resetMemoryRevokingRequested();
    awaitTaskThresholdAsynchronousCallbacksRun();
    assertMemoryRevokingNotRequested();

    // operator12 fills up: Task1 = 1 + 6 = 7 bytes, Task2 = 2 bytes
    task1SecondOperator.localRevocableMemoryContext().setBytes(6);
    awaitTaskThresholdAsynchronousCallbacksRun();
    // operators are asked in creation order within the task until it is back under the
    // threshold, so both operator11 and operator12 are revoking
    assertMemoryRevokingRequestedFor(task1FirstOperator, task1SecondOperator);

    // Task1 = 2 + 2 = 4 bytes, Task2 = 2 bytes
    task1FirstOperator.localRevocableMemoryContext().setBytes(2);
    task1FirstOperator.resetMemoryRevokingRequested();
    task1SecondOperator.localRevocableMemoryContext().setBytes(2);
    task1SecondOperator.resetMemoryRevokingRequested();
    awaitTaskThresholdAsynchronousCallbacksRun();
    // nothing to revoke
    assertMemoryRevokingNotRequested();

    // Task1 = 4 bytes, Task2 = 6 bytes: Task2's operator must be revoked
    task2Operator.localRevocableMemoryContext().setBytes(6);
    awaitTaskThresholdAsynchronousCallbacksRun();
    assertMemoryRevokingRequestedFor(task2Operator);
}
/**
 * Exercises revoking driven by the per-query memory limit (the last constructor argument of
 * {@link MemoryRevokingScheduler} is {@code true}) with the {@code ORDER_BY_REVOCABLE_BYTES} strategy.
 */
@Test
public void testQueryMemoryRevoking()
        throws Exception
{
    // Tasks created here consume a little system memory on top of what the test sets explicitly.
    // Working in steps of thousands of bytes instead of hundreds keeps that noise from mattering.
    // A race remains: in revokeMemory() usage may drop between reading the total and reading the
    // per-task usage, which can cause some extra spilling. To avoid flakiness the test therefore
    // resets "revoking requested" on every operator, even when only one of them spilled.
    QueryId queryId = new QueryId("query");

    // a pool large enough that filling the pool itself never triggers spilling
    MemoryPool roomyPool = new MemoryPool(new MemoryPoolId("test"), new DataSize(100, GIGABYTE));

    SqlTask task1 = newSqlTask(queryId, roomyPool);
    TestOperatorContext task1FirstOperator = createTestingOperatorContexts(task1, "operator11");
    TestOperatorContext task1SecondOperator = createTestingOperatorContexts(task1, "operator12");
    SqlTask task2 = newSqlTask(queryId, roomyPool);
    TestOperatorContext task2Operator = createTestingOperatorContexts(task2, "operator2");
    allOperatorContexts = ImmutableSet.of(task1FirstOperator, task1SecondOperator, task2Operator);

    List<SqlTask> allTasks = ImmutableList.of(task1, task2);
    MemoryRevokingScheduler scheduler = new MemoryRevokingScheduler(
            singletonList(roomyPool), () -> allTasks, queryContexts::get, 1.0, 1.0, ORDER_BY_REVOCABLE_BYTES, true);

    try {
        scheduler.start();
        assertMemoryRevokingNotRequested();

        // Task1 = 150k, Task2 = 100k
        task1FirstOperator.localRevocableMemoryContext().setBytes(150_000);
        task2Operator.localRevocableMemoryContext().setBytes(100_000);
        // waiting here proves we rely on the revocation listener, not on polling
        scheduler.awaitAsynchronousCallbacksRun();
        assertMemoryRevokingNotRequested();

        // Task1 = 450k, Task2 = 100k
        task1SecondOperator.localRevocableMemoryContext().setBytes(300_000);
        scheduler.awaitAsynchronousCallbacksRun();
        // (450k + 100k) - 500k limit = about 50k over, which operator11 alone can cover
        assertMemoryRevokingRequestedFor(task1FirstOperator);

        // operator11 releases everything: Task1 = 300k, Task2 = 100k
        task1FirstOperator.localRevocableMemoryContext().setBytes(0);
        scheduler.awaitAsynchronousCallbacksRun();
        task1FirstOperator.resetMemoryRevokingRequested();
        task1SecondOperator.resetMemoryRevokingRequested();
        task2Operator.resetMemoryRevokingRequested();
        assertMemoryRevokingNotRequested();

        // Task1 = 320k (operator11 20k + operator12 300k), Task2 = 100k
        task1FirstOperator.localRevocableMemoryContext().setBytes(20_000);
        scheduler.awaitAsynchronousCallbacksRun();
        assertMemoryRevokingNotRequested();

        // Task1 = 320k, Task2 = 250k (100k revocable + 150k system)
        task2Operator.localSystemMemoryContext().setBytes(150_000);
        scheduler.awaitAsynchronousCallbacksRun();
        // operators are asked in creation order within the task until usage is back under the
        // limit, so both operator11 and operator12 are revoking
        assertMemoryRevokingRequestedFor(task1FirstOperator, task1SecondOperator);

        task1FirstOperator.localRevocableMemoryContext().setBytes(0);
        task1SecondOperator.localRevocableMemoryContext().setBytes(0);
        scheduler.awaitAsynchronousCallbacksRun();
        task1FirstOperator.resetMemoryRevokingRequested();
        task1SecondOperator.resetMemoryRevokingRequested();
        task2Operator.resetMemoryRevokingRequested();
        assertMemoryRevokingNotRequested();

        // Task1 = 100k (50k + 50k revocable), Task2 = 300k (150k revocable + 150k system)
        task1FirstOperator.localRevocableMemoryContext().setBytes(50_000);
        task1SecondOperator.localRevocableMemoryContext().setBytes(50_000);
        task2Operator.localSystemMemoryContext().setBytes(150_000);
        task2Operator.localRevocableMemoryContext().setBytes(150_000);
        scheduler.awaitAsynchronousCallbacksRun();
        // nothing to revoke
        assertMemoryRevokingNotRequested();

        // Task1 = 400k (100k revocable + 300k user), Task2 = 300k (150k revocable + 150k system)
        task1SecondOperator.localUserMemoryContext().setBytes(300_000);
        scheduler.awaitAsynchronousCallbacksRun();
        assertMemoryRevokingRequestedFor(task2Operator, task1FirstOperator);
    }
    finally {
        scheduler.stop();
    }
}
/**
 * With query-limit revoking enabled (last constructor argument {@code true}), a full memory
 * pool must still trigger revoking even though no query is near its own limit.
 */
@Test
public void testRevokesPoolWhenFullBeforeQueryLimit()
        throws Exception
{
    QueryContext q1 = getOrCreateQueryContext(new QueryId("q1"), memoryPool);
    QueryContext q2 = getOrCreateQueryContext(new QueryId("q2"), memoryPool);

    SqlTask sqlTask1 = newSqlTask(q1.getQueryId(), memoryPool);
    SqlTask sqlTask2 = newSqlTask(q2.getQueryId(), memoryPool);

    // task 1: one pipeline, two drivers, operators 1-3
    TaskContext taskContext1 = getOrCreateTaskContext(sqlTask1);
    PipelineContext pipelineContext11 = taskContext1.addPipelineContext(0, false, false, false);
    DriverContext driverContext111 = pipelineContext11.addDriverContext();
    OperatorContext operatorContext1 = driverContext111.addOperatorContext(1, new PlanNodeId("na"), "na");
    OperatorContext operatorContext2 = driverContext111.addOperatorContext(2, new PlanNodeId("na"), "na");
    DriverContext driverContext112 = pipelineContext11.addDriverContext();
    OperatorContext operatorContext3 = driverContext112.addOperatorContext(3, new PlanNodeId("na"), "na");

    // task 2: one pipeline, one driver, operator 4
    TaskContext taskContext2 = getOrCreateTaskContext(sqlTask2);
    PipelineContext pipelineContext21 = taskContext2.addPipelineContext(1, false, false, false);
    DriverContext driverContext211 = pipelineContext21.addDriverContext();
    OperatorContext operatorContext4 = driverContext211.addOperatorContext(4, new PlanNodeId("na"), "na");

    List<SqlTask> tasks = ImmutableList.of(sqlTask1, sqlTask2);
    MemoryRevokingScheduler scheduler = new MemoryRevokingScheduler(
            singletonList(memoryPool),
            () -> tasks,
            queryContexts::get,
            1.0,
            1.0,
            ORDER_BY_CREATE_TIME,
            true);
    try {
        scheduler.start();
        allOperatorContexts = ImmutableSet.of(operatorContext1, operatorContext2, operatorContext3, operatorContext4);
        assertMemoryRevokingNotRequested();

        // TestNG's assertEquals takes (actual, expected)
        assertEquals(memoryPool.getFreeBytes(), 10);
        scheduler.awaitAsynchronousCallbacksRun();
        assertMemoryRevokingNotRequested();

        LocalMemoryContext revocableMemory1 = operatorContext1.localRevocableMemoryContext();
        LocalMemoryContext revocableMemory3 = operatorContext3.localRevocableMemoryContext();
        LocalMemoryContext revocableMemory4 = operatorContext4.localRevocableMemoryContext();

        revocableMemory1.setBytes(3);
        revocableMemory3.setBytes(6);
        assertEquals(memoryPool.getFreeBytes(), 1);
        scheduler.awaitAsynchronousCallbacksRun();
        // we are still good - no revoking needed
        assertMemoryRevokingNotRequested();

        revocableMemory4.setBytes(7);
        assertEquals(memoryPool.getFreeBytes(), -6);
        scheduler.awaitAsynchronousCallbacksRun();
        // we need to revoke 3 and 6
        assertMemoryRevokingRequestedFor(operatorContext1, operatorContext3);
    }
    finally {
        scheduler.stop();
    }
}
/**
 * When query-limit revoking is switched off (last constructor argument {@code false}),
 * exceeding the query memory limit alone must not request any revoking.
 */
@Test
public void testQueryMemoryNotRevokedWhenNotEnabled()
        throws Exception
{
    // Tasks created here consume a little system memory on top of what the test sets explicitly.
    // Working in steps of thousands of bytes instead of hundreds keeps that noise from mattering.
    QueryId queryId = new QueryId("query");

    // a pool large enough that filling the pool itself never triggers spilling
    MemoryPool roomyPool = new MemoryPool(new MemoryPoolId("test"), new DataSize(100, GIGABYTE));

    SqlTask onlyTask = newSqlTask(queryId, roomyPool);
    TestOperatorContext onlyOperator = createTestingOperatorContexts(onlyTask, "operator11");
    allOperatorContexts = ImmutableSet.of(onlyOperator);

    List<SqlTask> allTasks = ImmutableList.of(onlyTask);
    MemoryRevokingScheduler scheduler = new MemoryRevokingScheduler(
            singletonList(roomyPool), () -> allTasks, queryContexts::get, 1.0, 1.0, ORDER_BY_REVOCABLE_BYTES, false);

    try {
        scheduler.start();
        assertMemoryRevokingNotRequested();

        // go past the 500KB query memory limit
        onlyOperator.localRevocableMemoryContext().setBytes(600_000);
        scheduler.awaitAsynchronousCallbacksRun();
        assertMemoryRevokingNotRequested();
    }
    finally {
        scheduler.stop();
    }
}
/**
 * Builds a pipeline (id 0) with one driver on the task's context and returns a new operator
 * context with operator id 1 registered on that driver.
 */
private OperatorContext createContexts(SqlTask sqlTask)
{
    return getOrCreateTaskContext(sqlTask)
            .addPipelineContext(0, false, false, false)
            .addDriverContext()
            .addOperatorContext(1, new PlanNodeId("na"), "na");
}
/**
 * Updates the task, adds a pipeline (id 0) and a driver to its task context, and registers a
 * {@link TestOperatorContext} carrying {@code operatorName} on that driver.
 */
private TestOperatorContext createTestingOperatorContexts(SqlTask sqlTask, String operatorName)
{
    // updating the task populates the underlying taskHolderReference with a taskExecution and
    // creates a new taskContext
    sqlTask.updateTask(
            TEST_SESSION,
            Optional.of(PLAN_FRAGMENT),
            ImmutableList.of(new TaskSource(TABLE_SCAN_NODE_ID, ImmutableSet.of(SPLIT), false)),
            createInitialEmptyOutputBuffers(PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds(),
            Optional.of(new TableWriteInfo(Optional.empty(), Optional.empty())));

    // reuse the task context that updateTask created implicitly; it should be the only task in
    // this QueryContext's tasks
    DriverContext driver = sqlTask.getQueryContext()
            .getTaskContextByTaskId(sqlTask.getTaskId())
            .addPipelineContext(0, false, false, false)
            .addDriverContext();

    TestOperatorContext created = new TestOperatorContext(
            1,
            new PlanNodeId("na"),
            "na",
            driver,
            singleThreadedExecutor,
            driver.getDriverMemoryContext().newMemoryTrackingContext(),
            operatorName);
    driver.addOperatorContext(created);
    return created;
}
/**
 * {@link OperatorContext} that remembers, in the static {@code firstOperator}, the name of the
 * first instance whose {@code requestMemoryRevoking()} was invoked since the field was last
 * cleared.
 */
private static class TestOperatorContext
        extends OperatorContext
{
    public static String firstOperator;

    private final String operatorName;

    public TestOperatorContext(
            int operatorId,
            PlanNodeId planNodeId,
            String operatorType,
            DriverContext driverContext,
            Executor executor,
            MemoryTrackingContext operatorMemoryContext,
            String operatorName)
    {
        super(operatorId, planNodeId, operatorType, driverContext, executor, operatorMemoryContext);
        this.operatorName = operatorName;
    }

    @Override
    public long requestMemoryRevoking()
    {
        markAsFirstRevokerIfUnset();
        return super.requestMemoryRevoking();
    }

    // Records this operator's name only when no operator has been recorded yet.
    private void markAsFirstRevokerIfUnset()
    {
        // MemoryRevokingScheduler revokes tasks one by one, so two tasks can never be revoked
        // simultaneously; that is why updating this static member is safe
        if (firstOperator == null) {
            firstOperator = operatorName;
        }
    }
}
/**
 * Runs one revoking pass on the given scheduler via {@code revokeHighMemoryTasksIfNeeded()} and
 * then waits on {@code awaitTaskThresholdAsynchronousCallbacksRun()} before returning.
 */
private void requestMemoryRevoking(TaskThresholdMemoryRevokingScheduler scheduler)
throws Exception
{
scheduler.revokeHighMemoryTasksIfNeeded();
awaitTaskThresholdAsynchronousCallbacksRun();
}
/**
 * Barrier for work queued on {@code singleThreadedScheduledExecutor}: submits a no-op callable
 * through {@code invokeAll}, which returns only after that callable has completed.
 */
private void awaitTaskThresholdAsynchronousCallbacksRun()
throws Exception
{
// Make sure asynchronous callback got called (executor is single-threaded).
singleThreadedScheduledExecutor.invokeAll(singletonList((Callable<?>) () -> null));
}
/**
 * Asserts that revoking has been requested for exactly the given operator contexts: each listed
 * context must report a request, and every other member of {@code allOperatorContexts} must not.
 */
private void assertMemoryRevokingRequestedFor(OperatorContext... operatorContexts)
{
    ImmutableSet<OperatorContext> expectedRevoking = ImmutableSet.copyOf(operatorContexts);

    for (OperatorContext context : expectedRevoking) {
        assertTrue(context.isMemoryRevokingRequested(), "expected memory requested for operator " + context.getOperatorId());
    }

    for (OperatorContext context : Sets.difference(allOperatorContexts, expectedRevoking)) {
        assertFalse(context.isMemoryRevokingRequested(), "expected memory not requested for operator " + context.getOperatorId());
    }
}
/**
 * Asserts that no member of {@code allOperatorContexts} has revoking requested, by delegating to
 * {@code assertMemoryRevokingRequestedFor} with an empty list of expected operators.
 */
private void assertMemoryRevokingNotRequested()
{
assertMemoryRevokingRequestedFor();
}
/**
 * Creates a {@link SqlTask} for the given query, backed by the query context obtained from
 * {@code getOrCreateQueryContext} and identified by a {@link TaskId} that takes a fresh number
 * from {@code idGenerator}.
 */
private SqlTask newSqlTask(QueryId queryId, MemoryPool memoryPool)
{
    QueryContext context = getOrCreateQueryContext(queryId, memoryPool);
    int nextId = idGenerator.incrementAndGet();
    TaskId id = new TaskId(queryId.getId(), 0, 0, nextId, 0);

    return createSqlTask(
            id,
            URI.create("fake://task/" + id),
            "fake",
            context,
            sqlTaskExecutionFactory,
            new MockExchangeClientSupplier(),
            singleThreadedExecutor,
            Functions.identity(),
            new DataSize(32, MEGABYTE).toBytes(),
            new CounterStat(),
            new SpoolingOutputBufferFactory(new FeaturesConfig()));
}
/**
 * Returns the cached {@link QueryContext} for {@code queryId}, creating and caching one bound to
 * {@code memoryPool} when none exists yet. An already cached context is returned as is, whatever
 * pool is passed.
 */
private QueryContext getOrCreateQueryContext(QueryId queryId, MemoryPool memoryPool)
{
    QueryContext cached = queryContexts.get(queryId);
    if (cached != null) {
        return cached;
    }

    QueryContext created = new QueryContext(
            queryId,
            new DataSize(500, KILOBYTE),
            new DataSize(500, KILOBYTE),
            new DataSize(500, KILOBYTE),
            new DataSize(1, GIGABYTE),
            memoryPool,
            new TestingGcMonitor(),
            singleThreadedExecutor,
            scheduledExecutor,
            new DataSize(1, GIGABYTE),
            spillSpaceTracker,
            listJsonCodec(TaskMemoryReservationSummary.class));
    queryContexts.put(queryId, created);
    return created;
}
/**
 * Returns the task's {@link TaskContext}, first updating the task when it does not have one yet.
 *
 * @throws IllegalStateException if the task still has no context afterwards
 */
private TaskContext getOrCreateTaskContext(SqlTask sqlTask)
{
    boolean contextMissing = !sqlTask.getTaskContext().isPresent();
    if (contextMissing) {
        // updating the task populates the underlying taskHolderReference with a taskExecution
        // and creates a new taskContext
        OutputBuffers initialBuffers = createInitialEmptyOutputBuffers(PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds();
        updateTask(sqlTask, ImmutableList.of(), initialBuffers);
    }

    return sqlTask.getTaskContext()
            .orElseThrow(() -> new IllegalStateException("TaskContext not present"));
}
}