// MemoryRevokingScheduler.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.execution;
import com.facebook.airlift.log.Logger;
import com.facebook.presto.memory.LocalMemoryManager;
import com.facebook.presto.memory.MemoryPool;
import com.facebook.presto.memory.MemoryPoolListener;
import com.facebook.presto.memory.QueryContext;
import com.facebook.presto.memory.VoidTraversingQueryContextVisitor;
import com.facebook.presto.operator.OperatorContext;
import com.facebook.presto.operator.PipelineContext;
import com.facebook.presto.operator.TaskContext;
import com.facebook.presto.spi.QueryId;
import com.facebook.presto.spi.memory.MemoryPoolId;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.analyzer.FeaturesConfig.TaskSpillingStrategy;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Ordering;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import javax.inject.Inject;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import java.util.function.Supplier;
import static com.facebook.airlift.concurrent.Threads.threadsNamed;
import static com.facebook.presto.execution.MemoryRevokingUtils.getMemoryPools;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.TaskSpillingStrategy.PER_TASK_MEMORY_THRESHOLD;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Collections.singletonList;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.Executors.newSingleThreadExecutor;
/**
 * Listens to memory reservations on the local memory pools and asks operators to revoke
 * (release) revocable memory when either
 * <ul>
 * <li>a pool's free bytes drop to or below {@code maxBytes * (1 - memoryRevokingThreshold)}
 * while it holds revocable reservations, or</li>
 * <li>query-limit spilling is enabled and a query's reported reservation reaches its
 * maximum total memory.</li>
 * </ul>
 * All revocation work is executed on one dedicated single-thread executor, so revocation
 * requests are processed one at a time.
 */
public class MemoryRevokingScheduler
{
    private static final Logger log = Logger.get(MemoryRevokingScheduler.class);

    private static final Ordering<SqlTask> ORDER_BY_CREATE_TIME = Ordering.natural().onResultOf(SqlTask::getTaskCreatedTime);

    private final Function<QueryId, QueryContext> queryContextSupplier;
    private final Supplier<List<SqlTask>> currentTasksSupplier;
    private final ExecutorService memoryRevocationExecutor;
    // fraction of a pool's max bytes in use at (or above) which pool-level revoking is triggered
    private final double memoryRevokingThreshold;
    // fraction of a pool's max bytes in use that pool-level revoking tries to get back down to
    private final double memoryRevokingTarget;
    private final TaskSpillingStrategy spillingStrategy;

    private final List<MemoryPool> memoryPools;
    private final MemoryPoolListener memoryPoolListener = this::onMemoryReserved;

    private final boolean queryLimitSpillEnabled;

    /**
     * Creates a scheduler wired to the node's memory pools and task manager, taking the
     * revoking threshold/target, the spilling strategy and the query-limit-spill flag from
     * {@code config}.
     *
     * @throws NullPointerException if {@code sqlTaskManager} or {@code config} is null
     * @throws IllegalArgumentException if the configured values fail the validation done by
     * the delegate constructor
     */
    @Inject
    public MemoryRevokingScheduler(
            LocalMemoryManager localMemoryManager,
            SqlTaskManager sqlTaskManager,
            FeaturesConfig config)
    {
        this(
                ImmutableList.copyOf(getMemoryPools(localMemoryManager)),
                requireNonNull(sqlTaskManager, "sqlTaskManager cannot be null")::getAllTasks,
                requireNonNull(sqlTaskManager, "sqlTaskManager cannot be null")::getQueryContext,
                requireNonNull(config, "config is null").getMemoryRevokingThreshold(),
                config.getMemoryRevokingTarget(),
                config.getTaskSpillingStrategy(),
                config.isQueryLimitSpillEnabled());
    }

    /**
     * Creates a scheduler from explicit collaborators.
     *
     * @param memoryPools pools to listen on; copied defensively
     * @param currentTasksSupplier supplies the tasks considered for pool-level revoking
     * @param queryContextSupplier resolves the {@link QueryContext} for a query id
     * @param memoryRevokingThreshold value in {@code [0, 1]}
     * @param memoryRevokingTarget value in {@code [0, 1]}, not greater than the threshold
     * @param taskSpillingStrategy task traversal order; must not be {@code PER_TASK_MEMORY_THRESHOLD}
     * @param queryLimitSpillEnabled whether per-query limit based revoking is performed
     * @throws NullPointerException if a reference argument is null
     * @throws IllegalArgumentException if a numeric argument is out of range, the target
     * exceeds the threshold, or the strategy is {@code PER_TASK_MEMORY_THRESHOLD}
     */
    @VisibleForTesting
    MemoryRevokingScheduler(
            List<MemoryPool> memoryPools,
            Supplier<List<SqlTask>> currentTasksSupplier,
            Function<QueryId, QueryContext> queryContextSupplier,
            double memoryRevokingThreshold,
            double memoryRevokingTarget,
            TaskSpillingStrategy taskSpillingStrategy,
            boolean queryLimitSpillEnabled)
    {
        this.memoryPools = ImmutableList.copyOf(requireNonNull(memoryPools, "memoryPools is null"));
        this.currentTasksSupplier = requireNonNull(currentTasksSupplier, "currentTasksSupplier is null");
        this.queryContextSupplier = requireNonNull(queryContextSupplier, "queryContextSupplier is null");
        this.memoryRevokingThreshold = checkFraction(memoryRevokingThreshold, "memoryRevokingThreshold");
        this.memoryRevokingTarget = checkFraction(memoryRevokingTarget, "memoryRevokingTarget");
        this.spillingStrategy = requireNonNull(taskSpillingStrategy, "taskSpillingStrategy is null");
        checkArgument(spillingStrategy != PER_TASK_MEMORY_THRESHOLD, "spilling strategy cannot be PER_TASK_MEMORY_THRESHOLD in MemoryRevokingScheduler");
        checkArgument(
                memoryRevokingTarget <= memoryRevokingThreshold,
                "memoryRevokingTarget should be less than or equal memoryRevokingThreshold, but got %s and %s respectively",
                memoryRevokingTarget, memoryRevokingThreshold);
        this.queryLimitSpillEnabled = queryLimitSpillEnabled;

        // Created only once every argument has been validated.
        // by using a single thread executor, we don't need to worry about locking to ensure only
        // one revocation request per-query/memory pool is processed at a time.
        this.memoryRevocationExecutor = newSingleThreadExecutor(threadsNamed("memory-revocation"));
    }

    /**
     * Validates that {@code value} lies in {@code [0, 1]} and returns it.
     *
     * @throws IllegalArgumentException if {@code value} is outside {@code [0, 1]} (this includes NaN)
     */
    private static double checkFraction(double value, String valueName)
    {
        requireNonNull(valueName, "valueName is null");
        checkArgument(0 <= value && value <= 1, "%s should be within [0, 1] range, got %s", valueName, value);
        return value;
    }

    /**
     * Registers this scheduler's listener on every memory pool.
     */
    @PostConstruct
    public void start()
    {
        registerPoolListeners();
    }

    /**
     * Unregisters the listener from every memory pool and shuts the revocation executor down.
     */
    @PreDestroy
    public void stop()
    {
        memoryPools.forEach(memoryPool -> memoryPool.removeListener(memoryPoolListener));
        memoryRevocationExecutor.shutdown();
    }

    private void registerPoolListeners()
    {
        memoryPools.forEach(memoryPool -> memoryPool.addListener(memoryPoolListener));
    }

    /**
     * Blocks until a no-op task submitted to the revocation executor has completed; because
     * the executor is single-threaded, everything queued before the call has run by then.
     */
    @VisibleForTesting
    void awaitAsynchronousCallbacksRun()
            throws InterruptedException
    {
        memoryRevocationExecutor.invokeAll(singletonList((Callable<?>) () -> null));
    }

    /**
     * Submits {@code callable} to the revocation executor.
     */
    @VisibleForTesting
    void submitAsynchronousCallable(Callable<?> callable)
    {
        memoryRevocationExecutor.submit(callable);
    }

    /**
     * Memory pool listener callback. Runs the (optional) per-query limit check and the
     * pool-level check independently, so that a failure in one of them is logged without
     * suppressing the other. Exceptions are never propagated to the caller.
     */
    private void onMemoryReserved(MemoryPool memoryPool, QueryId queryId, long queryMemoryReservation)
    {
        if (queryLimitSpillEnabled) {
            try {
                scheduleQueryRevokingIfNeeded(queryId, queryMemoryReservation);
            }
            catch (Exception e) {
                log.error(e, "Error when acting on memory pool reservation");
            }
        }

        try {
            if (memoryRevokingNeededForPool(memoryPool)) {
                log.debug("Scheduling check for %s", memoryPool);
                scheduleMemoryPoolRevoking(memoryPool);
            }
        }
        catch (Exception e) {
            log.error(e, "Error when acting on memory pool reservation");
        }
    }

    /**
     * Schedules query-level revoking when the reported reservation has reached the query's
     * maximum total memory. Fails verification if no query context exists for {@code queryId}.
     */
    private void scheduleQueryRevokingIfNeeded(QueryId queryId, long queryMemoryReservation)
    {
        QueryContext queryContext = queryContextSupplier.apply(queryId);
        verify(queryContext != null, "QueryContext not found for queryId %s", queryId);
        long maxTotalMemory = queryContext.getMaxTotalMemory();
        if (memoryRevokingNeededForQuery(queryMemoryReservation, maxTotalMemory)) {
            log.debug("Scheduling check for %s", queryId);
            scheduleQueryRevoking(queryContext, maxTotalMemory);
        }
    }

    private boolean memoryRevokingNeededForQuery(long queryMemoryReservation, long maxTotalMemory)
    {
        return queryMemoryReservation >= maxTotalMemory;
    }

    /**
     * Queues a query-level revocation on the revocation executor; failures of the queued
     * work are logged rather than propagated.
     */
    private void scheduleQueryRevoking(QueryContext queryContext, long maxTotalMemory)
    {
        memoryRevocationExecutor.execute(() -> {
            try {
                revokeQueryMemory(queryContext, maxTotalMemory);
            }
            catch (Exception e) {
                log.error(e, "Error requesting memory revoking");
            }
        });
    }

    /**
     * Requests revoking from the query's operators until the amount by which the query
     * exceeds {@code maxTotalMemory} (less what is already being revoked) has been requested.
     * Tasks are visited in order of decreasing revocable memory.
     */
    private void revokeQueryMemory(QueryContext queryContext, long maxTotalMemory)
    {
        QueryId queryId = queryContext.getQueryId();
        MemoryPool memoryPool = queryContext.getMemoryPool();
        // get a fresh value for queryTotalMemory in case it's changed (e.g. by a previous revocation request)
        long queryTotalMemory = getTotalQueryMemoryReservation(queryId, memoryPool);
        long bytesOverLimit = queryTotalMemory - maxTotalMemory;
        if (bytesOverLimit <= 0) {
            // nothing to revoke; no revocation would be requested below either
            return;
        }

        // order tasks by decreasing revocableMemory so that we don't spill more tasks than needed.
        // Tasks are grouped per revocable size: a plain size -> task map would silently drop
        // every task whose size collides with that of another task.
        SortedMap<Long, List<TaskContext>> taskContextsByRevocableMemory = new TreeMap<>(Comparator.reverseOrder());
        queryContext.getAllTaskContexts()
                .forEach(taskContext -> taskContextsByRevocableMemory
                        .computeIfAbsent(taskContext.getTaskMemoryContext().getRevocableMemory(), ignored -> new ArrayList<>())
                        .add(taskContext));
        List<TaskContext> queryTaskContexts = new ArrayList<>();
        taskContextsByRevocableMemory.values().forEach(queryTaskContexts::addAll);

        AtomicLong remainingBytesToRevoke = new AtomicLong(bytesOverLimit);
        remainingBytesToRevoke.addAndGet(-MemoryRevokingSchedulerUtils.getMemoryAlreadyBeingRevoked(queryTaskContexts, remainingBytesToRevoke.get()));

        for (TaskContext taskContext : queryTaskContexts) {
            if (remainingBytesToRevoke.get() <= 0) {
                break;
            }
            taskContext.accept(new VoidTraversingQueryContextVisitor<AtomicLong>()
            {
                @Override
                public Void visitOperatorContext(OperatorContext operatorContext, AtomicLong remaining)
                {
                    if (remaining.get() > 0) {
                        long revokedBytes = operatorContext.requestMemoryRevoking();
                        if (revokedBytes > 0) {
                            remaining.addAndGet(-revokedBytes);
                            log.debug("taskId=%s: requested revoking %s; remaining %s", taskContext.getTaskId(), revokedBytes, remaining.get());
                        }
                    }
                    return null;
                }
            }, remainingBytesToRevoke);
        }
    }

    /**
     * Returns the sum of the query's regular and revocable reservations in {@code memoryPool}.
     */
    private static long getTotalQueryMemoryReservation(QueryId queryId, MemoryPool memoryPool)
    {
        return memoryPool.getQueryMemoryReservation(queryId) + memoryPool.getQueryRevocableMemoryReservation(queryId);
    }

    /**
     * Queues a pool-level revocation on the revocation executor; failures of the queued
     * work are logged rather than propagated.
     */
    private void scheduleMemoryPoolRevoking(MemoryPool memoryPool)
    {
        memoryRevocationExecutor.execute(() -> {
            try {
                runMemoryPoolRevoking(memoryPool);
            }
            catch (Exception e) {
                log.error(e, "Error requesting memory revoking");
            }
        });
    }

    /**
     * Re-checks whether {@code memoryPool} still needs revoking and, if so, requests it
     * from the currently known tasks.
     */
    @VisibleForTesting
    void runMemoryPoolRevoking(MemoryPool memoryPool)
    {
        if (!memoryRevokingNeededForPool(memoryPool)) {
            return;
        }
        Collection<SqlTask> allTasks = requireNonNull(currentTasksSupplier.get());
        requestMemoryPoolRevoking(memoryPool, allTasks);
    }

    /**
     * Computes how many bytes must be freed for the pool's free bytes to reach
     * {@code maxBytes * (1 - memoryRevokingTarget)}, subtracts what the pool's running tasks
     * are already revoking, and requests revoking for any positive remainder.
     */
    private void requestMemoryPoolRevoking(MemoryPool memoryPool, Collection<SqlTask> allTasks)
    {
        long remainingBytesToRevoke = (long) (-memoryPool.getFreeBytes() + (memoryPool.getMaxBytes() * (1.0 - memoryRevokingTarget)));
        List<SqlTask> runningTasksInPool = findRunningTasksInMemoryPool(allTasks, memoryPool);
        remainingBytesToRevoke -= getMemoryAlreadyBeingRevoked(runningTasksInPool, remainingBytesToRevoke);
        if (remainingBytesToRevoke > 0) {
            requestRevoking(memoryPool.getId(), runningTasksInPool, remainingBytesToRevoke);
        }
    }

    /**
     * Returns true when the pool holds revocable reservations and its free bytes are at or
     * below {@code maxBytes * (1 - memoryRevokingThreshold)}.
     */
    private boolean memoryRevokingNeededForPool(MemoryPool memoryPool)
    {
        return memoryPool.getReservedRevocableBytes() > 0
                && memoryPool.getFreeBytes() <= memoryPool.getMaxBytes() * (1.0 - memoryRevokingThreshold);
    }

    /**
     * Delegates to {@link MemoryRevokingSchedulerUtils#getMemoryAlreadyBeingRevoked} with the
     * task contexts of those {@code sqlTasks} that currently have one.
     */
    private long getMemoryAlreadyBeingRevoked(List<SqlTask> sqlTasks, long targetRevokingLimit)
    {
        List<TaskContext> taskContexts = sqlTasks.stream()
                .map(SqlTask::getTaskContext)
                .filter(Optional::isPresent)
                .map(Optional::get)
                .collect(toImmutableList());
        return MemoryRevokingSchedulerUtils.getMemoryAlreadyBeingRevoked(taskContexts, targetRevokingLimit);
    }

    /**
     * Sorts {@code sqlTasks} in place according to the spilling strategy and requests
     * revoking from their operators, task by task, until {@code remainingBytesToRevoke}
     * bytes have been requested or the tasks are exhausted.
     *
     * @param sqlTasks mutable list of candidate tasks; reordered by this call
     */
    private void requestRevoking(MemoryPoolId memoryPoolId, List<SqlTask> sqlTasks, long remainingBytesToRevoke)
    {
        VoidTraversingQueryContextVisitor<AtomicLong> visitor = new VoidTraversingQueryContextVisitor<AtomicLong>()
        {
            @Override
            public Void visitPipelineContext(PipelineContext pipelineContext, AtomicLong remaining)
            {
                if (remaining.get() <= 0) {
                    // exit immediately if no work needs to be done
                    return null;
                }
                return super.visitPipelineContext(pipelineContext, remaining);
            }

            @Override
            public Void visitOperatorContext(OperatorContext operatorContext, AtomicLong remaining)
            {
                if (remaining.get() > 0) {
                    long revokedBytes = operatorContext.requestMemoryRevoking();
                    if (revokedBytes > 0) {
                        remaining.addAndGet(-revokedBytes);
                        log.debug("memoryPool=%s, operatorContext: %s: requested revoking %s; remaining %s", memoryPoolId, operatorContext, revokedBytes, remaining.get());
                    }
                }
                return null;
            }
        };

        // Sort the tasks into their traversal order
        log.debug("Ordering by %s", spillingStrategy);
        sortTasksToTraversalOrder(sqlTasks, spillingStrategy);

        AtomicLong remainingBytesToRevokeAtomic = new AtomicLong(remainingBytesToRevoke);
        for (SqlTask task : sqlTasks) {
            Optional<TaskContext> taskContext = task.getTaskContext();
            if (taskContext.isPresent()) {
                taskContext.get().accept(visitor, remainingBytesToRevokeAtomic);
                if (remainingBytesToRevokeAtomic.get() <= 0) {
                    // No further revoking required
                    return;
                }
            }
        }
    }

    /**
     * Sorts {@code sqlTasks} in place: by task creation time for {@code ORDER_BY_CREATE_TIME},
     * or by decreasing revocable memory reservation for {@code ORDER_BY_REVOCABLE_BYTES}.
     *
     * @throws IllegalArgumentException for {@code PER_TASK_MEMORY_THRESHOLD}
     * @throws UnsupportedOperationException for any other strategy
     */
    private static void sortTasksToTraversalOrder(List<SqlTask> sqlTasks, TaskSpillingStrategy spillingStrategy)
    {
        switch (spillingStrategy) {
            case ORDER_BY_CREATE_TIME:
                sqlTasks.sort(ORDER_BY_CREATE_TIME);
                break;
            case ORDER_BY_REVOCABLE_BYTES:
                // To avoid repeatedly generating the task info, we have to compare by their mapping.
                // The snapshot also keeps the sort keys stable for the duration of the sort.
                HashMap<TaskId, Long> taskRevocableReservations = new HashMap<>();
                for (SqlTask sqlTask : sqlTasks) {
                    taskRevocableReservations.put(sqlTask.getTaskId(), sqlTask.getTaskInfo().getStats().getRevocableMemoryReservationInBytes());
                }
                sqlTasks.sort(Ordering.natural().reverse().onResultOf(task -> task == null ? 0L : taskRevocableReservations.getOrDefault(task.getTaskId(), 0L)));
                break;
            case PER_TASK_MEMORY_THRESHOLD:
                throw new IllegalArgumentException("spilling strategy cannot be PER_TASK_MEMORY_THRESHOLD in MemoryRevokingScheduler");
            default:
                throw new UnsupportedOperationException("Unexpected spilling strategy in MemoryRevokingScheduler");
        }
    }

    /**
     * Returns a new mutable list of the tasks that are in state {@code RUNNING} and whose
     * query context uses exactly {@code memoryPool} (reference comparison).
     */
    private static List<SqlTask> findRunningTasksInMemoryPool(Collection<SqlTask> allCurrentTasks, MemoryPool memoryPool)
    {
        // Resulting list must be mutable to enable sorting after the fact
        List<SqlTask> runningTasks = new ArrayList<>();
        for (SqlTask task : allCurrentTasks) {
            if (task.getTaskState() == TaskState.RUNNING && task.getQueryContext().getMemoryPool() == memoryPool) {
                runningTasks.add(task);
            }
        }
        return runningTasks;
    }
}