TestPredictorManager.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.router;

import com.facebook.airlift.bootstrap.Bootstrap;
import com.facebook.airlift.bootstrap.LifeCycleManager;
import com.facebook.airlift.http.server.testing.TestingHttpServerModule;
import com.facebook.airlift.jaxrs.JaxrsModule;
import com.facebook.airlift.json.JsonModule;
import com.facebook.airlift.log.Logging;
import com.facebook.airlift.node.testing.TestingNodeModule;
import com.facebook.presto.router.predictor.CpuInfo;
import com.facebook.presto.router.predictor.MemoryInfo;
import com.facebook.presto.router.predictor.PredictorManager;
import com.facebook.presto.router.predictor.ResourceGroup;
import com.facebook.presto.server.testing.TestingPrestoServer;
import com.facebook.presto.tpch.TpchPlugin;
import com.google.common.collect.ImmutableList;
import com.google.inject.Injector;
import okhttp3.mockwebserver.Dispatcher;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Optional;

import static com.facebook.presto.router.TestingRouterUtil.getConfigFile;
import static com.google.common.net.HttpHeaders.CONTENT_TYPE;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertTrue;

/**
 * Tests {@link PredictorManager} against a mocked predictor service.
 * <p>
 * The fixture starts {@value #NUM_CLUSTERS} testing Presto servers, writes a router config
 * file describing them, bootstraps the router modules with that config, and serves canned
 * CPU and memory predictions from a {@link MockWebServer} bound to port 8000.
 */
public class TestPredictorManager
{
    private static final int NUM_CLUSTERS = 2;

    private List<TestingPrestoServer> prestoServers;
    private LifeCycleManager lifeCycleManager;
    private PredictorManager predictorManager;
    private File configFile;
    private MockWebServer predictorServer;

    /**
     * Starts the Presto servers, bootstraps the router injector, and starts the mock
     * predictor service.
     */
    @BeforeClass
    public void setup()
            throws Exception
    {
        Logging.initialize();

        // set up server
        ImmutableList.Builder<TestingPrestoServer> builder = ImmutableList.builder();
        for (int i = 0; i < NUM_CLUSTERS; ++i) {
            builder.add(createPrestoServer());
        }
        prestoServers = builder.build();
        // File.createTempFile does not insert a dot before the suffix, so include it here.
        File tempFile = File.createTempFile("router", ".json");
        configFile = getConfigFile(prestoServers, tempFile);

        Bootstrap app = new Bootstrap(
                new TestingNodeModule("test"),
                new TestingHttpServerModule(),
                new JsonModule(),
                new JaxrsModule(true),
                new RouterModule(Optional.empty()));

        Injector injector = app
                .doNotInitializeLogging()
                .setRequiredConfigurationProperty("router.config-file", configFile.getAbsolutePath())
                .quiet()
                .initialize();

        lifeCycleManager = injector.getInstance(LifeCycleManager.class);
        predictorManager = injector.getInstance(PredictorManager.class);
        initializePredictorServer();
    }

    /**
     * Checks combined, parallel, and individual CPU/memory predictions for a query against
     * the mock predictor, and that the returned labels fall in {@code [0, 3]}.
     * <p>
     * NOTE(review): disabled; the reason is not recorded here — confirm before re-enabling.
     */
    @Test(enabled = false)
    public void testPredictor()
    {
        String sql = "select * from presto.logs";

        ResourceGroup resourceGroup = predictorManager.fetchPrediction(sql).orElse(null);
        assertNotNull(resourceGroup, "The resource group should not be null");
        assertNotNull(resourceGroup.getCpuInfo());
        assertNotNull(resourceGroup.getMemoryInfo());

        resourceGroup = predictorManager.fetchPredictionParallel(sql).orElse(null);
        assertNotNull(resourceGroup, "The resource group should not be null");
        assertNotNull(resourceGroup.getCpuInfo());
        assertNotNull(resourceGroup.getMemoryInfo());

        CpuInfo cpuInfo = predictorManager.fetchCpuPrediction(sql).orElse(null);
        MemoryInfo memoryInfo = predictorManager.fetchMemoryPrediction(sql).orElse(null);
        assertNotNull(cpuInfo);
        assertNotNull(memoryInfo);

        int low = 0;
        int high = 3;
        assertTrue(low <= cpuInfo.getCpuTimeLabel(), "CPU time label should be larger or equal to " + low);
        assertTrue(cpuInfo.getCpuTimeLabel() <= high, "CPU time label should be smaller or equal to " + high);
        assertTrue(low <= memoryInfo.getMemoryBytesLabel(), "Memory bytes label should be larger or equal to " + low);
        assertTrue(memoryInfo.getMemoryBytesLabel() <= high, "Memory bytes label should be smaller or equal to " + high);
    }

    /**
     * Releases everything started by {@link #setup()}: the Presto servers, the router
     * lifecycle, the mock predictor server, and the temporary config file.
     */
    @AfterClass(alwaysRun = true)
    public void tearDownServer()
            throws Exception
    {
        // alwaysRun means this executes even if setup() failed part-way, so any field may
        // still be null. The nested try/finally keeps one failed close from leaking the rest.
        try {
            if (prestoServers != null) {
                for (TestingPrestoServer prestoServer : prestoServers) {
                    prestoServer.close();
                }
            }
        }
        finally {
            try {
                if (lifeCycleManager != null) {
                    lifeCycleManager.stop();
                }
            }
            finally {
                try {
                    if (predictorServer != null) {
                        predictorServer.close();
                    }
                }
                finally {
                    if (configFile != null) {
                        // Best-effort removal of the temporary router config file.
                        configFile.delete();
                    }
                }
            }
        }
    }

    /**
     * Creates a testing Presto server with the TPC-H plugin installed and a {@code tpch}
     * catalog registered.
     */
    private static TestingPrestoServer createPrestoServer()
            throws Exception
    {
        TestingPrestoServer server = new TestingPrestoServer();
        server.installPlugin(new TpchPlugin());
        server.createCatalog("tpch", "tpch");
        server.refreshNodes();

        return server;
    }

    /**
     * Starts a mock predictor service on port 8000. It answers {@code /v1/cpu} and
     * {@code /v1/memory} with fixed JSON predictions and returns 404 for any other path.
     * <p>
     * NOTE(review): port 8000 is presumably the predictor URI the router is configured to
     * call by default — confirm against the router's predictor configuration.
     */
    private void initializePredictorServer()
            throws IOException
    {
        Dispatcher dispatcher = new Dispatcher()
        {
            @Override
            public MockResponse dispatch(RecordedRequest request)
            {
                String path = request.getPath();
                // A switch on a null String throws NPE; treat a missing path as unknown.
                if (path == null) {
                    return new MockResponse().setResponseCode(404);
                }
                switch (path) {
                    case "/v1/cpu":
                        return new MockResponse()
                                .addHeader(CONTENT_TYPE, "application/json")
                                .setBody("{\"cpu_pred_label\": 2, \"cpu_pred_str\": \"1h - 5h\"}");
                    case "/v1/memory":
                        return new MockResponse()
                                .addHeader(CONTENT_TYPE, "application/json")
                                .setBody("{\"memory_pred_label\": 2, \"memory_pred_str\": \"> 1TB\"}");
                }
                return new MockResponse().setResponseCode(404);
            }
        };

        predictorServer = new MockWebServer();
        predictorServer.setDispatcher(dispatcher);
        predictorServer.start(8000);
    }
}