/src/duckdb/src/optimizer/remove_unused_columns.cpp
Line | Count | Source |
1 | | #include "duckdb/optimizer/remove_unused_columns.hpp" |
2 | | |
3 | | #include "duckdb/function/aggregate/distributive_functions.hpp" |
4 | | #include "duckdb/function/function_binder.hpp" |
5 | | #include "duckdb/parser/parsed_data/vacuum_info.hpp" |
6 | | #include "duckdb/planner/binder.hpp" |
7 | | #include "duckdb/planner/column_binding_map.hpp" |
8 | | #include "duckdb/planner/expression/bound_aggregate_expression.hpp" |
9 | | #include "duckdb/planner/expression/bound_columnref_expression.hpp" |
10 | | #include "duckdb/planner/expression/bound_constant_expression.hpp" |
11 | | #include "duckdb/planner/expression/bound_function_expression.hpp" |
12 | | #include "duckdb/planner/expression_iterator.hpp" |
13 | | #include "duckdb/planner/operator/logical_aggregate.hpp" |
14 | | #include "duckdb/planner/operator/logical_comparison_join.hpp" |
15 | | #include "duckdb/planner/operator/logical_distinct.hpp" |
16 | | #include "duckdb/planner/operator/logical_filter.hpp" |
17 | | #include "duckdb/planner/operator/logical_get.hpp" |
18 | | #include "duckdb/planner/operator/logical_order.hpp" |
19 | | #include "duckdb/planner/operator/logical_projection.hpp" |
20 | | #include "duckdb/planner/operator/logical_set_operation.hpp" |
21 | | #include "duckdb/planner/operator/logical_simple.hpp" |
22 | | #include "duckdb/function/scalar/struct_utils.hpp" |
23 | | |
24 | | namespace duckdb { |
25 | | |
26 | 203 | void BaseColumnPruner::ReplaceBinding(ColumnBinding current_binding, ColumnBinding new_binding) { |
27 | 203 | auto colrefs = column_references.find(current_binding); |
28 | 203 | if (colrefs != column_references.end()) { |
29 | 693 | for (auto &colref_p : colrefs->second.bindings) { |
30 | 693 | auto &colref = colref_p.get(); |
31 | 693 | D_ASSERT(colref.binding == current_binding); |
32 | 693 | colref.binding = new_binding; |
33 | 693 | } |
34 | 203 | } |
35 | 203 | } |
36 | | |
37 | | template <class T> |
38 | 37.1k | void RemoveUnusedColumns::ClearUnusedExpressions(vector<T> &list, idx_t table_idx, bool replace) { |
39 | 37.1k | idx_t offset = 0; |
40 | 190k | for (idx_t col_idx = 0; col_idx < list.size(); col_idx++) { |
41 | 153k | auto current_binding = ColumnBinding(table_idx, col_idx + offset); |
42 | 153k | auto entry = column_references.find(current_binding); |
43 | 153k | if (entry == column_references.end()) { |
44 | | // this entry is not referred to, erase it from the set of expressions |
45 | 3.74k | list.erase_at(col_idx); |
46 | 3.74k | offset++; |
47 | 3.74k | col_idx--; |
48 | 149k | } else if (offset > 0 && replace) { |
49 | | // column is used but the ColumnBinding has changed because of removed columns |
50 | 203 | ReplaceBinding(current_binding, ColumnBinding(table_idx, col_idx)); |
51 | 203 | } |
52 | 153k | } |
53 | 37.1k | } void duckdb::RemoveUnusedColumns::ClearUnusedExpressions<duckdb::unique_ptr<duckdb::Expression, std::__1::default_delete<duckdb::Expression>, true> >(duckdb::vector<duckdb::unique_ptr<duckdb::Expression, std::__1::default_delete<duckdb::Expression>, true>, true, std::__1::allocator<duckdb::unique_ptr<duckdb::Expression, std::__1::default_delete<duckdb::Expression>, true> > >&, unsigned long, bool) Line | Count | Source | 38 | 23.9k | void RemoveUnusedColumns::ClearUnusedExpressions(vector<T> &list, idx_t table_idx, bool replace) { | 39 | 23.9k | idx_t offset = 0; | 40 | 99.2k | for (idx_t col_idx = 0; col_idx < list.size(); col_idx++) { | 41 | 75.2k | auto current_binding = ColumnBinding(table_idx, col_idx + offset); | 42 | 75.2k | auto entry = column_references.find(current_binding); | 43 | 75.2k | if (entry == column_references.end()) { | 44 | | // this entry is not referred to, erase it from the set of expressions | 45 | 3.38k | list.erase_at(col_idx); | 46 | 3.38k | offset++; | 47 | 3.38k | col_idx--; | 48 | 71.8k | } else if (offset > 0 && replace) { | 49 | | // column is used but the ColumnBinding has changed because of removed columns | 50 | 203 | ReplaceBinding(current_binding, ColumnBinding(table_idx, col_idx)); | 51 | 203 | } | 52 | 75.2k | } | 53 | 23.9k | } |
void duckdb::RemoveUnusedColumns::ClearUnusedExpressions<unsigned long>(duckdb::vector<unsigned long, true, std::__1::allocator<unsigned long> >&, unsigned long, bool) Line | Count | Source | 38 | 13.2k | void RemoveUnusedColumns::ClearUnusedExpressions(vector<T> &list, idx_t table_idx, bool replace) { | 39 | 13.2k | idx_t offset = 0; | 40 | 91.5k | for (idx_t col_idx = 0; col_idx < list.size(); col_idx++) { | 41 | 78.3k | auto current_binding = ColumnBinding(table_idx, col_idx + offset); | 42 | 78.3k | auto entry = column_references.find(current_binding); | 43 | 78.3k | if (entry == column_references.end()) { | 44 | | // this entry is not referred to, erase it from the set of expressions | 45 | 364 | list.erase_at(col_idx); | 46 | 364 | offset++; | 47 | 364 | col_idx--; | 48 | 78.0k | } else if (offset > 0 && replace) { | 49 | | // column is used but the ColumnBinding has changed because of removed columns | 50 | 0 | ReplaceBinding(current_binding, ColumnBinding(table_idx, col_idx)); | 51 | 0 | } | 52 | 78.3k | } | 53 | 13.2k | } |
|
54 | | |
55 | 286k | void RemoveUnusedColumns::VisitOperator(LogicalOperator &op) { |
56 | 286k | switch (op.type) { |
57 | 9.77k | case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: { |
58 | | // aggregate |
59 | 9.77k | auto &aggr = op.Cast<LogicalAggregate>(); |
60 | | // if there is more than one grouping set, the group by most likely has a rollup or cube |
61 | | // If there is an equality join underneath the aggregate, this can change the groups to avoid unused columns |
62 | | // This causes the duplicate eliminator to ignore functionality provided by grouping sets |
63 | 9.77k | bool new_root = false; |
64 | 9.77k | if (aggr.grouping_sets.size() > 1) { |
65 | 0 | new_root = true; |
66 | 0 | } |
67 | 9.77k | if (!everything_referenced && !new_root) { |
68 | | // FIXME: groups that are not referenced need to stay -> but they don't need to be scanned and output! |
69 | 8.13k | ClearUnusedExpressions(aggr.expressions, aggr.aggregate_index); |
70 | 8.13k | if (aggr.expressions.empty() && aggr.groups.empty()) { |
71 | | // removed all expressions from the aggregate: push a COUNT(*) |
72 | 0 | auto count_star_fun = CountStarFun::GetFunction(); |
73 | 0 | FunctionBinder function_binder(context); |
74 | 0 | aggr.expressions.push_back( |
75 | 0 | function_binder.BindAggregateFunction(count_star_fun, {}, nullptr, AggregateType::NON_DISTINCT)); |
76 | 0 | } |
77 | 8.13k | } |
78 | | |
79 | | // then recurse into the children of the aggregate |
80 | 9.77k | RemoveUnusedColumns remove(binder, context, new_root); |
81 | 9.77k | remove.VisitOperatorExpressions(op); |
82 | 9.77k | remove.VisitOperator(*op.children[0]); |
83 | 9.77k | return; |
84 | 0 | } |
85 | 0 | case LogicalOperatorType::LOGICAL_ASOF_JOIN: |
86 | 0 | case LogicalOperatorType::LOGICAL_DELIM_JOIN: |
87 | 44 | case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { |
88 | 44 | if (everything_referenced) { |
89 | 0 | break; |
90 | 0 | } |
91 | 44 | auto &comp_join = op.Cast<LogicalComparisonJoin>(); |
92 | | |
93 | 44 | if (comp_join.join_type != JoinType::INNER) { |
94 | 44 | break; |
95 | 44 | } |
96 | | // for inner joins with equality predicates in the form of (X=Y) |
97 | | // we can replace any references to the RHS (Y) to references to the LHS (X) |
98 | | // this reduces the amount of columns we need to extract from the join hash table |
99 | | // (except in the case of floating point numbers which have +0 and -0, equal but different). |
100 | 0 | for (auto &cond : comp_join.conditions) { |
101 | 0 | if (cond.comparison != ExpressionType::COMPARE_EQUAL) { |
102 | 0 | continue; |
103 | 0 | } |
104 | 0 | if (cond.left->GetExpressionClass() != ExpressionClass::BOUND_COLUMN_REF) { |
105 | 0 | continue; |
106 | 0 | } |
107 | 0 | if (cond.right->GetExpressionClass() != ExpressionClass::BOUND_COLUMN_REF) { |
108 | 0 | continue; |
109 | 0 | } |
110 | 0 | if (cond.left->Cast<BoundColumnRefExpression>().return_type.IsFloating()) { |
111 | 0 | continue; |
112 | 0 | } |
113 | 0 | if (cond.right->Cast<BoundColumnRefExpression>().return_type.IsFloating()) { |
114 | 0 | continue; |
115 | 0 | } |
116 | | // comparison join between two bound column refs |
117 | | // we can replace any reference to the RHS (build-side) with a reference to the LHS (probe-side) |
118 | 0 | auto &lhs_col = cond.left->Cast<BoundColumnRefExpression>(); |
119 | 0 | auto &rhs_col = cond.right->Cast<BoundColumnRefExpression>(); |
120 | | // if there are any columns that refer to the RHS, |
121 | 0 | auto colrefs = column_references.find(rhs_col.binding); |
122 | 0 | if (colrefs == column_references.end()) { |
123 | 0 | continue; |
124 | 0 | } |
125 | 0 | for (auto &entry : colrefs->second.bindings) { |
126 | 0 | auto &colref = entry.get(); |
127 | 0 | colref.binding = lhs_col.binding; |
128 | 0 | AddBinding(colref); |
129 | 0 | } |
130 | 0 | column_references.erase(rhs_col.binding); |
131 | 0 | } |
132 | 0 | break; |
133 | 44 | } |
134 | 0 | case LogicalOperatorType::LOGICAL_ANY_JOIN: |
135 | 0 | break; |
136 | 15.1k | case LogicalOperatorType::LOGICAL_UNION: { |
137 | 15.1k | auto &setop = op.Cast<LogicalSetOperation>(); |
138 | 15.1k | if (setop.setop_all && !everything_referenced) { |
139 | | // for UNION we can remove unreferenced columns if union all is used |
140 | | // it's possible not all columns are referenced, but unreferenced columns in the union can |
141 | | // still have an affect on the result of the union |
142 | 13.2k | vector<idx_t> entries; |
143 | 91.5k | for (idx_t i = 0; i < setop.column_count; i++) { |
144 | 78.3k | entries.push_back(i); |
145 | 78.3k | } |
146 | 13.2k | ClearUnusedExpressions(entries, setop.table_index); |
147 | 13.2k | if (entries.size() >= setop.column_count) { |
148 | 13.1k | return; |
149 | 13.1k | } |
150 | 91 | if (entries.empty()) { |
151 | | // no columns referenced: this happens in the case of a COUNT(*) |
152 | | // extract the first column |
153 | 91 | entries.push_back(0); |
154 | 91 | } |
155 | | // columns were cleared |
156 | 91 | setop.column_count = entries.size(); |
157 | | |
158 | 728 | for (idx_t child_idx = 0; child_idx < op.children.size(); child_idx++) { |
159 | 637 | RemoveUnusedColumns remove(binder, context, true); |
160 | 637 | auto &child = op.children[child_idx]; |
161 | | |
162 | | // we push a projection under this child that references the required columns of the union |
163 | 637 | child->ResolveOperatorTypes(); |
164 | 637 | auto bindings = child->GetColumnBindings(); |
165 | 637 | vector<unique_ptr<Expression>> expressions; |
166 | 637 | expressions.reserve(entries.size()); |
167 | 637 | for (auto &column_idx : entries) { |
168 | 637 | expressions.push_back( |
169 | 637 | make_uniq<BoundColumnRefExpression>(child->types[column_idx], bindings[column_idx])); |
170 | 637 | } |
171 | 637 | auto new_projection = make_uniq<LogicalProjection>(binder.GenerateTableIndex(), std::move(expressions)); |
172 | 637 | if (child->has_estimated_cardinality) { |
173 | 637 | new_projection->SetEstimatedCardinality(child->estimated_cardinality); |
174 | 637 | } |
175 | 637 | new_projection->children.push_back(std::move(child)); |
176 | 637 | op.children[child_idx] = std::move(new_projection); |
177 | | |
178 | 637 | remove.VisitOperator(*op.children[child_idx]); |
179 | 637 | } |
180 | 91 | return; |
181 | 13.2k | } |
182 | 16.7k | for (auto &child : op.children) { |
183 | 16.7k | RemoveUnusedColumns remove(binder, context, true); |
184 | 16.7k | remove.VisitOperator(*child); |
185 | 16.7k | } |
186 | 1.93k | return; |
187 | 15.1k | } |
188 | 4.89k | case LogicalOperatorType::LOGICAL_EXCEPT: |
189 | 4.89k | case LogicalOperatorType::LOGICAL_INTERSECT: { |
190 | | // for INTERSECT/EXCEPT operations we can't remove anything, just recursively visit the children |
191 | 9.78k | for (auto &child : op.children) { |
192 | 9.78k | RemoveUnusedColumns remove(binder, context, true); |
193 | 9.78k | remove.VisitOperator(*child); |
194 | 9.78k | } |
195 | 4.89k | return; |
196 | 4.89k | } |
197 | 109k | case LogicalOperatorType::LOGICAL_PROJECTION: { |
198 | 109k | if (!everything_referenced) { |
199 | 15.8k | auto &proj = op.Cast<LogicalProjection>(); |
200 | 15.8k | ClearUnusedExpressions(proj.expressions, proj.table_index); |
201 | | |
202 | 15.8k | if (proj.expressions.empty()) { |
203 | | // nothing references the projected expressions |
204 | | // this happens in the case of e.g. EXISTS(SELECT * FROM ...) |
205 | | // in this case we only need to project a single constant |
206 | 21 | proj.expressions.push_back(make_uniq<BoundConstantExpression>(Value::INTEGER(42))); |
207 | 21 | } |
208 | 15.8k | } |
209 | | // then recurse into the children of this projection |
210 | 109k | RemoveUnusedColumns remove(binder, context); |
211 | 109k | remove.VisitOperatorExpressions(op); |
212 | 109k | remove.VisitOperator(*op.children[0]); |
213 | 109k | return; |
214 | 4.89k | } |
215 | 0 | case LogicalOperatorType::LOGICAL_INSERT: |
216 | 0 | case LogicalOperatorType::LOGICAL_UPDATE: |
217 | 0 | case LogicalOperatorType::LOGICAL_DELETE: |
218 | 0 | case LogicalOperatorType::LOGICAL_MERGE_INTO: { |
219 | | //! When RETURNING is used, a PROJECTION is the top level operator for INSERTS, UPDATES, and DELETES |
220 | | //! We still need to project all values from these operators so the projection |
221 | | //! on top of them can select from only the table values being inserted. |
222 | | //! TODO: Push down the projections from the returning statement |
223 | | //! TODO: Be careful because you might be adding expressions when a user returns * |
224 | 0 | RemoveUnusedColumns remove(binder, context, true); |
225 | 0 | remove.VisitOperatorExpressions(op); |
226 | 0 | remove.VisitOperator(*op.children[0]); |
227 | 0 | return; |
228 | 0 | } |
229 | 735 | case LogicalOperatorType::LOGICAL_GET: { |
230 | 735 | LogicalOperatorVisitor::VisitOperatorExpressions(op); |
231 | 735 | if (everything_referenced) { |
232 | 0 | return; |
233 | 0 | } |
234 | 735 | auto &get = op.Cast<LogicalGet>(); |
235 | 735 | if (!get.function.projection_pushdown) { |
236 | 735 | return; |
237 | 735 | } |
238 | | |
239 | 0 | auto final_column_ids = get.GetColumnIds(); |
240 | | |
241 | | // Create "selection vector" of all column ids |
242 | 0 | vector<idx_t> proj_sel; |
243 | 0 | for (idx_t col_idx = 0; col_idx < final_column_ids.size(); col_idx++) { |
244 | 0 | proj_sel.push_back(col_idx); |
245 | 0 | } |
246 | | // Create a copy that we can use to match ids later |
247 | 0 | auto col_sel = proj_sel; |
248 | | // Clear unused ids, exclude filter columns that are projected out immediately |
249 | 0 | ClearUnusedExpressions(proj_sel, get.table_index, false); |
250 | |
|
251 | 0 | vector<unique_ptr<Expression>> filter_expressions; |
252 | | // for every table filter, push a column binding into the column references map to prevent the column from |
253 | | // being projected out |
254 | 0 | for (auto &filter : get.table_filters.filters) { |
255 | 0 | optional_idx index; |
256 | 0 | for (idx_t i = 0; i < final_column_ids.size(); i++) { |
257 | 0 | if (final_column_ids[i].GetPrimaryIndex() == filter.first) { |
258 | 0 | index = i; |
259 | 0 | break; |
260 | 0 | } |
261 | 0 | } |
262 | 0 | if (!index.IsValid()) { |
263 | 0 | throw InternalException("Could not find column index for table filter"); |
264 | 0 | } |
265 | | |
266 | 0 | auto column_type = get.GetColumnType(ColumnIndex(filter.first)); |
267 | |
|
268 | 0 | ColumnBinding filter_binding(get.table_index, index.GetIndex()); |
269 | 0 | auto column_ref = make_uniq<BoundColumnRefExpression>(std::move(column_type), filter_binding); |
270 | 0 | auto filter_expr = filter.second->ToExpression(*column_ref); |
271 | 0 | if (filter_expr->IsScalar()) { |
272 | 0 | filter_expr = std::move(column_ref); |
273 | 0 | } |
274 | 0 | VisitExpression(&filter_expr); |
275 | 0 | filter_expressions.push_back(std::move(filter_expr)); |
276 | 0 | } |
277 | | |
278 | | // Clear unused ids, include filter columns that are projected out immediately |
279 | 0 | ClearUnusedExpressions(col_sel, get.table_index); |
280 | | |
281 | | // Now set the column ids in the LogicalGet using the "selection vector" |
282 | 0 | vector<ColumnIndex> column_ids; |
283 | 0 | column_ids.reserve(col_sel.size()); |
284 | 0 | for (auto col_sel_idx : col_sel) { |
285 | 0 | auto entry = column_references.find(ColumnBinding(get.table_index, col_sel_idx)); |
286 | 0 | if (entry == column_references.end()) { |
287 | 0 | throw InternalException("RemoveUnusedColumns - could not find referenced column"); |
288 | 0 | } |
289 | 0 | ColumnIndex new_index(final_column_ids[col_sel_idx].GetPrimaryIndex(), entry->second.child_columns); |
290 | 0 | column_ids.emplace_back(new_index); |
291 | 0 | } |
292 | 0 | if (column_ids.empty()) { |
293 | | // this generally means we are only interested in whether or not anything exists in the table (e.g. |
294 | | // EXISTS(SELECT * FROM tbl)) in this case, we just scan the row identifier column as it means we do not |
295 | | // need to read any of the columns |
296 | 0 | column_ids.emplace_back(get.GetAnyColumn()); |
297 | 0 | } |
298 | 0 | get.SetColumnIds(std::move(column_ids)); |
299 | |
|
300 | 0 | if (!get.function.filter_prune) { |
301 | 0 | return; |
302 | 0 | } |
303 | | // Now set the projection cols by matching the "selection vector" that excludes filter columns |
304 | | // with the "selection vector" that includes filter columns |
305 | 0 | idx_t col_idx = 0; |
306 | 0 | get.projection_ids.clear(); |
307 | 0 | for (auto proj_sel_idx : proj_sel) { |
308 | 0 | for (; col_idx < col_sel.size(); col_idx++) { |
309 | 0 | if (proj_sel_idx == col_sel[col_idx]) { |
310 | 0 | get.projection_ids.push_back(col_idx); |
311 | 0 | break; |
312 | 0 | } |
313 | 0 | } |
314 | 0 | } |
315 | 0 | return; |
316 | 0 | } |
317 | 21 | case LogicalOperatorType::LOGICAL_DISTINCT: { |
318 | 21 | auto &distinct = op.Cast<LogicalDistinct>(); |
319 | 21 | if (distinct.distinct_type == DistinctType::DISTINCT_ON) { |
320 | | // distinct type references columns that need to be distinct on, so no |
321 | | // need to implicity reference everything. |
322 | 0 | break; |
323 | 0 | } |
324 | | // distinct, all projected columns are used for the DISTINCT computation |
325 | | // mark all columns as used and continue to the children |
326 | | // FIXME: DISTINCT with expression list does not implicitly reference everything |
327 | 21 | everything_referenced = true; |
328 | 21 | break; |
329 | 21 | } |
330 | 0 | case LogicalOperatorType::LOGICAL_RECURSIVE_CTE: |
331 | 1.63k | case LogicalOperatorType::LOGICAL_MATERIALIZED_CTE: |
332 | 32.8k | case LogicalOperatorType::LOGICAL_CTE_REF: |
333 | 32.8k | case LogicalOperatorType::LOGICAL_COPY_TO_FILE: |
334 | 32.8k | case LogicalOperatorType::LOGICAL_PIVOT: { |
335 | 32.8k | everything_referenced = true; |
336 | 32.8k | break; |
337 | 32.8k | } |
338 | 114k | default: |
339 | 114k | break; |
340 | 286k | } |
341 | 147k | LogicalOperatorVisitor::VisitOperatorExpressions(op); |
342 | 147k | LogicalOperatorVisitor::VisitOperatorChildren(op); |
343 | | |
344 | 147k | if (op.type == LogicalOperatorType::LOGICAL_ASOF_JOIN || op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN || |
345 | 147k | op.type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN) { |
346 | 44 | auto &comp_join = op.Cast<LogicalComparisonJoin>(); |
347 | | // after removing duplicate columns we may have duplicate join conditions (if the join graph is cyclical) |
348 | 44 | vector<JoinCondition> unique_conditions; |
349 | 44 | for (auto &cond : comp_join.conditions) { |
350 | 44 | bool found = false; |
351 | 44 | for (auto &unique_cond : unique_conditions) { |
352 | 0 | if (cond.comparison == unique_cond.comparison && cond.left->Equals(*unique_cond.left) && |
353 | 0 | cond.right->Equals(*unique_cond.right)) { |
354 | 0 | found = true; |
355 | 0 | break; |
356 | 0 | } |
357 | 0 | } |
358 | 44 | if (!found) { |
359 | 44 | unique_conditions.push_back(std::move(cond)); |
360 | 44 | } |
361 | 44 | } |
362 | 44 | comp_join.conditions = std::move(unique_conditions); |
363 | 44 | } |
364 | 147k | } |
365 | | |
366 | | bool BaseColumnPruner::HandleStructExtractRecursive(Expression &expr, optional_ptr<BoundColumnRefExpression> &colref, |
367 | 918k | vector<idx_t> &indexes) { |
368 | 918k | if (expr.GetExpressionClass() != ExpressionClass::BOUND_FUNCTION) { |
369 | 845k | return false; |
370 | 845k | } |
371 | 73.3k | auto &function = expr.Cast<BoundFunctionExpression>(); |
372 | 73.3k | if (function.function.name != "struct_extract_at" && function.function.name != "struct_extract" && |
373 | 73.3k | function.function.name != "array_extract") { |
374 | 73.3k | return false; |
375 | 73.3k | } |
376 | 0 | if (!function.bind_info) { |
377 | 0 | return false; |
378 | 0 | } |
379 | 0 | if (function.children[0]->return_type.id() != LogicalTypeId::STRUCT) { |
380 | 0 | return false; |
381 | 0 | } |
382 | 0 | auto &bind_data = function.bind_info->Cast<StructExtractBindData>(); |
383 | 0 | indexes.push_back(bind_data.index); |
384 | | // struct extract, check if left child is a bound column ref |
385 | 0 | if (function.children[0]->GetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF) { |
386 | | // column reference - check if it is a struct |
387 | 0 | auto &ref = function.children[0]->Cast<BoundColumnRefExpression>(); |
388 | 0 | if (ref.return_type.id() != LogicalTypeId::STRUCT) { |
389 | 0 | return false; |
390 | 0 | } |
391 | 0 | colref = &ref; |
392 | 0 | return true; |
393 | 0 | } |
394 | | // not a column reference - try to handle this recursively |
395 | 0 | if (!HandleStructExtractRecursive(*function.children[0], colref, indexes)) { |
396 | 0 | return false; |
397 | 0 | } |
398 | 0 | return true; |
399 | 0 | } |
400 | | |
401 | 918k | bool BaseColumnPruner::HandleStructExtract(Expression &expr) { |
402 | 918k | optional_ptr<BoundColumnRefExpression> colref; |
403 | 918k | vector<idx_t> indexes; |
404 | 918k | if (!HandleStructExtractRecursive(expr, colref, indexes)) { |
405 | 918k | return false; |
406 | 918k | } |
407 | 0 | D_ASSERT(!indexes.empty()); |
408 | | // construct the ColumnIndex |
409 | 0 | ColumnIndex index = ColumnIndex(indexes[0]); |
410 | 0 | for (idx_t i = 1; i < indexes.size(); i++) { |
411 | 0 | ColumnIndex new_index(indexes[i]); |
412 | 0 | new_index.AddChildIndex(std::move(index)); |
413 | 0 | index = std::move(new_index); |
414 | 0 | } |
415 | 0 | AddBinding(*colref, std::move(index)); |
416 | 0 | return true; |
417 | 918k | } |
418 | | |
419 | 0 | void MergeChildColumns(vector<ColumnIndex> ¤t_child_columns, ColumnIndex &new_child_column) { |
420 | 0 | if (current_child_columns.empty()) { |
421 | | // there's already a reference to the full column - we can't extract only a subfield |
422 | | // skip struct projection pushdown |
423 | 0 | return; |
424 | 0 | } |
425 | | // if we are already extract sub-fields, add it (if it is not there yet) |
426 | 0 | for (auto &binding : current_child_columns) { |
427 | 0 | if (binding.GetPrimaryIndex() != new_child_column.GetPrimaryIndex()) { |
428 | 0 | continue; |
429 | 0 | } |
430 | | // found a match: sub-field is already projected |
431 | | // check if we have child columns |
432 | 0 | auto &nested_child_columns = binding.GetChildIndexesMutable(); |
433 | 0 | if (!new_child_column.HasChildren()) { |
434 | | // new child is a reference to a full column - clear any existing bindings (if any) |
435 | 0 | nested_child_columns.clear(); |
436 | 0 | } else { |
437 | | // new child has a sub-reference - merge recursively |
438 | 0 | D_ASSERT(new_child_column.ChildIndexCount() == 1); |
439 | 0 | MergeChildColumns(nested_child_columns, new_child_column.GetChildIndex(0)); |
440 | 0 | } |
441 | 0 | return; |
442 | 0 | } |
443 | | // this child column is not projected yet - add it in |
444 | 0 | current_child_columns.push_back(std::move(new_child_column)); |
445 | 0 | } |
446 | | |
447 | 0 | void BaseColumnPruner::AddBinding(BoundColumnRefExpression &col, ColumnIndex child_column) { |
448 | 0 | auto entry = column_references.find(col.binding); |
449 | 0 | if (entry == column_references.end()) { |
450 | | // column not referenced yet - add a binding to it entirely |
451 | 0 | ReferencedColumn column; |
452 | 0 | column.bindings.push_back(col); |
453 | 0 | column.child_columns.push_back(std::move(child_column)); |
454 | 0 | column_references.insert(make_pair(col.binding, std::move(column))); |
455 | 0 | } else { |
456 | | // column reference already exists - check add the binding |
457 | 0 | auto &column = entry->second; |
458 | 0 | column.bindings.push_back(col); |
459 | |
|
460 | 0 | MergeChildColumns(column.child_columns, child_column); |
461 | 0 | } |
462 | 0 | } |
463 | | |
464 | 360k | void BaseColumnPruner::AddBinding(BoundColumnRefExpression &col) { |
465 | 360k | auto entry = column_references.find(col.binding); |
466 | 360k | if (entry == column_references.end()) { |
467 | | // column not referenced yet - add a binding to it entirely |
468 | 290k | column_references[col.binding].bindings.push_back(col); |
469 | 290k | } else { |
470 | | // column reference already exists - add the binding and clear any sub-references |
471 | 69.8k | auto &column = entry->second; |
472 | 69.8k | column.bindings.push_back(col); |
473 | 69.8k | column.child_columns.clear(); |
474 | 69.8k | } |
475 | 360k | } |
476 | | |
477 | 918k | void BaseColumnPruner::VisitExpression(unique_ptr<Expression> *expression) { |
478 | 918k | auto &expr = **expression; |
479 | 918k | if (HandleStructExtract(expr)) { |
480 | | // already handled |
481 | 0 | return; |
482 | 0 | } |
483 | | // recurse |
484 | 918k | LogicalOperatorVisitor::VisitExpression(expression); |
485 | 918k | } |
486 | | |
487 | | unique_ptr<Expression> BaseColumnPruner::VisitReplace(BoundColumnRefExpression &expr, |
488 | 360k | unique_ptr<Expression> *expr_ptr) { |
489 | | // add a reference to the entire column |
490 | 360k | AddBinding(expr); |
491 | 360k | return nullptr; |
492 | 360k | } |
493 | | |
494 | | unique_ptr<Expression> BaseColumnPruner::VisitReplace(BoundReferenceExpression &expr, |
495 | 0 | unique_ptr<Expression> *expr_ptr) { |
496 | | // BoundReferenceExpression should not be used here yet, they only belong in the physical plan |
497 | 0 | throw InternalException("BoundReferenceExpression should not be used here yet!"); |
498 | 0 | } |
499 | | |
500 | | } // namespace duckdb |