// TableWriterOperator.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator;
import com.facebook.airlift.json.JsonCodec;
import com.facebook.drift.annotations.ThriftConstructor;
import com.facebook.drift.annotations.ThriftField;
import com.facebook.drift.annotations.ThriftStruct;
import com.facebook.presto.Session;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.block.RunLengthEncodedBlock;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.execution.TaskId;
import com.facebook.presto.execution.TaskMetadataContext;
import com.facebook.presto.execution.scheduler.ExecutionWriterTarget;
import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.CreateHandle;
import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.InsertHandle;
import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.RefreshMaterializedViewHandle;
import com.facebook.presto.memory.context.LocalMemoryContext;
import com.facebook.presto.metadata.ConnectorMetadataUpdaterManager;
import com.facebook.presto.operator.OperationTimer.OperationTiming;
import com.facebook.presto.spi.ConnectorId;
import com.facebook.presto.spi.ConnectorPageSink;
import com.facebook.presto.spi.PageSinkContext;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.connector.ConnectorMetadataUpdater;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.split.PageSinkManager;
import com.facebook.presto.util.AutoCloseableCloser;
import com.facebook.presto.util.Mergeable;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.slice.Slice;
import io.airlift.units.Duration;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;
import static com.facebook.airlift.concurrent.MoreFutures.getFutureValue;
import static com.facebook.airlift.concurrent.MoreFutures.toListenableFuture;
import static com.facebook.presto.SystemSessionProperties.isStatisticsCpuTimerEnabled;
import static com.facebook.presto.common.RuntimeMetricName.WRITTEN_FILES_COUNT;
import static com.facebook.presto.common.RuntimeUnit.NONE;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.VarbinaryType.VARBINARY;
import static com.facebook.presto.operator.TableWriterUtils.STATS_START_CHANNEL;
import static com.facebook.presto.operator.TableWriterUtils.createStatisticsPage;
import static com.facebook.presto.spi.StandardErrorCode.CONSTRAINT_VIOLATION;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.util.concurrent.Futures.allAsList;
import static io.airlift.slice.Slices.wrappedBuffer;
import static io.airlift.units.Duration.succinctNanos;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.NANOSECONDS;
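// TableWriterOperator writes input pages to a ConnectorPageSink while simultaneously feeding the same
// pages to a statistics aggregation operator. While running it emits statistics pages tagged with a
// TableCommitContext; on finish it emits a final page containing the row count and the connector's
// commit fragments.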
public class TableWriterOperator
implements Operator
{
public static final String OPERATOR_TYPE = "TableWriterOperator";
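// Factory that creates TableWriterOperator instances for a driver; each created operator gets its own
// ConnectorPageSink for the configured write target.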
public static class TableWriterOperatorFactory
implements OperatorFactory
{
private final int operatorId;
private final PlanNodeId planNodeId;
private final PageSinkManager pageSinkManager;
private final ConnectorMetadataUpdaterManager metadataUpdaterManager;
private final TaskMetadataContext taskMetadataContext;
private final ExecutionWriterTarget target;
private final List<Integer> columnChannels;
private final List<String> notNullChannelColumnNames;
private final Session session;
private final OperatorFactory statisticsAggregationOperatorFactory;
private final List<Type> types;
private final PageSinkCommitStrategy pageSinkCommitStrategy;
private boolean closed;
private final JsonCodec<TableCommitContext> tableCommitContextCodec;
public TableWriterOperatorFactory(
int operatorId,
PlanNodeId planNodeId,
PageSinkManager pageSinkManager,
ConnectorMetadataUpdaterManager metadataUpdaterManager,
TaskMetadataContext taskMetadataContext,
ExecutionWriterTarget writerTarget,
List<Integer> columnChannels,
List<String> notNullChannelColumnNames,
Session session,
OperatorFactory statisticsAggregationOperatorFactory,
List<Type> types,
JsonCodec<TableCommitContext> tableCommitContextCodec,
PageSinkCommitStrategy pageSinkCommitStrategy)
{
this.operatorId = operatorId;
this.planNodeId = requireNonNull(planNodeId, "planNodeId is null");
this.columnChannels = requireNonNull(columnChannels, "columnChannels is null");
this.notNullChannelColumnNames = requireNonNull(notNullChannelColumnNames, "notNullChannelColumnNames is null");
this.pageSinkManager = requireNonNull(pageSinkManager, "pageSinkManager is null");
this.metadataUpdaterManager = requireNonNull(metadataUpdaterManager, "metadataUpdaterManager is null");
this.taskMetadataContext = requireNonNull(taskMetadataContext, "taskMetadataContext is null");
checkArgument(
writerTarget instanceof CreateHandle || writerTarget instanceof InsertHandle || writerTarget instanceof RefreshMaterializedViewHandle,
"writerTarget must be CreateHandle or InsertHandle or RefreshMaterializedViewHandle");
this.target = requireNonNull(writerTarget, "writerTarget is null");
this.session = requireNonNull(session, "session is null");
this.statisticsAggregationOperatorFactory = requireNonNull(statisticsAggregationOperatorFactory, "statisticsAggregationOperatorFactory is null");
this.types = ImmutableList.copyOf(requireNonNull(types, "types is null"));
this.tableCommitContextCodec = requireNonNull(tableCommitContextCodec, "tableCommitContextCodec is null");
this.pageSinkCommitStrategy = requireNonNull(pageSinkCommitStrategy, "pageSinkCommitStrategy is null");
}
@Override
public Operator createOperator(DriverContext driverContext)
{
checkState(!closed, "Factory is already closed");
OperatorContext context = driverContext.addOperatorContext(operatorId, planNodeId, TableWriterOperator.class.getSimpleName());
Operator statisticsAggregationOperator = statisticsAggregationOperatorFactory.createOperator(driverContext);
boolean statisticsCpuTimerEnabled = !(statisticsAggregationOperator instanceof DevNullOperator) && isStatisticsCpuTimerEnabled(session);
return new TableWriterOperator(
context,
createPageSink(),
columnChannels,
notNullChannelColumnNames,
statisticsAggregationOperator,
types,
statisticsCpuTimerEnabled,
tableCommitContextCodec,
pageSinkCommitStrategy);
}
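// Creates the ConnectorPageSink for the write target, registering the connector's metadata updater
// with the task metadata context when one is available.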
private ConnectorPageSink createPageSink()
{
ConnectorId connectorId = getConnectorId(target);
Optional<ConnectorMetadataUpdater> metadataUpdater = metadataUpdaterManager.getMetadataUpdater(connectorId);
if (metadataUpdater.isPresent()) {
taskMetadataContext.setConnectorId(connectorId);
taskMetadataContext.addMetadataUpdater(metadataUpdater.get());
}
PageSinkContext.Builder pageSinkContextBuilder = PageSinkContext.builder()
.setCommitRequired(pageSinkCommitStrategy.isCommitRequired());
metadataUpdater.ifPresent(pageSinkContextBuilder::setConnectorMetadataUpdater);
if (target instanceof CreateHandle) {
return pageSinkManager.createPageSink(session, ((CreateHandle) target).getHandle(), pageSinkContextBuilder.build());
}
if (target instanceof InsertHandle) {
return pageSinkManager.createPageSink(session, ((InsertHandle) target).getHandle(), pageSinkContextBuilder.build());
}
if (target instanceof RefreshMaterializedViewHandle) {
return pageSinkManager.createPageSink(session, ((RefreshMaterializedViewHandle) target).getHandle(), pageSinkContextBuilder.build());
}
throw new UnsupportedOperationException("Unhandled target type: " + target.getClass().getName());
}
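// Resolves the ConnectorId from the concrete writer target handle.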
private static ConnectorId getConnectorId(ExecutionWriterTarget handle)
{
if (handle instanceof CreateHandle) {
return ((CreateHandle) handle).getHandle().getConnectorId();
}
if (handle instanceof InsertHandle) {
return ((InsertHandle) handle).getHandle().getConnectorId();
}
if (handle instanceof RefreshMaterializedViewHandle) {
return ((RefreshMaterializedViewHandle) handle).getHandle().getConnectorId();
}
throw new UnsupportedOperationException("Unhandled target type: " + handle.getClass().getName());
}
@Override
public void noMoreOperators()
{
closed = true;
}
@Override
public OperatorFactory duplicate()
{
return new TableWriterOperatorFactory(
operatorId,
planNodeId,
pageSinkManager,
metadataUpdaterManager,
taskMetadataContext,
target,
columnChannels,
notNullChannelColumnNames,
session,
statisticsAggregationOperatorFactory,
types,
tableCommitContextCodec,
pageSinkCommitStrategy);
}
}
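// RUNNING accepts input, FINISHING waits for the page sink to finish, FINISHED has produced the final fragments page.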
private enum State
{
RUNNING, FINISHING, FINISHED
}
private final OperatorContext operatorContext;
private final LocalMemoryContext pageSinkMemoryContext;
private final ConnectorPageSink pageSink;
private final List<Integer> columnChannels;
private final List<String> notNullChannelColumnNames;
private final AtomicLong pageSinkPeakMemoryUsage = new AtomicLong();
private final Operator statisticAggregationOperator;
private final List<Type> types;
private ListenableFuture<?> blocked = NOT_BLOCKED;
private CompletableFuture<Collection<Slice>> finishFuture;
private State state = State.RUNNING;
private long rowCount;
private boolean committed;
private boolean closed;
private long writtenBytes;
private final OperationTiming statisticsTiming = new OperationTiming();
private final boolean statisticsCpuTimerEnabled;
private final JsonCodec<TableCommitContext> tableCommitContextCodec;
private final PageSinkCommitStrategy pageSinkCommitStrategy;
private final Supplier<TableWriterInfo> tableWriterInfoSupplier;
public TableWriterOperator(
OperatorContext operatorContext,
ConnectorPageSink pageSink,
List<Integer> columnChannels,
List<String> notNullChannelColumnNames,
Operator statisticAggregationOperator,
List<Type> types,
boolean statisticsCpuTimerEnabled,
JsonCodec<TableCommitContext> tableCommitContextCodec,
PageSinkCommitStrategy pageSinkCommitStrategy)
{
this.operatorContext = requireNonNull(operatorContext, "operatorContext is null");
this.pageSinkMemoryContext = operatorContext.localSystemMemoryContext();
this.pageSink = requireNonNull(pageSink, "pageSink is null");
this.columnChannels = requireNonNull(columnChannels, "columnChannels is null");
this.notNullChannelColumnNames = requireNonNull(notNullChannelColumnNames, "notNullChannelColumnNames is null");
checkArgument(columnChannels.size() == notNullChannelColumnNames.size(), "columnChannels and notNullChannelColumnNames have different sizes");
this.statisticAggregationOperator = requireNonNull(statisticAggregationOperator, "statisticAggregationOperator is null");
this.types = ImmutableList.copyOf(requireNonNull(types, "types is null"));
this.statisticsCpuTimerEnabled = statisticsCpuTimerEnabled;
this.tableCommitContextCodec = requireNonNull(tableCommitContextCodec, "tableCommitContextCodec is null");
this.pageSinkCommitStrategy = requireNonNull(pageSinkCommitStrategy, "pageSinkCommitStrategy is null");
this.tableWriterInfoSupplier = createTableWriterInfoSupplier(pageSinkPeakMemoryUsage, statisticsTiming, pageSink);
this.operatorContext.setInfoSupplier(tableWriterInfoSupplier);
}
@Override
public OperatorContext getOperatorContext()
{
return operatorContext;
}
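// Finishes the statistics aggregation operator and asks the page sink to finish; the operator stays
// blocked until both complete, and the final fragments page is produced later by getOutput().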
@Override
public void finish()
{
ListenableFuture<?> currentlyBlocked = blocked;
OperationTimer timer = new OperationTimer(statisticsCpuTimerEnabled);
statisticAggregationOperator.finish();
timer.end(statisticsTiming);
ListenableFuture<?> blockedOnAggregation = statisticAggregationOperator.isBlocked();
ListenableFuture<?> blockedOnFinish = NOT_BLOCKED;
if (state == State.RUNNING) {
state = State.FINISHING;
finishFuture = pageSink.finish();
blockedOnFinish = toListenableFuture(finishFuture);
updateWrittenBytes();
}
this.blocked = allAsList(currentlyBlocked, blockedOnAggregation, blockedOnFinish);
}
@Override
public boolean isFinished()
{
return state == State.FINISHED && blocked.isDone();
}
@Override
public ListenableFuture<?> isBlocked()
{
return blocked;
}
@Override
public boolean needsInput()
{
if (state != State.RUNNING || !blocked.isDone()) {
return false;
}
return statisticAggregationOperator.needsInput();
}
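// Routes the page to both the statistics aggregation operator and the page sink, after verifying
// that NOT NULL columns contain no nulls.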
@Override
public void addInput(Page page)
{
requireNonNull(page, "page is null");
checkState(needsInput(), "Operator does not need input");
Block[] blocks = new Block[columnChannels.size()];
for (int outputChannel = 0; outputChannel < columnChannels.size(); outputChannel++) {
Block block = page.getBlock(columnChannels.get(outputChannel));
String columnName = notNullChannelColumnNames.get(outputChannel);
if (columnName != null) {
verifyBlockHasNoNulls(block, columnName);
}
blocks[outputChannel] = block;
}
OperationTimer timer = new OperationTimer(statisticsCpuTimerEnabled);
statisticAggregationOperator.addInput(page);
timer.end(statisticsTiming);
ListenableFuture<?> blockedOnAggregation = statisticAggregationOperator.isBlocked();
CompletableFuture<?> future = pageSink.appendPage(new Page(blocks));
updateMemoryUsage();
ListenableFuture<?> blockedOnWrite = toListenableFuture(future);
blocked = allAsList(blockedOnAggregation, blockedOnWrite);
rowCount += page.getPositionCount();
updateWrittenBytes();
}
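// Throws CONSTRAINT_VIOLATION if the block contains a null value for a NOT NULL column.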
private void verifyBlockHasNoNulls(Block block, String columnName)
{
if (!block.mayHaveNull()) {
return;
}
for (int position = 0; position < block.getPositionCount(); position++) {
if (block.isNull(position)) {
throw new PrestoException(CONSTRAINT_VIOLATION, "NULL value not allowed for NOT NULL column: " + columnName);
}
}
}
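// Emits statistics pages while the aggregation operator still has output; once it is finished and the
// page sink finish future has completed, emits the final fragments page and transitions to FINISHED.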
@Override
public Page getOutput()
{
if (!blocked.isDone()) {
return null;
}
if (!statisticAggregationOperator.isFinished()) {
OperationTimer timer = new OperationTimer(statisticsCpuTimerEnabled);
Page aggregationOutput = statisticAggregationOperator.getOutput();
timer.end(statisticsTiming);
if (aggregationOutput == null) {
return null;
}
return createStatisticsPage(types, aggregationOutput, createTableCommitContext(false));
}
if (state != State.FINISHING) {
return null;
}
Page fragmentsPage = createFragmentsPage();
int positionCount = fragmentsPage.getPositionCount();
Block[] outputBlocks = new Block[types.size()];
for (int channel = 0; channel < types.size(); channel++) {
if (channel < STATS_START_CHANNEL) {
outputBlocks[channel] = fragmentsPage.getBlock(channel);
}
else {
outputBlocks[channel] = RunLengthEncodedBlock.create(types.get(channel), null, positionCount);
}
}
updateWrittenFilesCount();
state = State.FINISHED;
return new Page(positionCount, outputBlocks);
}
// Fragments page layout:
//
// row fragments context
// X null X
// null X X
// null X X
// null X X
// ...
private Page createFragmentsPage()
{
Collection<Slice> fragments = getFutureValue(finishFuture);
int positionCount = fragments.size() + 1;
committed = true;
updateWrittenBytes();
// The output page is constructed only once and the table commit context channel is a RunLengthEncodedBlock,
// so an individual BlockBuilder is used for each of the other channels instead of a PageBuilder.
BlockBuilder rowsBuilder = BIGINT.createBlockBuilder(null, positionCount);
BlockBuilder fragmentBuilder = VARBINARY.createBlockBuilder(null, positionCount);
// write row count
BIGINT.writeLong(rowsBuilder, rowCount);
fragmentBuilder.appendNull();
// write fragments
for (Slice fragment : fragments) {
rowsBuilder.appendNull();
VARBINARY.writeSlice(fragmentBuilder, fragment);
}
return new Page(positionCount, rowsBuilder.build(), fragmentBuilder.build(), RunLengthEncodedBlock.create(VARBINARY, createTableCommitContext(true), positionCount));
}
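// Serializes the TableCommitContext (lifespan, task id, commit strategy, last-page flag) to JSON for the context channel.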
private Slice createTableCommitContext(boolean lastPage)
{
TaskId taskId = operatorContext.getDriverContext().getPipelineContext().getTaskId();
return wrappedBuffer(tableCommitContextCodec.toJsonBytes(
new TableCommitContext(
operatorContext.getDriverContext().getLifespan(),
taskId,
pageSinkCommitStrategy,
lastPage)));
}
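// Aborts the page sink if the write was never committed, then closes the statistics aggregation operator
// and releases the page sink memory reservation.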
@Override
public void close()
throws Exception
{
AutoCloseableCloser closer = AutoCloseableCloser.create();
if (!closed) {
closed = true;
if (!committed) {
closer.register(pageSink::abort);
}
}
closer.register(statisticAggregationOperator);
closer.register(() -> pageSinkMemoryContext.setBytes(0));
closer.close();
}
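// Records the bytes written by the page sink since the last update.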
private void updateWrittenBytes()
{
long current = pageSink.getCompletedBytes();
operatorContext.recordPhysicalWrittenData(current - writtenBytes);
writtenBytes = current;
}
private void updateWrittenFilesCount()
{
operatorContext.getRuntimeStats().addMetricValue(WRITTEN_FILES_COUNT, NONE, pageSink.getWrittenFilesCount());
}
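// Tracks the page sink's current system memory usage and its peak.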
private void updateMemoryUsage()
{
long pageSinkMemoryUsage = pageSink.getSystemMemoryUsage();
pageSinkMemoryContext.setBytes(pageSinkMemoryUsage);
pageSinkPeakMemoryUsage.accumulateAndGet(pageSinkMemoryUsage, Math::max);
}
@VisibleForTesting
Operator getStatisticAggregationOperator()
{
return statisticAggregationOperator;
}
@VisibleForTesting
TableWriterInfo getInfo()
{
return tableWriterInfoSupplier.get();
}
private static Supplier<TableWriterInfo> createTableWriterInfoSupplier(AtomicLong pageSinkPeakMemoryUsage, OperationTiming statisticsTiming, ConnectorPageSink pageSink)
{
requireNonNull(pageSinkPeakMemoryUsage, "pageSinkPeakMemoryUsage is null");
requireNonNull(statisticsTiming, "statisticsTiming is null");
requireNonNull(pageSink, "pageSink is null");
return () -> new TableWriterInfo(
pageSinkPeakMemoryUsage.get(),
succinctNanos(statisticsTiming.getWallNanos()),
succinctNanos(statisticsTiming.getCpuNanos()),
succinctNanos(pageSink.getValidationCpuNanos()));
}
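// Operator info exposing the page sink peak memory usage plus statistics and validation timings;
// instances are merged by taking the maximum memory usage and summing the durations.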
@ThriftStruct
public static class TableWriterInfo
implements Mergeable<TableWriterInfo>, OperatorInfo
{
private final long pageSinkPeakMemoryUsage;
private final Duration statisticsWallTime;
private final Duration statisticsCpuTime;
private final Duration validationCpuTime;
@JsonCreator
@ThriftConstructor
public TableWriterInfo(
@JsonProperty("pageSinkPeakMemoryUsage") long pageSinkPeakMemoryUsage,
@JsonProperty("statisticsWallTime") Duration statisticsWallTime,
@JsonProperty("statisticsCpuTime") Duration statisticsCpuTime,
@JsonProperty("validationCpuTime") Duration validationCpuTime)
{
this.pageSinkPeakMemoryUsage = pageSinkPeakMemoryUsage;
this.statisticsWallTime = requireNonNull(statisticsWallTime, "statisticsWallTime is null");
this.statisticsCpuTime = requireNonNull(statisticsCpuTime, "statisticsCpuTime is null");
this.validationCpuTime = requireNonNull(validationCpuTime, "validationCpuTime is null");
}
@JsonProperty
@ThriftField(1)
public long getPageSinkPeakMemoryUsage()
{
return pageSinkPeakMemoryUsage;
}
@JsonProperty
@ThriftField(2)
public Duration getStatisticsWallTime()
{
return statisticsWallTime;
}
@JsonProperty
@ThriftField(3)
public Duration getStatisticsCpuTime()
{
return statisticsCpuTime;
}
@JsonProperty
@ThriftField(4)
public Duration getValidationCpuTime()
{
return validationCpuTime;
}
@Override
public TableWriterInfo mergeWith(TableWriterInfo other)
{
return new TableWriterInfo(
Math.max(pageSinkPeakMemoryUsage, other.pageSinkPeakMemoryUsage),
succinctNanos(statisticsWallTime.roundTo(NANOSECONDS) + other.statisticsWallTime.roundTo(NANOSECONDS)),
succinctNanos(statisticsCpuTime.roundTo(NANOSECONDS) + other.statisticsCpuTime.roundTo(NANOSECONDS)),
succinctNanos(validationCpuTime.roundTo(NANOSECONDS) + other.validationCpuTime.roundTo(NANOSECONDS)));
}
@Override
public boolean isFinal()
{
return true;
}
@Override
public String toString()
{
return toStringHelper(this)
.add("pageSinkPeakMemoryUsage", pageSinkPeakMemoryUsage)
.add("statisticsWallTime", statisticsWallTime)
.add("statisticsCpuTime", statisticsCpuTime)
.add("validationCpuTime", validationCpuTime)
.toString();
}
}
}