PushRemoteExchangeThroughGroupId.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.connector.system.GlobalSystemConnector;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.plan.Partitioning;
import com.facebook.presto.spi.plan.PartitioningHandle;
import com.facebook.presto.spi.plan.PartitioningScheme;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.google.common.collect.ImmutableList;

import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static com.facebook.presto.SystemSessionProperties.getHashPartitionCount;
import static com.facebook.presto.SystemSessionProperties.getPartitioningProviderCatalog;
import static com.facebook.presto.SystemSessionProperties.shouldPushRemoteExchangeThroughGroupId;
import static com.facebook.presto.matching.Capture.newCapture;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION;
import static com.facebook.presto.sql.planner.plan.Patterns.exchange;
import static com.facebook.presto.sql.planner.plan.Patterns.groupId;
import static com.facebook.presto.sql.planner.plan.Patterns.source;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static java.util.Objects.requireNonNull;

/**
 * Pushes a RemoteExchange node down through a GroupId node when the GroupId node has a non-empty
 * set of common grouping columns (columns present in every grouping set).
 *
 * For example, this rule changes the following plan
 * Aggregation [final]
 *   - RemoteExchange [repartition]
 *     - Aggregation [partial]
 *       - GroupId
 *         - TableScan
 * To
 * Aggregation
 *   - GroupId
 *     - RemoteExchange [repartition]
 *       - TableScan
 *
 * This rule can make the plan more efficient when the following
 * conditions hold:
 *
 * 1. The query has a large number of grouping sets.
 * 2. The reduction ratio of partial aggregation is poor.
 * 3. There is at least one grouping key common to all grouping sets.
 *
 * Note: This rule is disabled by default. Session property
 * PUSH_REMOTE_EXCHANGE_THROUGH_GROUP_ID can be used to enable it.
 */
public final class PushRemoteExchangeThroughGroupId
        implements Rule<ExchangeNode>
{
    private final Metadata metadata;
    private static final Capture<GroupIdNode> GROUP_ID = newCapture();
    private static final Pattern<ExchangeNode> PATTERN = exchange()
            .matching(exchange -> exchange.getScope().isRemote())
            .matching(exchange -> exchange.getType() == REPARTITION)
            .with(source().matching(
                    groupId()
                            .capturedAs(GROUP_ID)
                            .matching(groupId -> !groupId.getCommonGroupingColumns().isEmpty())));

    public PushRemoteExchangeThroughGroupId(Metadata metadata)
    {
        this.metadata = requireNonNull(metadata, "metadata is null");
    }

    @Override
    public Pattern<ExchangeNode> getPattern()
    {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session)
    {
        return shouldPushRemoteExchangeThroughGroupId(session);
    }

    @Override
    public Result apply(ExchangeNode node, Captures captures, Context context)
    {
        GroupIdNode groupIdNode = captures.get(GROUP_ID);

        List<VariableReferenceExpression> inputs = getOnlyElement(node.getInputs());
        inputs = removeVariable(inputs, groupIdNode.getGroupIdVariable());
        inputs = replaceAlias(inputs, groupIdNode.getGroupingColumns());

        PartitioningScheme partitioningScheme = node.getPartitioningScheme();
        List<VariableReferenceExpression> outputLayout = partitioningScheme.getOutputLayout();
        outputLayout = removeVariable(outputLayout, groupIdNode.getGroupIdVariable());
        outputLayout = replaceAlias(outputLayout, groupIdNode.getGroupingColumns());

        Set<VariableReferenceExpression> commonGroupingColumns = groupIdNode.getCommonGroupingColumns();
        List<VariableReferenceExpression> partitionColumns = replaceAlias(commonGroupingColumns, groupIdNode.getGroupingColumns());

        // Check that new |partitionColumns| must be subset of original partition columns.
        Map<VariableReferenceExpression, VariableReferenceExpression> groupingColumns = groupIdNode.getGroupingColumns();
        List<VariableReferenceExpression> originalPartitionColumns =
                partitioningScheme.getPartitioning().getVariableReferences()
                        .stream()
                        .map(expr -> groupingColumns.getOrDefault(expr, expr))
                        .collect(toImmutableList());
        if (!originalPartitionColumns.containsAll(partitionColumns)) {
            return Result.empty();
        }

        // Create new PartitioningHandle.
        PartitioningHandle partitioningHandle;
        if (GlobalSystemConnector.NAME.equals(getPartitioningProviderCatalog(context.getSession()))) {
            partitioningHandle = partitioningScheme.getPartitioning().getHandle();
        }
        else {
            partitioningHandle = createPartitioningHandle(context.getSession(), partitionColumns);
        }

        return Result.ofPlanNode(new GroupIdNode(
                node.getSourceLocation(),
                groupIdNode.getId(),
                new ExchangeNode(
                        node.getSourceLocation(),
                        node.getId(),
                        node.getType(),
                        node.getScope(),
                        new PartitioningScheme(
                                Partitioning.create(partitioningHandle, partitionColumns),
                                outputLayout,
                                partitioningScheme.getHashColumn(),
                                partitioningScheme.isReplicateNullsAndAny(),
                                partitioningScheme.isScaleWriters(),
                                partitioningScheme.getEncoding(),
                                partitioningScheme.getBucketToPartition()),
                        ImmutableList.of(groupIdNode.getSource()),
                        ImmutableList.of(inputs),
                        node.isEnsureSourceOrdering(),
                        node.getOrderingScheme()),
                groupIdNode.getGroupingSets(),
                groupIdNode.getGroupingColumns(),
                groupIdNode.getAggregationArguments(),
                groupIdNode.getGroupIdVariable()));
    }

    private static List<VariableReferenceExpression> removeVariable(List<VariableReferenceExpression> variables, VariableReferenceExpression variableToRemove)
    {
        return variables.stream()
                .filter(variable -> !variableToRemove.equals(variable))
                .collect(toImmutableList());
    }

    private static List<VariableReferenceExpression> replaceAlias(Collection<VariableReferenceExpression> variables, Map<VariableReferenceExpression, VariableReferenceExpression> mapping)
    {
        return variables.stream()
                .map(variable -> mapping.containsKey(variable) ? mapping.get(variable) : variable)
                .collect(toImmutableList());
    }

    private PartitioningHandle createPartitioningHandle(Session session, Collection<VariableReferenceExpression> partitioningColumns)
    {
        List<Type> partitioningTypes = partitioningColumns.stream()
                .map(VariableReferenceExpression::getType)
                .collect(toImmutableList());
        return metadata.getPartitioningHandleForExchange(
                session,
                getPartitioningProviderCatalog(session),
                getHashPartitionCount(session),
                partitioningTypes);
    }
}