// TestMemoryTracking.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.memory;

import com.facebook.airlift.stats.TestingGcMonitor;
import com.facebook.presto.ExceededMemoryLimitException;
import com.facebook.presto.execution.TaskId;
import com.facebook.presto.execution.TaskStateMachine;
import com.facebook.presto.memory.context.LocalMemoryContext;
import com.facebook.presto.memory.context.MemoryTrackingContext;
import com.facebook.presto.operator.DriverContext;
import com.facebook.presto.operator.DriverStats;
import com.facebook.presto.operator.OperatorContext;
import com.facebook.presto.operator.OperatorStats;
import com.facebook.presto.operator.PipelineContext;
import com.facebook.presto.operator.PipelineStats;
import com.facebook.presto.operator.TaskContext;
import com.facebook.presto.operator.TaskMemoryReservationSummary;
import com.facebook.presto.operator.TaskStats;
import com.facebook.presto.spi.QueryId;
import com.facebook.presto.spi.memory.MemoryPoolId;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spiller.SpillSpaceTracker;
import io.airlift.units.DataSize;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import java.util.function.Consumer;
import java.util.regex.Pattern;

import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed;
import static com.facebook.airlift.json.JsonCodec.listJsonCodec;
import static com.facebook.presto.execution.TaskTestUtils.PLAN_FRAGMENT;
import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
import static io.airlift.units.DataSize.Unit.GIGABYTE;
import static java.lang.String.format;
import static java.util.concurrent.Executors.newCachedThreadPool;
import static java.util.concurrent.Executors.newScheduledThreadPool;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;

@Test(singleThreaded = true)
public class TestMemoryTracking
{
    // Per-query limits handed to the QueryContext built in setUpTest().
    // queryMaxMemory is passed there as both the first and the third size argument.
    private static final DataSize queryMaxMemory = new DataSize(1, GIGABYTE);
    private static final DataSize queryMaxTotalMemory = new DataSize(1, GIGABYTE);
    private static final DataSize queryMaxRevocableMemory = new DataSize(2, GIGABYTE);
    // Size of the MemoryPool that setUpTest() recreates before each test method.
    private static final DataSize memoryPoolSize = new DataSize(1, GIGABYTE);
    // Size used to construct the shared SpillSpaceTracker below.
    private static final DataSize maxSpillSize = new DataSize(1, GIGABYTE);
    // Per-query spill limit handed to the QueryContext built in setUpTest().
    private static final DataSize queryMaxSpillSize = new DataSize(1, GIGABYTE);
    // One static tracker instance, reused by every QueryContext this class creates.
    private static final SpillSpaceTracker spillSpaceTracker = new SpillSpaceTracker(maxSpillSize);

    // Context chain rebuilt by setUpTest() before every test method:
    // query -> task -> pipeline -> driver -> operator.
    private QueryContext queryContext;
    private TaskContext taskContext;
    private PipelineContext pipelineContext;
    private DriverContext driverContext;
    private OperatorContext operatorContext;
    // Pool backing queryContext; several assertions read its reserved bytes directly.
    private MemoryPool memoryPool;
    // Executors created once in setUp() and shut down in tearDown().
    private ExecutorService notificationExecutor;
    private ScheduledExecutorService yieldExecutor;

    /**
     * Creates the two daemon-thread executors that are shared by all test methods
     * and later passed to the QueryContext and TaskStateMachine in setUpTest().
     */
    @BeforeClass
    public void setUp()
    {
        this.notificationExecutor = newCachedThreadPool(
                daemonThreadsNamed("local-query-runner-executor-%s"));
        this.yieldExecutor = newScheduledThreadPool(
                2,
                daemonThreadsNamed("local-query-runner-scheduler-%s"));
    }

    /**
     * Shuts down the shared executors and drops the references to the per-test contexts.
     *
     * Because of {@code alwaysRun = true} this method is also invoked when {@link #setUp()}
     * failed, in which case either executor may still be {@code null}. The null guards keep
     * the cleanup from throwing a NullPointerException on top of the real setup failure and
     * make sure the remaining cleanup steps still run.
     */
    @AfterClass(alwaysRun = true)
    public void tearDown()
    {
        if (notificationExecutor != null) {
            notificationExecutor.shutdownNow();
        }
        if (yieldExecutor != null) {
            yieldExecutor.shutdownNow();
        }
        queryContext = null;
        taskContext = null;
        pipelineContext = null;
        driverContext = null;
        operatorContext = null;
        memoryPool = null;
    }

    /**
     * Rebuilds the memory pool and the whole query -> task -> pipeline -> driver -> operator
     * context chain before each test method, so no test sees contexts created by another.
     */
    @BeforeMethod
    public void setUpTest()
    {
        memoryPool = new MemoryPool(new MemoryPoolId("test"), memoryPoolSize);

        QueryId queryId = new QueryId("test_query");
        queryContext = new QueryContext(
                queryId,
                queryMaxMemory,
                queryMaxTotalMemory,
                // queryMaxMemory is passed a second time as the third size argument
                queryMaxMemory,
                queryMaxRevocableMemory,
                memoryPool,
                new TestingGcMonitor(),
                notificationExecutor,
                yieldExecutor,
                queryMaxSpillSize,
                spillSpaceTracker,
                listJsonCodec(TaskMemoryReservationSummary.class));

        TaskId taskId = new TaskId("query", 0, 0, 0, 0);
        TaskStateMachine taskStateMachine = new TaskStateMachine(taskId, notificationExecutor);
        // The boolean flags are positional; their meaning is defined by QueryContext.addTaskContext.
        taskContext = queryContext.addTaskContext(
                taskStateMachine,
                testSessionBuilder().build(),
                Optional.of(PLAN_FRAGMENT.getRoot()),
                true,
                true,
                true,
                true,
                false);

        pipelineContext = taskContext.addPipelineContext(0, true, true, false);
        driverContext = pipelineContext.addDriverContext();

        PlanNodeId planNodeId = new PlanNodeId("a");
        operatorContext = driverContext.addOperatorContext(1, planNodeId, "test-operator");
    }

    /**
     * Drives the operator's user, system and revocable local contexts through a sequence of
     * reservations and checks, after each step, the values reported by the operator's
     * MemoryTrackingContext and by the memory pool. Also checks that driving the user
     * reservation below zero is rejected and that destroying the operator clears everything.
     */
    @Test
    public void testOperatorAllocations()
    {
        MemoryTrackingContext tracking = operatorContext.getOperatorMemoryContext();
        LocalMemoryContext system = operatorContext.localSystemMemoryContext();
        LocalMemoryContext user = operatorContext.localUserMemoryContext();
        LocalMemoryContext revocable = operatorContext.localRevocableMemoryContext();

        // grow user, then system memory
        user.setBytes(100);
        assertOperatorMemoryAllocations(tracking, 100, 0, 0);

        system.setBytes(1_000_000);
        assertOperatorMemoryAllocations(tracking, 100, 1_000_000, 0);

        system.setBytes(2_000_000);
        assertOperatorMemoryAllocations(tracking, 100, 2_000_000, 0);

        user.setBytes(500);
        assertOperatorMemoryAllocations(tracking, 500, 2_000_000, 0);

        // release all user memory again
        long userAfterRelease = user.getBytes() - 500;
        user.setBytes(userAfterRelease);
        assertOperatorMemoryAllocations(tracking, 0, 2_000_000, 0);

        revocable.setBytes(300);
        assertOperatorMemoryAllocations(tracking, 0, 2_000_000, 300);

        // user memory is already 0, so subtracting another 500 must be rejected
        assertAllocationFails(
                (unused) -> {
                    user.setBytes(user.getBytes() - 500);
                },
                "bytes cannot be negative");

        operatorContext.destroy();
        assertOperatorMemoryAllocations(tracking, 0, 0, 0);
    }

    /**
     * Reserves system memory up to exactly queryMaxTotalMemory, then checks that one more
     * byte raises ExceededMemoryLimitException with the expected message.
     */
    @Test
    public void testLocalTotalMemoryLimitExceeded()
    {
        LocalMemoryContext system = operatorContext.localSystemMemoryContext();
        long limitInBytes = queryMaxTotalMemory.toBytes();

        system.setBytes(100);
        assertOperatorMemoryAllocations(operatorContext.getOperatorMemoryContext(), 0, 100, 0);

        // reserving exactly the limit is still allowed
        system.setBytes(limitInBytes);
        assertOperatorMemoryAllocations(operatorContext.getOperatorMemoryContext(), 0, limitInBytes, 0);

        try {
            system.setBytes(limitInBytes + 1);
            fail("allocation should hit the per-node total memory limit");
        }
        catch (ExceededMemoryLimitException e) {
            String expectedMessage = format(
                    "Query exceeded per-node total memory limit of %1$s [Allocated: %1$s, Delta: 1B (test-operator), Top Consumers: {test-operator=%1$s}]",
                    queryMaxTotalMemory);
            assertEquals(e.getMessage(), expectedMessage);
        }
    }

    /**
     * Reserves revocable memory up to exactly queryMaxRevocableMemory, then checks that one
     * more byte raises ExceededMemoryLimitException with the expected message.
     */
    @Test
    public void testLocalRevocableMemoryLimitExceeded()
    {
        LocalMemoryContext revocable = operatorContext.localRevocableMemoryContext();
        long limitInBytes = queryMaxRevocableMemory.toBytes();

        revocable.setBytes(100);
        assertOperatorMemoryAllocations(operatorContext.getOperatorMemoryContext(), 0, 0, 100);

        // reserving exactly the limit is still allowed
        revocable.setBytes(limitInBytes);
        assertOperatorMemoryAllocations(operatorContext.getOperatorMemoryContext(), 0, 0, limitInBytes);

        try {
            revocable.setBytes(limitInBytes + 1);
            fail("allocation should hit the per-node revocable memory limit");
        }
        catch (ExceededMemoryLimitException e) {
            String expectedMessage = format(
                    "Query exceeded per-node revocable memory limit of %1$s [Allocated: %1$s, Delta: 1B (test-operator)]",
                    queryMaxRevocableMemory);
            assertEquals(e.getMessage(), expectedMessage);
        }
    }

    /**
     * Reserves system memory directly on the pipeline and on the task contexts and checks
     * that each reservation shows up in the pool and in its own context's local system
     * memory, while the task-level reservation is not reported by the pipeline stats.
     */
    @Test
    public void testLocalSystemAllocations()
    {
        long pipelineBytes = 1_000_000;
        long taskBytes = 10_000_000;

        // pipeline-level reservation
        LocalMemoryContext pipelineSystem = pipelineContext.localSystemMemoryContext();
        pipelineSystem.setBytes(pipelineBytes);
        assertLocalMemoryAllocations(pipelineContext.getPipelineMemoryContext(), pipelineBytes, 0, pipelineBytes);

        // task-level reservation on top of it
        LocalMemoryContext taskSystem = taskContext.localSystemMemoryContext();
        taskSystem.setBytes(taskBytes);
        long pipelinePlusTask = pipelineBytes + taskBytes;
        assertLocalMemoryAllocations(taskContext.getTaskMemoryContext(), pipelinePlusTask, 0, taskBytes);

        assertEquals(
                pipelineContext.getPipelineStats().getSystemMemoryReservationInBytes(),
                pipelineBytes,
                "task level allocations should not be visible at the pipeline level");

        // release the pipeline reservation; only the task reservation remains in the pool
        long pipelineRemaining = pipelineSystem.getBytes() - pipelineBytes;
        pipelineSystem.setBytes(pipelineRemaining);
        assertLocalMemoryAllocations(pipelineContext.getPipelineMemoryContext(), taskBytes, 0, 0);

        // release the task reservation; the pool is empty again
        long taskRemaining = taskSystem.getBytes() - taskBytes;
        taskSystem.setBytes(taskRemaining);
        assertLocalMemoryAllocations(taskContext.getTaskMemoryContext(), 0, 0, 0);
    }

    /**
     * Changes the operator's user and system reservations step by step and checks that the
     * operator, driver, pipeline and task stats all report the same current values, including
     * after the operator context has been destroyed.
     */
    @Test
    public void testStats()
    {
        long systemBytes = 200_000_000;
        long userStep = 300_000_000;

        LocalMemoryContext system = operatorContext.localSystemMemoryContext();
        LocalMemoryContext user = operatorContext.localUserMemoryContext();
        user.setBytes(100_000_000);
        system.setBytes(systemBytes);
        assertStats(
                operatorContext.getOperatorStats(), driverContext.getDriverStats(),
                pipelineContext.getPipelineStats(), taskContext.getTaskStats(),
                100_000_000, 0, systemBytes);

        // allocate more and check peak memory reservation
        user.setBytes(600_000_000);
        assertStats(
                operatorContext.getOperatorStats(), driverContext.getDriverStats(),
                pipelineContext.getPipelineStats(), taskContext.getTaskStats(),
                600_000_000, 0, systemBytes);

        // give back half of the user memory
        user.setBytes(user.getBytes() - userStep);
        assertStats(
                operatorContext.getOperatorStats(), driverContext.getDriverStats(),
                pipelineContext.getPipelineStats(), taskContext.getTaskStats(),
                userStep, 0, systemBytes);

        // give back the remaining user memory
        user.setBytes(user.getBytes() - userStep);
        assertStats(
                operatorContext.getOperatorStats(), driverContext.getDriverStats(),
                pipelineContext.getPipelineStats(), taskContext.getTaskStats(),
                0, 0, systemBytes);

        operatorContext.destroy();

        assertStats(
                operatorContext.getOperatorStats(), driverContext.getDriverStats(),
                pipelineContext.getPipelineStats(), taskContext.getTaskStats(),
                0, 0, 0);
    }

    /**
     * Checks that revocable reservations are reported by the operator, driver, pipeline and
     * task stats, both on their own and alongside user and system reservations.
     */
    @Test
    public void testRevocableMemoryAllocations()
    {
        long hundredMillion = 100_000_000;
        long twoHundredMillion = 200_000_000;

        LocalMemoryContext system = operatorContext.localSystemMemoryContext();
        LocalMemoryContext user = operatorContext.localUserMemoryContext();
        LocalMemoryContext revocable = operatorContext.localRevocableMemoryContext();

        // revocable memory only
        revocable.setBytes(hundredMillion);
        assertStats(
                operatorContext.getOperatorStats(), driverContext.getDriverStats(),
                pipelineContext.getPipelineStats(), taskContext.getTaskStats(),
                0, hundredMillion, 0);

        // all three kinds at once
        user.setBytes(hundredMillion);
        system.setBytes(hundredMillion);
        revocable.setBytes(twoHundredMillion);
        assertStats(
                operatorContext.getOperatorStats(), driverContext.getDriverStats(),
                pipelineContext.getPipelineStats(), taskContext.getTaskStats(),
                hundredMillion, twoHundredMillion, hundredMillion);
    }

    /**
     * Exercises trySetBytes on the operator's user memory context: growing and shrinking the
     * reservation succeeds and is reflected in the stats, while a request larger than the
     * pool is refused and leaves the stats untouched.
     */
    @Test
    public void testTrySetBytes()
    {
        long hundredMillion = 100_000_000;
        long twoHundredMillion = 200_000_000;
        LocalMemoryContext user = operatorContext.localUserMemoryContext();

        boolean firstAccepted = user.trySetBytes(hundredMillion);
        assertTrue(firstAccepted);
        assertStats(
                operatorContext.getOperatorStats(), driverContext.getDriverStats(),
                pipelineContext.getPipelineStats(), taskContext.getTaskStats(),
                hundredMillion, 0, 0);

        boolean growAccepted = user.trySetBytes(twoHundredMillion);
        assertTrue(growAccepted);
        assertStats(
                operatorContext.getOperatorStats(), driverContext.getDriverStats(),
                pipelineContext.getPipelineStats(), taskContext.getTaskStats(),
                twoHundredMillion, 0, 0);

        boolean shrinkAccepted = user.trySetBytes(hundredMillion);
        assertTrue(shrinkAccepted);
        assertStats(
                operatorContext.getOperatorStats(), driverContext.getDriverStats(),
                pipelineContext.getPipelineStats(), taskContext.getTaskStats(),
                hundredMillion, 0, 0);

        // allocating more than the pool size should fail and we should have the same stats as before
        long moreThanPool = memoryPool.getMaxBytes() + 1;
        boolean oversizedAccepted = user.trySetBytes(moreThanPool);
        assertFalse(oversizedAccepted);
        assertStats(
                operatorContext.getOperatorStats(), driverContext.getDriverStats(),
                pipelineContext.getPipelineStats(), taskContext.getTaskStats(),
                hundredMillion, 0, 0);
    }

    /**
     * Fills the pool completely through a direct reservation and then checks that
     * trySetBytes still succeeds when asked to keep the context at its current size.
     */
    @Test
    public void testTrySetZeroBytesFullPool()
    {
        LocalMemoryContext user = operatorContext.localUserMemoryContext();

        // fill up the pool
        long everythingFree = memoryPool.getFreeBytes();
        memoryPool.reserve(new QueryId("test_query"), "test", everythingFree);

        // try to reserve 0 bytes in the full pool
        long unchangedSize = user.getBytes();
        boolean accepted = user.trySetBytes(unchangedSize);
        assertTrue(accepted);
    }

    /**
     * Reserves system, revocable and user memory on the operator, confirms that all three
     * reservations are visible on the operator's MemoryTrackingContext, and then checks that
     * destroying the operator context brings every one of them back to zero.
     */
    @Test
    public void testDestroy()
    {
        LocalMemoryContext newLocalSystemMemoryContext = operatorContext.localSystemMemoryContext();
        LocalMemoryContext newLocalUserMemoryContext = operatorContext.localUserMemoryContext();
        LocalMemoryContext newLocalRevocableMemoryContext = operatorContext.localRevocableMemoryContext();
        newLocalSystemMemoryContext.setBytes(100_000);
        newLocalRevocableMemoryContext.setBytes(200_000);
        newLocalUserMemoryContext.setBytes(400_000);
        assertEquals(operatorContext.getOperatorMemoryContext().getSystemMemory(), 100_000);
        assertEquals(operatorContext.getOperatorMemoryContext().getUserMemory(), 400_000);
        // Also verify the revocable reservation before destroying; without this the final
        // all-zero check could not show that destroy() actually released revocable memory.
        assertEquals(operatorContext.getOperatorMemoryContext().getRevocableMemory(), 200_000);
        operatorContext.destroy();
        assertOperatorMemoryAllocations(operatorContext.getOperatorMemoryContext(), 0, 0, 0);
    }

    /**
     * Reserves user memory, reads the task's cumulative user memory while timing that read,
     * and asserts the reported value is below (elapsed milliseconds) x (half the reservation).
     * The statement order around the two System.nanoTime() calls is significant.
     */
    @Test
    public void testCumulativeUserMemoryEstimation()
    {
        LocalMemoryContext user = operatorContext.localUserMemoryContext();
        long reservedUserBytes = 100_000_000;
        user.setBytes(reservedUserBytes);

        // time only the stats read itself
        long beforeNanos = System.nanoTime();
        double reportedCumulative = taskContext.getTaskStats().getCumulativeUserMemory();
        long afterNanos = System.nanoTime();

        double elapsedMillis = (afterNanos - beforeNanos) / 1_000_000.0;
        long halfOfReservation = reservedUserBytes / 2;
        double upperBound = elapsedMillis * halfOfReservation;

        assertTrue(reportedCumulative < upperBound);
    }

    /**
     * Reserves user and system memory, reads the task's cumulative total memory while timing
     * that read, and asserts the reported value is below
     * (elapsed milliseconds) x (half of user + system reservation).
     * The statement order around the two System.nanoTime() calls is significant.
     */
    @Test
    public void testCumulativeTotalMemoryEstimation()
    {
        LocalMemoryContext user = operatorContext.localUserMemoryContext();
        LocalMemoryContext system = operatorContext.localSystemMemoryContext();
        long reservedUserBytes = 100_000_000;
        long reservedSystemBytes = 40_000_000;
        user.setBytes(reservedUserBytes);
        system.setBytes(reservedSystemBytes);

        // time only the stats read itself
        long beforeNanos = System.nanoTime();
        double reportedCumulative = taskContext.getTaskStats().getCumulativeTotalMemory();
        long afterNanos = System.nanoTime();

        double elapsedMillis = (afterNanos - beforeNanos) / 1_000_000.0;
        long halfOfReservation = (reservedUserBytes + reservedSystemBytes) / 2;
        double upperBound = elapsedMillis * halfOfReservation;

        assertTrue(reportedCumulative < upperBound);
    }

    /**
     * Asserts that the operator, driver, pipeline and task stats snapshots all report the
     * same user, system and revocable memory reservations.
     *
     * Checks run kind by kind (user, then system, then revocable) and, within each kind,
     * from the operator outwards to the task, so the first failing assertion identifies
     * both the memory kind and the innermost level that disagrees.
     *
     * @param operatorStats operator-level stats snapshot
     * @param driverStats driver-level stats snapshot
     * @param pipelineStats pipeline-level stats snapshot
     * @param taskStats task-level stats snapshot
     * @param expectedUserMemory expected user memory reservation in bytes at every level
     * @param expectedRevocableMemory expected revocable memory reservation in bytes at every level
     * @param expectedSystemMemory expected system memory reservation in bytes at every level
     */
    private void assertStats(
            OperatorStats operatorStats,
            DriverStats driverStats,
            PipelineStats pipelineStats,
            TaskStats taskStats,
            long expectedUserMemory,
            long expectedRevocableMemory,
            long expectedSystemMemory)
    {
        // user memory
        assertEquals(operatorStats.getUserMemoryReservationInBytes(), expectedUserMemory);
        assertEquals(driverStats.getUserMemoryReservationInBytes(), expectedUserMemory);
        assertEquals(pipelineStats.getUserMemoryReservationInBytes(), expectedUserMemory);
        assertEquals(taskStats.getUserMemoryReservationInBytes(), expectedUserMemory);

        // system memory
        assertEquals(operatorStats.getSystemMemoryReservationInBytes(), expectedSystemMemory);
        assertEquals(driverStats.getSystemMemoryReservationInBytes(), expectedSystemMemory);
        assertEquals(pipelineStats.getSystemMemoryReservationInBytes(), expectedSystemMemory);
        assertEquals(taskStats.getSystemMemoryReservationInBytes(), expectedSystemMemory);

        // revocable memory
        assertEquals(operatorStats.getRevocableMemoryReservationInBytes(), expectedRevocableMemory);
        assertEquals(driverStats.getRevocableMemoryReservationInBytes(), expectedRevocableMemory);
        assertEquals(pipelineStats.getRevocableMemoryReservationInBytes(), expectedRevocableMemory);
        assertEquals(taskStats.getRevocableMemoryReservationInBytes(), expectedRevocableMemory);
    }

    /**
     * Runs the given allocation and requires it to throw an IllegalArgumentException whose
     * message fully matches the supplied regular expression.
     *
     * @param allocationFunction allocation to attempt; it is invoked with a {@code null} argument
     * @param expectedPattern regular expression the whole exception message must match
     */
    private void assertAllocationFails(Consumer<Void> allocationFunction, String expectedPattern)
    {
        try {
            allocationFunction.accept(null);
        }
        catch (IllegalArgumentException e) {
            String actualMessage = e.getMessage();
            boolean messageMatches = Pattern.compile(expectedPattern).matcher(actualMessage).matches();
            assertTrue(messageMatches, "\nExpected (re) :" + expectedPattern + "\nActual :" + actualMessage);
            return;
        }
        // reached only when the allocation completed without throwing
        fail("Expected exception");
    }

    /**
     * Verifies operator-level reservations: the allocations that are done at the operator
     * level are reflected at that level and all the way up to the pool.
     *
     * @param memoryTrackingContext context whose user, system and revocable totals are checked
     * @param expectedUserMemory expected user memory in bytes
     * @param expectedSystemMemory expected system memory in bytes
     * @param expectedRevocableMemory expected revocable memory in bytes
     */
    private void assertOperatorMemoryAllocations(
            MemoryTrackingContext memoryTrackingContext,
            long expectedUserMemory,
            long expectedSystemMemory,
            long expectedRevocableMemory)
    {
        // user and system memory are both reserved from the same pool, so the pool holds their sum
        long expectedPoolBytes = expectedUserMemory + expectedSystemMemory;

        assertEquals(
                memoryTrackingContext.getUserMemory(),
                expectedUserMemory,
                "User memory verification failed");
        assertEquals(
                memoryPool.getReservedBytes(),
                expectedPoolBytes,
                "Memory pool verification failed");
        assertEquals(
                memoryTrackingContext.getSystemMemory(),
                expectedSystemMemory,
                "System memory verification failed");
        assertEquals(
                memoryTrackingContext.getRevocableMemory(),
                expectedRevocableMemory,
                "Revocable memory verification failed");
    }

    /**
     * Verifies context-local reservations: the local allocations are reflected only at that
     * level and all the way up to the pool.
     *
     * @param memoryTrackingContext context whose user memory and local system memory are checked
     * @param expectedPoolMemory expected reserved bytes in the memory pool
     * @param expectedContextUserMemory expected user memory of the context in bytes
     * @param expectedContextSystemMemory expected bytes in the context's local system memory context
     */
    private void assertLocalMemoryAllocations(
            MemoryTrackingContext memoryTrackingContext,
            long expectedPoolMemory,
            long expectedContextUserMemory,
            long expectedContextSystemMemory)
    {
        assertEquals(
                memoryTrackingContext.getUserMemory(),
                expectedContextUserMemory,
                "User memory verification failed");
        assertEquals(
                memoryPool.getReservedBytes(),
                expectedPoolMemory,
                "Memory pool verification failed");

        // fetched only after the two checks above, matching the original evaluation order
        LocalMemoryContext localSystemMemory = memoryTrackingContext.localSystemMemoryContext();
        assertEquals(
                localSystemMemory.getBytes(),
                expectedContextSystemMemory,
                "Local system memory verification failed");
    }
}