/src/duckdb/src/optimizer/cse_optimizer.cpp
Line | Count | Source |
1 | | #include "duckdb/optimizer/cse_optimizer.hpp" |
2 | | |
3 | | #include "duckdb/planner/expression/bound_columnref_expression.hpp" |
4 | | #include "duckdb/planner/expression_iterator.hpp" |
5 | | #include "duckdb/planner/operator/logical_projection.hpp" |
6 | | #include "duckdb/planner/column_binding_map.hpp" |
7 | | #include "duckdb/planner/binder.hpp" |
8 | | |
9 | | namespace duckdb { |
10 | | |
11 | | //! The CSENode contains information about a common subexpression; how many times it occurs, and the column index in the |
12 | | //! underlying projection |
13 | | struct CSENode { |
14 | | idx_t count; |
15 | | ProjectionIndex column_index; |
16 | | |
17 | 152k | CSENode() : count(1), column_index() { |
18 | 152k | } |
19 | | }; |
20 | | |
21 | | //! The CSEReplacementState |
22 | | struct CSEReplacementState { |
23 | | //! The projection index of the new projection |
24 | | TableIndex projection_index; |
25 | | //! Map of expression -> CSENode |
26 | | expression_map_t<CSENode> expression_count; |
27 | | //! Map of column bindings to column indexes in the projection expression list |
28 | | column_binding_map_t<ProjectionIndex> column_map; |
29 | | //! The set of expressions of the resulting projection |
30 | | vector<unique_ptr<Expression>> expressions; |
31 | | //! Cached expressions that are kept around so the expression_map always contains valid expressions |
32 | | vector<unique_ptr<Expression>> cached_expressions; |
33 | | //! Short circuit argument tracking |
34 | | bool short_circuited = false; |
35 | | }; |
36 | | |
37 | 506k | void CommonSubExpressionOptimizer::VisitOperator(LogicalOperator &op) { |
38 | 506k | switch (op.type) { |
39 | 185k | case LogicalOperatorType::LOGICAL_PROJECTION: |
40 | 242k | case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: |
41 | 242k | ExtractCommonSubExpresions(op); |
42 | 242k | break; |
43 | 264k | default: |
44 | 264k | break; |
45 | 506k | } |
46 | 506k | LogicalOperatorVisitor::VisitOperator(op); |
47 | 506k | } |
48 | | |
49 | 1.26M | void CommonSubExpressionOptimizer::CountExpressions(Expression &expr, CSEReplacementState &state) { |
50 | | // we only consider expressions with children for CSE elimination |
51 | 1.26M | switch (expr.GetExpressionClass()) { |
52 | 603k | case ExpressionClass::BOUND_COLUMN_REF: |
53 | 941k | case ExpressionClass::BOUND_CONSTANT: |
54 | 949k | case ExpressionClass::BOUND_PARAMETER: |
55 | 949k | return; |
56 | 318k | default: |
57 | 318k | break; |
58 | 1.26M | } |
59 | 318k | if (expr.GetExpressionClass() != ExpressionClass::BOUND_AGGREGATE && !expr.IsVolatile()) { |
60 | | // we can't move aggregates to a projection, so we only consider the children of the aggregate |
61 | 78.8k | auto node = state.expression_count.find(expr); |
62 | 78.8k | if (node == state.expression_count.end()) { |
63 | | // first time we encounter this expression, insert this node with [count = 1] |
64 | | // but only if it is not an interior argument of a short circuit sensitive expression. |
65 | 76.1k | if (!state.short_circuited) { |
66 | 76.0k | state.expression_count[expr] = CSENode(); |
67 | 76.0k | } |
68 | 76.1k | } else { |
69 | | // we encountered this expression before, increment the occurrence count |
70 | 2.71k | node->second.count++; |
71 | 2.71k | } |
72 | 78.8k | } |
73 | | |
74 | | // If we have a function that uses short circuiting, then we can only extract CSEs from the leftmost |
75 | | // side of the argument tree (child_no == 0) |
76 | 318k | switch (expr.GetExpressionClass()) { |
77 | 401 | case ExpressionClass::BOUND_CONJUNCTION: |
78 | 38.4k | case ExpressionClass::BOUND_CASE: { |
79 | | // Save the short circuit reference |
80 | 38.4k | const auto save_short_circuit = state.short_circuited; |
81 | 121k | ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) { |
82 | 121k | CountExpressions(child, state); |
83 | 121k | state.short_circuited = true; |
84 | 121k | }); |
85 | 38.4k | state.short_circuited = save_short_circuit; |
86 | 38.4k | break; |
87 | 401 | } |
88 | 280k | default: |
89 | | // recursively count the children |
90 | 334k | ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) { CountExpressions(child, state); }); |
91 | 280k | break; |
92 | 318k | } |
93 | 318k | } |
94 | | |
95 | 14.5k | void CommonSubExpressionOptimizer::PerformCSEReplacement(unique_ptr<Expression> &expr_ptr, CSEReplacementState &state) { |
96 | 14.5k | Expression &expr = *expr_ptr; |
97 | 14.5k | if (expr.GetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF) { |
98 | 269 | auto &bound_column_ref = expr.Cast<BoundColumnRefExpression>(); |
99 | | // bound column ref, check if this one has already been recorded in the expression list |
100 | 269 | auto column_entry = state.column_map.find(bound_column_ref.binding); |
101 | 269 | if (column_entry == state.column_map.end()) { |
102 | | // not there yet: push the expression |
103 | 45 | auto new_col_ref = make_uniq<BoundColumnRefExpression>( |
104 | 45 | bound_column_ref.GetAlias(), bound_column_ref.return_type, bound_column_ref.binding); |
105 | 45 | auto new_column_index = ColumnBinding::PushExpression(state.expressions, std::move(new_col_ref)); |
106 | 45 | state.column_map[bound_column_ref.binding] = new_column_index; |
107 | 45 | bound_column_ref.binding = ColumnBinding(state.projection_index, new_column_index); |
108 | 224 | } else { |
109 | | // else: just update the column binding! |
110 | 224 | bound_column_ref.binding = ColumnBinding(state.projection_index, column_entry->second); |
111 | 224 | } |
112 | 269 | return; |
113 | 269 | } |
114 | | // check if this child is eligible for CSE elimination |
115 | 14.2k | if (state.expression_count.find(expr) != state.expression_count.end()) { |
116 | 8.09k | auto &node = state.expression_count[expr]; |
117 | 8.09k | if (node.count > 1) { |
118 | | // this expression occurs more than once! push it into the projection |
119 | | // check if it has already been pushed into the projection |
120 | 1.74k | auto alias = expr.GetAlias(); |
121 | 1.74k | auto type = expr.return_type; |
122 | 1.74k | if (!node.column_index.IsValid()) { |
123 | | // has not been pushed yet: push it |
124 | 732 | node.column_index = ColumnBinding::PushExpression(state.expressions, std::move(expr_ptr)); |
125 | 1.00k | } else { |
126 | 1.00k | state.cached_expressions.push_back(std::move(expr_ptr)); |
127 | 1.00k | } |
128 | | // replace the original expression with a bound column ref |
129 | 1.74k | expr_ptr = make_uniq<BoundColumnRefExpression>(alias, type, |
130 | 1.74k | ColumnBinding(state.projection_index, node.column_index)); |
131 | 1.74k | return; |
132 | 1.74k | } |
133 | 8.09k | } |
134 | | // this expression only occurs once, we can't perform CSE elimination |
135 | | // look into the children to see if we can replace them |
136 | 12.5k | ExpressionIterator::EnumerateChildren(expr, |
137 | 12.5k | [&](unique_ptr<Expression> &child) { PerformCSEReplacement(child, state); }); |
138 | 12.5k | } |
139 | | |
140 | 242k | void CommonSubExpressionOptimizer::ExtractCommonSubExpresions(LogicalOperator &op) { |
141 | 242k | D_ASSERT(op.children.size() == 1); |
142 | | |
143 | | // first we count for each expression with children how many types it occurs |
144 | 242k | CSEReplacementState state; |
145 | 242k | LogicalOperatorVisitor::EnumerateExpressions( |
146 | 811k | op, [&](unique_ptr<Expression> *child) { CountExpressions(**child, state); }); |
147 | | // check if there are any expressions to extract |
148 | 242k | bool perform_replacement = false; |
149 | 242k | for (auto &expr : state.expression_count) { |
150 | 69.7k | if (expr.second.count > 1) { |
151 | 444 | perform_replacement = true; |
152 | 444 | break; |
153 | 444 | } |
154 | 69.7k | } |
155 | 242k | if (!perform_replacement) { |
156 | | // no CSEs to extract |
157 | 241k | return; |
158 | 241k | } |
159 | 444 | state.projection_index = binder.GenerateTableIndex(); |
160 | | // we found common subexpressions to extract |
161 | | // now we iterate over all the expressions and perform the actual CSE elimination |
162 | | |
163 | 444 | LogicalOperatorVisitor::EnumerateExpressions( |
164 | 2.86k | op, [&](unique_ptr<Expression> *child) { PerformCSEReplacement(*child, state); }); |
165 | 444 | D_ASSERT(state.expressions.size() > 0); |
166 | | // create a projection node as the child of this node |
167 | 444 | auto projection = make_uniq<LogicalProjection>(state.projection_index, std::move(state.expressions)); |
168 | 444 | if (op.children[0]->has_estimated_cardinality) { |
169 | 444 | projection->SetEstimatedCardinality(op.children[0]->estimated_cardinality); |
170 | 444 | } |
171 | 444 | projection->children.push_back(std::move(op.children[0])); |
172 | 444 | op.children[0] = std::move(projection); |
173 | 444 | } |
174 | | |
175 | | } // namespace duckdb |