TestPrestoSparkHttpClient.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.spark.execution.http;

import com.facebook.airlift.http.client.HeaderName;
import com.facebook.airlift.http.client.HttpClient;
import com.facebook.airlift.http.client.HttpStatus;
import com.facebook.airlift.http.client.Request;
import com.facebook.airlift.http.client.RequestStats;
import com.facebook.airlift.http.client.Response;
import com.facebook.airlift.http.client.ResponseHandler;
import com.facebook.airlift.json.JsonCodec;
import com.facebook.presto.client.ServerInfo;
import com.facebook.presto.execution.QueryManagerConfig;
import com.facebook.presto.execution.TaskId;
import com.facebook.presto.execution.TaskInfo;
import com.facebook.presto.execution.TaskManagerConfig;
import com.facebook.presto.execution.TaskSource;
import com.facebook.presto.execution.TaskState;
import com.facebook.presto.execution.TaskStatus;
import com.facebook.presto.execution.scheduler.TableWriteInfo;
import com.facebook.presto.operator.PageBufferClient;
import com.facebook.presto.operator.PageTransportErrorException;
import com.facebook.presto.operator.TaskStats;
import com.facebook.presto.server.smile.BaseResponse;
import com.facebook.presto.spark.execution.nativeprocess.HttpNativeExecutionTaskInfoFetcher;
import com.facebook.presto.spark.execution.nativeprocess.HttpNativeExecutionTaskResultFetcher;
import com.facebook.presto.spark.execution.nativeprocess.NativeExecutionProcess;
import com.facebook.presto.spark.execution.nativeprocess.NativeExecutionProcessFactory;
import com.facebook.presto.spark.execution.property.NativeExecutionConnectorConfig;
import com.facebook.presto.spark.execution.property.NativeExecutionNodeConfig;
import com.facebook.presto.spark.execution.property.NativeExecutionSystemConfig;
import com.facebook.presto.spark.execution.property.NativeExecutionVeloxConfig;
import com.facebook.presto.spark.execution.property.PrestoSparkWorkerProperty;
import com.facebook.presto.spark.execution.task.NativeExecutionTask;
import com.facebook.presto.spark.execution.task.NativeExecutionTaskFactory;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.PrestoTransportException;
import com.facebook.presto.spi.page.PageCodecMarker;
import com.facebook.presto.spi.page.PagesSerdeUtil;
import com.facebook.presto.spi.page.SerializedPage;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.testing.TestingSession;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ListMultimap;
import com.google.common.net.MediaType;
import com.google.common.util.concurrent.AbstractFuture;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.airlift.slice.DynamicSliceOutput;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;

import static com.facebook.airlift.http.client.HttpUriBuilder.uriBuilder;
import static com.facebook.airlift.http.client.HttpUriBuilder.uriBuilderFrom;
import static com.facebook.presto.PrestoMediaTypes.PRESTO_PAGES_TYPE;
import static com.facebook.presto.client.NodeVersion.UNKNOWN;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_BUFFER_COMPLETE;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_PAGE_NEXT_TOKEN;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_PAGE_TOKEN;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_TASK_INSTANCE_ID;
import static com.facebook.presto.execution.TaskTestUtils.createPlanFragment;
import static com.facebook.presto.execution.buffer.OutputBuffers.BufferType.PARTITIONED;
import static com.facebook.presto.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers;
import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
import static com.google.common.net.HttpHeaders.CONTENT_TYPE;
import static io.airlift.units.DataSize.Unit.MEGABYTE;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.Executors.newScheduledThreadPool;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.expectThrows;
import static org.testng.Assert.fail;

public class TestPrestoSparkHttpClient
{
    // Path prefix of the task REST endpoints (presumably what the testing response
    // manager matches against — confirm where it is used further down this file).
    private static final String TASK_ROOT_PATH = "/v1/task";
    // Base URI handed to every client under test. Requests are answered by
    // TestingHttpClient / TestingResponseManager, so presumably nothing really listens
    // on localhost:8080 — confirm in TestingHttpClient.
    private static final URI BASE_URI = uriBuilder()
            .scheme("http")
            .host("localhost")
            .port(8080)
            .build();
    // Zero-length duration constant.
    private static final Duration NO_DURATION = new Duration(0, TimeUnit.MILLISECONDS);
    // JSON codecs passed to the HTTP clients under test.
    private static final JsonCodec<TaskInfo> TASK_INFO_JSON_CODEC = JsonCodec.jsonCodec(TaskInfo.class);
    private static final JsonCodec<PlanFragment> PLAN_FRAGMENT_JSON_CODEC = JsonCodec.jsonCodec(PlanFragment.class);
    private static final JsonCodec<BatchTaskUpdateRequest> TASK_UPDATE_REQUEST_JSON_CODEC = JsonCodec.jsonCodec(BatchTaskUpdateRequest.class);
    private static final JsonCodec<ServerInfo> SERVER_INFO_JSON_CODEC = JsonCodec.jsonCodec(ServerInfo.class);

    // Scheduled pool shared by the testing HTTP client, the clients and the fetchers under
    // test. Created in beforeClass() and shut down in afterClass().
    private ScheduledExecutorService scheduledExecutorService;

    /**
     * Creates the four-thread scheduled pool shared by every test in this class.
     */
    @BeforeClass
    public void beforeClass()
    {
        int poolSize = 4;
        this.scheduledExecutorService = newScheduledThreadPool(poolSize);
    }

    /**
     * Stops the shared pool, if one was created, and clears the field.
     * Runs even when tests fail ({@code alwaysRun = true}).
     */
    @AfterClass(alwaysRun = true)
    public void afterClass()
    {
        ScheduledExecutorService executor = scheduledExecutorService;
        if (executor == null) {
            return;
        }
        executor.shutdownNow();
        scheduledExecutorService = null;
    }

    /**
     * Fetches the first result batch of a task and checks the returned token, the
     * client-complete flag and the echoed task instance id.
     */
    @Test
    public void testResultGet()
            throws InterruptedException, ExecutionException
    {
        TaskId taskId = new TaskId(
                "testid",
                0,
                0,
                0,
                0);

        PrestoSparkHttpTaskClient workerClient = createWorkerClient(taskId);
        ListenableFuture<PageBufferClient.PagesResponse> future = workerClient.getResults(
                0,
                new DataSize(32, MEGABYTE));
        // Let failures from get() propagate so TestNG reports the real cause instead of a
        // bare fail() with the stack trace only printed to stderr.
        PageBufferClient.PagesResponse page = future.get();
        // TestNG's assertEquals takes (actual, expected).
        assertEquals(page.getToken(), 0);
        assertTrue(page.isClientComplete());
        assertEquals(page.getTaskInstanceId(), taskId.toString());
    }

    /**
     * Smoke test: issuing an acknowledgement for token 1 must not throw synchronously.
     */
    @Test
    public void testResultAcknowledge()
    {
        createWorkerClient(new TaskId("testid", 0, 0, 0, 0))
                .acknowledgeResultsAsync(1);
    }

    /**
     * Creates a task client whose HTTP traffic is answered by a default
     * {@link TestingResponseManager} for {@code taskId}.
     */
    private PrestoSparkHttpTaskClient createWorkerClient(TaskId taskId)
    {
        TestingResponseManager responseManager = new TestingResponseManager(taskId.toString());
        TestingHttpClient httpClient = new TestingHttpClient(scheduledExecutorService, responseManager);
        return createWorkerClient(taskId, httpClient);
    }

    /**
     * Creates a task client for {@code taskId} that talks to {@code BASE_URI} through
     * {@code httpClient}. Both executor arguments are the shared scheduled pool, and both
     * duration arguments are one second.
     */
    private PrestoSparkHttpTaskClient createWorkerClient(TaskId taskId, TestingHttpClient httpClient)
    {
        Duration oneSecond = new Duration(1, TimeUnit.SECONDS);
        ScheduledExecutorService executor = scheduledExecutorService;
        return new PrestoSparkHttpTaskClient(
                httpClient,
                taskId,
                BASE_URI,
                TASK_INFO_JSON_CODEC,
                PLAN_FRAGMENT_JSON_CODEC,
                TASK_UPDATE_REQUEST_JSON_CODEC,
                oneSecond,
                executor,
                executor,
                oneSecond);
    }

    /**
     * Creates a result fetcher for {@code workerClient} that uses a fresh lock object
     * nobody else can wait on.
     */
    HttpNativeExecutionTaskResultFetcher createResultFetcher(PrestoSparkHttpTaskClient workerClient)
    {
        Object privateLock = new Object();
        return createResultFetcher(workerClient, privateLock);
    }

    /**
     * Creates a result fetcher that runs on the shared scheduled pool and reads from
     * {@code workerClient}.
     *
     * @param lock object handed to the fetcher. Some tests {@code wait()} on it, so the fetcher
     * presumably notifies it when its state changes — confirm in HttpNativeExecutionTaskResultFetcher.
     */
    HttpNativeExecutionTaskResultFetcher createResultFetcher(PrestoSparkHttpTaskClient workerClient, Object lock)
    {
        ScheduledExecutorService executor = scheduledExecutorService;
        return new HttpNativeExecutionTaskResultFetcher(executor, workerClient, lock);
    }

    /**
     * Aborting results must complete the returned future without error.
     */
    @Test
    public void testResultAbort()
            throws InterruptedException, ExecutionException
    {
        TaskId taskId = new TaskId("testid", 0, 0, 0, 0);

        PrestoSparkHttpTaskClient workerClient = createWorkerClient(taskId);
        ListenableFuture<?> future = workerClient.abortResultsAsync();
        // Propagate any failure so TestNG reports its cause rather than a bare fail().
        future.get();
    }

    /**
     * Fetching task info must return info for the same task id.
     */
    @Test
    public void testGetTaskInfo()
            throws Exception
    {
        TaskId taskId = new TaskId("testid", 0, 0, 0, 0);

        PrestoSparkHttpTaskClient workerClient = createWorkerClient(taskId);
        // Any failure propagates with its cause instead of being reduced to a bare fail().
        TaskInfo taskInfo = workerClient.getTaskInfo();
        assertEquals(taskInfo.getTaskId().toString(), taskId.toString());
    }

    /**
     * Updating a task with no sources must return task info for the same task id.
     */
    @Test
    public void testUpdateTask()
    {
        TaskId taskId = new TaskId("testid", 0, 0, 0, 0);

        PrestoSparkHttpTaskClient workerClient = createWorkerClient(taskId);

        List<TaskSource> sources = new ArrayList<>();

        // updateTask declares no checked exceptions; any failure propagates with its cause.
        TaskInfo taskInfo = workerClient.updateTask(
                sources,
                createPlanFragment(),
                new TableWriteInfo(Optional.empty(), Optional.empty()),
                Optional.empty(),
                Optional.empty(),
                TestingSession.testSessionBuilder().build(),
                createInitialEmptyOutputBuffers(PARTITIONED));
        assertEquals(taskInfo.getTaskId().toString(), taskId.toString());
    }

    /**
     * When the server answers the update with an unexpected response, updateTask must fail
     * with a PrestoException whose message mentions "500".
     */
    @Test
    public void testUpdateTaskUnexpectedResponse()
    {
        TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
        TestingResponseManager responseManager =
                new TestingResponseManager(taskId.toString(), new UnexpectedResponseTaskInfoRetryResponseManager());
        TestingHttpClient httpClient = new TestingHttpClient(scheduledExecutorService, responseManager);
        PrestoSparkHttpTaskClient workerClient = createWorkerClient(taskId, httpClient);

        assertThatThrownBy(() -> workerClient.updateTask(
                new ArrayList<>(),
                createPlanFragment(),
                new TableWriteInfo(Optional.empty(), Optional.empty()),
                Optional.empty(),
                Optional.empty(),
                TestingSession.testSessionBuilder().build(),
                createInitialEmptyOutputBuffers(PARTITIONED)))
                .isInstanceOf(PrestoException.class)
                .hasMessageContaining("500");
    }

    /**
     * updateTask must succeed despite the two leading failures injected by
     * FailureRetryTaskInfoResponseManager(2).
     */
    @Test
    public void testUpdateTaskWithRetries()
    {
        TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
        TestingResponseManager responseManager =
                new TestingResponseManager(taskId.toString(), new FailureRetryTaskInfoResponseManager(2));
        TestingHttpClient httpClient = new TestingHttpClient(scheduledExecutorService, responseManager);
        PrestoSparkHttpTaskClient workerClient = createWorkerClient(taskId, httpClient);

        workerClient.updateTask(
                new ArrayList<>(),
                createPlanFragment(),
                new TableWriteInfo(Optional.empty(), Optional.empty()),
                Optional.empty(),
                Optional.empty(),
                TestingSession.testSessionBuilder().build(),
                createInitialEmptyOutputBuffers(PARTITIONED));
    }

    /**
     * The server client must return the ServerInfo served by the default TestingResponseManager.
     */
    @Test
    public void testGetServerInfo()
            throws InterruptedException, ExecutionException
    {
        TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
        ServerInfo expected = new ServerInfo(UNKNOWN, "test", true, false, Optional.of(Duration.valueOf("2m")));

        PrestoSparkHttpServerClient workerClient = new PrestoSparkHttpServerClient(
                new TestingHttpClient(scheduledExecutorService, new TestingResponseManager(taskId.toString())),
                BASE_URI,
                SERVER_INFO_JSON_CODEC);
        ListenableFuture<BaseResponse<ServerInfo>> future = workerClient.getServerInfo();
        // Let failures from get() propagate with their cause.
        ServerInfo serverInfo = future.get().getValue();
        assertEquals(serverInfo, expected);
    }

    /**
     * getServerInfoWithRetry must eventually return the expected ServerInfo despite the five
     * failures injected by FailureRetryResponseManager(5), given a one-minute max timeout.
     */
    @Test
    public void testGetServerInfoWithRetry()
            throws InterruptedException, ExecutionException
    {
        TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
        ServerInfo expected = new ServerInfo(UNKNOWN, "test", true, false, Optional.of(Duration.valueOf("2m")));
        Duration maxTimeout = new Duration(1, TimeUnit.MINUTES);
        NativeExecutionProcess process = createNativeExecutionProcess(
                maxTimeout,
                new TestingResponseManager(taskId.toString(), new FailureRetryResponseManager(5)));

        SettableFuture<ServerInfo> future = process.getServerInfoWithRetry();
        // Let failures from get() propagate with their cause.
        ServerInfo serverInfo = future.get();
        assertEquals(serverInfo, expected);
    }

    /**
     * With a zero max timeout, getServerInfoWithRetry() is expected to give up before
     * FailureRetryResponseManager(5) stops failing, and to report the launch failure.
     */
    @Test
    public void testGetServerInfoWithRetryTimeout()
    {
        TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
        NativeExecutionProcess process = createNativeExecutionProcess(
                new Duration(0, TimeUnit.MILLISECONDS),
                new TestingResponseManager(taskId.toString(), new FailureRetryResponseManager(5)));

        SettableFuture<ServerInfo> serverInfoFuture = process.getServerInfoWithRetry();
        Exception failure = expectThrows(ExecutionException.class, serverInfoFuture::get);
        String message = failure.getMessage();
        assertTrue(message.contains("Native process launch failed with multiple retries"));
    }

    /**
     * Drains the fetcher until pollPage() returns empty, and expects exactly one zero-byte
     * page from the default TestingResponseManager.
     */
    @Test
    public void testResultFetcher()
            throws InterruptedException
    {
        TaskId taskId = new TaskId("testid", 0, 0, 0, 0);

        PrestoSparkHttpTaskClient workerClient = createWorkerClient(taskId);
        HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(workerClient);
        taskResultFetcher.start();

        // An InterruptedException propagates with its cause instead of being swallowed.
        List<SerializedPage> pages = new ArrayList<>();
        Optional<SerializedPage> page = taskResultFetcher.pollPage();
        while (page.isPresent()) {
            pages.add(page.get());
            page = taskResultFetcher.pollPage();
        }

        // TestNG's assertEquals takes (actual, expected).
        assertEquals(pages.size(), 1);
        assertEquals(pages.get(0).getSizeInBytes(), 0);
    }

    /**
     * Polls the fetcher at most 1,000 times, stopping early once {@code numPages} pages have
     * been collected. Returns fewer pages if they did not all arrive within that budget.
     */
    private List<SerializedPage> fetchResults(HttpNativeExecutionTaskResultFetcher taskResultFetcher, int numPages)
            throws InterruptedException
    {
        List<SerializedPage> fetched = new ArrayList<>();
        int attempts = 0;
        while (attempts < 1_000 && fetched.size() < numPages) {
            taskResultFetcher.pollPage().ifPresent(fetched::add);
            attempts++;
        }
        return fetched;
    }

    /**
     * Serves ten 1MB pages (only the last one flagged as buffer-complete) and checks that the
     * fetcher delivers all of them.
     */
    @Test
    public void testResultFetcherMultipleNonEmptyResults()
            throws InterruptedException
    {
        TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
        int serializedPageSize = (int) new DataSize(1, MEGABYTE).toBytes();
        int numPages = 10;
        // BreakingLimitResponseManager implements exactly the response sequence this test
        // needs (numPages pages, last one complete, fail on any extra request), so it is
        // reused instead of duplicating it in an anonymous class.
        PrestoSparkHttpTaskClient workerClient = createWorkerClient(
                taskId,
                new TestingHttpClient(
                        scheduledExecutorService,
                        new TestingResponseManager(
                                taskId.toString(),
                                new BreakingLimitResponseManager(serializedPageSize, numPages))));
        HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(workerClient);
        taskResultFetcher.start();

        List<SerializedPage> pages = fetchResults(taskResultFetcher, numPages);

        // TestNG's assertEquals takes (actual, expected).
        assertEquals(pages.size(), numPages);
        for (SerializedPage page : pages) {
            assertEquals(page.getSizeInBytes(), serializedPageSize);
        }
    }

    /**
     * Result response manager that serves {@code numPages} pages of {@code serializedPageSize}
     * bytes each. Only the last page is flagged as buffer-complete, and any request after it
     * fails the test.
     * <p>
     * {@code requestCount} is advanced by createResultResponse, which is presumably invoked on
     * the HTTP client's executor threads, and read by the test thread through
     * {@link #getRemainingPageCount()}. Both accesses are synchronized so the test thread
     * observes an up-to-date count.
     */
    private static class BreakingLimitResponseManager
            extends TestingResponseManager.TestingResultResponseManager
    {
        private final int serializedPageSize;
        private final int numPages;

        // Guarded by "this".
        private int requestCount;

        public BreakingLimitResponseManager(int serializedPageSize, int numPages)
        {
            this.serializedPageSize = serializedPageSize;
            this.numPages = numPages;
        }

        @Override
        public synchronized Response createResultResponse(String taskId)
                throws PageTransportErrorException
        {
            requestCount++;
            if (requestCount > numPages) {
                fail("Retrieving results after buffer completion");
                return null;
            }
            // Tokens run 0..numPages-1; only the final page marks the buffer complete.
            return createResultResponseHelper(
                    HttpStatus.OK,
                    taskId,
                    requestCount - 1,
                    requestCount,
                    requestCount == numPages,
                    serializedPageSize);
        }

        /**
         * Returns how many of the {@code numPages} pages have not been requested yet.
         */
        public synchronized int getRemainingPageCount()
        {
            return numPages - requestCount;
        }
    }

    /**
     * Serves ten 32MB pages and checks that, after the first page is available and a pause,
     * exactly five pages remain unrequested. The fetcher is presumably throttled by its buffer
     * limit at that point. The test then drains all ten pages.
     * <p>
     * The timeOut keeps the two busy-poll loops below from hanging the suite if pages stop
     * arriving.
     */
    @Test(timeOut = 60 * 1000)
    public void testResultFetcherExceedingBufferLimit()
            throws InterruptedException
    {
        int numPages = 10;
        int serializedPageSize = (int) new DataSize(32, MEGABYTE).toBytes();
        TaskId taskId = new TaskId("testid", 0, 0, 0, 0);

        BreakingLimitResponseManager breakingLimitResponseManager =
                new BreakingLimitResponseManager(serializedPageSize, numPages);

        PrestoSparkHttpTaskClient workerClient = createWorkerClient(
                taskId,
                new TestingHttpClient(
                        scheduledExecutorService,
                        new TestingResponseManager(
                                taskId.toString(),
                                breakingLimitResponseManager)));
        HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(workerClient);
        taskResultFetcher.start();

        Optional<SerializedPage> page = Optional.empty();
        while (!page.isPresent()) {
            page = taskResultFetcher.pollPage();
        }
        // Wait a bit for fetches to overwhelm memory.
        Thread.sleep(5000);
        assertEquals(breakingLimitResponseManager.getRemainingPageCount(), 5);
        List<SerializedPage> pages = new ArrayList<>();
        pages.add(page.get());
        while (pages.size() < numPages) {
            page = taskResultFetcher.pollPage();
            page.ifPresent(pages::add);
        }

        // TestNG's assertEquals takes (actual, expected).
        assertEquals(pages.size(), numPages);
        for (int i = 0; i < numPages; i++) {
            assertEquals(pages.get(i).getSizeInBytes(), serializedPageSize);
        }
    }

    /**
     * Result response manager that first throws {@code numInitialTimeouts} RuntimeExceptions
     * ("test failure"). After that it serves {@code numPages} pages of
     * {@code serializedPageSize} bytes, flagging only the last as buffer-complete, and fails
     * the test on any request after that.
     */
    private static class TimeoutResponseManager
            extends TestingResponseManager.TestingResultResponseManager
    {
        private final int serializedPageSize;
        private final int numPages;
        private final int numInitialTimeouts;

        private int requestCount;
        private int timeoutCount;

        public TimeoutResponseManager(int serializedPageSize, int numPages, int numInitialTimeouts)
        {
            this.serializedPageSize = serializedPageSize;
            this.numPages = numPages;
            this.numInitialTimeouts = numInitialTimeouts;
        }

        @Override
        public Response createResultResponse(String taskId)
                throws PageTransportErrorException
        {
            timeoutCount++;
            if (timeoutCount <= numInitialTimeouts) {
                throw new RuntimeException("test failure");
            }

            requestCount++;
            if (requestCount > numPages) {
                fail("Retrieving results after buffer completion");
                return null;
            }

            boolean lastPage = requestCount == numPages;
            return createResultResponseHelper(
                    HttpStatus.OK,
                    taskId,
                    requestCount - 1,
                    requestCount,
                    lastPage,
                    serializedPageSize);
        }
    }

    /**
     * Result response manager whose first call throws a non-retriable PrestoException.
     * Any later call throws a RuntimeException, flagging an unexpected retry.
     */
    private static class PrestoExceptionResponseManager
            extends TestingResponseManager.TestingResultResponseManager
    {
        private int requestCount;

        @Override
        public Response createResultResponse(String taskId)
                throws PageTransportErrorException
        {
            if (requestCount != 0) {
                throw new RuntimeException("expected to be called only once");
            }
            requestCount++;
            throw new PrestoException(GENERIC_INTERNAL_ERROR, "non retriable failure");
        }
    }

    /**
     * The first three result requests fail; the fetcher is expected to retry and still deliver
     * all ten (zero-byte) pages.
     */
    @Test
    public void testResultFetcherTransportErrorRecovery()
            throws InterruptedException
    {
        int numPages = 10;
        int serializedPageSize = 0;
        // Transport error count less than MAX_TRANSPORT_ERROR_RETRIES (5).
        // Expecting recovery from failed requests
        int numTransportErrors = 3;

        TaskId taskId = new TaskId("testid", 0, 0, 0, 0);

        TimeoutResponseManager timeoutResponseManager =
                new TimeoutResponseManager(serializedPageSize, numPages, numTransportErrors);

        PrestoSparkHttpTaskClient workerClient = createWorkerClient(
                taskId,
                new TestingHttpClient(
                        scheduledExecutorService,
                        new TestingResponseManager(
                                taskId.toString(),
                                timeoutResponseManager)));
        HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(workerClient);
        taskResultFetcher.start();

        // An InterruptedException propagates with its cause instead of being swallowed.
        List<SerializedPage> pages = fetchResults(taskResultFetcher, numPages);

        assertEquals(pages.size(), numPages);
        for (int i = 0; i < numPages; i++) {
            assertEquals(pages.get(i).getSizeInBytes(), serializedPageSize);
        }
    }

    /**
     * The first ten result requests all fail. Polling is expected to surface a
     * PrestoTransportException reporting too many errors within 1,000 polls.
     */
    @Test
    public void testResultFetcherTransportErrorFail()
            throws InterruptedException
    {
        TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
        TimeoutResponseManager alwaysFailing = new TimeoutResponseManager(0, 10, 10);
        TestingHttpClient httpClient = new TestingHttpClient(
                scheduledExecutorService,
                new TestingResponseManager(taskId.toString(), alwaysFailing));

        HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(createWorkerClient(taskId, httpClient));
        taskResultFetcher.start();

        String expectedPrefix = "getResults encountered too many errors talking to native process";
        try {
            int polls = 0;
            while (polls < 1_000) {
                taskResultFetcher.pollPage();
                polls++;
            }
            fail("Expected an exception");
        }
        catch (PrestoTransportException e) {
            assertTrue(e.getMessage().startsWith(expectedPrefix));
        }
    }

    /**
     * A non-retriable PrestoException from the server must be rethrown by pollPage() unchanged.
     * <p>
     * The timeOut, consistent with testInfoFetcherUnexpectedResponse, bounds the untimed
     * {@code monitor.wait()} loop below in case the fetcher never signals the monitor.
     */
    @Test(timeOut = 60 * 1000)
    public void testResultFetcherPrestoException()
            throws InterruptedException
    {
        TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
        PrestoSparkHttpTaskClient workerClient = createWorkerClient(
                taskId,
                new TestingHttpClient(
                        scheduledExecutorService,
                        new TestingResponseManager(taskId.toString(), new PrestoExceptionResponseManager())));
        Object monitor = new Object();
        HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(workerClient, monitor);
        taskResultFetcher.start();
        synchronized (monitor) {
            try {
                while (!taskResultFetcher.hasPage()) {
                    monitor.wait();
                }
            }
            catch (RuntimeException ignored) {
                // Deliberately ignored: hasPage() presumably rethrows the recorded fetch
                // failure once it occurs, which is the signal being waited for here. The
                // failure itself is asserted on through pollPage() below.
            }
        }
        assertThatThrownBy(taskResultFetcher::pollPage)
                .isInstanceOf(PrestoException.class)
                .hasMessage("non retriable failure");
    }

    /**
     * Waiting on the lock handed to the fetcher must eventually observe hasPage() == true.
     * <p>
     * The timeOut bounds the untimed {@code lock.wait()} loop in case the fetcher never
     * signals the lock.
     */
    @Test(timeOut = 60 * 1000)
    public void testResultFetcherWaitOnSignal()
            throws InterruptedException
    {
        TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
        Object lock = new Object();

        PrestoSparkHttpTaskClient workerClient = createWorkerClient(taskId);
        HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(workerClient, lock);
        taskResultFetcher.start();
        // An InterruptedException propagates instead of being swallowed by printStackTrace().
        synchronized (lock) {
            while (!taskResultFetcher.hasPage()) {
                lock.wait();
            }
        }
        assertTrue(taskResultFetcher.hasPage());
    }

    /**
     * Task info is absent before start() and present after the fetcher has run for a while.
     */
    @Test
    public void testInfoFetcher()
            throws InterruptedException
    {
        TaskId taskId = new TaskId("testid", 0, 0, 0, 0);

        // NOTE(review): fetchInterval only sizes the sleep below; it is not passed to the
        // fetcher. Keep it at least as long as the interval createTaskInfoFetcher configures — confirm.
        Duration fetchInterval = new Duration(1, TimeUnit.SECONDS);
        HttpNativeExecutionTaskInfoFetcher taskInfoFetcher = createTaskInfoFetcher(taskId, new TestingResponseManager(taskId.toString()));
        assertFalse(taskInfoFetcher.getTaskInfo().isPresent());
        taskInfoFetcher.start();
        // An InterruptedException propagates instead of being swallowed by printStackTrace().
        Thread.sleep(3 * fetchInterval.toMillis());
        assertTrue(taskInfoFetcher.getTaskInfo().isPresent());
    }

    /**
     * The first task-info fetch succeeds and later ones fail. After waiting long enough,
     * getTaskInfo() must throw a RuntimeException reporting too many errors.
     */
    @Test
    public void testInfoFetcherWithRetry()
            throws InterruptedException
    {
        TaskId taskId = new TaskId("testid", 0, 0, 0, 0);

        Duration fetchInterval = new Duration(1, TimeUnit.SECONDS);
        HttpNativeExecutionTaskInfoFetcher taskInfoFetcher = createTaskInfoFetcher(
                taskId,
                new TestingResponseManager(taskId.toString(), new FailureTaskInfoRetryResponseManager(1)),
                new Duration(5, TimeUnit.SECONDS),
                new Object());
        assertFalse(taskInfoFetcher.getTaskInfo().isPresent());
        taskInfoFetcher.start();
        // Interrupts propagate instead of being swallowed by printStackTrace().
        Thread.sleep(3 * fetchInterval.toMillis());

        // First fetch is expected to succeed.
        assertTrue(taskInfoFetcher.getTaskInfo().isPresent());

        // Wait well past the 5-second argument above (presumably the fetcher's error budget —
        // confirm) so that the accumulated failures surface.
        Thread.sleep(10 * fetchInterval.toMillis());
        Exception exception = expectThrows(RuntimeException.class, taskInfoFetcher::getTaskInfo);
        assertThat(exception.getMessage())
                .contains("getTaskInfo encountered too many errors talking to native process");
    }

    /**
     * After an unexpected server response, getTaskInfo() must throw a PrestoException whose
     * message mentions "500".
     */
    @Test(timeOut = 60 * 1000)
    public void testInfoFetcherUnexpectedResponse()
            throws InterruptedException
    {
        TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
        Object lock = new Object();
        TestingResponseManager responseManager =
                new TestingResponseManager(taskId.toString(), new UnexpectedResponseTaskInfoRetryResponseManager());
        HttpNativeExecutionTaskInfoFetcher taskInfoFetcher = createTaskInfoFetcher(
                taskId,
                responseManager,
                new Duration(5, TimeUnit.SECONDS),
                lock);
        taskInfoFetcher.start();

        // Block until the fetcher has either recorded an error or produced task info. The error
        // is checked first so getTaskInfo() is only consulted while no error is recorded.
        synchronized (lock) {
            while (true) {
                if (taskInfoFetcher.getLastException().get() != null) {
                    break;
                }
                if (taskInfoFetcher.getTaskInfo().isPresent()) {
                    break;
                }
                lock.wait();
            }
        }

        assertThatThrownBy(taskInfoFetcher::getTaskInfo)
                .isInstanceOf(PrestoException.class)
                .hasMessageContaining("500");
    }

    /**
     * With a server reporting FINISHED, waiting on the fetcher's lock must eventually observe
     * task info whose state is done.
     * <p>
     * The timeOut bounds the untimed {@code lock.wait()} loop in case the fetcher never
     * signals the lock.
     */
    @Test(timeOut = 60 * 1000)
    public void testInfoFetcherWaitOnSignal()
            throws InterruptedException
    {
        TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
        Object lock = new Object();

        HttpNativeExecutionTaskInfoFetcher taskInfoFetcher = createTaskInfoFetcher(taskId, new TestingResponseManager(taskId.toString(), TaskState.FINISHED), lock);
        assertFalse(taskInfoFetcher.getTaskInfo().isPresent());
        taskInfoFetcher.start();
        // An InterruptedException propagates; the original catch discarded it with no trace at all.
        synchronized (lock) {
            while (!isTaskDone(taskInfoFetcher.getTaskInfo())) {
                lock.wait();
            }
        }
        assertTrue(isTaskDone(taskInfoFetcher.getTaskInfo()));
    }

    private boolean isTaskDone(Optional<TaskInfo> taskInfo)
    {
        // An absent TaskInfo means nothing has been fetched yet, so the task cannot be reported done.
        return taskInfo
                .map(info -> info.getTaskStatus().getState().isDone())
                .orElse(false);
    }

    @Test
    public void testNativeExecutionTask()
    {
        // We need multi-thread scheduler to increase scheduling concurrency.
        // Otherwise, async execution assumption is not going to hold with a
        // single thread.
        TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
        TaskManagerConfig taskConfig = new TaskManagerConfig();
        QueryManagerConfig queryConfig = new QueryManagerConfig();
        taskConfig.setInfoRefreshMaxWait(new Duration(5, TimeUnit.SECONDS));
        taskConfig.setInfoUpdateInterval(new Duration(200, TimeUnit.MILLISECONDS));
        queryConfig.setRemoteTaskMaxErrorDuration(new Duration(1, TimeUnit.MINUTES));
        List<TaskSource> sources = new ArrayList<>();
        try {
            NativeExecutionTaskFactory taskFactory = new NativeExecutionTaskFactory(
                    new TestingHttpClient(
                            scheduledExecutorService,
                            new TestingResponseManager(taskId.toString(), new TimeoutResponseManager(0, 10, 0))),
                    scheduledExecutorService,
                    scheduledExecutorService,
                    TASK_INFO_JSON_CODEC,
                    PLAN_FRAGMENT_JSON_CODEC,
                    TASK_UPDATE_REQUEST_JSON_CODEC,
                    taskConfig,
                    queryConfig);
            NativeExecutionTask task = taskFactory.createNativeExecutionTask(
                    testSessionBuilder().build(),
                    BASE_URI,
                    taskId,
                    createPlanFragment(),
                    sources,
                    new TableWriteInfo(Optional.empty(), Optional.empty()),
                    Optional.empty(),
                    Optional.empty());
            assertNotNull(task);
            try {
                // Nothing is available before the task has been started.
                assertFalse(task.getTaskInfo().isPresent());
                assertFalse(task.pollResult().isPresent());

                // Start task
                TaskInfo taskInfo = task.start();
                assertFalse(taskInfo.getTaskStatus().getState().isDone());

                // Poll a bounded number of times until the ten expected pages have been collected.
                List<SerializedPage> resultPages = new ArrayList<>();
                for (int i = 0; i < 100 && resultPages.size() < 10; ++i) {
                    Optional<SerializedPage> page = task.pollResult();
                    page.ifPresent(resultPages::add);
                }
                assertFalse(task.pollResult().isPresent());
                assertEquals(10, resultPages.size());
                assertTrue(task.getTaskInfo().isPresent());
            }
            finally {
                // Stop the task even when an assertion above fails, so a failed run does not leave it
                // running on the shared executor.
                task.stop(true);
            }
        }
        catch (InterruptedException e) {
            // Restore the interrupt flag and fail with the original cause attached.
            Thread.currentThread().interrupt();
            throw new AssertionError("Interrupted while running the native execution task", e);
        }
    }

    /**
     * Creates a native execution process whose HTTP traffic is served by an in-memory
     * {@link TestingHttpClient} backed by {@code responseManager}, using default worker configs.
     */
    private NativeExecutionProcess createNativeExecutionProcess(
            Duration maxErrorDuration,
            TestingResponseManager responseManager)
    {
        NativeExecutionConnectorConfig connectorConfig = new NativeExecutionConnectorConfig();
        NativeExecutionNodeConfig nodeConfig = new NativeExecutionNodeConfig();
        NativeExecutionSystemConfig systemConfig = new NativeExecutionSystemConfig();
        NativeExecutionVeloxConfig veloxConfig = new NativeExecutionVeloxConfig();
        PrestoSparkWorkerProperty defaultWorkerProperty = new PrestoSparkWorkerProperty(connectorConfig, nodeConfig, systemConfig, veloxConfig);

        TestingHttpClient fakeHttpClient = new TestingHttpClient(scheduledExecutorService, responseManager);
        NativeExecutionProcessFactory processFactory = new NativeExecutionProcessFactory(
                fakeHttpClient,
                scheduledExecutorService,
                scheduledExecutorService,
                SERVER_INFO_JSON_CODEC,
                defaultWorkerProperty,
                new FeaturesConfig());

        return processFactory.createNativeExecutionProcess(testSessionBuilder().build(), maxErrorDuration);
    }

    private HttpNativeExecutionTaskInfoFetcher createTaskInfoFetcher(TaskId taskId, TestingResponseManager testingResponseManager)
    {
        // Convenience overload for tests that never wait on the fetcher: use a private, unshared lock.
        Object unsharedLock = new Object();
        return createTaskInfoFetcher(taskId, testingResponseManager, unsharedLock);
    }

    private HttpNativeExecutionTaskInfoFetcher createTaskInfoFetcher(TaskId taskId, TestingResponseManager testingResponseManager, Object lock)
    {
        // Default the error budget to one minute and forward everything else unchanged.
        Duration defaultMaxErrorDuration = new Duration(1, TimeUnit.MINUTES);
        return createTaskInfoFetcher(taskId, testingResponseManager, defaultMaxErrorDuration, lock);
    }

    /**
     * Creates a task info fetcher whose worker client talks to a {@link TestingHttpClient} backed by
     * {@code testingResponseManager}; {@code lock} is handed to the fetcher as its signalling object.
     * NOTE(review): {@code maxErrorDuration} is accepted but not forwarded anywhere; the fetcher is
     * always built with a fixed 1-second Duration (presumably its fetch interval — confirm).
     */
    private HttpNativeExecutionTaskInfoFetcher createTaskInfoFetcher(TaskId taskId, TestingResponseManager testingResponseManager, Duration maxErrorDuration, Object lock)
    {
        TestingHttpClient fakeHttpClient = new TestingHttpClient(scheduledExecutorService, testingResponseManager);
        PrestoSparkHttpTaskClient taskClient = createWorkerClient(taskId, fakeHttpClient);
        Duration fetchInterval = new Duration(1, TimeUnit.SECONDS);
        return new HttpNativeExecutionTaskInfoFetcher(scheduledExecutorService, taskClient, fetchInterval, lock);
    }

    /**
     * Settable future returned by {@link TestingHttpClient#executeAsync}; completion is driven
     * explicitly by the fake client.
     */
    private static class TestingHttpResponseFuture<T>
            extends AbstractFuture<T>
            implements HttpClient.HttpResponseFuture<T>
    {
        /**
         * This fake does not track request state, so there is never anything to report.
         */
        @Override
        public String getState()
        {
            return null;
        }

        /**
         * Completes this future successfully with {@code value}.
         */
        public void complete(T value)
        {
            set(value);
        }

        /**
         * Completes this future with the failure {@code t}.
         */
        public void completeExceptionally(Throwable t)
        {
            setException(t);
        }
    }

    /**
     * In-memory HTTP client fake that routes requests by method and path to a {@link TestingResponseManager}.
     * Each request is answered asynchronously on the supplied executor; requests matching no route complete
     * exceptionally with an "Unsupported request" error.
     */
    public static class TestingHttpClient
            implements com.facebook.airlift.http.client.HttpClient
    {
        // Note: the dots are unescaped, so each one matches any single character.
        private static final String TASK_ID_REGEX = "/v1/task/[a-zA-Z0-9]+.[0-9]+.[0-9]+.[0-9]+.[0-9]+";

        // Route patterns are compiled once rather than on every request. Pattern instances are immutable
        // and safe to share across executor threads; a fresh Matcher is created per request.
        // GET or DELETE /v1/task/{taskId}
        private static final Pattern TASK_PATTERN = Pattern.compile(TASK_ID_REGEX + "\\z");
        // GET /v1/task/{taskId}/results/{bufferId}/{token}/acknowledge
        private static final Pattern RESULTS_ACKNOWLEDGE_PATTERN = Pattern.compile(".*/results/[0-9]+/[0-9]+/acknowledge\\z");
        // GET /v1/task/{taskId}/results/{bufferId}/{token}
        private static final Pattern RESULTS_PATTERN = Pattern.compile(".*/results/[0-9]+/[0-9]+\\z");
        // GET /v1/info
        private static final Pattern SERVER_INFO_PATTERN = Pattern.compile("/v1/info");
        // POST /v1/task/{taskId}/batch
        private static final Pattern TASK_BATCH_PATTERN = Pattern.compile(format("%s\\/batch\\z", TASK_ID_REGEX));
        // DELETE /v1/task/{taskId}/results/{bufferId}
        private static final Pattern RESULTS_BUFFER_PATTERN = Pattern.compile(format("%s\\/results\\/[0-9]+\\z", TASK_ID_REGEX));

        private final ScheduledExecutorService executor;
        private final TestingResponseManager responseManager;

        public TestingHttpClient(ScheduledExecutorService executor, TestingResponseManager responseManager)
        {
            this.executor = executor;
            this.responseManager = responseManager;
        }

        /**
         * Blocking variant of {@link #executeAsync}. Any failure is printed and swallowed, and {@code null}
         * is returned in its place.
         */
        @Override
        public <T, E extends Exception> T execute(Request request, ResponseHandler<T, E> responseHandler)
                throws E
        {
            try {
                return executeAsync(request, responseHandler).get();
            }
            catch (InterruptedException e) {
                // Keep the swallow-and-return-null contract, but do not lose the caller's interrupt.
                Thread.currentThread().interrupt();
                e.printStackTrace();
                return null;
            }
            catch (Exception e) {
                e.printStackTrace();
                return null;
            }
        }

        @Override
        public <T, E extends Exception> HttpResponseFuture<T> executeAsync(Request request, ResponseHandler<T, E> responseHandler)
        {
            TestingHttpResponseFuture<T> future = new TestingHttpResponseFuture<>();
            executor.schedule(
                    () ->
                    {
                        URI uri = request.getUri();
                        String method = request.getMethod();
                        String path = uri.getPath();
                        try {
                            if (method.equalsIgnoreCase("GET")) {
                                // GET /v1/task/{taskId}
                                if (TASK_PATTERN.matcher(path).find()) {
                                    future.complete(responseHandler.handle(request, responseManager.createTaskInfoResponse(HttpStatus.OK)));
                                }
                                // GET /v1/task/{taskId}/results/{bufferId}/{token}/acknowledge
                                else if (RESULTS_ACKNOWLEDGE_PATTERN.matcher(path).find()) {
                                    future.complete(responseHandler.handle(request, responseManager.createDummyResultResponse()));
                                }
                                // GET /v1/task/{taskId}/results/{bufferId}/{token}
                                else if (RESULTS_PATTERN.matcher(path).find()) {
                                    future.complete(responseHandler.handle(
                                            request,
                                            responseManager.createResultResponse()));
                                }
                                // GET /v1/info
                                else if (SERVER_INFO_PATTERN.matcher(path).find()) {
                                    future.complete(responseHandler.handle(
                                            request,
                                            responseManager.createServerInfoResponse()));
                                }
                            }
                            else if (method.equalsIgnoreCase("POST")) {
                                // POST /v1/task/{taskId}/batch
                                if (TASK_BATCH_PATTERN.matcher(path).find()) {
                                    future.complete(responseHandler.handle(request, responseManager.createTaskInfoResponse(HttpStatus.OK)));
                                }
                            }
                            else if (method.equalsIgnoreCase("DELETE")) {
                                // DELETE /v1/task/{taskId}/results/{bufferId}
                                if (RESULTS_BUFFER_PATTERN.matcher(path).find()) {
                                    future.complete(responseHandler.handle(request, responseManager.createDummyResultResponse()));
                                }
                                // DELETE /v1/task/{taskId}
                                else if (TASK_PATTERN.matcher(path).find()) {
                                    future.complete(responseHandler.handle(request, responseManager.createDummyResultResponse()));
                                }
                            }
                        }
                        catch (Exception e) {
                            e.printStackTrace();
                            future.completeExceptionally(e);
                        }

                        // No route produced a response: report the request as unsupported.
                        if (!future.isDone()) {
                            future.completeExceptionally(new Exception(format("Unsupported request: %s %s", method, path)));
                        }
                    },
                    (long) NO_DURATION.getValue(),
                    NO_DURATION.getUnit());
            return future;
        }

        @Override
        public RequestStats getStats()
        {
            return null;
        }

        @Override
        public long getMaxContentLength()
        {
            return 0;
        }

        @Override
        public void close()
        {
        }

        @Override
        public boolean isClosed()
        {
            return false;
        }
    }

    /**
     * A stateful response manager for testing purposes. The lifetime of an instance of this class should match the lifetime of the HTTP client that uses it.
     * Requests are delegated to three replaceable sub-managers: one for result fetching, one for server info and one for task info.
     */
    public static class TestingResponseManager
    {
        private static final JsonCodec<TaskInfo> taskInfoCodec = JsonCodec.jsonCodec(TaskInfo.class);
        private static final JsonCodec<ServerInfo> serverInfoCodec = JsonCodec.jsonCodec(ServerInfo.class);
        private final TestingResultResponseManager resultResponseManager;
        private final TestingServerResponseManager serverResponseManager;
        private final TestingTaskInfoResponseManager taskInfoResponseManager;
        private final String taskId;

        /**
         * Uses default managers for every endpoint; task info reports {@link TaskState#PLANNED}.
         */
        public TestingResponseManager(String taskId)
        {
            this(
                    requireNonNull(taskId, "taskId is null"),
                    new TestingResultResponseManager(),
                    new TestingServerResponseManager(),
                    new TestingTaskInfoResponseManager());
        }

        /**
         * Uses default managers for every endpoint; task info reports {@code taskState}.
         */
        public TestingResponseManager(String taskId, TaskState taskState)
        {
            this(
                    requireNonNull(taskId, "taskId is null"),
                    new TestingResultResponseManager(),
                    new TestingServerResponseManager(),
                    new TestingTaskInfoResponseManager(taskState));
        }

        /**
         * Uses the given result manager and default managers for the other endpoints.
         */
        public TestingResponseManager(String taskId, TestingResultResponseManager resultResponseManager)
        {
            this(
                    requireNonNull(taskId, "taskId is null"),
                    requireNonNull(resultResponseManager, "resultResponseManager is null."),
                    new TestingServerResponseManager(),
                    new TestingTaskInfoResponseManager());
        }

        /**
         * Uses the given server-info manager and default managers for the other endpoints.
         */
        public TestingResponseManager(String taskId, TestingServerResponseManager serverResponseManager)
        {
            this(
                    requireNonNull(taskId, "taskId is null"),
                    new TestingResultResponseManager(),
                    requireNonNull(serverResponseManager, "serverResponseManager is null"),
                    new TestingTaskInfoResponseManager());
        }

        /**
         * Uses the given task-info manager and default managers for the other endpoints.
         */
        public TestingResponseManager(String taskId, TestingTaskInfoResponseManager taskInfoResponseManager)
        {
            this(
                    requireNonNull(taskId, "taskId is null"),
                    new TestingResultResponseManager(),
                    new TestingServerResponseManager(),
                    requireNonNull(taskInfoResponseManager, "taskInfoResponseManager is null"));
        }

        /**
         * Canonical constructor. Every public constructor validates its arguments (in left-to-right order)
         * before delegating here, so the fields are simply assigned.
         */
        private TestingResponseManager(
                String taskId,
                TestingResultResponseManager resultResponseManager,
                TestingServerResponseManager serverResponseManager,
                TestingTaskInfoResponseManager taskInfoResponseManager)
        {
            this.taskId = taskId;
            this.resultResponseManager = resultResponseManager;
            this.serverResponseManager = serverResponseManager;
            this.taskInfoResponseManager = taskInfoResponseManager;
        }

        /**
         * Returns an empty 200 OK response with no headers.
         */
        public Response createDummyResultResponse()
        {
            return new TestingResponse();
        }

        public Response createResultResponse()
                throws PageTransportErrorException
        {
            return resultResponseManager.createResultResponse(taskId);
        }

        public Response createServerInfoResponse()
                throws PrestoException
        {
            return serverResponseManager.createServerInfoResponse();
        }

        public Response createTaskInfoResponse(HttpStatus httpStatus)
                throws PrestoException
        {
            return taskInfoResponseManager.createTaskInfoResponse(httpStatus, taskId);
        }

        /**
         * Manager for server-related endpoints. It keeps any stateful information inside itself. Callers can extend this class to create their own response handling
         * logic.
         */
        public static class TestingServerResponseManager
        {
            /**
             * Returns a 200 OK JSON response carrying a fixed {@link ServerInfo}.
             */
            public Response createServerInfoResponse()
                    throws PrestoException
            {
                ServerInfo serverInfo = new ServerInfo(UNKNOWN, "test", true, false, Optional.of(Duration.valueOf("2m")));
                HttpStatus httpStatus = HttpStatus.OK;
                ListMultimap<HeaderName, String> headers = ArrayListMultimap.create();
                headers.put(HeaderName.of(CONTENT_TYPE), String.valueOf(MediaType.create("application", "json")));
                return new TestingResponse(
                        httpStatus.code(),
                        headers,
                        new ByteArrayInputStream(serverInfoCodec.toBytes(serverInfo)));
            }
        }

        /**
         * Manager for result-fetching endpoints. It keeps any stateful information inside itself. Callers can extend this class to create their own response handling
         * logic.
         */
        public static class TestingResultResponseManager
        {
            /**
             * A dummy implementation of the result creation logic: a single empty page with token 0, next token 1 and
             * the buffer marked complete. Override it to create customized result-returning logic.
             */
            public Response createResultResponse(String taskId)
                    throws PageTransportErrorException
            {
                return createResultResponseHelper(HttpStatus.OK,
                        taskId,
                        0,
                        1,
                        true,
                        0);
            }

            /**
             * Builds a results response carrying one serialized page of {@code serializedPageSizeBytes} bytes, with the
             * page-token, next-token, buffer-complete and task-instance-id headers set from the arguments.
             */
            protected Response createResultResponseHelper(
                    HttpStatus httpStatus,
                    String taskId,
                    long token,
                    long nextToken,
                    boolean bufferComplete,
                    int serializedPageSizeBytes)
            {
                DynamicSliceOutput slicedOutput = new DynamicSliceOutput(1024);
                PagesSerdeUtil.writeSerializedPage(slicedOutput, createSerializedPage(serializedPageSizeBytes));
                ListMultimap<HeaderName, String> headers = ArrayListMultimap.create();
                headers.put(HeaderName.of(PRESTO_PAGE_TOKEN), String.valueOf(token));
                headers.put(HeaderName.of(PRESTO_PAGE_NEXT_TOKEN), String.valueOf(nextToken));
                headers.put(HeaderName.of(PRESTO_BUFFER_COMPLETE), String.valueOf(bufferComplete));
                headers.put(HeaderName.of(PRESTO_TASK_INSTANCE_ID), taskId);
                headers.put(HeaderName.of(CONTENT_TYPE), PRESTO_PAGES_TYPE.toString());
                return new TestingResponse(
                        httpStatus.code(),
                        headers,
                        slicedOutput.slice().getInput());
            }
        }

        /**
         * Manager for task-info endpoints. It keeps any stateful information inside itself. Callers can extend this class to create their own response handling
         * logic.
         */
        public static class TestingTaskInfoResponseManager
        {
            private final TaskState taskState;

            /**
             * Reports {@link TaskState#PLANNED} in every task info response.
             */
            public TestingTaskInfoResponseManager()
            {
                this(TaskState.PLANNED);
            }

            /**
             * Reports {@code taskState} in every task info response.
             */
            public TestingTaskInfoResponseManager(TaskState taskState)
            {
                this.taskState = requireNonNull(taskState, "taskState is null");
            }

            /**
             * Returns a JSON-encoded initial {@link TaskInfo} for {@code taskId}, whose status carries this manager's
             * configured state, with {@code httpStatus} as the response code.
             */
            public Response createTaskInfoResponse(HttpStatus httpStatus, String taskId)
                    throws PrestoException
            {
                URI location = uriBuilderFrom(BASE_URI).appendPath(TASK_ROOT_PATH).build();
                ListMultimap<HeaderName, String> headers = ArrayListMultimap.create();
                headers.put(HeaderName.of(CONTENT_TYPE), String.valueOf(MediaType.create("application", "json")));
                TaskInfo taskInfo = TaskInfo.createInitialTask(
                        TaskId.valueOf(taskId),
                        location,
                        new ArrayList<>(),
                        new TaskStats(System.currentTimeMillis(), 0L),
                        "dummy-node").withTaskStatus(createTaskStatusDone(location));
                return new TestingResponse(
                        httpStatus.code(),
                        headers,
                        new ByteArrayInputStream(taskInfoCodec.toBytes(taskInfo)));
            }

            /**
             * Builds a TaskStatus with every counter zeroed. Despite the method name, the state is whatever this
             * manager was configured with (PLANNED by default), not necessarily a done state.
             */
            private TaskStatus createTaskStatusDone(URI location)
            {
                return new TaskStatus(
                        0L,
                        0L,
                        0,
                        taskState,
                        location,
                        ImmutableSet.of(),
                        ImmutableList.of(),
                        0,
                        0,
                        0.0,
                        false,
                        0,
                        0,
                        0,
                        0,
                        0,
                        0,
                        0,
                        0,
                        0L,
                        0L);
            }
        }

        /**
         * Task-info manager that throws on every request, simulating an unreachable server.
         */
        public static class CrashingTaskInfoResponseManager
                extends TestingResponseManager.TestingTaskInfoResponseManager
        {
            public CrashingTaskInfoResponseManager()
            {
                super();
            }

            @Override
            public Response createTaskInfoResponse(HttpStatus httpStatus, String taskId)
                    throws PrestoException
            {
                throw new RuntimeException("Server refused connection");
            }
        }
    }

    /**
     * Minimal immutable {@link Response} with a fixed status code, header multimap and body stream.
     */
    public static class TestingResponse
            implements Response
    {
        private final int statusCode;
        private final ListMultimap<HeaderName, String> headers;
        private final InputStream inputStream;

        /**
         * Creates an empty 200 OK response with no headers and no body; {@link #getInputStream()} returns {@code null}.
         */
        private TestingResponse()
        {
            this(HttpStatus.OK.code(), ArrayListMultimap.create(), null);
        }

        private TestingResponse(
                int statusCode,
                ListMultimap<HeaderName, String> headers,
                InputStream inputStream)
        {
            this.statusCode = statusCode;
            this.headers = headers;
            this.inputStream = inputStream;
        }

        @Override
        public int getStatusCode()
        {
            return statusCode;
        }

        @Override
        public ListMultimap<HeaderName, String> getHeaders()
        {
            return headers;
        }

        @Override
        public long getBytesRead()
        {
            return 0;
        }

        @Override
        public InputStream getInputStream()
        {
            return inputStream;
        }
    }

    /**
     * Builds a serialized page with no codec markers whose payload is {@code numBytes} bytes, each equal to 8.
     */
    private static SerializedPage createSerializedPage(int numBytes)
    {
        byte[] payload = new byte[numBytes];
        for (int index = 0; index < payload.length; index++) {
            payload[index] = (byte) 8;
        }
        return new SerializedPage(Slices.wrappedBuffer(payload), PageCodecMarker.none(), 0, numBytes, 0);
    }

    /**
     * Server-info manager whose first {@code maxRetryCount} requests throw a RuntimeException; every later
     * request returns the default server info. The request counter is a plain, unsynchronized field.
     */
    public static class FailureRetryResponseManager
            extends TestingResponseManager.TestingServerResponseManager
    {
        private final int maxRetryCount;
        private int retryCount;

        public FailureRetryResponseManager(int maxRetryCount)
        {
            this.maxRetryCount = maxRetryCount;
        }

        @Override
        public Response createServerInfoResponse()
                throws PrestoException
        {
            int attempt = retryCount;
            retryCount = attempt + 1;
            if (attempt >= maxRetryCount) {
                return super.createServerInfoResponse();
            }
            throw new RuntimeException("Get ServerInfo request failure.");
        }
    }

    /**
     * Task-info manager whose first {@code maxRetryCount} requests throw a RuntimeException; every later
     * request returns the default task info. The request counter is a plain, unsynchronized field.
     */
    public static class FailureRetryTaskInfoResponseManager
            extends TestingResponseManager.TestingTaskInfoResponseManager
    {
        private final int maxRetryCount;
        private int retryCount;

        public FailureRetryTaskInfoResponseManager(int maxRetryCount)
        {
            this.maxRetryCount = maxRetryCount;
        }

        @Override
        public Response createTaskInfoResponse(HttpStatus httpStatus, String taskId)
                throws PrestoException
        {
            int attempt = retryCount;
            retryCount = attempt + 1;
            if (attempt >= maxRetryCount) {
                return super.createTaskInfoResponse(httpStatus, taskId);
            }
            throw new RuntimeException("retriable failure");
        }
    }

    /**
     * Task-info manager that answers the first {@code failureCount + 1} requests normally and throws a
     * RuntimeException on every request after that.
     * NOTE(review): despite its name, {@code failureCount} bounds the number of successful responses,
     * not the number of failures.
     */
    private static class FailureTaskInfoRetryResponseManager
            extends TestingResponseManager.TestingTaskInfoResponseManager
    {
        private final int failureCount;
        private int retryCount;

        public FailureTaskInfoRetryResponseManager(int failureCount)
        {
            super();
            this.failureCount = failureCount;
        }

        @Override
        public Response createTaskInfoResponse(HttpStatus httpStatus, String taskId)
                throws PrestoException
        {
            int attempt = retryCount;
            retryCount = attempt + 1;
            if (attempt <= failureCount) {
                return super.createTaskInfoResponse(httpStatus, taskId);
            }
            throw new RuntimeException("retriable failure");
        }
    }

    /**
     * Task-info manager that answers exactly one request, always with HTTP 500 (ignoring the requested
     * status), and throws on any further request.
     */
    private static class UnexpectedResponseTaskInfoRetryResponseManager
            extends TestingResponseManager.TestingTaskInfoResponseManager
    {
        private int requestCount;

        @Override
        public Response createTaskInfoResponse(HttpStatus httpStatus, String taskId)
                throws PrestoException
        {
            if (requestCount != 0) {
                throw new RuntimeException("response handler is not expected to be called more than once");
            }
            requestCount++;
            return super.createTaskInfoResponse(HttpStatus.INTERNAL_SERVER_ERROR, taskId);
        }
    }
}