Coverage Report

Created: 2026-03-31 07:54

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/duckdb/src/optimizer/cse_optimizer.cpp
Line
Count
Source
1
#include "duckdb/optimizer/cse_optimizer.hpp"
2
3
#include "duckdb/planner/expression/bound_columnref_expression.hpp"
4
#include "duckdb/planner/expression_iterator.hpp"
5
#include "duckdb/planner/operator/logical_projection.hpp"
6
#include "duckdb/planner/column_binding_map.hpp"
7
#include "duckdb/planner/binder.hpp"
8
9
namespace duckdb {
10
11
//! The CSENode contains information about a common subexpression; how many times it occurs, and the column index in the
12
//! underlying projection
13
struct CSENode {
14
  idx_t count;
15
  ProjectionIndex column_index;
16
17
152k
  CSENode() : count(1), column_index() {
18
152k
  }
19
};
20
21
//! The CSEReplacementState
22
struct CSEReplacementState {
23
  //! The projection index of the new projection
24
  TableIndex projection_index;
25
  //! Map of expression -> CSENode
26
  expression_map_t<CSENode> expression_count;
27
  //! Map of column bindings to column indexes in the projection expression list
28
  column_binding_map_t<ProjectionIndex> column_map;
29
  //! The set of expressions of the resulting projection
30
  vector<unique_ptr<Expression>> expressions;
31
  //! Cached expressions that are kept around so the expression_map always contains valid expressions
32
  vector<unique_ptr<Expression>> cached_expressions;
33
  //! Short circuit argument tracking
34
  bool short_circuited = false;
35
};
36
37
506k
void CommonSubExpressionOptimizer::VisitOperator(LogicalOperator &op) {
38
506k
  switch (op.type) {
39
185k
  case LogicalOperatorType::LOGICAL_PROJECTION:
40
242k
  case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY:
41
242k
    ExtractCommonSubExpresions(op);
42
242k
    break;
43
264k
  default:
44
264k
    break;
45
506k
  }
46
506k
  LogicalOperatorVisitor::VisitOperator(op);
47
506k
}
48
49
1.26M
void CommonSubExpressionOptimizer::CountExpressions(Expression &expr, CSEReplacementState &state) {
50
  // we only consider expressions with children for CSE elimination
51
1.26M
  switch (expr.GetExpressionClass()) {
52
603k
  case ExpressionClass::BOUND_COLUMN_REF:
53
941k
  case ExpressionClass::BOUND_CONSTANT:
54
949k
  case ExpressionClass::BOUND_PARAMETER:
55
949k
    return;
56
318k
  default:
57
318k
    break;
58
1.26M
  }
59
318k
  if (expr.GetExpressionClass() != ExpressionClass::BOUND_AGGREGATE && !expr.IsVolatile()) {
60
    // we can't move aggregates to a projection, so we only consider the children of the aggregate
61
78.8k
    auto node = state.expression_count.find(expr);
62
78.8k
    if (node == state.expression_count.end()) {
63
      // first time we encounter this expression, insert this node with [count = 1]
64
      // but only if it is not an interior argument of a short circuit sensitive expression.
65
76.1k
      if (!state.short_circuited) {
66
76.0k
        state.expression_count[expr] = CSENode();
67
76.0k
      }
68
76.1k
    } else {
69
      // we encountered this expression before, increment the occurrence count
70
2.71k
      node->second.count++;
71
2.71k
    }
72
78.8k
  }
73
74
  // If we have a function that uses short circuiting, then we can only extract CSEs from the leftmost
75
  // side of the argument tree (child_no == 0)
76
318k
  switch (expr.GetExpressionClass()) {
77
401
  case ExpressionClass::BOUND_CONJUNCTION:
78
38.4k
  case ExpressionClass::BOUND_CASE: {
79
    // Save the short circuit reference
80
38.4k
    const auto save_short_circuit = state.short_circuited;
81
121k
    ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) {
82
121k
      CountExpressions(child, state);
83
121k
      state.short_circuited = true;
84
121k
    });
85
38.4k
    state.short_circuited = save_short_circuit;
86
38.4k
    break;
87
401
  }
88
280k
  default:
89
    // recursively count the children
90
334k
    ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) { CountExpressions(child, state); });
91
280k
    break;
92
318k
  }
93
318k
}
94
95
14.5k
void CommonSubExpressionOptimizer::PerformCSEReplacement(unique_ptr<Expression> &expr_ptr, CSEReplacementState &state) {
96
14.5k
  Expression &expr = *expr_ptr;
97
14.5k
  if (expr.GetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF) {
98
269
    auto &bound_column_ref = expr.Cast<BoundColumnRefExpression>();
99
    // bound column ref, check if this one has already been recorded in the expression list
100
269
    auto column_entry = state.column_map.find(bound_column_ref.binding);
101
269
    if (column_entry == state.column_map.end()) {
102
      // not there yet: push the expression
103
45
      auto new_col_ref = make_uniq<BoundColumnRefExpression>(
104
45
          bound_column_ref.GetAlias(), bound_column_ref.return_type, bound_column_ref.binding);
105
45
      auto new_column_index = ColumnBinding::PushExpression(state.expressions, std::move(new_col_ref));
106
45
      state.column_map[bound_column_ref.binding] = new_column_index;
107
45
      bound_column_ref.binding = ColumnBinding(state.projection_index, new_column_index);
108
224
    } else {
109
      // else: just update the column binding!
110
224
      bound_column_ref.binding = ColumnBinding(state.projection_index, column_entry->second);
111
224
    }
112
269
    return;
113
269
  }
114
  // check if this child is eligible for CSE elimination
115
14.2k
  if (state.expression_count.find(expr) != state.expression_count.end()) {
116
8.09k
    auto &node = state.expression_count[expr];
117
8.09k
    if (node.count > 1) {
118
      // this expression occurs more than once! push it into the projection
119
      // check if it has already been pushed into the projection
120
1.74k
      auto alias = expr.GetAlias();
121
1.74k
      auto type = expr.return_type;
122
1.74k
      if (!node.column_index.IsValid()) {
123
        // has not been pushed yet: push it
124
732
        node.column_index = ColumnBinding::PushExpression(state.expressions, std::move(expr_ptr));
125
1.00k
      } else {
126
1.00k
        state.cached_expressions.push_back(std::move(expr_ptr));
127
1.00k
      }
128
      // replace the original expression with a bound column ref
129
1.74k
      expr_ptr = make_uniq<BoundColumnRefExpression>(alias, type,
130
1.74k
                                                     ColumnBinding(state.projection_index, node.column_index));
131
1.74k
      return;
132
1.74k
    }
133
8.09k
  }
134
  // this expression only occurs once, we can't perform CSE elimination
135
  // look into the children to see if we can replace them
136
12.5k
  ExpressionIterator::EnumerateChildren(expr,
137
12.5k
                                        [&](unique_ptr<Expression> &child) { PerformCSEReplacement(child, state); });
138
12.5k
}
139
140
242k
void CommonSubExpressionOptimizer::ExtractCommonSubExpresions(LogicalOperator &op) {
141
242k
  D_ASSERT(op.children.size() == 1);
142
143
  // first we count for each expression with children how many types it occurs
144
242k
  CSEReplacementState state;
145
242k
  LogicalOperatorVisitor::EnumerateExpressions(
146
811k
      op, [&](unique_ptr<Expression> *child) { CountExpressions(**child, state); });
147
  // check if there are any expressions to extract
148
242k
  bool perform_replacement = false;
149
242k
  for (auto &expr : state.expression_count) {
150
69.7k
    if (expr.second.count > 1) {
151
444
      perform_replacement = true;
152
444
      break;
153
444
    }
154
69.7k
  }
155
242k
  if (!perform_replacement) {
156
    // no CSEs to extract
157
241k
    return;
158
241k
  }
159
444
  state.projection_index = binder.GenerateTableIndex();
160
  // we found common subexpressions to extract
161
  // now we iterate over all the expressions and perform the actual CSE elimination
162
163
444
  LogicalOperatorVisitor::EnumerateExpressions(
164
2.86k
      op, [&](unique_ptr<Expression> *child) { PerformCSEReplacement(*child, state); });
165
444
  D_ASSERT(state.expressions.size() > 0);
166
  // create a projection node as the child of this node
167
444
  auto projection = make_uniq<LogicalProjection>(state.projection_index, std::move(state.expressions));
168
444
  if (op.children[0]->has_estimated_cardinality) {
169
444
    projection->SetEstimatedCardinality(op.children[0]->estimated_cardinality);
170
444
  }
171
444
  projection->children.push_back(std::move(op.children[0]));
172
444
  op.children[0] = std::move(projection);
173
444
}
174
175
} // namespace duckdb