JoinExpandOrToUnionRule.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.rel.rules;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.tools.RelBuilder;
import org.immutables.value.Value;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* Planner rule that matches a
* {@link org.apache.calcite.rel.core.Join}
* and expands OR clauses in join conditions.
*
* <p>This rule expands OR conditions in join clauses into
* multiple separate join conditions, allowing the optimizer
* to handle these conditions more efficiently.
*
* <p>The following is an example of inner join,
* and other examples for other kinds of joins
* can be found in the code below.
*
* <pre>{@code
* Project[*]
* ��������� Join[OR(t1.id=t2.id, t1.age=t2.age), inner]
* ��������� TableScan[t1]
* ��������� TableScan[t2]
* }</pre>
*
* <p>into
*
* <pre>{@code
* Project[*]
* ��������� UnionAll
* ��������� Join[t1.id=t2.id, inner]
* ��� ��������� TableScan[t1]
* ��� ��������� TableScan[t2]
* ��������� Join[t1.age=t2.age AND t1.id���t2.id, inner]
* ������������ TableScan[t1]
* ������������ TableScan[t2]
* }</pre>
*/
@Value.Enclosing
public class JoinExpandOrToUnionRule
extends RelRule<JoinExpandOrToUnionRule.Config>
implements TransformationRule {
/** Creates an JoinExpandOrToUnionRule. */
protected JoinExpandOrToUnionRule(Config config) {
super(config);
}
//~ Methods ----------------------------------------------------------------
@Override public boolean matches(RelOptRuleCall call) {
Join join = call.rel(0);
List<RexNode> orConds = RelOptUtil.disjunctions(join.getCondition());
return orConds.size() > 1;
}
@Override public void onMatch(RelOptRuleCall call) {
Join join = call.rel(0);
RelBuilder relBuilder = call.builder();
RelNode expanded;
switch (join.getJoinType()) {
case INNER:
expanded = expandInnerJoin(join, relBuilder);
break;
case ANTI:
expanded = expandAntiJoin(join, relBuilder);
break;
case LEFT:
expanded = expandLeftOrRightJoin(join, true, relBuilder);
break;
case RIGHT:
expanded = expandLeftOrRightJoin(join, false, relBuilder);
break;
case FULL:
expanded = expandFullJoin(join, relBuilder);
break;
default:
return;
}
if (expanded instanceof Join
&& ((Join) expanded).getCondition().equals(join.getCondition())) {
return;
}
call.transformTo(expanded);
}
private List<RexNode> splitCond(Join join) {
final RexBuilder builder = join.getCluster().getRexBuilder();
final List<RexNode> orConds = RelOptUtil.disjunctions(join.getCondition());
final int leftFieldCount = join.getLeft().getRowType().getFieldCount();
final List<RexNode> result = new ArrayList<>();
final List<RexNode> otherBuffer = new ArrayList<>();
for (RexNode cond : orConds) {
if (isValidCond(cond, leftFieldCount)) {
if (!otherBuffer.isEmpty()) {
result.add(RexUtil.composeDisjunction(builder, otherBuffer));
otherBuffer.clear();
}
result.add(cond);
} else {
otherBuffer.add(cond);
}
}
if (!otherBuffer.isEmpty()) {
result.add(RexUtil.composeDisjunction(builder, otherBuffer));
}
return result;
}
private boolean isValidCond(RexNode node, int leftFieldCount) {
boolean hasJoinKeyCond = false;
List<RexNode> conds = RelOptUtil.conjunctions(node);
for (RexNode cond : conds) {
// When components of the "call" are predicates that contain
// equality (when the above conditions are met), that are
// single-side (refer to only on of the collections joined),
// or which are constant, they will all trigger the expansion.
if (!doesNotReferToBothInputs(cond, leftFieldCount)) {
return false;
}
if (RexUtil.SubQueryFinder.find(cond) != null
|| RexUtil.containsCorrelation(cond)) {
// The "call" does not support sub-queries or correlation yet
return false;
}
if (cond instanceof RexCall) {
RexCall call = (RexCall) cond;
// Checks if the "call" is valid for use as a join key.
if (isEquiJoinCond(call, leftFieldCount)) {
hasJoinKeyCond = true;
}
}
}
return hasJoinKeyCond;
}
private boolean isEquiJoinCond(RexCall call, int leftFieldCount) {
if (call.getKind() != SqlKind.EQUALS
&& call.getKind() != SqlKind.IS_NOT_DISTINCT_FROM) {
return false;
}
RexNode left = call.getOperands().get(0);
RexNode right = call.getOperands().get(1);
if (left instanceof RexInputRef && right instanceof RexInputRef) {
int leftIndex = ((RexInputRef) left).getIndex();
int rightIndex = ((RexInputRef) right).getIndex();
return (leftIndex < leftFieldCount && rightIndex >= leftFieldCount)
|| (rightIndex < leftFieldCount && leftIndex >= leftFieldCount);
}
return false;
}
private boolean doesNotReferToBothInputs(RexNode rex, int leftFieldCount) {
RexInputRefCounter counter = new RexInputRefCounter(leftFieldCount);
rex.accept(counter);
return counter.doesNotReferToBothInputs();
}
/**
* Counts the number of InputRefs in a RexNode expression. */
private static class RexInputRefCounter extends RexVisitorImpl<Void> {
private int leftFieldCount;
public int leftInputRefCount = 0;
public int rightInputRefCount = 0;
RexInputRefCounter(int leftFieldCount) {
super(true);
this.leftFieldCount = leftFieldCount;
}
@Override public Void visitInputRef(RexInputRef inputRef) {
if (inputRef.getIndex() < leftFieldCount) {
leftFieldCount++;
} else {
rightInputRefCount++;
}
return null;
}
public boolean doesNotReferToBothInputs() {
return leftInputRefCount == 0 || rightInputRefCount == 0;
}
}
/**
* This method will make the following conversions.
*
* <pre>{@code
* Project[*]
* ��������� Join[OR(t1.id=t2.id, t1.age=t2.age), left]
* ��������� TableScan[t1]
* ��������� TableScan[t2]
*}</pre>
*
* <p>into
*
* <pre>{@code
* Project[*]
* ��������� UnionAll
* ��������� Join[t1.id=t2.id, inner]
* ��� ��������� TableScan[t1]
* ��� ��������� TableScan[t2]
* ��������� Join[t1.age=t2.age AND t1.id���t2.id, inner]
* ��� ��������� TableScan[t1]
* ��� ��������� TableScan[t2]
* ��������� Project[t1-side cols + NULLs]
* ��������� Join[t1.id=t2.id, anti]
* ��������� Join[t1.age=t2.age, anti]
* ��� ��������� TableScan[t1]
* ��� ��������� TableScan[t2]
* ��������� TableScan[t2]
* }</pre>
*/
private RelNode expandLeftOrRightJoin(Join join, boolean isLeftJoin,
RelBuilder relBuilder) {
List<RexNode> orConds = splitCond(join);
List<RelNode> joins = expandLeftOrRightJoinToRelNodes(join, orConds, isLeftJoin, relBuilder);
return relBuilder.pushAll(joins)
.union(true, joins.size())
.build();
}
private List<RelNode> expandLeftOrRightJoinToRelNodes(Join join, List<RexNode> orConds,
boolean isLeftJoin, RelBuilder relBuilder) {
List<RelNode> joins = new ArrayList<>();
joins.addAll(expandInnerJoinToRelNodes(join, orConds, relBuilder));
joins.add(expandAntiJoinToRelNode(join, orConds, isLeftJoin, true, relBuilder));
return joins;
}
/**
* This method will make the following conversions.
*
* <pre>{@code
* Project[*]
* ��������� Join[OR(t1.id=t2.id, t1.age=t2.age), full]
* ��������� TableScan[t1]
* ��������� TableScan[t2]
* }</pre>
*
* <p>into
*
* <pre>{@code
* Project[*]
* ��������� UnionAll
* ��������� Join[t1.id=t2.id, inner]
* ��� ��������� TableScan[t1]
* ��� ��������� TableScan[t2]
* ��������� Join[t1.age=t2.age AND t1.id���t2.id, inner]
* ��� ��������� TableScan[t1]
* ��� ��������� TableScan[t2]
* ��������� Project[t1-side cols + NULLs]
* ��� ��������� Join[t1.id=t2.id, anti]
* ��� ��������� Join[t1.age=t2.age, anti]
* ��� ��� ��������� TableScan[t1]
* ��� ��� ��������� TableScan[t2]
* ��� ��������� TableScan[t2]
* ��������� Project[NULLs + t2-side cols]
* ��������� Join[t2.id=t1.id, anti]
* ��������� Join[t2.age=t1.age, anti]
* ��� ��������� TableScan[t2]
* ��� ��������� TableScan[t1]
* ��������� TableScan[t1]
* }</pre>
*/
private RelNode expandFullJoin(Join join, RelBuilder relBuilder) {
List<RexNode> orConds = splitCond(join);
List<RelNode> joins = new ArrayList<>();
joins.addAll(expandInnerJoinToRelNodes(join, orConds, relBuilder));
joins.add(expandAntiJoinToRelNode(join, orConds, false, true, relBuilder));
joins.add(expandAntiJoinToRelNode(join, orConds, true, true, relBuilder));
relBuilder.pushAll(joins)
.union(true, joins.size());
final List<RexNode> projects = join.getRowType().getFieldList().stream()
.map(field -> {
RexNode rexNode = relBuilder.field(field.getIndex());
return field.getType().equals(rexNode.getType())
? rexNode
: relBuilder.getRexBuilder().makeCast(field.getType(), rexNode, true, false);
}).collect(Collectors.toList());
return relBuilder.project(projects)
.build();
}
/**
* This method will make the following conversions.
*
* <pre>{@code
* Project[*]
* ��������� Join[OR(t1.id=t2.id, t1.age=t2.age), inner]
* ��������� TableScan[t1]
* ��������� TableScan[t2]
* }</pre>
*
* <p>into
*
* <pre>{@code
* Project[*]
* ��������� UnionAll
* ��������� Join[t1.id=t2.id, inner]
* ��� ��������� TableScan[t1]
* ��� ��������� TableScan[t2]
* ��������� Join[t1.age=t2.age AND t1.id���t2.id, inner]
* ������������ TableScan[t1]
* ������������ TableScan[t2]
* }</pre>
*/
private RelNode expandInnerJoin(Join join, RelBuilder relBuilder) {
List<RexNode> orConds = splitCond(join);
List<RelNode> joins = expandInnerJoinToRelNodes(join, orConds, relBuilder);
return relBuilder.pushAll(joins)
.union(true, joins.size())
.build();
}
private List<RelNode> expandInnerJoinToRelNodes(Join join, List<RexNode> orConds,
RelBuilder relBuilder) {
List<RelNode> joins = new ArrayList<>();
for (int i = 0; i < orConds.size(); i++) {
RexNode orCond = orConds.get(i);
for (int j = 0; j < i; j++) {
orCond = relBuilder.and(orCond, relBuilder.not(orConds.get(j)));
}
relBuilder.push(join.getLeft())
.push(join.getRight())
.join(JoinRelType.INNER, orCond);
joins.add(relBuilder.build());
}
return joins;
}
/**
* This method will make the following conversions.
*
* <pre>{@code
* Project[*]
* ��������� Join[OR(id=id0, age=age0), anti]
* ��������� TableScan[tbl]
* ��������� TableScan[tbl]
* }</pre>
*
* <p>into
*
* <pre>{@code
* HashJoin[id=id0, anti]
* ��������� HashJoin[age=age0, anti]
* ��� ��������� TableScan[tbl]
* ��� ��������� TableScan[tbl]
* ��������� TableScan[tbl]
* }</pre>
*/
private RelNode expandAntiJoin(Join join, RelBuilder relBuilder) {
List<RexNode> orConds = splitCond(join);
return expandAntiJoinToRelNode(join, orConds, true, false, relBuilder);
}
private RelNode expandAntiJoinToRelNode(Join join, List<RexNode> orConds,
boolean isLeftAnti, boolean isAppendNulls, RelBuilder relBuilder) {
RelNode left = isLeftAnti ? join.getLeft() : join.getRight();
RelNode right = isLeftAnti ? join.getRight() : join.getLeft();
RelNode top = left;
for (int i = 0; i < orConds.size(); i++) {
RexNode orCond = orConds.get(i);
relBuilder.push(top)
.push(right)
.join(JoinRelType.ANTI,
isLeftAnti
? orCond
: JoinCommuteRule.swapJoinCond(orCond, join, relBuilder.getRexBuilder()));
top = relBuilder.build();
}
if (!isAppendNulls) {
return top;
}
relBuilder.push(top);
List<RexNode> fields = new ArrayList<>(relBuilder.fields());
List<RexNode> nulls = new ArrayList<>();
for (int i = 0; i < right.getRowType().getFieldCount(); i++) {
nulls.add(
relBuilder.getRexBuilder().makeNullLiteral(
right.getRowType().getFieldList().get(i).getType()));
}
List<RexNode> projects = isLeftAnti
? Stream.concat(fields.stream(), nulls.stream()).collect(Collectors.toList())
: Stream.concat(nulls.stream(), fields.stream()).collect(Collectors.toList());
return relBuilder.project(projects)
.build();
}
/** Rule configuration. */
@Value.Immutable
public interface Config extends RelRule.Config {
Config DEFAULT = ImmutableJoinExpandOrToUnionRule.Config.of()
.withOperandFor(Join.class);
@Override default JoinExpandOrToUnionRule toRule() {
return new JoinExpandOrToUnionRule(this);
}
/** Defines an operand tree for the given classes. */
default Config withOperandFor(Class<? extends Join> joinClass) {
return withOperandSupplier(b -> b.operand(joinClass).anyInputs())
.as(Config.class);
}
}
}