MockRemoteTaskFactory.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.execution;
import com.facebook.airlift.stats.TestingGcMonitor;
import com.facebook.presto.Session;
import com.facebook.presto.common.predicate.TupleDomain;
import com.facebook.presto.cost.StatsAndCosts;
import com.facebook.presto.execution.NodeTaskMap.NodeStatsTracker;
import com.facebook.presto.execution.buffer.LazyOutputBuffer;
import com.facebook.presto.execution.buffer.OutputBuffer;
import com.facebook.presto.execution.buffer.OutputBuffers;
import com.facebook.presto.execution.buffer.SpoolingOutputBufferFactory;
import com.facebook.presto.execution.scheduler.TableWriteInfo;
import com.facebook.presto.memory.MemoryPool;
import com.facebook.presto.memory.QueryContext;
import com.facebook.presto.memory.context.SimpleLocalMemoryContext;
import com.facebook.presto.metadata.InternalNode;
import com.facebook.presto.metadata.Split;
import com.facebook.presto.operator.TaskContext;
import com.facebook.presto.operator.TaskMemoryReservationSummary;
import com.facebook.presto.operator.TaskStats;
import com.facebook.presto.spi.ConnectorId;
import com.facebook.presto.spi.SplitWeight;
import com.facebook.presto.spi.TableHandle;
import com.facebook.presto.spi.memory.MemoryPoolId;
import com.facebook.presto.spi.plan.Partitioning;
import com.facebook.presto.spi.plan.PartitioningScheme;
import com.facebook.presto.spi.plan.PlanFragmentId;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.plan.StageExecutionDescriptor;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.spiller.SpillSpaceTracker;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.testing.TestingHandle;
import com.facebook.presto.testing.TestingMetadata.TestingColumnHandle;
import com.facebook.presto.testing.TestingMetadata.TestingTableHandle;
import com.facebook.presto.testing.TestingTransactionHandle;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.airlift.units.DataSize;
import javax.annotation.concurrent.GuardedBy;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Stream;
import static com.facebook.airlift.json.JsonCodec.listJsonCodec;
import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.execution.StateMachine.StateChangeListener;
import static com.facebook.presto.execution.buffer.OutputBuffers.BufferType.BROADCAST;
import static com.facebook.presto.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers;
import static com.facebook.presto.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext;
import static com.facebook.presto.metadata.MetadataUpdates.DEFAULT_METADATA_UPDATES;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION;
import static com.facebook.presto.util.Failures.toFailures;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.util.concurrent.Futures.nonCancellationPropagating;
import static io.airlift.units.DataSize.Unit.GIGABYTE;
import static io.airlift.units.DataSize.Unit.MEGABYTE;
import static java.lang.Math.addExact;
import static java.lang.System.currentTimeMillis;
import static java.util.Objects.requireNonNull;
/**
 * Test {@link RemoteTaskFactory} that creates in-memory {@link MockRemoteTask}s. No worker is contacted:
 * splits are held in the task object, and split, memory and CPU statistics are synthesized and pushed to the
 * supplied {@link NodeStatsTracker}.
 */
public class MockRemoteTaskFactory
        implements RemoteTaskFactory
{
    // Static, so every MockRemoteTask in this JVM reports the same task instance id.
    private static final UUID TASK_INSTANCE_ID = UUID.randomUUID();

    private final Executor executor;
    private final ScheduledExecutorService scheduledExecutor;

    /**
     * @param executor executor handed to each task's state machine, query context and output buffer
     * @param scheduledExecutor scheduled executor handed to each task's query context
     */
    public MockRemoteTaskFactory(Executor executor, ScheduledExecutorService scheduledExecutor)
    {
        this.executor = executor;
        this.scheduledExecutor = scheduledExecutor;
    }

    /**
     * Creates a task whose fragment is a single table scan (source id {@code "sourceId"}) over one VARCHAR
     * column, with every split in {@code splits} assigned to that scan. {@link MockRemoteTask#start()} is not
     * called.
     *
     * @param taskId id of the new task
     * @param newNode node the task is placed on; only its node identifier is used
     * @param splits initial splits for the table scan source
     * @param nodeStatsTracker tracker that receives the task's synthetic split, memory and CPU stats
     * @return the new task
     */
    public MockRemoteTask createTableScanTask(TaskId taskId, InternalNode newNode, List<Split> splits, NodeTaskMap.NodeStatsTracker nodeStatsTracker)
    {
        VariableReferenceExpression variable = new VariableReferenceExpression(Optional.empty(), "column", VARCHAR);
        PlanNodeId sourceId = new PlanNodeId("sourceId");
        PlanFragment testFragment = new PlanFragment(
                new PlanFragmentId(0),
                new TableScanNode(
                        Optional.empty(),
                        sourceId,
                        new TableHandle(new ConnectorId("test"), new TestingTableHandle(), TestingTransactionHandle.create(), Optional.of(TestingHandle.INSTANCE)),
                        ImmutableList.of(variable),
                        ImmutableMap.of(variable, new TestingColumnHandle("column")),
                        TupleDomain.all(),
                        TupleDomain.all(),
                        Optional.empty()),
                ImmutableSet.of(variable),
                SOURCE_DISTRIBUTION,
                ImmutableList.of(sourceId),
                new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(variable)),
                StageExecutionDescriptor.ungroupedExecution(),
                false,
                Optional.of(StatsAndCosts.empty()),
                Optional.empty());

        ImmutableMultimap.Builder<PlanNodeId, Split> initialSplits = ImmutableMultimap.builder();
        for (Split sourceSplit : splits) {
            initialSplits.put(sourceId, sourceSplit);
        }
        return createRemoteTask(
                TEST_SESSION,
                taskId,
                newNode,
                testFragment,
                initialSplits.build(),
                createInitialEmptyOutputBuffers(BROADCAST),
                nodeStatsTracker,
                true,
                new TableWriteInfo(Optional.empty(), Optional.empty()),
                SchedulerStatsTracker.NOOP);
    }

    /**
     * Creates a {@link MockRemoteTask} for {@code fragment} on {@code node}. Only {@code taskId},
     * {@code fragment}, the node identifier, {@code initialSplits} and {@code nodeStatsTracker} are used; the
     * session, output buffers, summarize flag, table write info and scheduler stats tracker are ignored.
     */
    @Override
    public MockRemoteTask createRemoteTask(
            Session session,
            TaskId taskId,
            InternalNode node,
            PlanFragment fragment,
            Multimap<PlanNodeId, Split> initialSplits,
            OutputBuffers outputBuffers,
            NodeTaskMap.NodeStatsTracker nodeStatsTracker,
            boolean summarizeTaskInfo,
            TableWriteInfo tableWriteInfo,
            SchedulerStatsTracker schedulerStatsTracker)
    {
        return new MockRemoteTask(taskId, fragment, node.getNodeIdentifier(), executor, scheduledExecutor, initialSplits, nodeStatsTracker);
    }

    /**
     * In-memory {@link RemoteTask}. Splits are stored in a multimap and tests drive progress through
     * {@link #startSplits}, {@link #finishSplits}, {@link #clearSplits} and the unacknowledged-split setters.
     * The first {@code runningDrivers} splits found in table-scan scheduling order are treated as running and
     * the remainder as queued.
     */
    public static final class MockRemoteTask
            implements RemoteTask
    {
        // Queued split weight at or above which the split queue is reported as full (see updateSplitQueueSpace).
        private static final long MAX_QUEUED_SPLIT_WEIGHT = 900L;

        private final AtomicLong nextTaskInfoVersion = new AtomicLong(TaskStatus.STARTING_VERSION);
        // Grows by one on every stats refresh and is added to the CPU usage reported to the tracker.
        private final AtomicLong nextAgeOffset = new AtomicLong(0);

        private final URI location;
        private final TaskStateMachine taskStateMachine;
        private final TaskContext taskContext;
        private final OutputBuffer outputBuffer;
        private final String nodeId;

        private final PlanFragment fragment;

        // Sources for which noMoreSplits(PlanNodeId) has been called.
        @GuardedBy("this")
        private final Set<PlanNodeId> noMoreSplits = new HashSet<>();

        @GuardedBy("this")
        private final Multimap<PlanNodeId, Split> splits = HashMultimap.create();

        // Number of splits treated as running; see getQueuedPartitionedSplitsInfo().
        @GuardedBy("this")
        private int runningDrivers;

        @GuardedBy("this")
        private int maxUnacknowledgedSplits = Integer.MAX_VALUE;
        @GuardedBy("this")
        private int unacknowledgedSplits;

        // Completed while the split queue has space; replaced with a fresh future when it fills up.
        @GuardedBy("this")
        private SettableFuture<?> whenSplitQueueHasSpace = SettableFuture.create();

        private final NodeStatsTracker nodeStatsTracker;

        /**
         * Creates the task, stores {@code initialSplits}, and immediately publishes stats to
         * {@code nodeStatsTracker} and evaluates split queue space.
         */
        public MockRemoteTask(TaskId taskId,
                PlanFragment fragment,
                String nodeId,
                Executor executor,
                ScheduledExecutorService scheduledExecutor,
                Multimap<PlanNodeId, Split> initialSplits,
                NodeTaskMap.NodeStatsTracker nodeStatsTracker)
        {
            // Validate arguments before anything dereferences them (fragment.getRoot() is used below).
            this.fragment = requireNonNull(fragment, "fragment is null");
            this.nodeId = requireNonNull(nodeId, "nodeId is null");
            this.nodeStatsTracker = requireNonNull(nodeStatsTracker, "nodeStatsTracker is null");
            requireNonNull(initialSplits, "initialSplits is null");

            this.taskStateMachine = new TaskStateMachine(requireNonNull(taskId, "taskId is null"), requireNonNull(executor, "executor is null"));

            MemoryPool memoryPool = new MemoryPool(new MemoryPoolId("test"), new DataSize(1, GIGABYTE));
            SpillSpaceTracker spillSpaceTracker = new SpillSpaceTracker(new DataSize(1, GIGABYTE));
            QueryContext queryContext = new QueryContext(taskId.getQueryId(),
                    new DataSize(1, MEGABYTE),
                    new DataSize(2, MEGABYTE),
                    new DataSize(1, MEGABYTE),
                    new DataSize(1, GIGABYTE),
                    memoryPool,
                    new TestingGcMonitor(),
                    executor,
                    scheduledExecutor,
                    new DataSize(1, MEGABYTE),
                    spillSpaceTracker,
                    listJsonCodec(TaskMemoryReservationSummary.class));
            this.taskContext = queryContext.addTaskContext(
                    taskStateMachine,
                    TEST_SESSION,
                    Optional.of(fragment.getRoot()),
                    true,
                    true,
                    true,
                    true,
                    false);

            this.location = URI.create("fake://task/" + taskId);
            this.outputBuffer = new LazyOutputBuffer(
                    taskId,
                    TASK_INSTANCE_ID.toString(),
                    executor,
                    1L,
                    () -> new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"),
                    new SpoolingOutputBufferFactory(new FeaturesConfig()));

            splits.putAll(initialSplits);
            updateTaskStats();
            updateSplitQueueSpace();
        }

        @Override
        public TaskId getTaskId()
        {
            return taskStateMachine.getTaskId();
        }

        @Override
        public String getNodeId()
        {
            return nodeId;
        }

        /**
         * Builds a {@link TaskInfo} from the current state. The embedded {@link TaskStatus} reports zero
         * splits and memory, includes failure causes when the task is FAILED, and consumes a new info version.
         */
        @Override
        public TaskInfo getTaskInfo()
        {
            // One snapshot, so the reported task age and the embedded TaskStats come from the same stats.
            TaskStats stats = taskContext.getTaskStats();

            TaskState state = taskStateMachine.getState();
            List<ExecutionFailureInfo> failures = ImmutableList.of();
            if (state == TaskState.FAILED) {
                failures = toFailures(taskStateMachine.getFailureCauses());
            }

            return new TaskInfo(
                    taskStateMachine.getTaskId(),
                    new TaskStatus(
                            TASK_INSTANCE_ID.getLeastSignificantBits(),
                            TASK_INSTANCE_ID.getMostSignificantBits(),
                            nextTaskInfoVersion.getAndIncrement(),
                            state,
                            location,
                            ImmutableSet.of(),
                            failures,
                            0,
                            0,
                            0.0,
                            false,
                            0,
                            0,
                            0,
                            0,
                            0,
                            0,
                            0,
                            // Adding 100 millis to make sure task age > 0 for testing
                            currentTimeMillis() + 100 - stats.getCreateTimeInMillis(),
                            0L,
                            0L),
                    currentTimeMillis(),
                    outputBuffer.getInfo(),
                    ImmutableSet.of(),
                    stats,
                    true,
                    DEFAULT_METADATA_UPDATES,
                    nodeId);
        }

        /**
         * Returns the fake location {@code fake://task/<taskId>}.
         */
        @Override
        public URI getRemoteTaskLocation()
        {
            return location;
        }

        /**
         * Builds a {@link TaskStatus} with queued/running split counts and weights from this task's splits and
         * memory/CPU figures from the task context. Failures are always reported as an empty list.
         */
        @Override
        public TaskStatus getTaskStatus()
        {
            TaskStats stats = taskContext.getTaskStats();
            PartitionedSplitsInfo combinedSplitsInfo = getPartitionedSplitsInfo();
            PartitionedSplitsInfo queuedSplitsInfo = getQueuedPartitionedSplitsInfo();
            return new TaskStatus(
                    TASK_INSTANCE_ID.getLeastSignificantBits(),
                    TASK_INSTANCE_ID.getMostSignificantBits(),
                    nextTaskInfoVersion.get(),
                    taskStateMachine.getState(),
                    location,
                    ImmutableSet.of(),
                    ImmutableList.of(),
                    queuedSplitsInfo.getCount(),
                    combinedSplitsInfo.getCount() - queuedSplitsInfo.getCount(),
                    0.0,
                    false,
                    stats.getPhysicalWrittenDataSizeInBytes(),
                    stats.getUserMemoryReservationInBytes(),
                    stats.getSystemMemoryReservationInBytes(),
                    stats.getPeakNodeTotalMemoryInBytes(),
                    0,
                    0,
                    stats.getTotalCpuTimeInNanos(),
                    // Adding 100 millis to make sure task age > 0 for testing
                    currentTimeMillis() + 100 - stats.getCreateTimeInMillis(),
                    queuedSplitsInfo.getWeightSum(),
                    combinedSplitsInfo.getWeightSum() - queuedSplitsInfo.getWeightSum());
        }

        /**
         * Pushes synthetic stats to the node stats tracker. A finished task reports zero splits, zero memory
         * and zero CPU; a live task reports its current partitioned splits, a fixed memory usage of 100, and
         * its age plus an increasing offset as CPU usage.
         */
        private void updateTaskStats()
        {
            TaskStatus taskStatus = getTaskStatus();
            if (taskStatus.getState().isDone()) {
                nodeStatsTracker.setPartitionedSplits(PartitionedSplitsInfo.forZeroSplits());
                nodeStatsTracker.setMemoryUsage(0);
                nodeStatsTracker.setCpuUsage(taskStatus.getTaskAgeInMillis(), 0);
            }
            else {
                nodeStatsTracker.setPartitionedSplits(getPartitionedSplitsInfo());
                // setting some values for testing
                nodeStatsTracker.setMemoryUsage(100);
                long ageOffset = nextAgeOffset.addAndGet(1);
                nodeStatsTracker.setCpuUsage(taskStatus.getTaskAgeInMillis() + ageOffset, taskStatus.getTaskAgeInMillis() + ageOffset);
            }
        }

        /**
         * Completes the split-queue-space future when fewer than {@code maxUnacknowledgedSplits} splits are
         * unacknowledged and the queued split weight is below {@link #MAX_QUEUED_SPLIT_WEIGHT}; otherwise
         * ensures a pending (not yet completed) future is installed.
         */
        private synchronized void updateSplitQueueSpace()
        {
            boolean hasSpace = unacknowledgedSplits < maxUnacknowledgedSplits
                    && getQueuedPartitionedSplitsInfo().getWeightSum() < MAX_QUEUED_SPLIT_WEIGHT;
            if (hasSpace) {
                if (!whenSplitQueueHasSpace.isDone()) {
                    whenSplitQueueHasSpace.set(null);
                }
            }
            else if (whenSplitQueueHasSpace.isDone()) {
                whenSplitQueueHasSpace = SettableFuture.create();
            }
        }

        /**
         * Removes up to {@code splits} splits, in the iteration order of the underlying multimap, as if they had
         * completed, then refreshes the node stats tracker and the split-queue-space future.
         *
         * @param splits maximum number of splits to remove
         */
        public synchronized void finishSplits(int splits)
        {
            List<Map.Entry<PlanNodeId, Split>> toRemove = new ArrayList<>();
            Iterator<Map.Entry<PlanNodeId, Split>> iterator = this.splits.entries().iterator();
            while (toRemove.size() < splits && iterator.hasNext()) {
                toRemove.add(iterator.next());
            }
            // Remove in a second pass so the multimap is not modified while it is being iterated.
            for (Map.Entry<PlanNodeId, Split> entry : toRemove) {
                this.splits.remove(entry.getKey(), entry.getValue());
            }
            // Keep the node-level split stats in sync with the removal, as addSplits() and clearSplits() do.
            updateTaskStats();
            updateSplitQueueSpace();
        }

        /**
         * Drops all splits and resets the running and unacknowledged counters, then refreshes stats and
         * split-queue space.
         */
        public synchronized void clearSplits()
        {
            unacknowledgedSplits = 0;
            runningDrivers = 0;
            splits.clear();
            updateTaskStats();
            updateSplitQueueSpace();
        }

        /**
         * Marks up to {@code maxRunning} of the current splits as running; the rest are counted as queued.
         */
        public synchronized void startSplits(int maxRunning)
        {
            runningDrivers = Math.min(splits.size(), maxRunning);
            updateSplitQueueSpace();
        }

        /**
         * Sets the unacknowledged-split limit used by {@link #whenSplitQueueHasSpace(long)}.
         *
         * @throws IllegalArgumentException if {@code maxUnacknowledgedSplits} is not positive
         */
        public synchronized void setMaxUnacknowledgedSplits(int maxUnacknowledgedSplits)
        {
            checkArgument(maxUnacknowledgedSplits > 0);
            this.maxUnacknowledgedSplits = maxUnacknowledgedSplits;
            updateSplitQueueSpace();
        }

        /**
         * Sets the number of splits reported as unacknowledged.
         *
         * @throws IllegalArgumentException if {@code unacknowledgedSplits} is negative
         */
        public synchronized void setUnacknowledgedSplits(int unacknowledgedSplits)
        {
            checkArgument(unacknowledgedSplits >= 0);
            this.unacknowledgedSplits = unacknowledgedSplits;
            updateSplitQueueSpace();
        }

        /**
         * Registers a state listener that clears all splits once the task reaches a done state.
         */
        @Override
        public void start()
        {
            taskStateMachine.addStateChangeListener(newValue -> {
                if (newValue.isDone()) {
                    clearSplits();
                }
            });
        }

        /**
         * Adds {@code splits} to the task, then refreshes stats and split-queue space.
         */
        @Override
        public void addSplits(Multimap<PlanNodeId, Split> splits)
        {
            synchronized (this) {
                this.splits.putAll(splits);
            }
            updateTaskStats();
            updateSplitQueueSpace();
        }

        /**
         * Records that {@code sourceId} will receive no more splits, and finishes the task once every
         * table-scan source and every remote-source node of the fragment has been recorded.
         */
        @Override
        public synchronized void noMoreSplits(PlanNodeId sourceId)
        {
            noMoreSplits.add(sourceId);

            boolean allSourcesComplete = Stream.concat(
                    fragment.getTableScanSchedulingOrder().stream(),
                    fragment.getRemoteSourceNodes().stream()
                            .map(PlanNode::getId))
                    .allMatch(noMoreSplits::contains);

            if (allSourcesComplete) {
                taskStateMachine.finished();
            }
        }

        /**
         * Not supported by this mock.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public void noMoreSplits(PlanNodeId sourceId, Lifespan lifespan)
        {
            throw new UnsupportedOperationException();
        }

        @Override
        public void setOutputBuffers(OutputBuffers outputBuffers)
        {
            outputBuffer.setOutputBuffers(outputBuffers);
        }

        /**
         * Not supported by this mock.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public ListenableFuture<?> removeRemoteSource(TaskId remoteSourceTaskId)
        {
            throw new UnsupportedOperationException();
        }

        /**
         * Forwards each task state change to {@code stateChangeListener} as a fresh {@link #getTaskStatus()}.
         */
        @Override
        public void addStateChangeListener(StateChangeListener<TaskStatus> stateChangeListener)
        {
            taskStateMachine.addStateChangeListener(newValue -> stateChangeListener.stateChanged(getTaskStatus()));
        }

        /**
         * Invokes {@code stateChangeListener} with {@link #getTaskInfo()} at most once, the first time a done
         * state is observed, either from a later state change or from the current state checked immediately.
         */
        @Override
        public void addFinalTaskInfoListener(StateChangeListener<TaskInfo> stateChangeListener)
        {
            AtomicBoolean done = new AtomicBoolean();
            StateChangeListener<TaskState> fireOnceStateChangeListener = state -> {
                if (state.isDone() && done.compareAndSet(false, true)) {
                    stateChangeListener.stateChanged(getTaskInfo());
                }
            };
            taskStateMachine.addStateChangeListener(fireOnceStateChangeListener);
            fireOnceStateChangeListener.stateChanged(taskStateMachine.getState());
        }

        /**
         * Returns a non-cancellation-propagating view of the split-queue-space future. {@code weightThreshold}
         * is ignored; space is governed by the unacknowledged-split limit and {@link #MAX_QUEUED_SPLIT_WEIGHT}.
         */
        @Override
        public synchronized ListenableFuture<?> whenSplitQueueHasSpace(long weightThreshold)
        {
            return nonCancellationPropagating(whenSplitQueueHasSpace);
        }

        @Override
        public void cancel()
        {
            taskStateMachine.cancel();
        }

        /**
         * Aborts the task state machine and clears all splits immediately.
         */
        @Override
        public void abort()
        {
            taskStateMachine.abort();
            clearSplits();
        }

        /**
         * Returns the count and total weight of splits assigned to the fragment's table-scan sources, or zero
         * splits once the task is done.
         */
        @Override
        public PartitionedSplitsInfo getPartitionedSplitsInfo()
        {
            if (taskStateMachine.getState().isDone()) {
                return PartitionedSplitsInfo.forZeroSplits();
            }
            synchronized (this) {
                int count = 0;
                long weight = 0;
                for (PlanNodeId tableScanPlanNodeId : fragment.getTableScanSchedulingOrder()) {
                    Collection<Split> partitionedSplits = splits.get(tableScanPlanNodeId);
                    count += partitionedSplits.size();
                    weight = addExact(weight, SplitWeight.rawValueSum(partitionedSplits, Split::getSplitWeight));
                }
                return PartitionedSplitsInfo.forSplitCountAndWeightSum(count, weight);
            }
        }

        /**
         * Returns the count and total weight of table-scan splits that are not counted as running: the first
         * {@code runningDrivers} splits encountered in scheduling order are skipped. Zero once the task is done.
         */
        @Override
        public synchronized PartitionedSplitsInfo getQueuedPartitionedSplitsInfo()
        {
            if (taskStateMachine.getState().isDone()) {
                return PartitionedSplitsInfo.forZeroSplits();
            }
            // Let's consider the first drivers encountered to be "running"
            int remainingRunning = runningDrivers;
            int queuedCount = 0;
            long queuedWeight = 0;
            for (PlanNodeId tableScanPlanNodeId : fragment.getTableScanSchedulingOrder()) {
                for (Split split : splits.get(tableScanPlanNodeId)) {
                    if (remainingRunning > 0) {
                        remainingRunning--;
                    }
                    else {
                        queuedCount++;
                        queuedWeight = addExact(queuedWeight, split.getSplitWeight().getRawValue());
                    }
                }
            }
            return PartitionedSplitsInfo.forSplitCountAndWeightSum(queuedCount, queuedWeight);
        }

        @Override
        public synchronized int getUnacknowledgedPartitionedSplitCount()
        {
            return unacknowledgedSplits;
        }

        @Override
        public PlanFragment getPlanFragment()
        {
            return fragment;
        }
    }
}