TaskThresholdMemoryRevokingScheduler.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.execution;
import com.facebook.airlift.log.Logger;
import com.facebook.presto.memory.LocalMemoryManager;
import com.facebook.presto.memory.MemoryPool;
import com.facebook.presto.memory.TaskRevocableMemoryListener;
import com.facebook.presto.memory.VoidTraversingQueryContextVisitor;
import com.facebook.presto.operator.OperatorContext;
import com.facebook.presto.operator.TaskContext;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import javax.annotation.Nullable;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import javax.inject.Inject;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import java.util.function.Supplier;
import static com.facebook.presto.execution.MemoryRevokingUtils.getMemoryPools;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.SECONDS;
public class TaskThresholdMemoryRevokingScheduler
{
private static final Logger log = Logger.get(TaskThresholdMemoryRevokingScheduler.class);
private final Supplier<List<SqlTask>> allTasksSupplier;
private final Function<TaskId, SqlTask> taskSupplier;
private final ScheduledExecutorService taskManagementExecutor;
private final long maxRevocableMemoryPerTask;
// Technically not thread safe but should be fine since we only call this on PostConstruct and PreDestroy.
// PreDestroy isn't called until server shuts down/ in between tests.
@Nullable
private ScheduledFuture<?> scheduledFuture;
private final AtomicBoolean checkPending = new AtomicBoolean();
private final List<MemoryPool> memoryPools;
private final TaskRevocableMemoryListener taskRevocableMemoryListener = this::onMemoryReserved;
@Inject
public TaskThresholdMemoryRevokingScheduler(
LocalMemoryManager localMemoryManager,
SqlTaskManager sqlTaskManager,
TaskManagementExecutor taskManagementExecutor,
FeaturesConfig config)
{
this(
ImmutableList.copyOf(getMemoryPools(localMemoryManager)),
requireNonNull(sqlTaskManager, "sqlTaskManager cannot be null")::getAllTasks,
requireNonNull(sqlTaskManager, "sqlTaskManager cannot be null")::getTask,
requireNonNull(taskManagementExecutor, "taskManagementExecutor cannot be null").getExecutor(),
requireNonNull(config.getMaxRevocableMemoryPerTask(), "maxRevocableMemoryPerTask cannot be null").toBytes());
log.debug("Using TaskThresholdMemoryRevokingScheduler spilling strategy");
}
@VisibleForTesting
TaskThresholdMemoryRevokingScheduler(
List<MemoryPool> memoryPools,
Supplier<List<SqlTask>> allTasksSupplier,
Function<TaskId, SqlTask> taskSupplier,
ScheduledExecutorService taskManagementExecutor,
long maxRevocableMemoryPerTask)
{
this.memoryPools = ImmutableList.copyOf(requireNonNull(memoryPools, "memoryPools is null"));
this.allTasksSupplier = requireNonNull(allTasksSupplier, "allTasksSupplier is null");
this.taskSupplier = requireNonNull(taskSupplier, "taskSupplier is null");
this.taskManagementExecutor = requireNonNull(taskManagementExecutor, "taskManagementExecutor is null");
this.maxRevocableMemoryPerTask = maxRevocableMemoryPerTask;
}
@PostConstruct
public void start()
{
registerTaskMemoryPeriodicCheck();
registerPoolListeners();
}
private void registerTaskMemoryPeriodicCheck()
{
this.scheduledFuture = taskManagementExecutor.scheduleWithFixedDelay(() -> {
try {
revokeHighMemoryTasksIfNeeded();
}
catch (Exception e) {
log.error(e, "Error requesting task memory revoking");
}
}, 1, 1, SECONDS);
}
@PreDestroy
public void stop()
{
if (scheduledFuture != null) {
scheduledFuture.cancel(true);
scheduledFuture = null;
}
memoryPools.forEach(memoryPool -> memoryPool.removeTaskRevocableMemoryListener(taskRevocableMemoryListener));
}
@VisibleForTesting
void registerPoolListeners()
{
memoryPools.forEach(memoryPool -> memoryPool.addTaskRevocableMemoryListener(taskRevocableMemoryListener));
}
@VisibleForTesting
void revokeHighMemoryTasksIfNeeded()
{
if (checkPending.compareAndSet(false, true)) {
revokeHighMemoryTasks();
}
}
private void onMemoryReserved(TaskId taskId)
{
try {
SqlTask task = taskSupplier.apply(taskId);
if (!memoryRevokingNeeded(task)) {
return;
}
if (checkPending.compareAndSet(false, true)) {
log.debug("Scheduling check for %s", taskId);
scheduleRevoking();
}
}
catch (Exception e) {
log.error(e, "Error when acting on memory pool reservation");
}
}
private void scheduleRevoking()
{
taskManagementExecutor.execute(() -> {
try {
revokeHighMemoryTasks();
}
catch (Exception e) {
log.error(e, "Error requesting memory revoking");
}
});
}
private boolean memoryRevokingNeeded(SqlTask task)
{
return task.getTaskContext().filter(taskContext -> taskContext.getTaskMemoryContext().getRevocableMemory() >= maxRevocableMemoryPerTask).isPresent();
}
private synchronized void revokeHighMemoryTasks()
{
if (checkPending.getAndSet(false)) {
Collection<SqlTask> sqlTasks = requireNonNull(allTasksSupplier.get());
for (SqlTask task : sqlTasks) {
Optional<TaskContext> taskContext = task.getTaskContext();
if (!taskContext.isPresent()) {
continue;
}
long currentTaskRevocableMemory = taskContext.get().getTaskMemoryContext().getRevocableMemory();
if (currentTaskRevocableMemory < maxRevocableMemoryPerTask) {
continue;
}
AtomicLong remainingBytesToRevokeAtomic = new AtomicLong(currentTaskRevocableMemory - maxRevocableMemoryPerTask);
taskContext.get().accept(new VoidTraversingQueryContextVisitor<AtomicLong>()
{
@Override
public Void visitOperatorContext(OperatorContext operatorContext, AtomicLong remainingBytesToRevoke)
{
if (remainingBytesToRevoke.get() > 0) {
long revokedBytes = operatorContext.requestMemoryRevoking();
if (revokedBytes > 0) {
remainingBytesToRevoke.addAndGet(-revokedBytes);
log.debug("taskId=%s: requested revoking %s; remaining %s", task.getTaskId(), revokedBytes, remainingBytesToRevoke.get());
}
}
return null;
}
}, remainingBytesToRevokeAtomic);
}
}
}
}