TestPrestoSparkTaskExecution.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.spark.task;

import com.facebook.presto.Session;
import com.facebook.presto.execution.ScheduledSplit;
import com.facebook.presto.execution.TaskId;
import com.facebook.presto.execution.TaskSource;
import com.facebook.presto.execution.TaskStateMachine;
import com.facebook.presto.execution.TaskTestUtils;
import com.facebook.presto.execution.executor.TaskExecutor;
import com.facebook.presto.execution.scheduler.TableWriteInfo;
import com.facebook.presto.metadata.Split;
import com.facebook.presto.operator.TaskContext;
import com.facebook.presto.spark.execution.task.PrestoSparkTaskExecution;
import com.facebook.presto.spi.ConnectorId;
import com.facebook.presto.sql.planner.LocalExecutionPlanner;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.TestingOutputBuffer;
import com.facebook.presto.sql.planner.TestingRemoteSourceFactory;
import com.facebook.presto.testing.TestingSplit;
import com.facebook.presto.testing.TestingTaskContext;
import com.facebook.presto.testing.TestingTransactionHandle;
import com.google.common.base.Ticker;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.airlift.units.DataSize;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

import java.util.ArrayList;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;

import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed;
import static com.facebook.presto.execution.TaskManagerConfig.TaskPriorityTracking.TASK_FAIR;
import static com.facebook.presto.execution.TaskTestUtils.TABLE_SCAN_NODE_ID;
import static com.facebook.presto.execution.TaskTestUtils.createPlanFragment;
import static com.facebook.presto.execution.TaskTestUtils.createTestingPlanner;
import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME;
import static io.airlift.units.DataSize.Unit.GIGABYTE;
import static java.util.concurrent.Executors.newCachedThreadPool;
import static java.util.concurrent.Executors.newScheduledThreadPool;
import static org.testng.Assert.assertEquals;

@Test(singleThreaded = true)
public class TestPrestoSparkTaskExecution
{
    Session nativeTestSession;
    Session nonNativeTestSession;
    LocalExecutionPlanner.LocalExecutionPlan localExecutionPlan;
    TaskExecutor taskExecutor;
    PlanFragment planFragment = createPlanFragment();
    ExecutorService taskNotificationExecutor;
    ScheduledExecutorService scheduledExecutor;
    TaskStateMachine taskStateMachine;

    Set<ScheduledSplit> splits = ImmutableSet.of(
            new ScheduledSplit(1, TABLE_SCAN_NODE_ID, new Split(new ConnectorId("test"), TestingTransactionHandle.create(), TestingSplit.createLocalSplit())),
            new ScheduledSplit(2, TABLE_SCAN_NODE_ID, new Split(new ConnectorId("test"), TestingTransactionHandle.create(), TestingSplit.createLocalSplit())),
            new ScheduledSplit(3, TABLE_SCAN_NODE_ID, new Split(new ConnectorId("test"), TestingTransactionHandle.create(), TestingSplit.createLocalSplit())));

    @BeforeMethod
    public void setUp()
    {
        taskNotificationExecutor = newCachedThreadPool(daemonThreadsNamed("test-executor-%s"));
        scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed("test-scheduledExecutor-%s"));
        taskExecutor = new TaskExecutor(8, 16, 3, 4, TASK_FAIR, Ticker.systemTicker());

        nativeTestSession = testSessionBuilder()
                .setCatalog("tpch")
                .setSchema(TINY_SCHEMA_NAME)
                .build();
        nonNativeTestSession = testSessionBuilder()
                .setCatalog("tpch")
                .setSchema(TINY_SCHEMA_NAME)
                .build();

        taskStateMachine = new TaskStateMachine(new TaskId("test_query_id", 4, 4, 1, 0), taskNotificationExecutor);

        localExecutionPlan = createTestingPlanner().plan(
                TestingTaskContext.createTaskContext(taskNotificationExecutor, scheduledExecutor, nativeTestSession),
                planFragment,
                new TestingOutputBuffer(),
                new TestingRemoteSourceFactory(),
                new TableWriteInfo(Optional.empty(), Optional.empty()),
                new ArrayList<>());
    }

    @AfterMethod
    public void tearDown()
    {
        taskStateMachine.finished();
        taskExecutor.stop();
        scheduledExecutor.shutdown();
        taskNotificationExecutor.shutdown();
    }

    @Test
    public void testJavaDriverInstanceCount()
    {
        testDriverCount(nonNativeTestSession, 3);
    }

    private void testDriverCount(Session session, int expectedDriverCount)
    {
        TaskContext taskContext = TestingTaskContext.createTaskContext(taskNotificationExecutor, scheduledExecutor, session, new DataSize(2, GIGABYTE));
        taskExecutor.start();
        PrestoSparkTaskExecution taskExecution = new PrestoSparkTaskExecution(taskStateMachine, taskContext, localExecutionPlan, taskExecutor, TaskTestUtils.createTestSplitMonitor(), taskNotificationExecutor, scheduledExecutor, false);
        taskExecution.start(ImmutableList.of(new TaskSource(TABLE_SCAN_NODE_ID, splits, true)));
        assertEquals(taskContext.getPipelineContexts().get(0).getPipelineStats().getDrivers().size(), expectedDriverCount);
    }
}