RuntimeReorderJoinSides.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;
import com.facebook.airlift.log.Logger;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import java.util.Optional;
import static com.facebook.presto.spi.plan.JoinDistributionType.PARTITIONED;
import static com.facebook.presto.spi.plan.JoinDistributionType.REPLICATED;
import static com.facebook.presto.spi.plan.JoinType.LEFT;
import static com.facebook.presto.spi.plan.JoinType.RIGHT;
import static com.facebook.presto.sql.planner.iterative.rule.JoinSwappingUtils.createRuntimeSwappedJoinNode;
import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static com.facebook.presto.sql.planner.plan.Patterns.join;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
/**
 * Optimizer rule that swaps the two sides of a {@link JoinNode} when the estimated output size
 * of its right side is strictly larger than that of its left side.
 *
 * <p>The rule only fires when every leaf of the join's subtree is a {@link TableScanNode},
 * both side sizes can be estimated, and the swapped join passes {@link #isSwappedJoinValid}.
 * The actual rewrite is delegated to {@code JoinSwappingUtils.createRuntimeSwappedJoinNode}.
 */
public class RuntimeReorderJoinSides
        implements Rule<JoinNode>
{
    private static final Logger log = Logger.get(RuntimeReorderJoinSides.class);
    private static final Pattern<JoinNode> PATTERN = join();

    private final Metadata metadata;
    private final boolean nativeExecution;

    /**
     * @param metadata metadata handle forwarded to the join swapping utility; must not be null
     * @param nativeExecution flag forwarded unchanged to the join swapping utility
     * @throws NullPointerException if {@code metadata} is null
     */
    public RuntimeReorderJoinSides(Metadata metadata, boolean nativeExecution)
    {
        this.metadata = requireNonNull(metadata, "metadata is null");
        this.nativeExecution = nativeExecution;
    }

    @Override
    public Pattern<JoinNode> getPattern()
    {
        return PATTERN;
    }

    /**
     * Attempts to swap the sides of {@code joinNode}.
     *
     * @return the swapped join when the right side is estimated to be larger than the left side
     * and the swap is valid; {@code Result.empty()} in every other case
     */
    @Override
    public Result apply(JoinNode joinNode, Captures captures, Context context)
    {
        // Early exit if the leaves of the joinNode subtree include non tableScan nodes.
        if (hasNonTableScanLeaf(joinNode, context)) {
            return Result.empty();
        }

        double leftOutputSizeInBytes = Double.NaN;
        double rightOutputSizeInBytes = Double.NaN;
        StatsProvider statsProvider = context.getStatsProvider();
        if (isSimplePlan(joinNode, context)) {
            // Simple plan is characterized as Join directly on tableScanNodes only with exchangeNode in between.
            // For simple plans, directly fetch the overall table sizes as the size of the join sides to have
            // accurate input bytes statistics and meanwhile avoid non-negligible cost of collecting and processing
            // per-column statistics.
            leftOutputSizeInBytes = statsProvider.getStats(joinNode.getLeft()).getOutputSizeInBytes();
            rightOutputSizeInBytes = statsProvider.getStats(joinNode.getRight()).getOutputSizeInBytes();
        }
        if (Double.isNaN(leftOutputSizeInBytes) || Double.isNaN(rightOutputSizeInBytes)) {
            // Per-column estimate left and right output size for complex plans or when size statistics is unavailable.
            // Both sides are re-estimated together, even if only one of them was unknown.
            leftOutputSizeInBytes = statsProvider.getStats(joinNode.getLeft()).getOutputSizeInBytes(joinNode.getLeft());
            rightOutputSizeInBytes = statsProvider.getStats(joinNode.getRight()).getOutputSizeInBytes(joinNode.getRight());
        }

        // Without an estimate for both sides there is nothing to compare. The explicit NaN check is
        // required because every comparison against NaN evaluates to false.
        if (Double.isNaN(leftOutputSizeInBytes) || Double.isNaN(rightOutputSizeInBytes)) {
            return Result.empty();
        }
        // Only swap when the right side is strictly larger than the left side.
        if (rightOutputSizeInBytes <= leftOutputSizeInBytes) {
            return Result.empty();
        }

        // Check if the swapped join is valid.
        if (!isSwappedJoinValid(joinNode)) {
            return Result.empty();
        }

        Optional<JoinNode> rewrittenNode = createRuntimeSwappedJoinNode(joinNode, metadata, context.getLookup(), context.getSession(), context.getIdAllocator(), nativeExecution);
        if (rewrittenNode.isPresent()) {
            log.debug(format("Probe size: %.2f is smaller than Build size: %.2f => invoke runtime join swapping on JoinNode ID: %s.", leftOutputSizeInBytes, rightOutputSizeInBytes, joinNode.getId()));
            return Result.ofPlanNode(rewrittenNode.get());
        }
        return Result.empty();
    }

    /**
     * Returns true when the subtree rooted at {@code joinNode} contains a node without sources
     * that is not a {@link TableScanNode}.
     */
    private static boolean hasNonTableScanLeaf(JoinNode joinNode, Context context)
    {
        return searchFrom(joinNode, context.getLookup())
                .where(node -> node.getSources().isEmpty() && !(node instanceof TableScanNode))
                .matches();
    }

    /**
     * Returns true when exactly one node found by the search from {@code joinNode} is neither a
     * {@link TableScanNode} nor an {@link ExchangeNode} (presumably the join node itself, i.e. the
     * join sits directly on table scans with at most exchanges in between).
     */
    private static boolean isSimplePlan(JoinNode joinNode, Context context)
    {
        return searchFrom(joinNode, context.getLookup())
                .where(node -> !(node instanceof TableScanNode) && !(node instanceof ExchangeNode))
                .findAll()
                .size() == 1;
    }

    /**
     * Decides whether {@code join} may be swapped.
     *
     * <p>Swapping is rejected for a REPLICATED LEFT join and for a PARTITIONED RIGHT join that has
     * no join criteria. NOTE(review): presumably because the flipped join type would not be
     * supported with that distribution type; confirm against the join distribution rules.
     *
     * <p>A join whose distribution type is not set is conservatively treated as not swappable,
     * instead of failing with {@code NoSuchElementException} on an empty {@code Optional}.
     */
    private boolean isSwappedJoinValid(JoinNode join)
    {
        return join.getDistributionType()
                .map(distributionType ->
                        !(distributionType == REPLICATED && join.getType() == LEFT) &&
                                !(distributionType == PARTITIONED && join.getCriteria().isEmpty() && join.getType() == RIGHT))
                .orElse(false);
    }
}