// HyperGraph.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.rel.rules;
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.linq4j.function.Experimental;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.AbstractRelNode;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelWriter;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexNodeAndFieldIndex;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import com.google.common.collect.Lists;
import org.checkerframework.checker.nullness.qual.Nullable;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static com.google.common.base.Preconditions.checkArgument;
/**
* HyperGraph represents a join graph.
*/
@Experimental
public class HyperGraph extends AbstractRelNode {
// the inputs of this node; NodeState.nodeIndex is an index into this list
private final List<RelNode> inputs;
// unprojected input (from the right child of the semi/anti join) bitmap. Used to convert
// RexInputRef to RexNodeAndFieldIndex when building hypergraph
private final long notProjectInputs;
// row type handed to the constructor and returned as-is by deriveRowType()
@SuppressWarnings("HidingField")
private final RelDataType rowType;
// all hyper edges of the graph; the BitSets below store indices into this list
private final List<HyperEdge> edges;
// record the indices of complex hyper edges in the 'edges'
private final ImmutableBitSet complexEdgesBitmap;
/*
 * For the HashMap fields below, the key is a bitmap of inputs and
 * the value is a bitmap of hyper edge indices in 'edges'.
 */
// record which hyper edges have been used by the enumerated csg-cmp pairs
private final HashMap<Long, BitSet> ccpUsedEdgesMap;
// simple hyper edges with one endpoint fully inside the key and the other endpoint
// disjoint from it
private final HashMap<Long, BitSet> simpleEdgesMap;
// complex hyper edges with one endpoint fully inside the key and the other endpoint
// disjoint from it
private final HashMap<Long, BitSet> complexEdgesMap;
// node bitmap overlaps edge's leftNodeBits or rightNodeBits, but does not completely cover
private final HashMap<Long, BitSet> overlapEdgesMap;
/**
 * Creates a HyperGraph whose per-subset edge caches are all empty.
 *
 * <p>The given {@code inputs} and {@code edges} lists are copied. The bitmap of complex
 * edges is derived here: it contains the position of every edge for which
 * {@code isSimple()} returns false.
 *
 * @param cluster cluster passed on to the superclass
 * @param traitSet trait set passed on to the superclass
 * @param inputs inputs of this node
 * @param notProjectInputs bitmap of the inputs that are not projected
 * @param edges hyper edges of the graph
 * @param rowType row type returned by {@link #deriveRowType()}
 */
protected HyperGraph(RelOptCluster cluster,
    RelTraitSet traitSet,
    List<RelNode> inputs,
    long notProjectInputs,
    List<HyperEdge> edges,
    RelDataType rowType) {
  super(cluster, traitSet);
  this.inputs = Lists.newArrayList(inputs);
  this.notProjectInputs = notProjectInputs;
  this.edges = Lists.newArrayList(edges);
  this.rowType = rowType;

  // remember at which positions of the edge list the complex (non-simple) edges sit
  final ImmutableBitSet.Builder complexPositions = ImmutableBitSet.builder();
  for (Ord<HyperEdge> indexedEdge : Ord.zip(edges)) {
    if (!indexedEdge.e.isSimple()) {
      complexPositions.set(indexedEdge.i);
    }
  }
  this.complexEdgesBitmap = complexPositions.build();

  // nothing has been enumerated yet, so every cache starts out empty
  this.ccpUsedEdgesMap = new HashMap<>();
  this.simpleEdgesMap = new HashMap<>();
  this.complexEdgesMap = new HashMap<>();
  this.overlapEdgesMap = new HashMap<>();
}
/**
 * Creates a HyperGraph from already computed state, as done by
 * {@link #copy(RelTraitSet, List)} and {@link #accept(RexShuttle)}.
 *
 * <p>The {@code inputs} and {@code edges} lists and the four maps are copied, so adding
 * entries to the new graph's caches does not affect the maps passed in. The copies are
 * shallow: the BitSet values themselves are shared. {@code complexEdgesBitmap} is stored
 * as given.
 *
 * @param cluster cluster passed on to the superclass
 * @param traitSet trait set passed on to the superclass
 * @param inputs inputs of this node
 * @param notProjectInputs bitmap of the inputs that are not projected
 * @param edges hyper edges of the graph
 * @param rowType row type returned by {@link #deriveRowType()}
 * @param complexEdgesBitmap indices of the complex hyper edges in {@code edges}
 * @param ccpUsedEdgesMap hyper edges already used per enumerated csg-cmp pair
 * @param simpleEdgesMap cached simple hyper edges per input subset
 * @param complexEdgesMap cached complex hyper edges per input subset
 * @param overlapEdgesMap cached overlapping hyper edges per input subset
 */
protected HyperGraph(RelOptCluster cluster,
    RelTraitSet traitSet,
    List<RelNode> inputs,
    long notProjectInputs,
    List<HyperEdge> edges,
    RelDataType rowType,
    ImmutableBitSet complexEdgesBitmap,
    HashMap<Long, BitSet> ccpUsedEdgesMap,
    HashMap<Long, BitSet> simpleEdgesMap,
    HashMap<Long, BitSet> complexEdgesMap,
    HashMap<Long, BitSet> overlapEdgesMap) {
  super(cluster, traitSet);

  // values stored as given
  this.notProjectInputs = notProjectInputs;
  this.rowType = rowType;
  this.complexEdgesBitmap = complexEdgesBitmap;

  // private copies of the mutable lists
  this.inputs = Lists.newArrayList(inputs);
  this.edges = Lists.newArrayList(edges);

  // private (shallow) copies of the caches
  this.ccpUsedEdgesMap = new HashMap<>(ccpUsedEdgesMap);
  this.simpleEdgesMap = new HashMap<>(simpleEdgesMap);
  this.complexEdgesMap = new HashMap<>(complexEdgesMap);
  this.overlapEdgesMap = new HashMap<>(overlapEdgesMap);
}
/**
 * Creates a HyperGraph with the given trait set and inputs that otherwise carries over
 * this graph's edges, row type, complex-edge bitmap and cached edge maps.
 *
 * @param traitSet trait set of the new node
 * @param inputs inputs of the new node
 * @return the new HyperGraph
 */
@Override public HyperGraph copy(RelTraitSet traitSet, List<RelNode> inputs) {
  final HyperGraph duplicate =
      new HyperGraph(
          getCluster(),
          traitSet,
          inputs,
          this.notProjectInputs,
          this.edges,
          this.rowType,
          this.complexEdgesBitmap,
          this.ccpUsedEdgesMap,
          this.simpleEdgesMap,
          this.complexEdgesMap,
          this.overlapEdgesMap);
  return duplicate;
}
/**
 * Describes this node to the writer: the terms of the superclass, then one
 * {@code input#N} entry per input, then a single {@code edges} item holding the
 * comma-separated string forms of all hyper edges.
 *
 * @param pw writer to add the terms to
 * @return the same writer
 */
@Override public RelWriter explainTerms(RelWriter pw) {
  super.explainTerms(pw);
  for (int ordinal = 0; ordinal < inputs.size(); ordinal++) {
    pw.input("input#" + ordinal, inputs.get(ordinal));
  }
  final String joinedEdges = edges.stream()
      .map(HyperEdge::toString)
      .collect(Collectors.joining(","));
  pw.item("edges", joinedEdges);
  return pw;
}
/**
 * Returns the inputs of this node. The list returned is the one held by this node, not a
 * copy.
 */
@Override public List<RelNode> getInputs() {
  return this.inputs;
}
/**
 * Replaces one input in place and then recomputes the digest of this node.
 *
 * @param ordinalInParent position of the input to replace
 * @param p the new input
 */
@Override public void replaceInput(int ordinalInParent, RelNode p) {
  this.inputs.set(ordinalInParent, p);
  recomputeDigest();
}
/**
 * Returns the row type that was supplied when this node was constructed; nothing is
 * derived from the inputs.
 */
@Override public RelDataType deriveRowType() {
  return this.rowType;
}
/**
 * Applies the shuttle to every hyper edge and returns a new HyperGraph built from the
 * resulting edges. Inputs, row type, the complex-edge bitmap and the cached edge maps are
 * carried over from this graph. A new node is returned even if no edge changed.
 *
 * @param shuttle shuttle handed to each edge's {@code accept}
 * @return a new HyperGraph holding the edges returned by the shuttle
 */
@Override public RelNode accept(RexShuttle shuttle) {
  final List<HyperEdge> rewrittenEdges = edges.stream()
      .map(edge -> edge.accept(shuttle))
      .collect(Collectors.toList());
  return new HyperGraph(
      getCluster(),
      traitSet,
      inputs,
      notProjectInputs,
      rewrittenEdges,
      rowType,
      complexEdgesBitmap,
      ccpUsedEdgesMap,
      simpleEdgesMap,
      complexEdgesMap,
      overlapEdgesMap);
}
//~ hyper graph method ----------------------------------------------------------
/**
 * Returns the hyper edges of this graph. The list returned is the one held by this node,
 * not a copy.
 */
public List<HyperEdge> getEdges() {
  return this.edges;
}
/**
 * Returns the bitmap of the inputs that are not projected (inputs coming from the right
 * child of a semi/anti join).
 */
public long getNotProjectInputs() {
  return this.notProjectInputs;
}
/**
 * Computes the neighbors of {@code csg} from the edges cached for exactly that subset.
 *
 * <p>Simple edges contribute all of their endpoint nodes, minus the nodes in {@code csg}
 * or {@code forbidden}. A complex edge contributes a single representative, the
 * lowest-numbered node of its far side, provided its near side lies inside {@code csg}
 * and its far side touches neither {@code csg}, {@code forbidden} nor the neighbors
 * found through simple edges. If nothing is cached for {@code csg} the result is 0.
 *
 * @param csg bitmap of the connected subgraph
 * @param forbidden bitmap of nodes that must not be returned
 * @return bitmap of neighbor nodes
 */
public long getNeighborBitmap(long csg, long forbidden) {
  long neighbors = 0L;
  final BitSet simpleIndexes = simpleEdgesMap.getOrDefault(csg, new BitSet());
  for (int index : simpleIndexes.stream().toArray()) {
    neighbors |= edges.get(index).getEndpoint();
  }

  // nodes that may not become (or already are) neighbors
  long excluded = forbidden | csg;
  neighbors &= ~excluded;
  excluded |= neighbors;

  final BitSet complexIndexes = complexEdgesMap.getOrDefault(csg, new BitSet());
  for (int index : complexIndexes.stream().toArray()) {
    final HyperEdge edge = edges.get(index);
    final long left = edge.getLeftEndpoint();
    final long right = edge.getRightEndpoint();
    final long farSide;
    if (LongBitmap.isSubSet(left, csg) && !LongBitmap.isOverlap(right, excluded)) {
      farSide = right;
    } else if (LongBitmap.isSubSet(right, csg) && !LongBitmap.isOverlap(left, excluded)) {
      farSide = left;
    } else {
      continue;
    }
    // a hyper node is represented by its lowest-numbered member only
    neighbors |= Long.lowestOneBit(farSide);
  }
  return neighbors;
}
/**
 * If csg and cmp are connected, return the edges that connect them.
 *
 * <p>Both subsets must already have an entry in the simple-edge cache. As a side effect,
 * the edges used by the union of csg and cmp (the returned edges plus the edges already
 * recorded for csg and for cmp) are stored in the used-edges map under the key
 * {@code csg | cmp}.
 *
 * @param csg bitmap of the connected subgraph
 * @param cmp bitmap of the complement
 * @return the connecting edges in ascending edge-index order; empty if there are none
 * @throws IllegalArgumentException if csg or cmp has no cached edges, or if the union was
 *     recorded before with a different set of used edges
 */
public List<HyperEdge> connectCsgCmp(long csg, long cmp) {
  checkArgument(simpleEdgesMap.containsKey(csg));
  checkArgument(simpleEdgesMap.containsKey(cmp));
  final long csgCmpUnion = csg | cmp;

  // edges cached for csg that are also cached for cmp
  final BitSet connecting = cachedEdgesOf(csg);
  connecting.and(cachedEdgesOf(cmp));

  final BitSet usedByCsg = ccpUsedEdgesMap.getOrDefault(csg, new BitSet());
  final BitSet usedByCmp = ccpUsedEdgesMap.getOrDefault(cmp, new BitSet());

  // only consider the records related to csg and cmp in the simpleEdgesMap/complexEdgesMap,
  // may omit some complex hyper edges. e.g.
  // csg = {t1, t3}, cmp = {t2}, will omit the edge (t1, t2)--(t3)
  final BitSet possiblyMissed = new BitSet();
  possiblyMissed.or(complexEdgesBitmap.toBitSet());
  possiblyMissed.andNot(usedByCsg);
  possiblyMissed.andNot(usedByCmp);
  possiblyMissed.andNot(connecting);
  for (int index : possiblyMissed.stream().toArray()) {
    if (LongBitmap.isSubSet(edges.get(index).getEndpoint(), csgCmpUnion)) {
      connecting.set(index);
    }
  }

  // record the hyper edges used by the union of the current csg and cmp
  final BitSet usedByUnion = new BitSet();
  usedByUnion.or(connecting);
  usedByUnion.or(usedByCsg);
  usedByUnion.or(usedByCmp);
  if (ccpUsedEdgesMap.containsKey(csgCmpUnion)) {
    checkArgument(
        usedByUnion.equals(ccpUsedEdgesMap.get(csgCmpUnion)));
  }
  ccpUsedEdgesMap.put(csgCmpUnion, usedByUnion);

  return connecting.stream()
      .mapToObj(edges::get)
      .collect(Collectors.toCollection(ArrayList::new));
}

/** Returns a fresh bitmap of the simple and complex edges cached for the given subset. */
private BitSet cachedEdgesOf(long subset) {
  final BitSet cached = new BitSet();
  cached.or(simpleEdgesMap.getOrDefault(subset, new BitSet()));
  cached.or(complexEdgesMap.getOrDefault(subset, new BitSet()));
  return cached;
}
/**
 * Classifies every hyper edge against {@code subset} and stores the result in the three
 * edge caches, replacing any entries already present for that key.
 *
 * <p>An edge with one endpoint inside the subset and the other endpoint disjoint from it
 * is recorded as simple or complex according to {@code isSimple()}. An edge that merely
 * overlaps the subset on one side is recorded as an overlap edge. All other edges are
 * not recorded.
 *
 * @param subset bitmap of inputs to classify the edges against
 */
public void initEdgeBitMap(long subset) {
  final BitSet simpleIndexes = new BitSet();
  final BitSet complexIndexes = new BitSet();
  final BitSet overlapIndexes = new BitSet();
  int index = 0;
  for (HyperEdge edge : edges) {
    final BitSet bucket;
    if (isAccurateEdge(edge, subset)) {
      bucket = edge.isSimple() ? simpleIndexes : complexIndexes;
    } else if (isOverlapEdge(edge, subset)) {
      bucket = overlapIndexes;
    } else {
      bucket = null;
    }
    if (bucket != null) {
      bucket.set(index);
    }
    index++;
  }
  simpleEdgesMap.put(subset, simpleIndexes);
  complexEdgesMap.put(subset, complexIndexes);
  overlapEdgesMap.put(subset, overlapIndexes);
}
/**
 * Makes sure the edge caches contain an entry for {@code subset1 | subset2}, deriving it
 * from the entries of the two subsets (which are initialized first if missing). Does
 * nothing further if the union is already cached.
 *
 * @param subset1 bitmap of the first input subset
 * @param subset2 bitmap of the second input subset
 */
public void updateEdgesForUnion(long subset1, long subset2) {
  for (long subset : new long[] {subset1, subset2}) {
    if (!simpleEdgesMap.containsKey(subset)) {
      initEdgeBitMap(subset);
    }
  }
  final long unionSet = subset1 | subset2;
  if (simpleEdgesMap.containsKey(unionSet)) {
    return;
  }

  final BitSet simpleOfUnion = mergeCached(simpleEdgesMap, subset1, subset2);
  final BitSet complexOfUnion = mergeCached(complexEdgesMap, subset1, subset2);
  final BitSet overlapOfUnion = mergeCached(overlapEdgesMap, subset1, subset2);

  // the overlaps edge that belongs to subset1/subset2
  // may be complex edge for subset1 union subset2
  for (int index : overlapOfUnion.stream().toArray()) {
    if (isAccurateEdge(edges.get(index), unionSet)) {
      complexOfUnion.set(index);
      overlapOfUnion.clear(index);
    }
  }

  // remove cycle in subset1 union subset2: drop the edges (simple first, then complex)
  // that no longer have exactly one endpoint inside the union
  for (BitSet candidates : new BitSet[] {simpleOfUnion, complexOfUnion}) {
    for (int index : candidates.stream().toArray()) {
      if (!isAccurateEdge(edges.get(index), unionSet)) {
        candidates.clear(index);
      }
    }
  }

  simpleEdgesMap.put(unionSet, simpleOfUnion);
  complexEdgesMap.put(unionSet, complexOfUnion);
  overlapEdgesMap.put(unionSet, overlapOfUnion);
}

/** Returns a fresh bitmap holding the edges the cache records for either subset. */
private static BitSet mergeCached(HashMap<Long, BitSet> cache, long subset1, long subset2) {
  final BitSet merged = new BitSet();
  merged.or(cache.getOrDefault(subset1, new BitSet()));
  merged.or(cache.getOrDefault(subset2, new BitSet()));
  return merged;
}
/**
 * Returns whether exactly one side of the edge is covered by {@code subset}: one endpoint
 * is a subset of {@code subset} while the other endpoint shares no node with it.
 */
private static boolean isAccurateEdge(HyperEdge edge, long subset) {
  final long left = edge.getLeftEndpoint();
  final long right = edge.getRightEndpoint();
  final boolean leftCoveredRightOutside =
      LongBitmap.isSubSet(left, subset) && !LongBitmap.isOverlap(right, subset);
  final boolean rightCoveredLeftOutside =
      LongBitmap.isSubSet(right, subset) && !LongBitmap.isOverlap(left, subset);
  return leftCoveredRightOutside || rightCoveredLeftOutside;
}
/**
 * Returns whether {@code subset} shares nodes with exactly one of the two endpoints of
 * the edge.
 */
private static boolean isOverlapEdge(HyperEdge edge, long subset) {
  final boolean touchesLeft = LongBitmap.isOverlap(edge.getLeftEndpoint(), subset);
  final boolean touchesRight = LongBitmap.isOverlap(edge.getRightEndpoint(), subset);
  // (left && !right) || (right && !left) is the same as "exactly one of the two"
  return touchesLeft != touchesRight;
}
/**
 * Returns the join type shared by all the given hyper edges.
 *
 * @param edges hyper edges to inspect
 * @return the common join type, or null if the edges do not all have the same join type
 *     or if {@code edges} is empty
 */
public @Nullable JoinRelType extractJoinType(List<HyperEdge> edges) {
  if (edges.isEmpty()) {
    // no edge means no join type to agree on; report it like a mismatch instead of
    // failing with an IndexOutOfBoundsException
    return null;
  }
  final JoinRelType joinType = edges.get(0).getJoinType();
  for (HyperEdge edge : edges.subList(1, edges.size())) {
    if (edge.getJoinType() != joinType) {
      return null;
    }
  }
  return joinType;
}
/**
 * Check whether the csg-cmp pair is applicable through the conflict rule.
 *
 * <p>The pair is applicable when every edge has all of its endpoint nodes inside
 * {@code csgcmp} and none of the edges' conflict rules is violated.
 *
 * @param csgcmp csg-cmp pair
 * @param edges hyper edges with conflict rules
 * @return true if the csg-cmp pair is applicable
 */
public boolean applicable(long csgcmp, List<HyperEdge> edges) {
  for (HyperEdge edge : edges) {
    final boolean endpointsCovered = LongBitmap.isSubSet(edge.getEndpoint(), csgcmp);
    if (!endpointsCovered) {
      return false;
    }
    for (ConflictRule rule : edge.getConflictRules()) {
      // for conflict rule T1 -> T2: if the intersection of T1 and csgcmp is not empty,
      // then T2 must be included in csgcmp
      final boolean ruleTriggered = LongBitmap.isOverlap(csgcmp, rule.from);
      if (ruleTriggered && !LongBitmap.isSubSet(rule.to, csgcmp)) {
        return false;
      }
    }
  }
  return true;
}
/**
 * Build an RexNode expression for the predicate corresponding to a set of hyperedges.
 *
 * <p>The conditions of the edges refer to fields as (node index, field index). They are
 * rewritten to input refs by numbering the fields of the projected nodes of
 * {@code inputOrder} consecutively, in list order, and the rewritten conditions are
 * combined into a conjunction.
 *
 * <p>Side effect: if {@code joinType} does not project its right side, every entry of
 * {@code inputOrder} from position {@code leftCount} onwards is replaced by a
 * non-projected NodeState for the same node.
 *
 * @param inputOrder node status ordered list for current csg-cmp; may be modified
 * @param leftCount number of tables in left child
 * @param edges hyper edges
 * @param joinType join type
 * @return join condition
 * @throws AssertionError if a condition refers to a field of a node that is not projected
 *     by {@code inputOrder}
 */
public RexNode extractJoinCond(
    List<NodeState> inputOrder,
    int leftCount,
    List<HyperEdge> edges,
    JoinRelType joinType) {
  // a map from (node index, field index) to input ref
  final Map<Pair<Integer, Integer>, Integer> nodeAndFieldIndex2InputRefMap = new HashMap<>();
  int nextInputRef = 0;
  for (NodeState nodeState : inputOrder) {
    if (!nodeState.projected) {
      // unprojected nodes contribute no fields to the row type
      continue;
    }
    int fieldCount = inputs.get(nodeState.nodeIndex).getRowType().getFieldCount();
    for (int fieldIndex = 0; fieldIndex < fieldCount; fieldIndex++) {
      nodeAndFieldIndex2InputRefMap.put(
          Pair.of(nodeState.nodeIndex, fieldIndex), nextInputRef++);
    }
  }

  RexShuttle nodeAndFieldIndex2InputRefShuttle = new RexShuttle() {
    @Override public RexNode visitNodeAndFieldIndex(RexNodeAndFieldIndex nodeAndFieldIndex) {
      Integer inputRef =
          nodeAndFieldIndex2InputRefMap.get(
              Pair.of(nodeAndFieldIndex.getNodeIndex(), nodeAndFieldIndex.getFieldIndex()));
      if (inputRef == null) {
        // fail explicitly (also when assertions are disabled) rather than with an
        // unboxing NullPointerException below
        throw new AssertionError(
            "No projected input provides field " + nodeAndFieldIndex.getFieldIndex()
                + " of node " + nodeAndFieldIndex.getNodeIndex());
      }
      return new RexInputRef(inputRef, nodeAndFieldIndex.getType());
    }
  };

  List<RexNode> joinConds = new ArrayList<>();
  for (HyperEdge edge : edges) {
    joinConds.add(edge.getCondition().accept(nodeAndFieldIndex2InputRefShuttle));
  }

  // update the node status of subgraph(csg-cmp pair) according to the join type.
  if (!joinType.projectsRight()) {
    for (int i = leftCount; i < inputOrder.size(); i++) {
      int nodeIndex = inputOrder.get(i).nodeIndex;
      inputOrder.set(i, new NodeState(nodeIndex, false));
    }
  }
  return RexUtil.composeConjunction(getCluster().getRexBuilder(), joinConds);
}
/**
 * Ensure that the fields produced by the reordered join are in the same order as in the original
 * plan.
 *
 * <p>Inputs flagged in the not-projected bitmap contribute no expressions.
 *
 * @param resultOrder node status ordered list for the final result
 * @param rowTypeList rowType of the final result
 * @return list of RexInputRef
 * @throws AssertionError if a projected input with at least one field does not occur in
 *     {@code resultOrder}
 */
public List<RexNode> restoreProjectionOrder(
    List<NodeState> resultOrder,
    List<RelDataTypeField> rowTypeList) {
  // for each node of the result: how many fields are projected in front of it
  final Map<Integer, Integer> offsetByNode = new HashMap<>();
  int fieldsSoFar = 0;
  for (NodeState state : resultOrder) {
    offsetByNode.put(state.nodeIndex, fieldsSoFar);
    if (state.projected) {
      fieldsSoFar += inputs.get(state.nodeIndex).getRowType().getFieldCount();
    }
  }

  // Example. The original input order is [n0, n1, n2], each input projecting
  // [field0, field1], so the original row type is
  //   [n0.field0, n0.field1, n1.field0, n1.field1, n2.field0, n2.field1]
  // produced by the plan
  //          inner join
  //           /      \
  //     inner join    n2
  //      /      \
  //     n0      n1
  //
  // If the node order of the final result is [n2, n1, n0], its row type is
  //   [n2.field0, n2.field1, n1.field0, n1.field1, n0.field0, n0.field1]
  // produced by the reordered plan
  //          inner join
  //           /      \
  //     inner join    n0
  //      /      \
  //     n2      n1
  //
  // and we need a projection like [$4, $5, $2, $3, $0, $1].
  final List<RexNode> projects = new ArrayList<>();
  final int inputCount = inputs.size();
  for (int inputIndex = 0; inputIndex < inputCount; inputIndex++) {
    final boolean notProjected =
        LongBitmap.isOverlap(notProjectInputs, LongBitmap.newBitmap(inputIndex));
    if (notProjected) {
      continue;
    }
    final int fieldCount = inputs.get(inputIndex).getRowType().getFieldCount();
    final Integer fieldOffset = offsetByNode.get(inputIndex);
    for (int fieldIndex = 0; fieldIndex < fieldCount; fieldIndex++) {
      // checked per field so that an input without fields never triggers the error
      if (fieldOffset == null) {
        throw new AssertionError(
            "The result order looses the " + inputIndex + "-th input");
      }
      final int inputRef = fieldIndex + fieldOffset;
      final RelDataTypeField field = rowTypeList.get(inputRef);
      projects.add(new RexInputRef(inputRef, field.getType()));
    }
  }
  return projects;
}
/**
 * Record the projection state of vertices in the hypergraph during enumerating. It is used to
 * calculate the index of RexInputRef when building join conditions from hyperedges.
 *
 * <p>Instances are immutable; a change of state is expressed by replacing the instance.
 */
public static class NodeState {
// position of the vertex in the HyperGraph's input list
final int nodeIndex;
// whether the fields of this vertex are counted when input refs are numbered
final boolean projected;
// stores the two values as given
NodeState(int nodeIndex, boolean projected) {
this.nodeIndex = nodeIndex;
this.projected = projected;
}
}
}