PayloadJoinOptimizer.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.optimizations;
import com.facebook.presto.Session;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.VariablesExtractor;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.airlift.slice.Slices;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static com.facebook.presto.SystemSessionProperties.isOptimizePayloadJoins;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.TypeUtils.isNumericType;
import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE;
import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet;
import static com.facebook.presto.spi.plan.JoinType.LEFT;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IS_NULL;
import static com.facebook.presto.sql.planner.PlannerUtils.addProjections;
import static com.facebook.presto.sql.planner.PlannerUtils.clonePlanNode;
import static com.facebook.presto.sql.planner.PlannerUtils.coalesce;
import static com.facebook.presto.sql.planner.PlannerUtils.equalityPredicate;
import static com.facebook.presto.sql.planner.PlannerUtils.isScanFilterProject;
import static com.facebook.presto.sql.planner.PlannerUtils.restrictOutput;
import static com.facebook.presto.sql.planner.plan.ChildReplacer.replaceChildren;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.facebook.presto.sql.relational.Expressions.specialForm;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Sets.intersection;
import static java.util.Objects.requireNonNull;
/*
This optimization targets long chains of LOJ, where a large base table is extended with columns from medium-sized (not broadcastable) tables.
We rewrite a query of the form:
SELECT T.*, S1.A, S2.B...
FROM
T
LOJ S1 ON T.k1 = S1.k
LOJ S2 ON T.k2 = S2.k
LOJ S3 ON T.k3 = S3.k
...
LOJ Sn ON T.kn = Sn.k
into something like this:
SELECT T.*, S1.A, S2.B, ...
FROM
(
(SELECT DISTINCT k1, k2, .. kn,
k1 IS NULL as k1_null,
k2 IS NULL as k2_null,
k3 IS NULL as k3_null,
...
kn IS NULL as kn_null
FROM T ) AS T_keys
LOJ S1 ON T_keys.k1 = S1.k
LOJ S2 ON T_keys.k2 = S2.k
LOJ S3 ON T_keys.k3 = S3.k
...
LOJ Sn ON T_keys.kn = Sn.k)
ROJ T
ON
T.k1 IS NULL = k1_null AND
T.k2 IS NULL = k2_null AND
T.k3 IS NULL = k3_null AND
...
T.kn IS NULL = kn_null AND
COALESCE(T.k1, '') = COALESCE(T_keys.k1, '') AND
COALESCE(T.k2, '') = COALESCE(T_keys.k2, '') AND
COALESCE(T.k3, '') = COALESCE(T_keys.k3, '') AND
...
COALESCE(T.kn, '') = COALESCE(T_keys.kn, '')
Here '' stands for a fixed default of the key's type ('' for varchar keys, 0 for numeric keys).
COALESCE lets NULL keys compare with '='.
The IS NULL flags stop a NULL key from matching a key that actually equals the default.
*/
/**
 * Plan optimizer that rewrites chains of LEFT joins.
 *
 * <p>The join chain is evaluated over the DISTINCT join keys of its leftmost (probe) input instead of over the full
 * probe rows. A copy of the probe input (the "payload") is then joined back to the result on the keys.
 *
 * <p>The rewrite runs only when {@link #isEnabled(Session)} is true.
 */
public class PayloadJoinOptimizer
        implements PlanOptimizer
{
    private final Metadata metadata;
    private boolean isEnabledForTesting;

    public PayloadJoinOptimizer(Metadata metadata)
    {
        this.metadata = requireNonNull(metadata, "metadata is null");
    }

    /**
     * Applies the payload-join rewrite to {@code plan} if the optimizer is enabled for {@code session}.
     *
     * @return the possibly rewritten plan. It is reported as changed only if at least one payload rewrite was
     * committed.
     */
    @Override
    public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector)
    {
        if (!isEnabled(session)) {
            return PlanOptimizerResult.optimizerResult(plan, false);
        }
        FunctionAndTypeManager functionAndTypeManager = metadata.getFunctionAndTypeManager();
        Rewriter rewriter = new Rewriter(session, metadata, types, functionAndTypeManager, idAllocator, variableAllocator);
        PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(rewriter, plan, new JoinContext());
        return PlanOptimizerResult.optimizerResult(rewrittenPlan, rewriter.isPlanChanged());
    }

    @Override
    public void setEnabledForTesting(boolean isSet)
    {
        isEnabledForTesting = isSet;
    }

    /**
     * Returns true when the rewrite was forced on for tests or the session enables payload-join optimization.
     */
    @Override
    public boolean isEnabled(Session session)
    {
        return isEnabledForTesting || isOptimizePayloadJoins(session);
    }

    private static class Rewriter
            extends SimplePlanRewriter<JoinContext>
    {
        private final Session session;
        private final Metadata metadata;
        private final TypeProvider types;
        private final FunctionAndTypeManager functionAndTypeManager;
        private final PlanNodeIdAllocator planNodeIdAllocator;
        private final VariableAllocator variableAllocator;
        // set only when a payload rewrite is committed by the outermost join of a chain
        private boolean planChanged;

        private Rewriter(Session session, Metadata metadata, TypeProvider types, FunctionAndTypeManager functionAndTypeManager, PlanNodeIdAllocator planNodeIdAllocator, VariableAllocator variableAllocator)
        {
            this.session = requireNonNull(session, "session is null");
            this.metadata = requireNonNull(metadata, "metadata is null");
            this.types = requireNonNull(types, "types is null");
            this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
            this.planNodeIdAllocator = requireNonNull(planNodeIdAllocator, "planNodeIdAllocator is null");
            this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null");
        }

        public boolean isPlanChanged()
        {
            return planChanged;
        }

        /**
         * Default rewrite for node types without special handling.
         */
        @Override
        public PlanNode visitPlan(PlanNode planNode, RewriteContext<JoinContext> context)
        {
            // do a default rewrite with a new context for each child to avoid lateral propagation of context information
            List<PlanNode> newChildren = planNode.getSources().stream().map(childNode -> context.rewrite(childNode, new JoinContext())).collect(Collectors.toList());
            return replaceChildren(planNode, newChildren);
        }

        /**
         * Collects the left-side join keys of a LEFT join and keeps rewriting down its left input with the shared
         * context.
         *
         * <p>The outermost join of the chain is the one entered with an empty key set. When the keys plan comes back
         * to it, that join re-attaches the payload. Joins that do not qualify are rewritten child by child, each
         * child with a fresh context.
         */
        @Override
        public PlanNode visitJoin(JoinNode joinNode, RewriteContext<JoinContext> context)
        {
            JoinContext joinContext = context.get();
            Set<VariableReferenceExpression> inputJoinKeys = joinContext.getJoinKeys();
            PlanNode leftNode = joinNode.getLeft();
            PlanNode rightNode = joinNode.getRight();
            // keys are only added by enclosing joins that pass this same context down, so no keys => outermost join
            boolean isTopJoin = inputJoinKeys.isEmpty();
            ImmutableSet<VariableReferenceExpression> leftColumns = ImmutableSet.copyOf(leftNode.getOutputVariables());
            Set<VariableReferenceExpression> rightColumns = ImmutableSet.copyOf(rightNode.getOutputVariables());
            // abort rewrite if some of the collected join keys are in the RHS of the current join
            boolean hasKeysFromRight = inputJoinKeys.stream().anyMatch(rightColumns::contains);
            Set<VariableReferenceExpression> joinKeys = extractJoinKeys(joinNode.getFilter(), joinNode.getCriteria());
            ImmutableSet<VariableReferenceExpression> leftJoinKeys = intersection(joinKeys, leftColumns).immutableCopy();
            if (hasKeysFromRight || !needsRewrite(joinNode.getType(), leftColumns, leftJoinKeys)) {
                List<PlanNode> newChildren = joinNode.getSources().stream()
                        .map(child -> defaultRewriteJoinChild(child, context, joinNode.isCrossJoin()))
                        .collect(toImmutableList());
                return replaceChildren(joinNode, newChildren);
            }

            joinContext.addKeys(leftJoinKeys);
            joinContext.incrementNumJoins();
            boolean planChangedBeforeRewrite = planChanged;
            PlanNode newLeftNode = context.rewrite(leftNode, joinContext);
            if (leftNode.equals(newLeftNode)) {
                // The chain rewrite was abandoned below this join, e.g. a projection rejected it or the chain was too
                // short. Discard everything that attempt recorded in the shared context. Otherwise an enclosing
                // outermost join whose left side changes for another reason would re-attach a stale payload to a
                // plan that is not a DISTINCT keys plan. Then rewrite the left input on its own.
                planChanged = planChangedBeforeRewrite;
                joinContext.setPayloadNode(null);
                joinContext.setJoinKeyMap(null);
                newLeftNode = context.rewrite(leftNode, new JoinContext());
                return replaceChildren(joinNode, ImmutableList.of(newLeftNode, rightNode));
            }

            List<VariableReferenceExpression> allCols = Stream.concat(newLeftNode.getOutputVariables().stream(), rightNode.getOutputVariables().stream())
                    .collect(toImmutableList());
            JoinNode newJoinNode = new JoinNode(
                    joinNode.getSourceLocation(),
                    planNodeIdAllocator.getNextId(),
                    joinNode.getType(),
                    newLeftNode,
                    rightNode,
                    joinNode.getCriteria(),
                    allCols,
                    joinNode.getFilter(),
                    joinNode.getLeftHashVariable(),
                    joinNode.getRightHashVariable(),
                    joinNode.getDistributionType(),
                    joinNode.getDynamicFilters());

            if (isTopJoin && joinContext.needsPayloadRejoin()) {
                PlanNode payloadJoin = transformJoin(newJoinNode, joinContext);
                // reset payload node as it has been reattached to the plan node
                joinContext.setPayloadNode(null);
                // do a final check that the rewrite didn't lose any columns (can happen if there are intermediate projections on non-join keys that get hidden because of the DISTINCT keys computation)
                List<VariableReferenceExpression> outputVariables = joinNode.getOutputVariables();
                if (!payloadJoin.getOutputVariables().containsAll(outputVariables)) {
                    // the whole rewritten left side is discarded, so nothing below changed after all
                    planChanged = planChangedBeforeRewrite;
                    return joinNode;
                }
                planChanged = true;
                return restrictOutput(payloadJoin, planNodeIdAllocator, outputVariables);
            }
            // Intermediate join of a chain, or an outermost join whose left input changed for another reason.
            // Any change that survives here comes from a committed rewrite, which has already set planChanged.
            return newJoinNode;
        }

        /**
         * Rewrites one join child with a fresh context.
         *
         * <p>For cross joins, the child's output is restricted back to its original variables (and order) when the
         * rewrite changed them.
         */
        private PlanNode defaultRewriteJoinChild(PlanNode child, RewriteContext<JoinContext> context, boolean isCrossJoin)
        {
            PlanNode newChild = context.rewrite(child, new JoinContext());
            // compare the lists by content; reference inequality is true for any freshly built list
            if (isCrossJoin && !child.getOutputVariables().equals(newChild.getOutputVariables())) {
                return restrictOutput(newChild, planNodeIdAllocator, child.getOutputVariables());
            }
            return newChild;
        }

        /**
         * A join qualifies when it meets all of the following:
         * <ul>
         * <li>it is a LEFT join;</li>
         * <li>all its left join keys are varchar or numeric;</li>
         * <li>its left input carries at least one non-key column (payload).</li>
         * </ul>
         */
        private boolean needsRewrite(JoinType joinType, ImmutableSet<VariableReferenceExpression> leftColumns, Set<VariableReferenceExpression> joinKeys)
        {
            return joinType == LEFT && supportedJoinKeyTypes(joinKeys) && leftColumns.stream().anyMatch(var -> !joinKeys.contains(var));
        }

        /**
         * Passes the shared context through a projection on the join chain's left spine.
         *
         * <p>Join keys computed in this projection are queued so they can be evaluated at the scan. If the rewrite
         * below succeeded, the projection is rebuilt over the new child. If any assignment would then reference a
         * column hidden by the DISTINCT keys computation, the original projection is kept instead.
         */
        @Override
        public PlanNode visitProject(ProjectNode projectNode, RewriteContext<JoinContext> context)
        {
            if (isScanFilterProject(projectNode)) {
                return rewriteScanFilterProject(projectNode, context);
            }
            JoinContext joinContext = context.get();
            PlanNode child = projectNode.getSource();
            Set<VariableReferenceExpression> inputJoinKeys = joinContext.getJoinKeys();
            if (!child.getOutputVariables().containsAll(inputJoinKeys)) {
                Map<VariableReferenceExpression, RowExpression> pushableExpressions = new HashMap<>();
                projectNode.getAssignments().forEach((var, expr) -> {
                    if (inputJoinKeys.contains(var) && !var.equals(expr)) {
                        // join key computed in this projection: need to push down
                        pushableExpressions.put(var, expr);
                    }
                });
                joinContext.addProjectionsToPush(pushableExpressions);
            }
            boolean planChangedBeforeRewrite = planChanged;
            PlanNode newChild = context.rewrite(child, joinContext);
            if (child.equals(newChild)) {
                return projectNode;
            }
            // remove assignments that were pushed down
            Set<VariableReferenceExpression> joinKeys = joinContext.getJoinKeys();
            Assignments newAssignments = projectNode.getAssignments();
            if (joinContext.needsPayloadRejoin() && !child.getOutputVariables().containsAll(joinKeys)) {
                Assignments.Builder assignments = Assignments.builder();
                projectNode.getAssignments().forEach((var, expr) -> {
                    if (joinKeys.contains(var) && !var.equals(expr)) {
                        // join key now computed below (pushed down): just pass it through
                        assignments.put(var, var);
                    }
                    else {
                        assignments.put(var, expr);
                    }
                });
                newAssignments = assignments.build();
            }
            Set<VariableReferenceExpression> newChildOutputVarSet = ImmutableSet.copyOf(newChild.getOutputVariables());
            Assignments newProjectAssignments = removeHiddenColumns(newAssignments, newChildOutputVarSet, joinKeys);
            ProjectNode newProjectNode = new ProjectNode(projectNode.getId(), newChild, newProjectAssignments);
            if (!validateProjectAssignments(newProjectNode)) {
                // cancel rewrite when some columns needed for the project were hidden by the rewrite; the new child
                // is dropped, so whatever it changed is dropped too
                planChanged = planChangedBeforeRewrite;
                return projectNode;
            }
            return newProjectNode;
        }

        /**
         * Filters that form a scan/filter/project subtree are rewritten like a scan. Any other filter is rewritten
         * with a fresh context.
         */
        @Override
        public PlanNode visitFilter(FilterNode filterNode, RewriteContext<JoinContext> context)
        {
            if (isScanFilterProject(filterNode)) {
                return rewriteScanFilterProject(filterNode, context);
            }
            return context.defaultRewrite(filterNode, new JoinContext());
        }

        @Override
        public PlanNode visitTableScan(TableScanNode scanNode, RewriteContext<JoinContext> context)
        {
            return rewriteScanFilterProject(scanNode, context);
        }

        /**
         * Replaces the probe input of a qualifying chain (at least two joins with collected keys) with a DISTINCT
         * aggregation over the join keys.
         *
         * <p>Pending key projections are evaluated here when the keys are not already produced. The input is
         * returned unchanged when the keys cannot be made available.
         */
        private PlanNode rewriteScanFilterProject(PlanNode planNode, RewriteContext<JoinContext> context)
        {
            JoinContext joinContext = context.get();
            Set<VariableReferenceExpression> joinKeys = joinContext.getJoinKeys();
            if (joinKeys.isEmpty() || joinContext.getNumJoins() < 2) {
                return planNode;
            }
            List<VariableReferenceExpression> outputCols = planNode.getOutputVariables();
            if (outputCols.containsAll(joinKeys)) {
                return constructDistinctKeysPlan(planNode, context, joinKeys);
            }
            // not all join keys are in the plan node: check if there are any pushable projections
            Map<VariableReferenceExpression, RowExpression> projectionsToPush = joinContext.getProjectionsToPush();
            if (!outputCols.containsAll(VariablesExtractor.extractUnique(projectionsToPush.values()))) {
                // abort rewrite
                return planNode;
            }
            PlanNode newProjectNode = addProjections(planNode, planNodeIdAllocator, projectionsToPush);
            if (!newProjectNode.getOutputVariables().containsAll(joinKeys)) {
                // abort rewrite: grouping on a variable the input does not produce would yield an invalid plan
                return planNode;
            }
            return constructDistinctKeysPlan(newProjectNode, context, joinKeys);
        }

        /**
         * Builds {@code SELECT DISTINCT joinKeys FROM planNode} as a single-step aggregation.
         *
         * <p>It also records in the context a copy of {@code planNode} whose join keys are renamed to fresh
         * variables, together with the original-to-renamed key map. This copy is the payload that the outermost
         * join re-attaches.
         */
        private AggregationNode constructDistinctKeysPlan(PlanNode planNode, RewriteContext<JoinContext> context, Set<VariableReferenceExpression> joinKeys)
        {
            List<VariableReferenceExpression> groupingKeys = ImmutableList.copyOf(joinKeys);
            AggregationNode distinctKeys = new AggregationNode(
                    planNode.getSourceLocation(),
                    planNodeIdAllocator.getNextId(),
                    planNode,
                    ImmutableMap.of(),
                    singleGroupingSet(groupingKeys),
                    ImmutableList.of(),
                    SINGLE,
                    Optional.empty(),
                    Optional.empty(),
                    Optional.empty());

            Map<VariableReferenceExpression, VariableReferenceExpression> varMap = new HashMap<>();
            for (VariableReferenceExpression var : joinKeys) {
                varMap.put(var, variableAllocator.newVariable(var.getName(), var.getType()));
            }
            context.get().setJoinKeyMap(new HashMap<>(varMap));
            PlanNode planNodeCopy = clonePlanNode(planNode, session, metadata, planNodeIdAllocator, planNode.getOutputVariables(), varMap);
            context.get().setPayloadNode(planNodeCopy);
            return distinctKeys;
        }

        /**
         * Joins the payload back to the rewritten keys chain {@code keysNode}.
         *
         * <p>The result is {@code payload LEFT JOIN project(keysNode, key IS NULL AS key_NULL ...)}. For every join
         * key the join condition is {@code (payloadKey IS NULL) = key_NULL AND
         * coalesce(payloadKey, zero) = coalesce(key, zero)}. The condition is placed in the join filter, and the
         * equi-criteria list is left empty.
         */
        private PlanNode transformJoin(JoinNode keysNode, JoinContext context)
        {
            PlanNode payloadPlanNode = context.getPayloadNode();
            Set<VariableReferenceExpression> joinKeys = context.getJoinKeys();
            Map<VariableReferenceExpression, VariableReferenceExpression> joinKeyMap = context.getJoinKeyMap();
            checkState(null != payloadPlanNode, "Payload plannode not initialized");
            checkState(null != joinKeyMap, "joinkey map not initialized");
            FunctionResolution functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());

            // pass through all keys-chain columns and add "jk IS NULL as jk_NULL" for every join key
            Assignments.Builder assignments = Assignments.builder();
            keysNode.getOutputVariables().forEach(var -> assignments.put(var, var));
            ImmutableList.Builder<RowExpression> nullComparisons = ImmutableList.builder();
            ImmutableList.Builder<RowExpression> coalesceComparisons = ImmutableList.builder();
            for (VariableReferenceExpression var : joinKeys) {
                VariableReferenceExpression payloadVar = joinKeyMap.get(var);
                checkState(payloadVar != null, "no payload variable for join key %s", var);
                VariableReferenceExpression isNullVar = variableAllocator.newVariable(var.getName() + "_NULL", BOOLEAN);
                assignments.put(isNullVar, specialForm(IS_NULL, BOOLEAN, ImmutableList.of(var)));
                nullComparisons.add(equalityPredicate(functionResolution, specialForm(IS_NULL, BOOLEAN, ImmutableList.of(payloadVar)), isNullVar));
                // "coalesce(payloadVar, zero) = coalesce(var, zero)"
                coalesceComparisons.add(equalityPredicate(functionResolution, coalesceToZero(payloadVar), coalesceToZero(var)));
            }
            ProjectNode projectNode = new ProjectNode(planNodeIdAllocator.getNextId(), keysNode, assignments.build());

            List<VariableReferenceExpression> resultOutputCols = Stream.concat(payloadPlanNode.getOutputVariables().stream(), projectNode.getOutputVariables().stream())
                    .collect(toImmutableList());
            List<RowExpression> joinCriteria = Stream.concat(nullComparisons.build().stream(), coalesceComparisons.build().stream())
                    .collect(toImmutableList());
            // The new join has different inputs and no equi-criteria, so the hash variables and dynamic filters of
            // keysNode do not apply to it. Copying them would reference variables of other nodes and duplicate the
            // dynamic filter ids that keysNode (still in the plan below) produces.
            return new JoinNode(
                    keysNode.getSourceLocation(),
                    planNodeIdAllocator.getNextId(),
                    JoinType.LEFT,
                    payloadPlanNode,
                    projectNode,
                    ImmutableList.of(),
                    resultOutputCols,
                    Optional.of(LogicalRowExpressions.and(joinCriteria)),
                    Optional.empty(),
                    Optional.empty(),
                    keysNode.getDistributionType(),
                    ImmutableMap.of());
        }

        /**
         * Builds the projection's new assignments.
         *
         * <p>Assignments whose inputs are not all produced by the new child are dropped. Identity assignments are
         * added for join keys that the new child produces but the assignments do not yet output.
         */
        private Assignments removeHiddenColumns(Assignments newAssignments, Set<VariableReferenceExpression> newChildOutputVarSet, Set<VariableReferenceExpression> joinKeys)
        {
            Map<VariableReferenceExpression, RowExpression> newAssignmentsMap = newAssignments.entrySet().stream()
                    .filter(assignment -> newChildOutputVarSet.containsAll(VariablesExtractor.extractUnique(assignment.getValue())))
                    .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
            Set<VariableReferenceExpression> outputKeys = newAssignmentsMap.keySet();
            Map<VariableReferenceExpression, RowExpression> joinKeyMap = joinKeys.stream()
                    .filter(key -> !outputKeys.contains(key) && newChildOutputVarSet.contains(key))
                    .collect(Collectors.toMap(Function.identity(), Function.identity()));
            newAssignmentsMap.putAll(joinKeyMap);
            return new Assignments(newAssignmentsMap);
        }

        /**
         * Returns true when every assignment only references variables produced by the project's source.
         */
        private boolean validateProjectAssignments(ProjectNode projectNode)
        {
            Set<VariableReferenceExpression> inputColsSet = ImmutableSet.copyOf(projectNode.getSource().getOutputVariables());
            for (Map.Entry<VariableReferenceExpression, RowExpression> assignment : projectNode.getAssignments().entrySet()) {
                if (!inputColsSet.containsAll(VariablesExtractor.extractUnique(assignment.getValue()))) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Returns {@code coalesce(var, zero)}, where zero is the type-specific default from {@link #zeroForType}.
         */
        private RowExpression coalesceToZero(RowExpression var)
        {
            RowExpression zero = zeroForType(var.getType());
            return coalesce(ImmutableList.of(var, zero));
        }

        /**
         * Returns every variable in the equi-criteria (both sides) and every variable referenced by the join filter.
         */
        private Set<VariableReferenceExpression> extractJoinKeys(Optional<RowExpression> filter, List<EquiJoinClause> criteria)
        {
            ImmutableSet.Builder<VariableReferenceExpression> builder = ImmutableSet.builder();
            for (EquiJoinClause clause : criteria) {
                builder.add(clause.getLeft());
                builder.add(clause.getRight());
            }
            filter.ifPresent(expression -> builder.addAll(VariablesExtractor.extractAll(expression)));
            return builder.build();
        }

        /**
         * Only varchar and numeric keys are supported, because those are the types {@link #zeroForType} accepts.
         */
        private boolean supportedJoinKeyTypes(Set<VariableReferenceExpression> joinKeys)
        {
            return joinKeys.stream().allMatch(key -> key.getType() instanceof VarcharType || isNumericType(key.getType()));
        }
    }

    /**
     * Returns a non-null constant used to replace NULL join keys inside COALESCE.
     *
     * <p>The actual value is irrelevant because NULL-ness is compared separately through the IS NULL flags. What
     * matters is that the constant has the key's own type, so that {@code COALESCE(key, zero)} is well typed.
     */
    private static RowExpression zeroForType(Type type)
    {
        checkArgument(isNumericType(type) || type instanceof VarcharType, "join key should be of numeric or varchar type");
        if (type instanceof VarcharType) {
            // use the key's own (possibly bounded) varchar type rather than unbounded VARCHAR
            return constant(Slices.utf8Slice(""), type);
        }
        Class<?> javaType = type.getJavaType();
        if (javaType == long.class) {
            // long-backed numeric types (presumably BIGINT/INTEGER/SMALLINT/TINYINT, REAL as raw float bits and
            // short DECIMAL as an unscaled value — verify) all represent zero as 0L
            return constant(0L, type);
        }
        if (javaType == double.class) {
            return constant(0.0d, type);
        }
        // NOTE(review): any other numeric type (presumably long DECIMAL) still gets a BIGINT zero as before, which
        // does not match the key type; consider a typed zero here or excluding such types in supportedJoinKeyTypes.
        return constant(0L, BIGINT);
    }

    /**
     * Mutable state shared by the joins, projections and the scan of one LEFT join chain while it is rewritten.
     */
    private static class JoinContext
    {
        // left-side join keys collected from every join of the chain visited so far
        private Set<VariableReferenceExpression> joinKeys = new HashSet<>();
        // original join key -> fresh variable used for that key in the payload copy; set together with payloadNode
        private Map<VariableReferenceExpression, VariableReferenceExpression> joinKeyMap;
        // join keys computed by projections above the scan, to be evaluated at the scan instead
        private Map<VariableReferenceExpression, RowExpression> projectionsToPush = new HashMap<>();
        private int numJoins;
        // copy of the probe input to re-attach at the outermost join; null when no rewrite is pending
        private PlanNode payloadNode;

        public JoinContext() {}

        public Set<VariableReferenceExpression> getJoinKeys()
        {
            return joinKeys;
        }

        public void addKeys(ImmutableSet<VariableReferenceExpression> keys)
        {
            joinKeys.addAll(keys);
        }

        public Map<VariableReferenceExpression, RowExpression> getProjectionsToPush()
        {
            return projectionsToPush;
        }

        public void addProjectionsToPush(Map<VariableReferenceExpression, RowExpression> map)
        {
            projectionsToPush.putAll(map);
        }

        public Map<VariableReferenceExpression, VariableReferenceExpression> getJoinKeyMap()
        {
            return joinKeyMap;
        }

        public void setJoinKeyMap(Map<VariableReferenceExpression, VariableReferenceExpression> map)
        {
            joinKeyMap = map;
        }

        public PlanNode getPayloadNode()
        {
            return payloadNode;
        }

        public void setPayloadNode(PlanNode payloadNode)
        {
            this.payloadNode = payloadNode;
        }

        public void reset()
        {
            joinKeys = new HashSet<>();
            projectionsToPush = new HashMap<>();
            joinKeyMap = null;
            numJoins = 0;
            payloadNode = null;
        }

        public int getNumJoins()
        {
            return numJoins;
        }

        public void incrementNumJoins()
        {
            numJoins++;
        }

        public boolean needsPayloadRejoin()
        {
            return payloadNode != null;
        }
    }
}