/src/duckdb/src/optimizer/unnest_rewriter.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | #include "duckdb/optimizer/unnest_rewriter.hpp" |
2 | | |
3 | | #include "duckdb/common/pair.hpp" |
4 | | #include "duckdb/planner/expression/bound_columnref_expression.hpp" |
5 | | #include "duckdb/planner/expression/bound_unnest_expression.hpp" |
6 | | #include "duckdb/planner/operator/logical_comparison_join.hpp" |
7 | | #include "duckdb/planner/operator/logical_delim_get.hpp" |
8 | | #include "duckdb/planner/operator/logical_projection.hpp" |
9 | | #include "duckdb/planner/operator/logical_unnest.hpp" |
10 | | #include "duckdb/planner/operator/logical_window.hpp" |
11 | | |
12 | | namespace duckdb { |
13 | | |
14 | 0 | void UnnestRewriterPlanUpdater::VisitOperator(LogicalOperator &op) { |
15 | 0 | VisitOperatorChildren(op); |
16 | 0 | VisitOperatorExpressions(op); |
17 | 0 | } |
18 | | |
19 | 0 | void UnnestRewriterPlanUpdater::VisitExpression(unique_ptr<Expression> *expression) { |
20 | 0 | auto &expr = *expression; |
21 | |
|
22 | 0 | if (expr->GetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF) { |
23 | 0 | auto &bound_column_ref = expr->Cast<BoundColumnRefExpression>(); |
24 | 0 | for (idx_t i = 0; i < replace_bindings.size(); i++) { |
25 | 0 | if (bound_column_ref.binding == replace_bindings[i].old_binding) { |
26 | 0 | bound_column_ref.binding = replace_bindings[i].new_binding; |
27 | 0 | break; |
28 | 0 | } |
29 | 0 | } |
30 | 0 | } |
31 | |
|
32 | 0 | VisitExpressionChildren(**expression); |
33 | 0 | } |
34 | | |
35 | 54.2k | unique_ptr<LogicalOperator> UnnestRewriter::Optimize(unique_ptr<LogicalOperator> op) { |
36 | | |
37 | 54.2k | UnnestRewriterPlanUpdater updater; |
38 | 54.2k | vector<reference<unique_ptr<LogicalOperator>>> candidates; |
39 | 54.2k | FindCandidates(op, candidates); |
40 | | |
41 | | // rewrite the plan and update the bindings |
42 | 54.2k | for (auto &candidate : candidates) { |
43 | | |
44 | | // rearrange the logical operators |
45 | 0 | if (RewriteCandidate(candidate)) { |
46 | 0 | updater.overwritten_tbl_idx = overwritten_tbl_idx; |
47 | | // update the bindings of the BOUND_UNNEST expression |
48 | 0 | UpdateBoundUnnestBindings(updater, candidate); |
49 | | // update the sequence of LOGICAL_PROJECTION(s) |
50 | 0 | UpdateRHSBindings(op, candidate, updater); |
51 | | // reset |
52 | 0 | delim_columns.clear(); |
53 | 0 | lhs_bindings.clear(); |
54 | 0 | } |
55 | 0 | } |
56 | | |
57 | 54.2k | return op; |
58 | 54.2k | } |
59 | | |
60 | | void UnnestRewriter::FindCandidates(unique_ptr<LogicalOperator> &op, |
61 | 587k | vector<reference<unique_ptr<LogicalOperator>>> &candidates) { |
62 | | // search children before adding, so that we add candidates bottom-up |
63 | 587k | for (auto &child : op->children) { |
64 | 533k | FindCandidates(child, candidates); |
65 | 533k | } |
66 | | |
67 | | // search for operator that has a LOGICAL_DELIM_JOIN as its child |
68 | 587k | if (op->children.size() != 1) { |
69 | 225k | return; |
70 | 225k | } |
71 | 362k | if (op->children[0]->type != LogicalOperatorType::LOGICAL_DELIM_JOIN) { |
72 | 362k | return; |
73 | 362k | } |
74 | | |
75 | | // found a delim join |
76 | 0 | auto &delim_join = op->children[0]->Cast<LogicalComparisonJoin>(); |
77 | | // only support INNER delim joins |
78 | 0 | if (delim_join.join_type != JoinType::INNER) { |
79 | 0 | return; |
80 | 0 | } |
81 | | // INNER delim join must have exactly one condition |
82 | 0 | if (delim_join.conditions.size() != 1) { |
83 | 0 | return; |
84 | 0 | } |
85 | | |
86 | | // LHS child is a window |
87 | 0 | idx_t delim_idx = delim_join.delim_flipped ? 1 : 0; |
88 | 0 | idx_t other_idx = 1 - delim_idx; |
89 | 0 | if (delim_join.children[delim_idx]->type != LogicalOperatorType::LOGICAL_WINDOW) { |
90 | 0 | return; |
91 | 0 | } |
92 | | |
93 | | // RHS child must be projection(s) followed by an UNNEST |
94 | 0 | auto curr_op = &delim_join.children[other_idx]; |
95 | 0 | while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) { |
96 | 0 | if (curr_op->get()->children.size() != 1) { |
97 | 0 | break; |
98 | 0 | } |
99 | 0 | curr_op = &curr_op->get()->children[0]; |
100 | 0 | } |
101 | |
|
102 | 0 | if (curr_op->get()->type == LogicalOperatorType::LOGICAL_UNNEST && |
103 | 0 | curr_op->get()->children[0]->type == LogicalOperatorType::LOGICAL_DELIM_GET) { |
104 | 0 | candidates.push_back(op); |
105 | 0 | } |
106 | 0 | } |
107 | | |
108 | 0 | bool UnnestRewriter::RewriteCandidate(unique_ptr<LogicalOperator> &candidate) { |
109 | |
|
110 | 0 | auto &topmost_op = *candidate; |
111 | 0 | if (topmost_op.type != LogicalOperatorType::LOGICAL_PROJECTION && |
112 | 0 | topmost_op.type != LogicalOperatorType::LOGICAL_WINDOW && |
113 | 0 | topmost_op.type != LogicalOperatorType::LOGICAL_FILTER && |
114 | 0 | topmost_op.type != LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY && |
115 | 0 | topmost_op.type != LogicalOperatorType::LOGICAL_UNNEST) { |
116 | 0 | return false; |
117 | 0 | } |
118 | | |
119 | | // get the LOGICAL_DELIM_JOIN, which is a child of the candidate |
120 | 0 | D_ASSERT(topmost_op.children.size() == 1); |
121 | 0 | auto &delim_join = topmost_op.children[0]->Cast<LogicalComparisonJoin>(); |
122 | 0 | D_ASSERT(delim_join.type == LogicalOperatorType::LOGICAL_DELIM_JOIN); |
123 | 0 | GetDelimColumns(delim_join); |
124 | | |
125 | | // LHS of the LOGICAL_DELIM_JOIN is a LOGICAL_WINDOW that contains a LOGICAL_PROJECTION/LOGICAL_CROSS_JOIN |
126 | | // this lhs_proj later becomes the child of the UNNEST |
127 | |
|
128 | 0 | idx_t delim_idx = delim_join.delim_flipped ? 1 : 0; |
129 | 0 | idx_t other_idx = 1 - delim_idx; |
130 | 0 | auto &window = *delim_join.children[delim_idx]; |
131 | 0 | auto &lhs_op = window.children[0]; |
132 | 0 | GetLHSExpressions(*lhs_op); |
133 | | |
134 | | // find the LOGICAL_UNNEST |
135 | | // and get the path down to the LOGICAL_UNNEST |
136 | 0 | vector<unique_ptr<LogicalOperator> *> path_to_unnest; |
137 | 0 | auto curr_op = &delim_join.children[other_idx]; |
138 | 0 | while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) { |
139 | 0 | path_to_unnest.push_back(curr_op); |
140 | 0 | curr_op = &curr_op->get()->children[0]; |
141 | 0 | } |
142 | | |
143 | | // store the table index of the child of the LOGICAL_UNNEST |
144 | | // then update the plan by making the lhs_proj the child of the LOGICAL_UNNEST |
145 | 0 | D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_UNNEST); |
146 | 0 | auto &unnest = curr_op->get()->Cast<LogicalUnnest>(); |
147 | 0 | D_ASSERT(unnest.children[0]->type == LogicalOperatorType::LOGICAL_DELIM_GET); |
148 | 0 | overwritten_tbl_idx = unnest.children[0]->Cast<LogicalDelimGet>().table_index; |
149 | |
|
150 | 0 | D_ASSERT(!unnest.children.empty()); |
151 | 0 | auto &delim_get = unnest.children[0]->Cast<LogicalDelimGet>(); |
152 | 0 | D_ASSERT(delim_get.chunk_types.size() > 1); |
153 | 0 | distinct_unnest_count = delim_get.chunk_types.size(); |
154 | 0 | unnest.children[0] = std::move(lhs_op); |
155 | | |
156 | | // replace the LOGICAL_DELIM_JOIN with its RHS child operator |
157 | 0 | topmost_op.children[0] = std::move(*path_to_unnest.front()); |
158 | 0 | return true; |
159 | 0 | } |
160 | | |
161 | | void UnnestRewriter::UpdateRHSBindings(unique_ptr<LogicalOperator> &plan, unique_ptr<LogicalOperator> &candidate, |
162 | 0 | UnnestRewriterPlanUpdater &updater) { |
163 | |
|
164 | 0 | auto &topmost_op = *candidate; |
165 | 0 | idx_t shift = lhs_bindings.size(); |
166 | |
|
167 | 0 | vector<unique_ptr<LogicalOperator> *> path_to_unnest; |
168 | 0 | auto curr_op = &topmost_op.children[0]; |
169 | 0 | while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) { |
170 | |
|
171 | 0 | path_to_unnest.push_back(curr_op); |
172 | 0 | D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION); |
173 | 0 | auto &proj = curr_op->get()->Cast<LogicalProjection>(); |
174 | | |
175 | | // pop the unnest columns and the delim index |
176 | 0 | D_ASSERT(proj.expressions.size() > distinct_unnest_count); |
177 | 0 | for (idx_t i = 0; i < distinct_unnest_count; i++) { |
178 | 0 | proj.expressions.pop_back(); |
179 | 0 | } |
180 | | |
181 | | // store all shifted current bindings |
182 | 0 | idx_t tbl_idx = proj.table_index; |
183 | 0 | for (idx_t i = 0; i < proj.expressions.size(); i++) { |
184 | 0 | ReplaceBinding replace_binding(ColumnBinding(tbl_idx, i), ColumnBinding(tbl_idx, i + shift)); |
185 | 0 | updater.replace_bindings.push_back(replace_binding); |
186 | 0 | } |
187 | |
|
188 | 0 | curr_op = &curr_op->get()->children[0]; |
189 | 0 | } |
190 | | |
191 | | // update all bindings by shifting them |
192 | 0 | updater.VisitOperator(*plan); |
193 | 0 | updater.replace_bindings.clear(); |
194 | | |
195 | | // update all bindings coming from the LHS to RHS bindings |
196 | 0 | D_ASSERT(topmost_op.children[0]->type == LogicalOperatorType::LOGICAL_PROJECTION); |
197 | 0 | auto &top_proj = topmost_op.children[0]->Cast<LogicalProjection>(); |
198 | 0 | for (idx_t i = 0; i < lhs_bindings.size(); i++) { |
199 | 0 | ReplaceBinding replace_binding(lhs_bindings[i].binding, ColumnBinding(top_proj.table_index, i)); |
200 | 0 | updater.replace_bindings.push_back(replace_binding); |
201 | 0 | } |
202 | | |
203 | | // temporarily remove the BOUND_UNNESTs and the child of the LOGICAL_UNNEST from the plan |
204 | 0 | D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_UNNEST); |
205 | 0 | auto &unnest = curr_op->get()->Cast<LogicalUnnest>(); |
206 | 0 | vector<unique_ptr<Expression>> temp_bound_unnests; |
207 | 0 | for (auto &temp_bound_unnest : unnest.expressions) { |
208 | 0 | temp_bound_unnests.push_back(std::move(temp_bound_unnest)); |
209 | 0 | } |
210 | 0 | D_ASSERT(unnest.children.size() == 1); |
211 | 0 | auto temp_unnest_child = std::move(unnest.children[0]); |
212 | 0 | unnest.expressions.clear(); |
213 | 0 | unnest.children.clear(); |
214 | | // update the bindings of the plan |
215 | 0 | updater.VisitOperator(*plan); |
216 | 0 | updater.replace_bindings.clear(); |
217 | | // add the children again |
218 | 0 | for (auto &temp_bound_unnest : temp_bound_unnests) { |
219 | 0 | unnest.expressions.push_back(std::move(temp_bound_unnest)); |
220 | 0 | } |
221 | 0 | unnest.children.push_back(std::move(temp_unnest_child)); |
222 | | |
223 | | // add the LHS expressions to each LOGICAL_PROJECTION |
224 | 0 | for (idx_t i = path_to_unnest.size(); i > 0; i--) { |
225 | |
|
226 | 0 | D_ASSERT(path_to_unnest[i - 1]->get()->type == LogicalOperatorType::LOGICAL_PROJECTION); |
227 | 0 | auto &proj = path_to_unnest[i - 1]->get()->Cast<LogicalProjection>(); |
228 | | |
229 | | // temporarily store the existing expressions |
230 | 0 | vector<unique_ptr<Expression>> existing_expressions; |
231 | 0 | for (idx_t expr_idx = 0; expr_idx < proj.expressions.size(); expr_idx++) { |
232 | 0 | existing_expressions.push_back(std::move(proj.expressions[expr_idx])); |
233 | 0 | } |
234 | |
|
235 | 0 | proj.expressions.clear(); |
236 | | |
237 | | // add the new expressions |
238 | 0 | for (idx_t expr_idx = 0; expr_idx < lhs_bindings.size(); expr_idx++) { |
239 | 0 | auto new_expr = make_uniq<BoundColumnRefExpression>( |
240 | 0 | lhs_bindings[expr_idx].alias, lhs_bindings[expr_idx].type, lhs_bindings[expr_idx].binding); |
241 | 0 | proj.expressions.push_back(std::move(new_expr)); |
242 | | |
243 | | // update the table index |
244 | 0 | lhs_bindings[expr_idx].binding.table_index = proj.table_index; |
245 | 0 | lhs_bindings[expr_idx].binding.column_index = expr_idx; |
246 | 0 | } |
247 | | |
248 | | // add the existing expressions again |
249 | 0 | for (idx_t expr_idx = 0; expr_idx < existing_expressions.size(); expr_idx++) { |
250 | 0 | proj.expressions.push_back(std::move(existing_expressions[expr_idx])); |
251 | 0 | } |
252 | 0 | } |
253 | 0 | } |
254 | | |
255 | | void UnnestRewriter::UpdateBoundUnnestBindings(UnnestRewriterPlanUpdater &updater, |
256 | 0 | unique_ptr<LogicalOperator> &candidate) { |
257 | |
|
258 | 0 | auto &topmost_op = *candidate; |
259 | | |
260 | | // traverse LOGICAL_PROJECTION(s) |
261 | 0 | auto curr_op = &topmost_op.children[0]; |
262 | 0 | while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) { |
263 | 0 | curr_op = &curr_op->get()->children[0]; |
264 | 0 | } |
265 | | |
266 | | // found the LOGICAL_UNNEST |
267 | 0 | D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_UNNEST); |
268 | 0 | auto &unnest = curr_op->get()->Cast<LogicalUnnest>(); |
269 | |
|
270 | 0 | D_ASSERT(unnest.children.size() == 1); |
271 | 0 | auto unnest_cols = unnest.children[0]->GetColumnBindings(); |
272 | |
|
273 | 0 | for (idx_t i = 0; i < delim_columns.size(); i++) { |
274 | 0 | auto delim_binding = delim_columns[i]; |
275 | |
|
276 | 0 | auto unnest_it = unnest_cols.begin(); |
277 | 0 | while (unnest_it != unnest_cols.end()) { |
278 | 0 | auto unnest_binding = *unnest_it; |
279 | |
|
280 | 0 | if (delim_binding.table_index == unnest_binding.table_index) { |
281 | 0 | unnest_binding.table_index = overwritten_tbl_idx; |
282 | 0 | unnest_binding.column_index = i; |
283 | 0 | updater.replace_bindings.emplace_back(unnest_binding, delim_binding); |
284 | 0 | unnest_cols.erase(unnest_it); |
285 | 0 | break; |
286 | 0 | } |
287 | 0 | unnest_it++; |
288 | 0 | } |
289 | 0 | } |
290 | | |
291 | | // update bindings |
292 | 0 | for (auto &unnest_expr : unnest.expressions) { |
293 | 0 | updater.VisitExpression(&unnest_expr); |
294 | 0 | } |
295 | 0 | updater.replace_bindings.clear(); |
296 | 0 | } |
297 | | |
298 | 0 | void UnnestRewriter::GetDelimColumns(LogicalOperator &op) { |
299 | |
|
300 | 0 | D_ASSERT(op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN); |
301 | 0 | auto &delim_join = op.Cast<LogicalComparisonJoin>(); |
302 | 0 | for (idx_t i = 0; i < delim_join.duplicate_eliminated_columns.size(); i++) { |
303 | 0 | auto &expr = *delim_join.duplicate_eliminated_columns[i]; |
304 | 0 | D_ASSERT(expr.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF); |
305 | 0 | auto &bound_colref_expr = expr.Cast<BoundColumnRefExpression>(); |
306 | 0 | delim_columns.push_back(bound_colref_expr.binding); |
307 | 0 | } |
308 | 0 | } |
309 | | |
310 | 0 | void UnnestRewriter::GetLHSExpressions(LogicalOperator &op) { |
311 | |
|
312 | 0 | op.ResolveOperatorTypes(); |
313 | 0 | auto col_bindings = op.GetColumnBindings(); |
314 | 0 | D_ASSERT(op.types.size() == col_bindings.size()); |
315 | |
|
316 | 0 | bool set_alias = false; |
317 | | // we can easily extract the alias for LOGICAL_PROJECTION(s) |
318 | 0 | if (op.type == LogicalOperatorType::LOGICAL_PROJECTION) { |
319 | 0 | auto &proj = op.Cast<LogicalProjection>(); |
320 | 0 | if (proj.expressions.size() == op.types.size()) { |
321 | 0 | set_alias = true; |
322 | 0 | } |
323 | 0 | } |
324 | |
|
325 | 0 | for (idx_t i = 0; i < op.types.size(); i++) { |
326 | 0 | lhs_bindings.emplace_back(col_bindings[i], op.types[i]); |
327 | 0 | if (set_alias) { |
328 | 0 | auto &proj = op.Cast<LogicalProjection>(); |
329 | 0 | lhs_bindings.back().alias = proj.expressions[i]->GetAlias(); |
330 | 0 | } |
331 | 0 | } |
332 | 0 | } |
333 | | |
334 | | } // namespace duckdb |