HighMemoryTaskKiller.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.memory;

import com.facebook.airlift.log.Logger;
import com.facebook.airlift.stats.GarbageCollectionNotificationInfo;
import com.facebook.presto.execution.SqlTask;
import com.facebook.presto.execution.SqlTaskManager;
import com.facebook.presto.execution.TaskInfo;
import com.facebook.presto.execution.TaskManagerConfig;
import com.facebook.presto.operator.TaskStats;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.QueryId;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Ticker;
import com.google.common.collect.ListMultimap;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;

import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import javax.inject.Inject;
import javax.management.JMException;
import javax.management.Notification;
import javax.management.NotificationListener;
import javax.management.ObjectName;
import javax.management.openmbean.CompositeData;

import java.lang.management.GarbageCollectorMXBean;
import java.lang.management.ManagementFactory;
import java.lang.management.MemoryMXBean;
import java.lang.management.MemoryUsage;
import java.util.AbstractMap;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

import static com.facebook.presto.memory.HighMemoryTaskKillerStrategy.FREE_MEMORY_ON_FREQUENT_FULL_GC;
import static com.facebook.presto.memory.HighMemoryTaskKillerStrategy.FREE_MEMORY_ON_FULL_GC;
import static com.facebook.presto.spi.StandardErrorCode.EXCEEDED_HEAP_MEMORY_LIMIT;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableListMultimap.toImmutableListMultimap;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

public class HighMemoryTaskKiller
{
    private static final Logger log = Logger.get(HighMemoryTaskKiller.class);
    private static final String GC_NOTIFICATION_TYPE = "com.sun.management.gc.notification";
    private static final MemoryMXBean memoryMXBean = ManagementFactory.getMemoryMXBean();
    private final NotificationListener gcNotificationListener = (notification, ignored) -> onGCNotification(notification);
    private final SqlTaskManager sqlTaskManager;
    private final HighMemoryTaskKillerStrategy taskKillerStrategy;
    private final boolean taskKillerEnabled;
    private final Duration taskKillerFrequentFullGCDurationThreshold;
    private Duration lastFullGCTimestamp;
    private long lastFullGCCollectedBytes;
    private final long reclaimMemoryThreshold;
    private final long heapMemoryThreshold;
    Ticker ticker;

    @Inject
    public HighMemoryTaskKiller(SqlTaskManager sqlTaskManager, TaskManagerConfig taskManagerConfig)
    {
        requireNonNull(taskManagerConfig, "taskManagerConfig is null");

        this.sqlTaskManager = requireNonNull(sqlTaskManager, "sqlTaskManager must not be null");

        this.taskKillerStrategy = taskManagerConfig.getHighMemoryTaskKillerStrategy();
        this.taskKillerEnabled = taskManagerConfig.isHighMemoryTaskKillerEnabled();

        this.taskKillerFrequentFullGCDurationThreshold = taskManagerConfig.getHighMemoryTaskKillerFrequentFullGCDurationThreshold();
        this.reclaimMemoryThreshold = (long) (memoryMXBean.getHeapMemoryUsage().getMax() * taskManagerConfig.getHighMemoryTaskKillerGCReclaimMemoryThreshold());

        this.heapMemoryThreshold = (long) (memoryMXBean.getHeapMemoryUsage().getMax() * taskManagerConfig.getHighMemoryTaskKillerHeapMemoryThreshold());
        this.ticker = Ticker.systemTicker();
    }

    @PostConstruct
    public void start()
    {
        if (!taskKillerEnabled) {
            return;
        }

        for (GarbageCollectorMXBean mbean : ManagementFactory.getGarbageCollectorMXBeans()) {
            if (mbean.getName().equals("TestingMBeanServer")) {
                continue;
            }

            ObjectName objectName = mbean.getObjectName();
            try {
                ManagementFactory.getPlatformMBeanServer().addNotificationListener(
                        objectName,
                        gcNotificationListener,
                        null,
                        null);
            }
            catch (JMException e) {
                throw new RuntimeException("Unable to add listener", e);
            }
        }
    }

    @PreDestroy
    public void stop()
    {
        if (!taskKillerEnabled) {
            return;
        }

        for (GarbageCollectorMXBean mbean : ManagementFactory.getGarbageCollectorMXBeans()) {
            ObjectName objectName = mbean.getObjectName();
            try {
                ManagementFactory.getPlatformMBeanServer().removeNotificationListener(objectName, gcNotificationListener);
            }
            catch (JMException ignored) {
                log.error("Error removing notification: " + ignored);
            }
        }
    }

    private void onGCNotification(Notification notification)
    {
        if (GC_NOTIFICATION_TYPE.equals(notification.getType())) {
            GarbageCollectionNotificationInfo info = new GarbageCollectionNotificationInfo((CompositeData) notification.getUserData());
            if (info.isMajorGc()) {
                if (shouldTriggerTaskKiller(info)) {
                    //Kill task consuming most memory
                    List<SqlTask> activeTasks = getActiveTasks();
                    ListMultimap<QueryId, SqlTask> activeQueriesToTasksMap = activeTasks.stream()
                            .collect(toImmutableListMultimap(task -> task.getQueryContext().getQueryId(), Function.identity()));

                    Optional<QueryId> queryId = getMaxMemoryConsumingQuery(activeQueriesToTasksMap);

                    if (queryId.isPresent()) {
                        List<SqlTask> activeTasksToKill = activeQueriesToTasksMap.get(queryId.get());
                        for (SqlTask sqlTask : activeTasksToKill) {
                            TaskStats taskStats = sqlTask.getTaskInfo().getStats();
                            sqlTask.failed(new PrestoException(EXCEEDED_HEAP_MEMORY_LIMIT, format("Worker heap memory limit exceeded: User Memory: %d, System Memory: %d, Revocable Memory: %d", taskStats.getUserMemoryReservationInBytes(), taskStats.getSystemMemoryReservationInBytes(), taskStats.getRevocableMemoryReservationInBytes())));
                        }
                    }
                }
            }
        }
    }

    private boolean shouldTriggerTaskKiller(GarbageCollectionNotificationInfo info)
    {
        boolean triggerTaskKiller = false;
        DataSize beforeGcDataSize = info.getBeforeGcTotal();
        DataSize afterGcDataSize = info.getAfterGcTotal();

        if (taskKillerStrategy == FREE_MEMORY_ON_FREQUENT_FULL_GC) {
            long currentGarbageCollectedBytes = beforeGcDataSize.toBytes() - afterGcDataSize.toBytes();
            Duration currentFullGCTimestamp = new Duration(ticker.read(), TimeUnit.NANOSECONDS);

            if (isFrequentFullGC(lastFullGCTimestamp, currentFullGCTimestamp) && !hasFullGCFreedEnoughBytes(currentGarbageCollectedBytes)) {
                triggerTaskKiller = true;
            }

            lastFullGCTimestamp = currentFullGCTimestamp;
            lastFullGCCollectedBytes = currentGarbageCollectedBytes;
        }
        else if (taskKillerStrategy == FREE_MEMORY_ON_FULL_GC) {
            if (isLowMemory() && beforeGcDataSize.toBytes() - afterGcDataSize.toBytes() < reclaimMemoryThreshold) {
                triggerTaskKiller = true;
            }
        }
        log.debug("Task Killer Trigger: " + triggerTaskKiller + ", Before Full GC Head Size: " + beforeGcDataSize.toBytes() + " After Full GC Heap Size: " + afterGcDataSize.toBytes());

        return triggerTaskKiller;
    }

    private List<SqlTask> getActiveTasks()
    {
        return sqlTaskManager.getAllTasks().stream()
                .filter(task -> !task.getTaskState().isDone())
                .collect(toImmutableList());
    }

    @VisibleForTesting
    public static Optional<QueryId> getMaxMemoryConsumingQuery(ListMultimap<QueryId, SqlTask> queryIDToSqlTaskMap)
    {
        if (queryIDToSqlTaskMap.isEmpty()) {
            return Optional.empty();
        }

        Comparator<Map.Entry<QueryId, Long>> comparator = Comparator.comparingLong(Map.Entry::getValue);

        Optional<QueryId> maxMemoryConsumpingQueryId = queryIDToSqlTaskMap.asMap().entrySet().stream()
                .map(entry ->
                        new AbstractMap.SimpleEntry<>(entry.getKey(), entry.getValue().stream()
                                .map(SqlTask::getTaskInfo)
                                .map(TaskInfo::getStats)
                                .mapToLong(stats -> stats.getUserMemoryReservationInBytes() + stats.getSystemMemoryReservationInBytes() + stats.getRevocableMemoryReservationInBytes())
                                .sum())
                ).max(comparator).map(Map.Entry::getKey);

        return maxMemoryConsumpingQueryId;
    }

    private boolean isFrequentFullGC(Duration lastFullGCTime, Duration currentFullGCTime)
    {
        long diffBetweenFullGCMilis = currentFullGCTime.toMillis() - lastFullGCTime.toMillis();
        log.debug("Time difference between last 2 full GC in miliseconds: " + diffBetweenFullGCMilis);
        if (diffBetweenFullGCMilis > taskKillerFrequentFullGCDurationThreshold.getValue(TimeUnit.MILLISECONDS)) {
            log.debug("Skip killing tasks Due to full GCs were not happening frequently.");
            return false;
        }
        return true;
    }

    private boolean hasFullGCFreedEnoughBytes(long currentGarbageCollectedBytes)
    {
        if (currentGarbageCollectedBytes < reclaimMemoryThreshold && lastFullGCCollectedBytes < reclaimMemoryThreshold) {
            log.debug("Full GC not able to free enough memory. Current freed bytes: " + currentGarbageCollectedBytes + " previously freed bytes: " + lastFullGCCollectedBytes);
            return false;
        }
        log.debug("Full GC able to free enough memory. Current freed bytes: " + currentGarbageCollectedBytes + " previously freed bytes: " + lastFullGCCollectedBytes);
        return true;
    }

    private boolean isLowMemory()
    {
        MemoryUsage memoryUsage = memoryMXBean.getHeapMemoryUsage();

        if (memoryUsage.getUsed() > heapMemoryThreshold) {
            return true;
        }

        return false;
    }
}