DpHyp.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.rel.rules;
import org.apache.calcite.linq4j.function.Experimental;
import org.apache.calcite.plan.PlanTooComplexError;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.trace.CalciteTrace;
import com.google.common.collect.ImmutableList;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.slf4j.Logger;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
 * The core process of dphyp enumeration algorithm.
 *
 * <p>Subgraphs of the {@link HyperGraph} are represented as {@code long} bitmaps. For every
 * connected subgraph that is enumerated, the cheapest join plan found so far (by cumulative
 * cost) is kept in a dp table keyed by that bitmap. Call {@link #startEnumerateJoin()} first
 * and then {@link #getBestPlan()} to obtain the plan that covers the whole graph.
 */
@Experimental
public class DpHyp {

  private static final Logger LOGGER = CalciteTrace.getDpHypJoinReorderTracer();

  protected final HyperGraph hyperGraph;

  // a map from subgraph (bitmap) to the best join tree found so far for that subgraph
  private final Map<Long, RelNode> dpTable;

  // a map from subgraph to the node list of the best join tree. The node list records the node
  // index and whether it is projected, which is used to convert the RexNodeAndFieldIndex in
  // hyperedge to the RexInputRef in join condition
  private final Map<Long, ImmutableList<HyperGraph.NodeState>> resultInputOrder;

  protected final RelBuilder builder;

  private final RelMetadataQuery mq;

  // once the dp table holds more entries than this, registering a new subgraph aborts the
  // enumeration with a PlanTooComplexError
  private final int bloat;

  /**
   * Creates an enumerator that works on a copy of the given hypergraph (obtained through
   * {@code hyperGraph.copy(...)} with the same trait set and inputs).
   *
   * @param hyperGraph hypergraph whose inputs are to be reordered
   * @param builder builder used to create the candidate joins and the final projection
   * @param relMetadataQuery metadata query used to obtain cumulative cost and row count
   * @param bloat threshold on the dp table size; see {@link #startEnumerateJoin()}
   */
  public DpHyp(HyperGraph hyperGraph, RelBuilder builder, RelMetadataQuery relMetadataQuery,
      int bloat) {
    this.hyperGraph =
        hyperGraph.copy(
            hyperGraph.getTraitSet(),
            hyperGraph.getInputs());
    this.dpTable = new HashMap<>();
    this.resultInputOrder = new HashMap<>();
    this.builder = builder;
    this.mq = relMetadataQuery;
    this.bloat = bloat;
  }

  /**
   * The entry function of the algorithm. We use a bitmap to represent a leaf node,
   * which indicates the position of the corresponding leaf node in {@link HyperGraph}.
   *
   * <p>After the enumeration is completed, the best join order will be stored
   * in the {@link DpHyp#dpTable}.
   *
   * <p>If the dp table grows beyond the configured bloat, the resulting
   * {@link PlanTooComplexError} is caught here, an error is logged and the enumeration stops;
   * {@link #getBestPlan()} may then return null.
   */
  public void startEnumerateJoin() {
    int size = hyperGraph.getInputs().size();
    for (int i = 0; i < size; i++) {
      long singleNode = LongBitmap.newBitmap(i);
      // dumping a plan is expensive, so only do it when the message will actually be emitted
      if (LOGGER.isDebugEnabled()) {
        LOGGER.debug("Initialize the dp table. Node {{}} is:\n {}",
            i,
            RelOptUtil.toString(hyperGraph.getInput(i)));
      }
      dpTable.put(singleNode, hyperGraph.getInput(i));
      resultInputOrder.put(
          singleNode,
          ImmutableList.of(new HyperGraph.NodeState(i, true)));
      hyperGraph.initEdgeBitMap(singleNode);
    }

    try {
      // start enumerating from the second to last
      for (int i = size - 2; i >= 0; i--) {
        long csg = LongBitmap.newBitmap(i);
        // for a single-bit csg, (csg - 1) has exactly the bits below that bit set
        long forbidden = csg - 1;
        emitCsg(csg);
        enumerateCsgRec(csg, forbidden);
      }
    } catch (PlanTooComplexError e) {
      LOGGER.error("The dp table is too large, and the enumeration ends automatically.");
    }
  }

  /**
   * Given a connected subgraph (csg), enumerate all possible complements subgraph (cmp)
   * that do not include anything from the exclusion subset.
   *
   * <p>Corresponding to EmitCsg in origin paper.
   */
  private void emitCsg(long csg) {
    long forbidden = csg | LongBitmap.getBvBitmap(csg);
    long neighbors = hyperGraph.getNeighborBitmap(csg, forbidden);

    LongBitmap.ReverseIterator reverseIterator = new LongBitmap.ReverseIterator(neighbors);
    for (long cmp : reverseIterator) {
      List<HyperEdge> edges = hyperGraph.connectCsgCmp(csg, cmp);
      if (!edges.isEmpty()) {
        emitCsgCmp(csg, cmp, edges);
      }

      // When extending cmp, forbid the neighbors that are smaller than the current cmp, e.g.
      // neighbors = {t1, t2}, t1 and t2 are connected.
      // when t2 is extended, we will get (t1, t2)
      // when t1 is extended, we would get (t1, t2) again
      long newForbidden =
          (cmp | LongBitmap.getBvBitmap(cmp)) & neighbors;
      newForbidden = newForbidden | forbidden;
      enumerateCmpRec(csg, cmp, newForbidden);
    }
  }

  /**
   * Given a connected subgraph (csg), expands it recursively by its neighbors.
   * If the expanded csg is connected, try to enumerate its cmp (note that for complex hyperedge,
   * we only select a single representative node to add to the neighbors, so csg and subNeighbor
   * are not necessarily connected. However, it still needs to be expanded to prevent missing
   * complex hyperedge). This method is called after the enumeration of csg is completed,
   * that is, after {@link DpHyp#emitCsg(long csg)}.
   *
   * <p>Corresponding to EnumerateCsgRec in origin paper.
   */
  private void enumerateCsgRec(long csg, long forbidden) {
    long neighbors = hyperGraph.getNeighborBitmap(csg, forbidden);
    LongBitmap.SubsetIterator subsetIterator = new LongBitmap.SubsetIterator(neighbors);

    // first pass: emit every expanded csg that already has a plan in the dp table
    for (long subNeighbor : subsetIterator) {
      hyperGraph.updateEdgesForUnion(csg, subNeighbor);
      long newCsg = csg | subNeighbor;
      if (dpTable.containsKey(newCsg)) {
        emitCsg(newCsg);
      }
    }

    // second pass: recurse into every expanded csg, with the neighbors now forbidden
    long newForbidden = forbidden | neighbors;
    subsetIterator.reset();
    for (long subNeighbor : subsetIterator) {
      long newCsg = csg | subNeighbor;
      enumerateCsgRec(newCsg, newForbidden);
    }
  }

  /**
   * Given a connected subgraph (csg) and its complement subgraph (cmp), expands the cmp
   * recursively by neighbors of cmp (cmp and subNeighbor are not necessarily connected,
   * which is the same logic as in {@link DpHyp#enumerateCsgRec}).
   *
   * <p>Corresponding to EnumerateCmpRec in origin paper.
   */
  private void enumerateCmpRec(long csg, long cmp, long forbidden) {
    long neighbors = hyperGraph.getNeighborBitmap(cmp, forbidden);
    LongBitmap.SubsetIterator subsetIterator = new LongBitmap.SubsetIterator(neighbors);

    // first pass: emit every expanded cmp that has a plan and is connected to csg
    for (long subNeighbor : subsetIterator) {
      long newCmp = cmp | subNeighbor;
      hyperGraph.updateEdgesForUnion(cmp, subNeighbor);
      if (dpTable.containsKey(newCmp)) {
        List<HyperEdge> edges = hyperGraph.connectCsgCmp(csg, newCmp);
        if (!edges.isEmpty()) {
          emitCsgCmp(csg, newCmp, edges);
        }
      }
    }

    // second pass: recurse into every expanded cmp, with the neighbors now forbidden
    long newForbidden = forbidden | neighbors;
    subsetIterator.reset();
    for (long subNeighbor : subsetIterator) {
      long newCmp = cmp | subNeighbor;
      enumerateCmpRec(csg, newCmp, newForbidden);
    }
  }

  /**
   * Given a connected csg-cmp pair and the hyperedges that connect them, build the
   * corresponding Join plan. If the new Join plan is better than the existing plan,
   * update the {@link DpHyp#dpTable}.
   *
   * <p>Nothing is emitted when no join type can be extracted from the edges or when the
   * conflict rules of the edges reject the union of csg and cmp. On a cost tie the plan
   * already stored in the dp table is kept.
   *
   * <p>Corresponding to EmitCsgCmp in origin paper.
   *
   * @throws PlanTooComplexError if the union of csg and cmp is a new subgraph and the dp table
   *     already holds more than {@code bloat} entries
   */
  private void emitCsgCmp(long csg, long cmp, List<HyperEdge> edges) {
    RelNode child1 = dpTable.get(csg);
    RelNode child2 = dpTable.get(cmp);
    ImmutableList<HyperGraph.NodeState> csgOrder = resultInputOrder.get(csg);
    ImmutableList<HyperGraph.NodeState> cmpOrder = resultInputOrder.get(cmp);
    assert child1 != null && child2 != null && csgOrder != null && cmpOrder != null;
    assert Long.bitCount(csg) == csgOrder.size() && Long.bitCount(cmp) == cmpOrder.size();

    JoinRelType joinType = hyperGraph.extractJoinType(edges);
    if (joinType == null) {
      return;
    }

    // bitmap of the subgraph produced by joining csg and cmp
    long subgraph = csg | cmp;

    // verify whether the subgraph is legal by using the conflict rules in hyperedges
    if (!hyperGraph.applicable(subgraph, edges)) {
      return;
    }

    List<HyperGraph.NodeState> unionOrder = new ArrayList<>(csgOrder);
    unionOrder.addAll(cmpOrder);
    // build join condition from hyperedges. e.g.
    // case.1
    // csg: node0_projected [field0, field1], node1_projected [field0, field1],
    //
    //          join
    //         /    \
    //     node0    node1
    //
    // cmp: node2_projected [field0, field1]
    // hyperedge1: node0.field0 = node2.field0
    // hyperedge2: node1.field1 = node2.field1
    // we will get join condition: ($0 = $4) and ($3 = $5)
    //
    //     new_join(condition=[AND(($0 = $4), ($3 = $5))])
    //         /    \
    //      join    node2
    //     /    \
    // node0    node1
    //
    // case.2
    // csg: node0_projected [field0, field1], node1_not_projected [field0, field1],
    //
    //     join(joinType=semi/anti)
    //         /    \
    //     node0    node1
    //
    // cmp: node2_projected [field0, field1]
    // hyperedge1: node0.field0 = node2.field0
    // hyperedge2: node0.field1 = node2.field1
    // we will get join condition: ($0 = $2) and ($1 = $3)
    //
    //     new_join(condition=[AND(($0 = $2), ($1 = $3))])
    //             /    \
    // join(joinType=semi/anti)    node2
    //         /    \
    //     node0    node1
    RexNode joinCond1 =
        hyperGraph.extractJoinCond(unionOrder, csgOrder.size(), edges, joinType);
    RelNode newPlan1 = builder
        .push(child1)
        .push(child2)
        .join(joinType, joinCond1)
        .build();
    RelNode winPlan = newPlan1;
    ImmutableList<HyperGraph.NodeState> winOrder = ImmutableList.copyOf(unionOrder);
    assert verifyDpResultRowType(newPlan1, unionOrder);

    if (ConflictDetectionHelper.isCommutative(joinType)) {
      // swap left and right
      unionOrder = new ArrayList<>(cmpOrder);
      unionOrder.addAll(csgOrder);
      RexNode joinCond2 =
          hyperGraph.extractJoinCond(unionOrder, cmpOrder.size(), edges, joinType);
      RelNode newPlan2 = builder
          .push(child2)
          .push(child1)
          .join(joinType, joinCond2)
          .build();
      winPlan = chooseBetterPlan(winPlan, newPlan2);
      assert verifyDpResultRowType(newPlan2, unionOrder);
      if (winPlan.equals(newPlan2)) {
        winOrder = ImmutableList.copyOf(unionOrder);
      }
    }

    // This method runs once per csg-cmp pair; composing the condition and querying the row
    // count only for a debug message is wasted work unless debug logging is enabled.
    if (LOGGER.isDebugEnabled()) {
      LOGGER.debug("Found set {} and {}, connected by condition {}. [cost={}, rows={}]",
          LongBitmap.printBitmap(csg),
          LongBitmap.printBitmap(cmp),
          RexUtil.composeConjunction(
              builder.getRexBuilder(),
              edges.stream()
                  .map(edge -> edge.getCondition()).collect(Collectors.toList())),
          mq.getCumulativeCost(winPlan),
          mq.getRowCount(winPlan));
    }

    RelNode oriPlan = dpTable.get(subgraph);
    boolean dpTableUpdated = true;
    if (oriPlan != null) {
      winPlan = chooseBetterPlan(winPlan, oriPlan);
      if (winPlan.equals(oriPlan)) {
        winOrder = resultInputOrder.get(subgraph);
        dpTableUpdated = false;
      }
    } else {
      // when enumerating a new connected subgraph, check whether the dpTable size is too large
      if (dpTable.size() > bloat) {
        throw new PlanTooComplexError();
      }
    }
    assert winOrder != null;

    if (dpTableUpdated && LOGGER.isDebugEnabled()) {
      LOGGER.debug("Dp table is updated. The better plan for subgraph {} now is:\n {}",
          LongBitmap.printBitmap(subgraph),
          RelOptUtil.toString(winPlan));
    }
    dpTable.put(subgraph, winPlan);
    resultInputOrder.put(subgraph, winOrder);
  }

  /**
   * Returns the best plan found for the whole graph, wrapped in a projection that restores
   * the field order of the original plan.
   *
   * @return the reordered plan, or null if the dp table has no entry covering all inputs
   *     (for example because the enumeration ended prematurely)
   */
  public @Nullable RelNode getBestPlan() {
    int size = hyperGraph.getInputs().size();
    long wholeGraph = LongBitmap.newBitmapBetween(0, size);
    RelNode orderedJoin = dpTable.get(wholeGraph);
    if (orderedJoin == null) {
      LOGGER.error("The optimal plan was not generated because the enumeration ended prematurely");
      return null;
    }
    if (LOGGER.isDebugEnabled()) {
      LOGGER.debug("Enumeration completed. The best plan is:\n {}",
          RelOptUtil.toString(orderedJoin));
    }
    ImmutableList<HyperGraph.NodeState> resultOrder = resultInputOrder.get(wholeGraph);
    assert resultOrder != null && resultOrder.size() == size;

    // ensure that the fields produced by the reordered join are in the same order as in the
    // original plan.
    List<RexNode> projects =
        hyperGraph.restoreProjectionOrder(resultOrder,
            orderedJoin.getRowType().getFieldList());
    return builder
        .push(orderedJoin)
        .project(projects)
        .build();
  }

  /**
   * Picks the cheaper of two plans by cumulative cost.
   *
   * <p>{@code plan1} is returned only when its cost is strictly lower than that of
   * {@code plan2}, or when only {@code plan1} has a cost; in every other case (including a
   * tie or both costs being unavailable) {@code plan2} is returned.
   *
   * @param plan1 first candidate
   * @param plan2 second candidate, preferred on ties
   * @return one of the two arguments
   */
  private RelNode chooseBetterPlan(RelNode plan1, RelNode plan2) {
    RelOptCost cost1 = mq.getCumulativeCost(plan1);
    RelOptCost cost2 = mq.getCumulativeCost(plan2);
    if (cost1 != null && cost2 != null) {
      return cost1.isLt(cost2) ? plan1 : plan2;
    } else if (cost1 != null) {
      return plan1;
    } else {
      return plan2;
    }
  }

  /**
   * Verify that the row type of plans generated by dphyp is equivalent to the origin plan.
   *
   * <p>Only plans that cover every input of the hypergraph are checked; for any smaller
   * subgraph this method returns true without inspecting the plan.
   *
   * @param plan plan generated by dphyp
   * @param resultOrder node status ordered list
   * @return true if the plan row type equivalent to the hyperGraph row type
   */
  protected boolean verifyDpResultRowType(RelNode plan, List<HyperGraph.NodeState> resultOrder) {
    // only verify the whole graph
    if (resultOrder.size() != hyperGraph.getInputs().size()) {
      return true;
    }
    List<RexNode> projects =
        hyperGraph.restoreProjectionOrder(resultOrder,
            plan.getRowType().getFieldList());
    RelNode resultNode = builder
        .push(plan)
        .project(projects)
        .build();
    return RelOptUtil.areRowTypesEqual(resultNode.getRowType(), hyperGraph.getRowType(), false);
  }
}