/src/duckdb/src/optimizer/expression_rewriter.cpp
Line | Count | Source |
1 | | #include "duckdb/optimizer/expression_rewriter.hpp" |
2 | | |
3 | | #include "duckdb/common/exception.hpp" |
4 | | #include "duckdb/planner/expression_iterator.hpp" |
5 | | #include "duckdb/planner/operator/logical_filter.hpp" |
6 | | #include "duckdb/function/scalar/generic_functions.hpp" |
7 | | #include "duckdb/function/scalar/generic_common.hpp" |
8 | | #include "duckdb/planner/expression/bound_constant_expression.hpp" |
9 | | #include "duckdb/planner/expression/bound_function_expression.hpp" |
10 | | |
11 | | namespace duckdb { |
12 | | |
13 | | unique_ptr<Expression> ExpressionRewriter::ApplyRules(LogicalOperator &op, const vector<reference<Rule>> &rules, |
14 | 660k | unique_ptr<Expression> expr, bool &changes_made, bool is_root) { |
15 | 12.8M | for (auto &rule : rules) { |
16 | 12.8M | vector<reference<Expression>> bindings; |
17 | 12.8M | if (rule.get().root->Match(*expr, bindings)) { |
18 | | // the rule matches! try to apply it |
19 | 264k | bool rule_made_change = false; |
20 | 264k | auto alias = expr->alias; |
21 | 264k | auto result = rule.get().Apply(op, bindings, rule_made_change, is_root); |
22 | 264k | if (result) { |
23 | 51.8k | changes_made = true; |
24 | | // the base node changed: the rule applied changes |
25 | | // rerun on the new node |
26 | 51.8k | if (!alias.empty()) { |
27 | 13.3k | result->alias = std::move(alias); |
28 | 13.3k | } |
29 | 51.8k | return ExpressionRewriter::ApplyRules(op, rules, std::move(result), changes_made); |
30 | 212k | } else if (rule_made_change) { |
31 | 176 | changes_made = true; |
32 | | // the base node didn't change, but changes were made, rerun |
33 | 176 | return expr; |
34 | 176 | } |
35 | | // else nothing changed, continue to the next rule |
36 | 212k | continue; |
37 | 264k | } |
38 | 12.8M | } |
39 | | // no changes could be made to this node |
40 | | // recursively run on the children of this node |
41 | 608k | ExpressionIterator::EnumerateChildren(*expr, [&](unique_ptr<Expression> &child) { |
42 | 256k | child = ExpressionRewriter::ApplyRules(op, rules, std::move(child), changes_made); |
43 | 256k | }); |
44 | 608k | return expr; |
45 | 660k | } |
46 | | |
47 | 51 | unique_ptr<Expression> ExpressionRewriter::ConstantOrNull(unique_ptr<Expression> child, Value value) { |
48 | 51 | vector<unique_ptr<Expression>> children; |
49 | 51 | children.push_back(make_uniq<BoundConstantExpression>(value)); |
50 | 51 | children.push_back(std::move(child)); |
51 | 51 | return ConstantOrNull(std::move(children), std::move(value)); |
52 | 51 | } |
53 | | |
54 | 51 | unique_ptr<Expression> ExpressionRewriter::ConstantOrNull(vector<unique_ptr<Expression>> children, Value value) { |
55 | 51 | auto type = value.type(); |
56 | 51 | auto func = ConstantOrNullFun::GetFunction(); |
57 | 51 | func.arguments[0] = type; |
58 | 51 | func.SetReturnType(type); |
59 | 51 | children.insert(children.begin(), make_uniq<BoundConstantExpression>(value)); |
60 | 51 | return make_uniq<BoundFunctionExpression>(type, func, std::move(children), ConstantOrNull::Bind(std::move(value))); |
61 | 51 | } |
62 | | |
63 | 228k | void ExpressionRewriter::VisitOperator(LogicalOperator &op) { |
64 | 228k | VisitOperatorChildren(op); |
65 | 228k | this->op = &op; |
66 | | |
67 | 228k | to_apply_rules.clear(); |
68 | 4.80M | for (auto &rule : rules) { |
69 | 4.80M | to_apply_rules.push_back(*rule); |
70 | 4.80M | } |
71 | | |
72 | 228k | VisitOperatorExpressions(op); |
73 | | |
74 | | // if it is a LogicalFilter, we split up filter conjunctions again |
75 | 228k | if (op.type == LogicalOperatorType::LOGICAL_FILTER) { |
76 | 840 | auto &filter = op.Cast<LogicalFilter>(); |
77 | 840 | filter.SplitPredicates(); |
78 | 840 | } |
79 | 228k | } |
80 | | |
81 | 305k | void ExpressionRewriter::VisitExpression(unique_ptr<Expression> *expression) { |
82 | 305k | bool changes_made; |
83 | 352k | do { |
84 | 352k | changes_made = false; |
85 | 352k | *expression = ExpressionRewriter::ApplyRules(*op, to_apply_rules, std::move(*expression), changes_made, true); |
86 | 352k | } while (changes_made); |
87 | 305k | } |
88 | | |
89 | 99.8k | ClientContext &Rule::GetContext() const { |
90 | 99.8k | return rewriter.context; |
91 | 99.8k | } |
92 | | |
93 | | } // namespace duckdb |