TestMemoryPools.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.memory;

import com.facebook.airlift.stats.TestingGcMonitor;
import com.facebook.presto.CompressionCodec;
import com.facebook.presto.Session;
import com.facebook.presto.common.Page;
import com.facebook.presto.execution.buffer.TestingPagesSerdeFactory;
import com.facebook.presto.memory.context.LocalMemoryContext;
import com.facebook.presto.operator.Driver;
import com.facebook.presto.operator.DriverContext;
import com.facebook.presto.operator.Operator;
import com.facebook.presto.operator.OperatorContext;
import com.facebook.presto.operator.OutputFactory;
import com.facebook.presto.operator.TableScanOperator;
import com.facebook.presto.operator.TaskContext;
import com.facebook.presto.operator.TaskMemoryReservationSummary;
import com.facebook.presto.spi.QueryId;
import com.facebook.presto.spi.memory.MemoryPoolId;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spiller.SpillSpaceTracker;
import com.facebook.presto.testing.LocalQueryRunner;
import com.facebook.presto.testing.PageConsumerOperator.PageConsumerOutputFactory;
import com.facebook.presto.tpch.TpchConnectorFactory;
import com.google.common.base.Predicate;
import com.google.common.base.Supplier;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.units.DataSize;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.Test;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

import static com.facebook.airlift.json.JsonCodec.listJsonCodec;
import static com.facebook.presto.SystemSessionProperties.REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN;
import static com.facebook.presto.testing.LocalQueryRunner.queryRunnerWithInitialTransaction;
import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
import static com.facebook.presto.testing.TestingTaskContext.createTaskContext;
import static com.google.common.base.Preconditions.checkState;
import static io.airlift.units.DataSize.Unit.BYTE;
import static io.airlift.units.DataSize.Unit.GIGABYTE;
import static io.airlift.units.DataSize.Unit.MEGABYTE;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertNull;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;

@Test(singleThreaded = true)
public class TestMemoryPools
{
    // Size given to the user memory pool created in setUp(); also reused there for several QueryContext limits.
    private static final DataSize TEN_MEGABYTES = new DataSize(10, MEGABYTE);
    // Two bytes short of TEN_MEGABYTES: reserving this much from a TEN_MEGABYTES pool leaves exactly two bytes free.
    private static final DataSize TEN_MEGABYTES_WITHOUT_TWO_BYTES = new DataSize(TEN_MEGABYTES.toBytes() - 2, BYTE);
    // Amount of revocable memory the test operator reserves per produced page in the revocable-memory tests.
    private static final DataSize ONE_BYTE = new DataSize(1, BYTE);

    // The fields below are (re)assigned by setUp(); tests that build their own pools do not use them.

    // Query id under which the tests reserve and free memory directly against userPool.
    private QueryId fakeQueryId;
    // Non-null only between setUp() and tearDown(); setUp() refuses to run while it is set.
    private LocalQueryRunner localQueryRunner;
    // Memory pool handed to the QueryContext that backs taskContext.
    private MemoryPool userPool;
    // Drivers produced by the supplier passed to setUp().
    private List<Driver> drivers;
    // Task context created in setUp(); driver suppliers use it to build their drivers.
    private TaskContext taskContext;

    /**
     * Creates the shared fixture: a query runner with the "tpch" catalog, a ten megabyte
     * user pool, a task context bound to that pool, and finally the drivers under test.
     * The supplier is invoked last, so it may rely on {@code localQueryRunner} and
     * {@code taskContext} being assigned.
     *
     * @param driversSupplier produces the drivers that the test will run
     * @throws IllegalStateException if a previous fixture has not been torn down
     */
    private void setUp(Supplier<List<Driver>> driversSupplier)
    {
        checkState(localQueryRunner == null, "Already set up");

        Session session = testSessionBuilder()
                .setCatalog("tpch")
                .setSchema("tiny")
                .setSystemProperty("task_default_concurrency", "1")
                .setSystemProperty(REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN, "false")
                .build();

        // query runner with the tpch catalog registered
        localQueryRunner = queryRunnerWithInitialTransaction(session);
        localQueryRunner.createCatalog("tpch", new TpchConnectorFactory(1), ImmutableMap.of());

        // pool and query id that the tests manipulate directly
        userPool = new MemoryPool(new MemoryPoolId("test"), TEN_MEGABYTES);
        fakeQueryId = new QueryId("fake");

        SpillSpaceTracker spillSpaceTracker = new SpillSpaceTracker(new DataSize(1, GIGABYTE));
        DataSize twentyMegabytes = new DataSize(20, MEGABYTE);
        DataSize oneGigabyte = new DataSize(1, GIGABYTE);
        QueryContext queryContext = new QueryContext(
                new QueryId("query"),
                TEN_MEGABYTES,
                twentyMegabytes,
                TEN_MEGABYTES,
                oneGigabyte,
                userPool,
                new TestingGcMonitor(),
                localQueryRunner.getExecutor(),
                localQueryRunner.getScheduler(),
                TEN_MEGABYTES,
                spillSpaceTracker,
                listJsonCodec(TaskMemoryReservationSummary.class));

        taskContext = createTaskContext(queryContext, localQueryRunner.getExecutor(), session);
        drivers = driversSupplier.get();
    }

    /**
     * Sets up the fixture with the drivers of a join query over orders and lineitem
     * whose output pages are thrown away.
     */
    private void setUpCountStarFromOrdersWithJoin()
    {
        String sql = "SELECT COUNT(*) FROM orders JOIN lineitem ON CAST(orders.orderkey AS VARCHAR) = CAST(lineitem.orderkey AS VARCHAR)";

        // query will reserve all memory in the user pool and discard the output
        setUp(() -> {
            OutputFactory discardingOutput = new PageConsumerOutputFactory(types -> (page -> {}));
            return localQueryRunner.createDrivers(sql, discardingOutput, taskContext);
        });
    }

    /**
     * Sets up the fixture with a single driver made of a {@link RevocableMemoryOperator}
     * feeding an output operator that discards every page.
     *
     * @param reservedPerPage revocable memory the source operator adds for each page it produces
     * @param numberOfPages number of pages after which the source operator reports itself finished
     * @return the source operator, so tests can reach its operator context
     */
    private RevocableMemoryOperator setupConsumeRevocableMemory(DataSize reservedPerPage, long numberOfPages)
    {
        // the operator is created inside the supplier, so it is handed back through this holder
        AtomicReference<RevocableMemoryOperator> operatorHolder = new AtomicReference<>();

        setUp(() -> {
            DriverContext driverContext = taskContext.addPipelineContext(0, false, false, false).addDriverContext();
            PlanNodeId sourceNodeId = new PlanNodeId("revokable_operator");
            OperatorContext revocableOperatorContext = driverContext.addOperatorContext(
                    Integer.MAX_VALUE,
                    sourceNodeId,
                    TableScanOperator.class.getSimpleName());

            OutputFactory discardingOutput = new PageConsumerOutputFactory(types -> (page -> {}));
            Operator sink = discardingOutput
                    .createOutputOperator(2, new PlanNodeId("output"), ImmutableList.of(), Function.identity(), Optional.empty(), new TestingPagesSerdeFactory(CompressionCodec.LZ4))
                    .createOperator(driverContext);

            RevocableMemoryOperator source = new RevocableMemoryOperator(revocableOperatorContext, reservedPerPage, numberOfPages);
            operatorHolder.set(source);

            return ImmutableList.of(Driver.createDriver(driverContext, source, sink));
        });

        return operatorHolder.get();
    }

    /**
     * Closes the query runner created by {@code setUp}, if any, and clears the field so
     * that the next test can set up again.
     */
    @AfterMethod(alwaysRun = true)
    public void tearDown()
    {
        // tests that never call setUp() leave nothing to release
        if (localQueryRunner == null) {
            return;
        }
        localQueryRunner.close();
        localQueryRunner = null;
    }

    @Test
    public void testBlockingOnUserMemory()
    {
        setUpCountStarFromOrdersWithJoin();
        long wholePool = TEN_MEGABYTES.toBytes();

        // take the entire pool for the fake query, then run the real query until it blocks on memory
        assertTrue(userPool.tryReserve(fakeQueryId, "test", wholePool));
        runDriversUntilBlocked(waitingForUserMemory());
        assertTrue(userPool.getFreeBytes() <= 0, String.format("Expected empty pool but got [%d]", userPool.getFreeBytes()));

        // once the fake reservation is returned the query must run to completion
        userPool.free(fakeQueryId, "test", wholePool);
        assertDriversProgress(waitingForUserMemory());
    }

    @Test
    public void testNotifyListenerOnMemoryReserved()
    {
        setupConsumeRevocableMemory(ONE_BYTE, 10);

        // record what the listener observes when it is called
        AtomicReference<MemoryPool> observedPool = new AtomicReference<>();
        AtomicLong observedReservedBytes = new AtomicLong();
        userPool.addListener((pool, queryId, memoryReservation) -> {
            observedPool.set(pool);
            observedReservedBytes.set(pool.getReservedBytes());
        });

        userPool.reserve(fakeQueryId, "test", 3);

        // the listener must have seen this pool with the three reserved bytes
        assertEquals(observedPool.get(), userPool);
        assertEquals(observedReservedBytes.get(), 3L);
    }

    @Test
    public void testMemoryFutureCancellation()
    {
        setUpCountStarFromOrdersWithJoin();

        // reserving the whole pool yields a future that stays pending until memory is returned
        ListenableFuture<?> future = userPool.reserve(fakeQueryId, "test", TEN_MEGABYTES.toBytes());
        assertFalse(future.isDone());

        // the pending future must reject cancellation
        try {
            future.cancel(true);
            fail("cancel should fail");
        }
        catch (UnsupportedOperationException e) {
            assertEquals(e.getMessage(), "cancellation is not supported");
        }

        // returning the memory completes the future
        userPool.free(fakeQueryId, "test", TEN_MEGABYTES.toBytes());
        assertTrue(future.isDone());
    }

    @Test
    public void testBlockingOnRevocableMemoryFreeUser()
    {
        setupConsumeRevocableMemory(ONE_BYTE, 10);
        long fakeReservation = TEN_MEGABYTES_WITHOUT_TWO_BYTES.toBytes();
        assertTrue(userPool.tryReserve(fakeQueryId, "test", fakeReservation));

        // two bytes are left in the pool and each page takes one byte, so two iterations run before blocking
        long iterations = runDriversUntilBlocked(waitingForRevocableSystemMemory());
        assertEquals(iterations, 2);
        assertTrue(userPool.getFreeBytes() <= 0, String.format("Expected empty pool but got [%d]", userPool.getFreeBytes()));

        // let's free 5 bytes: exactly five more iterations fit
        userPool.free(fakeQueryId, "test", 5);
        iterations = runDriversUntilBlocked(waitingForRevocableSystemMemory());
        assertEquals(iterations, 5);
        assertTrue(userPool.getFreeBytes() <= 0, String.format("Expected empty pool but got [%d]", userPool.getFreeBytes()));

        // 3 more bytes is enough for driver to finish
        userPool.free(fakeQueryId, "test", 3);
        assertDriversProgress(waitingForRevocableSystemMemory());

        // 2 bytes initially free plus the 5 and 3 bytes freed above
        assertEquals(userPool.getFreeBytes(), 10);
    }

    @Test
    public void testBlockingOnRevocableMemoryFreeViaRevoke()
    {
        RevocableMemoryOperator revocableMemoryOperator = setupConsumeRevocableMemory(ONE_BYTE, 5);
        OperatorContext revocableContext = revocableMemoryOperator.getOperatorContext();
        long fakeReservation = TEN_MEGABYTES_WITHOUT_TWO_BYTES.toBytes();
        assertTrue(userPool.tryReserve(fakeQueryId, "test", fakeReservation));

        // two bytes are left in the pool and each page takes one byte, so two iterations run before blocking
        long iterations = runDriversUntilBlocked(waitingForRevocableSystemMemory());
        assertEquals(iterations, 2);
        revocableContext.requestMemoryRevoking();

        // after the revoke request two more iterations run before the driver blocks again
        iterations = runDriversUntilBlocked(waitingForRevocableSystemMemory());
        assertEquals(iterations, 2);
        revocableContext.requestMemoryRevoking();

        // no memory is freed by the test here; the second revoke request must be enough for the driver to finish
        assertDriversProgress(waitingForRevocableSystemMemory());
        assertEquals(userPool.getFreeBytes(), 2);
    }

    @Test
    public void testTaggedAllocations()
    {
        QueryId testQuery = new QueryId("test_query");
        MemoryPool testPool = new MemoryPool(new MemoryPoolId("test"), new DataSize(1000, BYTE));

        // a fresh reservation shows up under its tag
        testPool.reserve(testQuery, "test_tag", 10);
        assertEquals(testPool.getTaggedMemoryAllocations(testQuery), ImmutableMap.of("test_tag", 10L));

        // a partial free shrinks the tagged amount
        testPool.free(testQuery, "test_tag", 5);
        assertEquals(testPool.getTaggedMemoryAllocations(testQuery), ImmutableMap.of("test_tag", 5L));

        // a second tag is tracked next to the first one
        testPool.reserve(testQuery, "test_tag2", 20);
        assertEquals(testPool.getTaggedMemoryAllocations(testQuery), ImmutableMap.of("test_tag", 5L, "test_tag2", 20L));

        // freeing the rest of test_tag removes that tag and leaves test_tag2 untouched
        testPool.free(testQuery, "test_tag", 5);
        assertEquals(testPool.getTaggedMemoryAllocations(testQuery), ImmutableMap.of("test_tag2", 20L));

        // freeing everything leaves no tagged allocations for any query
        testPool.free(testQuery, "test_tag2", 20);
        assertEquals(testPool.getTaggedMemoryAllocations().size(), 0);
    }

    @Test
    public void testMoveQuery()
    {
        QueryId testQuery = new QueryId("test_query");
        MemoryPool pool1 = new MemoryPool(new MemoryPoolId("test"), new DataSize(1000, BYTE));
        MemoryPool pool2 = new MemoryPool(new MemoryPoolId("test"), new DataSize(1000, BYTE));
        pool1.reserve(testQuery, "test_tag", 10);

        Map<String, Long> allocations = pool1.getTaggedMemoryAllocations(testQuery);
        assertEquals(allocations, ImmutableMap.of("test_tag", 10L));

        // moving the query transfers both its tagged allocations and its reserved bytes
        pool1.moveQuery(testQuery, pool2);
        assertNull(pool1.getTaggedMemoryAllocations(testQuery));
        allocations = pool2.getTaggedMemoryAllocations(testQuery);
        assertEquals(allocations, ImmutableMap.of("test_tag", 10L));

        assertEquals(pool1.getFreeBytes(), 1000);
        assertEquals(pool2.getFreeBytes(), 990);

        // free under the same tag the memory was reserved with
        pool2.free(testQuery, "test_tag", 10);
        assertEquals(pool2.getFreeBytes(), 1000);
    }

    /**
     * Repeatedly processes every driver until at least one operator context satisfies
     * {@code reason}, then asserts that no driver has finished.
     * <p>
     * Fails instead of looping forever when every driver has finished (or there are no
     * drivers) without the blocking condition ever being observed.
     *
     * @param reason predicate describing the blocked state to wait for
     * @return the number of rounds in which all drivers were processed
     */
    private long runDriversUntilBlocked(Predicate<OperatorContext> reason)
    {
        long iterationsCount = 0;

        // run driver, until it blocks
        while (!isOperatorBlocked(drivers, reason)) {
            // once every driver is finished nothing can block any more, so fail rather than spin
            assertFalse(drivers.stream().allMatch(Driver::isFinished), "All drivers finished without blocking");
            for (Driver driver : drivers) {
                driver.process();
            }
            iterationsCount++;
        }

        // driver should be blocked waiting for memory
        for (Driver driver : drivers) {
            assertFalse(driver.isFinished());
        }
        return iterationsCount;
    }

    /**
     * Runs all drivers to completion, asserting before every round that no operator
     * context satisfies {@code reason} and after every round that at least one driver
     * returned an already completed future.
     *
     * @param reason predicate describing the blocked state that must not occur
     */
    private void assertDriversProgress(Predicate<OperatorContext> reason)
    {
        while (true) {
            assertFalse(isOperatorBlocked(drivers, reason));

            boolean anyUnblocked = false;
            for (Driver driver : drivers) {
                // compound assignment always evaluates process(), so every driver runs each round
                anyUnblocked |= driver.process().isDone();
            }
            // query should not block
            assertTrue(anyUnblocked);

            if (drivers.stream().allMatch(Driver::isFinished)) {
                return;
            }
        }
    }

    /**
     * @return a predicate that holds while the operator context's memory future is not yet done
     */
    private Predicate<OperatorContext> waitingForUserMemory()
    {
        return operatorContext -> {
            ListenableFuture<?> memoryFuture = operatorContext.isWaitingForMemory();
            return !memoryFuture.isDone();
        };
    }

    /**
     * @return a predicate that holds while the operator context's revocable memory future
     * is not yet done and no memory revoking has been requested for it
     */
    private Predicate<OperatorContext> waitingForRevocableSystemMemory()
    {
        return operatorContext -> {
            if (operatorContext.isWaitingForRevocableMemory().isDone()) {
                return false;
            }
            return !operatorContext.isMemoryRevokingRequested();
        };
    }

    /**
     * Tells whether any operator context of any of the given drivers satisfies {@code reason}.
     *
     * @param drivers drivers whose operator contexts are inspected
     * @param reason predicate describing the blocked state
     * @return {@code true} if at least one operator context matches, {@code false} otherwise
     */
    private static boolean isOperatorBlocked(List<Driver> drivers, Predicate<OperatorContext> reason)
    {
        return drivers.stream()
                .flatMap(driver -> driver.getDriverContext().getOperatorContexts().stream())
                .anyMatch(reason::apply);
    }

    private class RevocableMemoryOperator
            implements Operator
    {
        private final DataSize reservedPerPage;
        private final long numberOfPages;
        private final OperatorContext operatorContext;
        private long producedPagesCount;
        private final LocalMemoryContext revocableMemoryContext;

        public RevocableMemoryOperator(OperatorContext operatorContext, DataSize reservedPerPage, long numberOfPages)
        {
            this.operatorContext = operatorContext;
            this.reservedPerPage = reservedPerPage;
            this.numberOfPages = numberOfPages;
            this.revocableMemoryContext = operatorContext.localRevocableMemoryContext();
        }

        @Override
        public ListenableFuture<?> startMemoryRevoke()
        {
            return Futures.immediateFuture(null);
        }

        @Override
        public void finishMemoryRevoke()
        {
            revocableMemoryContext.setBytes(0);
        }

        @Override
        public OperatorContext getOperatorContext()
        {
            return operatorContext;
        }

        @Override
        public void finish()
        {
            revocableMemoryContext.setBytes(0);
        }

        @Override
        public boolean isFinished()
        {
            return producedPagesCount >= numberOfPages;
        }

        @Override
        public boolean needsInput()
        {
            return false;
        }

        @Override
        public void addInput(Page page)
        {
            throw new UnsupportedOperationException();
        }

        @Override
        public Page getOutput()
        {
            revocableMemoryContext.setBytes(revocableMemoryContext.getBytes() + reservedPerPage.toBytes());
            producedPagesCount++;
            if (producedPagesCount == numberOfPages) {
                finish();
            }
            return new Page(10);
        }
    }
}