Coverage Report

Created: 2025-11-15 07:36

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/duckdb/src/optimizer/rule/date_trunc_simplification.cpp
Line
Count
Source
1
#include "duckdb/optimizer/rule/date_trunc_simplification.hpp"
2
3
#include "duckdb/common/exception.hpp"
4
#include "duckdb/common/enums/expression_type.hpp"
5
#include "duckdb/execution/expression_executor.hpp"
6
#include "duckdb/planner/expression/bound_cast_expression.hpp"
7
#include "duckdb/planner/expression/bound_columnref_expression.hpp"
8
#include "duckdb/planner/expression/bound_comparison_expression.hpp"
9
#include "duckdb/planner/expression/bound_conjunction_expression.hpp"
10
#include "duckdb/planner/expression/bound_constant_expression.hpp"
11
#include "duckdb/planner/expression/bound_operator_expression.hpp"
12
#include "duckdb/optimizer/matcher/expression_matcher.hpp"
13
#include "duckdb/optimizer/expression_rewriter.hpp"
14
#include "duckdb/common/enums/date_part_specifier.hpp"
15
#include "duckdb/function/function.hpp"
16
#include "duckdb/function/function_binder.hpp"
17
18
namespace duckdb {
19
20
72.7k
DateTruncSimplificationRule::DateTruncSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) {
  // Pattern: date_trunc(<constant part>, <any expression>) <comparison> <constant>.
  // The date_trunc arguments must appear in declaration order (part first); the two comparison operands may
  // appear in either order. Bindings produced (in order): comparison, date_trunc call, part constant,
  // date_trunc input, compared constant.
  auto trunc_call = make_uniq<FunctionExpressionMatcher>();
  trunc_call->function = make_uniq<ManyFunctionMatcher>(unordered_set<string> {"date_trunc", "datetrunc"});
  trunc_call->matchers.push_back(make_uniq<ConstantExpressionMatcher>());
  trunc_call->matchers.push_back(make_uniq<ExpressionMatcher>());
  trunc_call->policy = SetMatcher::Policy::ORDERED;

  auto comparison = make_uniq<ComparisonExpressionMatcher>();
  comparison->matchers.push_back(std::move(trunc_call));
  comparison->matchers.push_back(make_uniq<ConstantExpressionMatcher>());
  comparison->policy = SetMatcher::Policy::UNORDERED;

  root = std::move(comparison);
}
37
38
unique_ptr<Expression> DateTruncSimplificationRule::Apply(LogicalOperator &op, vector<reference<Expression>> &bindings,
                                                          bool &changes_made, bool is_root) {
  // Rewrites `date_trunc(part, column) <cmp> constant` into predicates on `column` itself.
  // Returns a replacement expression, or nullptr. A nullptr means either that nothing was done, or that the
  // comparison was rewritten in place, in which case `changes_made` is set.
  //
  // Bindings follow the matcher built in the constructor:
  //   [0] the comparison, [1] the date_trunc call, [2] the date part constant,
  //   [3] the date_trunc input, [4] the constant the call is compared against.
  auto &expr = bindings[0].get().Cast<BoundComparisonExpression>();
  auto comparison_type = expr.GetExpressionType();

  auto &date_part = bindings[2].get().Cast<BoundConstantExpression>();
  // date_trunc(NULL, x) has no usable date part: leave the expression untouched.
  if (date_part.value.IsNull()) {
    return nullptr;
  }
  // We must have only a column inside date_trunc.
  if (bindings[3].get().GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) {
    return nullptr;
  }

  auto &column_part = bindings[3].get().Cast<BoundColumnRefExpression>();
  auto &rhs = bindings[4].get().Cast<BoundConstantExpression>();

  // Every rewrite below relies on date_trunc(part, column) being monotonically non-decreasing in `column`.
  // That holds for dates and timestamps, but not for intervals: date_trunc('day', INTERVAL '50 hours') is
  // INTERVAL '0 days', while INTERVAL '50 hours' compares greater than INTERVAL '1 day'.
  switch (column_part.return_type.id()) {
  case LogicalTypeId::DATE:
  case LogicalTypeId::TIMESTAMP:
  case LogicalTypeId::TIMESTAMP_TZ:
  case LogicalTypeId::TIMESTAMP_SEC:
  case LogicalTypeId::TIMESTAMP_MS:
  case LogicalTypeId::TIMESTAMP_NS:
    break;
  default:
    return nullptr;
  }

  // The computed bounds are cast to the column type. For a DATE column compared against a non-DATE constant
  // (date_trunc with a sub-day part yields a TIMESTAMP), that cast drops the time of day and the bounds become
  // wrong. For example, date_trunc('hour', d) = TIMESTAMP '2024-03-15 00:00:00' would become an empty range.
  if (column_part.return_type.id() == LogicalTypeId::DATE && rhs.return_type.id() != LogicalTypeId::DATE) {
    return nullptr;
  }

  // Determine whether the date_trunc call (and thus the column) is on the left-hand side.
  const bool col_is_lhs = (expr.left->GetExpressionClass() == ExpressionClass::BOUND_FUNCTION);

  // We want to treat `rhs >= col` equivalently to `col <= rhs`, so use the comparison type as if the constant
  // were on the right-hand side.
  ExpressionType rhs_comparison_type = comparison_type;
  if (!col_is_lhs) {
    rhs_comparison_type = FlipComparisonExpression(comparison_type);
  }

  // Whether date_trunc(part, constant_rhs) == constant_rhs (NULL counts as truncated).
  const bool is_truncated = DateIsTruncated(date_part, rhs);

  // NOTE(review): the bounds use date_trunc(part, rhs + INTERVAL 1 part); for infinite constants this does not
  // advance past rhs. Confirm the behaviour for +/-infinity constants.
  switch (rhs_comparison_type) {
  case ExpressionType::COMPARE_EQUAL:
  case ExpressionType::COMPARE_NOT_DISTINCT_FROM: {
    // date_trunc(part, column) = constant_rhs  -->  column >= date_trunc(part, constant_rhs) AND
    //                                               column < date_trunc(part, constant_rhs + INTERVAL 1 part)
    //   or, if date_trunc(part, constant_rhs) <> constant_rhs, this can never be true.
    //
    // date_trunc(part, column) IS NOT DISTINCT FROM constant_rhs becomes `column IS NULL` when constant_rhs is
    // NULL. Otherwise it becomes the range above AND `column IS NOT NULL`.
    const bool not_distinct = rhs_comparison_type == ExpressionType::COMPARE_NOT_DISTINCT_FROM;
    if (not_distinct && rhs.value.IsNull()) {
      auto is_null = make_uniq<BoundOperatorExpression>(ExpressionType::OPERATOR_IS_NULL, LogicalType::BOOLEAN);
      is_null->children.push_back(column_part.Copy());
      return std::move(is_null);
    }

    if (!is_truncated) {
      if (not_distinct) {
        // A NULL column gives NULL IS NOT DISTINCT FROM <non-NULL>, which is false, so `false` is exact.
        return make_uniq<BoundConstantExpression>(Value::BOOLEAN(false));
      }
      // For plain '=' a NULL column must still produce NULL, not false.
      return ExpressionRewriter::ConstantOrNull(column_part.Copy(), Value::BOOLEAN(false));
    }

    auto lower = CreateTrunc(date_part, rhs, column_part.return_type);
    if (!lower) {
      return nullptr;
    }
    auto upper = CreateTruncAdd(date_part, rhs, column_part.return_type);
    if (!upper) {
      return nullptr;
    }

    auto gteq = make_uniq<BoundComparisonExpression>(ExpressionType::COMPARE_GREATERTHANOREQUALTO,
                                                     column_part.Copy(), std::move(lower));
    auto lt = make_uniq<BoundComparisonExpression>(ExpressionType::COMPARE_LESSTHAN, column_part.Copy(),
                                                   std::move(upper));

    if (!not_distinct) {
      return make_uniq<BoundConjunctionExpression>(ExpressionType::CONJUNCTION_AND, std::move(gteq), std::move(lt));
    }

    // For IS NOT DISTINCT FROM, we also have to add the extra NULL term.
    auto range =
        make_uniq<BoundConjunctionExpression>(ExpressionType::CONJUNCTION_AND, std::move(gteq), std::move(lt));
    auto is_not_null = make_uniq<BoundOperatorExpression>(ExpressionType::OPERATOR_IS_NOT_NULL, LogicalType::BOOLEAN);
    is_not_null->children.push_back(column_part.Copy());
    return make_uniq<BoundConjunctionExpression>(ExpressionType::CONJUNCTION_AND, std::move(range),
                                                 std::move(is_not_null));
  }

  case ExpressionType::COMPARE_NOTEQUAL:
  case ExpressionType::COMPARE_DISTINCT_FROM: {
    // date_trunc(part, column) <> constant_rhs  -->  column < date_trunc(part, constant_rhs) OR
    //                                                column >= date_trunc(part, constant_rhs + INTERVAL 1 part)
    //   or, if date_trunc(part, constant_rhs) <> constant_rhs, this is true for every non-NULL column.
    //
    // date_trunc(part, column) IS DISTINCT FROM constant_rhs becomes `column IS NOT NULL` when constant_rhs is
    // NULL. Otherwise it becomes the disjunction above OR `column IS NULL`.
    const bool distinct = rhs_comparison_type == ExpressionType::COMPARE_DISTINCT_FROM;
    if (distinct && rhs.value.IsNull()) {
      auto is_not_null =
          make_uniq<BoundOperatorExpression>(ExpressionType::OPERATOR_IS_NOT_NULL, LogicalType::BOOLEAN);
      is_not_null->children.push_back(column_part.Copy());
      return std::move(is_not_null);
    }

    if (!is_truncated) {
      if (distinct) {
        // A NULL column gives NULL IS DISTINCT FROM <non-NULL>, which is true, so `true` is exact.
        return make_uniq<BoundConstantExpression>(Value::BOOLEAN(true));
      }
      // For plain '<>' a NULL column must still produce NULL, not true.
      return ExpressionRewriter::ConstantOrNull(column_part.Copy(), Value::BOOLEAN(true));
    }

    auto lower = CreateTrunc(date_part, rhs, column_part.return_type);
    if (!lower) {
      return nullptr;
    }
    auto upper = CreateTruncAdd(date_part, rhs, column_part.return_type);
    if (!upper) {
      return nullptr;
    }

    auto lt = make_uniq<BoundComparisonExpression>(ExpressionType::COMPARE_LESSTHAN, column_part.Copy(),
                                                   std::move(lower));
    auto gteq = make_uniq<BoundComparisonExpression>(ExpressionType::COMPARE_GREATERTHANOREQUALTO,
                                                     column_part.Copy(), std::move(upper));

    if (!distinct) {
      return make_uniq<BoundConjunctionExpression>(ExpressionType::CONJUNCTION_OR, std::move(gteq), std::move(lt));
    }

    // For IS DISTINCT FROM, we need to add the 'column IS NULL' term.
    auto outside =
        make_uniq<BoundConjunctionExpression>(ExpressionType::CONJUNCTION_OR, std::move(gteq), std::move(lt));
    auto is_null = make_uniq<BoundOperatorExpression>(ExpressionType::OPERATOR_IS_NULL, LogicalType::BOOLEAN);
    is_null->children.push_back(column_part.Copy());
    return make_uniq<BoundConjunctionExpression>(ExpressionType::CONJUNCTION_OR, std::move(outside),
                                                 std::move(is_null));
  }

  case ExpressionType::COMPARE_LESSTHAN:
  case ExpressionType::COMPARE_GREATERTHANOREQUALTO: {
    // date_trunc(part, column) <  constant_rhs  -->  column <  date_trunc(part, constant_rhs + INTERVAL 1 part)
    // date_trunc(part, column) >= constant_rhs  -->  column >= date_trunc(part, constant_rhs + INTERVAL 1 part)
    // If trunc(rhs) = rhs, the rhs is used as-is instead of trunc(rhs + 1 part).
    // Note: assigning expr.left/expr.right destroys the date_trunc call, so column_part and date_part must not
    // be used after the first assignment (the new operands are built before it).
    if (!is_truncated) {
      auto bound = CreateTruncAdd(date_part, rhs, column_part.return_type);
      if (!bound) {
        return nullptr; // Something went wrong---don't do the optimization.
      }
      if (col_is_lhs) {
        expr.left = column_part.Copy();
        expr.right = std::move(bound);
      } else {
        expr.right = column_part.Copy();
        expr.left = std::move(bound);
      }
    } else {
      // The constant is already truncated: keep it, casting it to the column type if needed.
      if (col_is_lhs) {
        expr.left = column_part.Copy();
        if (rhs.return_type.id() != expr.left->return_type.id()) {
          expr.right = CastAndEvaluate(std::move(expr.right), expr.left->return_type);
        }
      } else {
        expr.right = column_part.Copy();
        if (rhs.return_type.id() != expr.right->return_type.id()) {
          expr.left = CastAndEvaluate(std::move(expr.left), expr.right->return_type);
        }
      }
    }

    changes_made = true;
    return nullptr;
  }

  case ExpressionType::COMPARE_LESSTHANOREQUALTO:
  case ExpressionType::COMPARE_GREATERTHAN: {
    // date_trunc(part, column) <= constant_rhs  -->  column <  date_trunc(part, constant_rhs + INTERVAL 1 part)
    // date_trunc(part, column) >  constant_rhs  -->  column >= date_trunc(part, constant_rhs + INTERVAL 1 part)
    auto bound = CreateTruncAdd(date_part, rhs, column_part.return_type);
    if (!bound) {
      return nullptr; // Something went wrong---don't do the optimization.
    }

    if (col_is_lhs) {
      expr.left = column_part.Copy();
      expr.right = std::move(bound);
    } else {
      expr.right = column_part.Copy();
      expr.left = std::move(bound);
    }

    // > needs to become >=, and <= needs to become < (mirrored when the column is on the right).
    if (rhs_comparison_type == ExpressionType::COMPARE_GREATERTHAN) {
      expr.SetExpressionTypeUnsafe(col_is_lhs ? ExpressionType::COMPARE_GREATERTHANOREQUALTO
                                              : ExpressionType::COMPARE_LESSTHANOREQUALTO);
    } else {
      expr.SetExpressionTypeUnsafe(col_is_lhs ? ExpressionType::COMPARE_LESSTHAN
                                              : ExpressionType::COMPARE_GREATERTHAN);
    }

    changes_made = true;
    return nullptr;
  }

  default:
    return nullptr;
  }
}
291
292
0
string DateTruncSimplificationRule::DatePartToFunc(const DatePartSpecifier &date_part) {
  // Maps a date part to the name of the scalar function that builds an INTERVAL of that unit.
  // Parts with no interval counterpart map to an empty string, which callers treat as "cannot optimize".
  const char *interval_function = "";
  switch (date_part) {
  case DatePartSpecifier::MILLENNIUM:
    interval_function = "to_millennia";
    break;
  case DatePartSpecifier::CENTURY:
    interval_function = "to_centuries";
    break;
  case DatePartSpecifier::DECADE:
    interval_function = "to_decades";
    break;
  case DatePartSpecifier::YEAR:
    interval_function = "to_years";
    break;
  case DatePartSpecifier::QUARTER:
    interval_function = "to_quarters";
    break;
  case DatePartSpecifier::MONTH:
    interval_function = "to_months";
    break;
  case DatePartSpecifier::WEEK:
    interval_function = "to_weeks";
    break;
  case DatePartSpecifier::DAY:
    interval_function = "to_days";
    break;
  case DatePartSpecifier::HOUR:
    interval_function = "to_hours";
    break;
  case DatePartSpecifier::MINUTE:
    interval_function = "to_minutes";
    break;
  case DatePartSpecifier::SECOND:
    interval_function = "to_seconds";
    break;
  case DatePartSpecifier::MILLISECONDS:
    interval_function = "to_milliseconds";
    break;
  case DatePartSpecifier::MICROSECONDS:
    interval_function = "to_microseconds";
    break;
  // Date-part-only specifiers with no interval form (DOW, ISODOW, DOY, ISOYEAR, YEARWEEK, ERA, TIMEZONE,
  // TIMEZONE_HOUR, TIMEZONE_MINUTE, ...) keep the empty name.
  default:
    break;
  }
  return interval_function;
}
337
338
unique_ptr<Expression> DateTruncSimplificationRule::CreateTrunc(const BoundConstantExpression &date_part,
                                                                const BoundConstantExpression &rhs,
                                                                const LogicalType &return_type) {
  // Builds date_trunc(date_part, rhs), cast to `return_type` when the type ids differ, and folds it into a
  // constant when possible. Returns nullptr if the date_trunc call cannot be bound.
  FunctionBinder binder(rewriter.context);
  ErrorData error;

  vector<unique_ptr<Expression>> args;
  args.emplace_back(date_part.Copy());
  args.emplace_back(rhs.Copy());
  auto trunc = binder.BindScalarFunction(DEFAULT_SCHEMA, "date_trunc", std::move(args), error);
  if (!trunc) {
    // Binding failed (the binder reports this through `error` and a null result); skip the optimization.
    return nullptr;
  }

  // Ensure that the RHS type matches the column type.
  if (trunc->return_type.id() != return_type.id()) {
    trunc = BoundCastExpression::AddDefaultCastToType(std::move(trunc), return_type, true);
  }

  if (trunc->IsFoldable()) {
    Value result;
    if (!ExpressionExecutor::TryEvaluateScalar(rewriter.context, *trunc, result)) {
      return trunc;
    }
    return make_uniq<BoundConstantExpression>(result);
  }

  return trunc;
}
365
366
unique_ptr<Expression> DateTruncSimplificationRule::CreateTruncAdd(const BoundConstantExpression &date_part,
                                                                   const BoundConstantExpression &rhs,
                                                                   const LogicalType &return_type) {
  // Builds date_trunc(date_part, rhs + <1 unit of date_part>), cast to `return_type` when the type ids differ,
  // and folds it into a constant when possible. Returns nullptr whenever the expression cannot be built; callers
  // then skip the optimization.

  // StringValue::Get requires a non-NULL string value.
  if (date_part.value.IsNull() || date_part.value.type().InternalType() != PhysicalType::VARCHAR) {
    return nullptr;
  }
  DatePartSpecifier part {};
  if (!TryGetDatePartSpecifier(StringValue::Get(date_part.value), part)) {
    return nullptr;
  }
  const string interval_func_name = DatePartToFunc(part);

  // If the date part cannot be represented as an interval, then we cannot
  // perform the optimization.
  if (interval_func_name.empty()) {
    return nullptr;
  }

  FunctionBinder binder(rewriter.context);
  ErrorData error;

  // <interval_func>(1), e.g. to_months(1).
  vector<unique_ptr<Expression>> interval_args;
  interval_args.emplace_back(make_uniq<BoundConstantExpression>(Value::INTEGER(1)));
  auto interval = binder.BindScalarFunction(DEFAULT_SCHEMA, interval_func_name, std::move(interval_args), error);
  if (!interval) {
    return nullptr; // Something wrong---just don't do the optimization.
  }

  // rhs + interval
  vector<unique_ptr<Expression>> add_args;
  add_args.emplace_back(rhs.Copy());
  add_args.emplace_back(std::move(interval));
  auto add = binder.BindScalarFunction(DEFAULT_SCHEMA, "+", std::move(add_args), error);
  if (!add) {
    return nullptr;
  }

  // date_trunc(part, rhs + interval)
  vector<unique_ptr<Expression>> trunc_args;
  trunc_args.emplace_back(date_part.Copy());
  trunc_args.emplace_back(std::move(add));
  auto trunc = binder.BindScalarFunction(DEFAULT_SCHEMA, "date_trunc", std::move(trunc_args), error);
  if (!trunc) {
    return nullptr;
  }

  // Ensure that the RHS type matches the column type.
  if (trunc->return_type.id() != return_type.id()) {
    trunc = BoundCastExpression::AddDefaultCastToType(std::move(trunc), return_type, true);
  }

  if (trunc->IsFoldable()) {
    Value result;
    if (!ExpressionExecutor::TryEvaluateScalar(rewriter.context, *trunc, result)) {
      return trunc;
    }
    return make_uniq<BoundConstantExpression>(result);
  }

  return trunc;
}
415
416
bool DateTruncSimplificationRule::DateIsTruncated(const BoundConstantExpression &date_part,
417
0
                                                  const BoundConstantExpression &rhs) {
418
  // If the rhs is null, then the date is "truncated" in the sense that date_trunc(..., NULL) is also NULL.
419
0
  if (rhs.value.IsNull()) {
420
0
    return true;
421
0
  }
422
423
  // Create the node date_trunc(date_part, rhs).
424
0
  auto trunc = CreateTrunc(date_part, rhs, rhs.return_type);
425
426
0
  Value trunc_result, result;
427
0
  if (!ExpressionExecutor::TryEvaluateScalar(rewriter.context, *trunc, trunc_result)) {
428
0
    return false;
429
0
  }
430
0
  if (!ExpressionExecutor::TryEvaluateScalar(rewriter.context, rhs, result)) {
431
0
    return false;
432
0
  }
433
434
0
  return (result == trunc_result);
435
0
}
436
437
unique_ptr<Expression> DateTruncSimplificationRule::CastAndEvaluate(unique_ptr<Expression> rhs,
                                                                    const LogicalType &return_type) {
  // Wraps `rhs` in a default (try-)cast to `return_type`. When the result is foldable and evaluates
  // successfully, it is replaced by the resulting constant; otherwise the cast expression itself is returned.
  auto cast = BoundCastExpression::AddDefaultCastToType(std::move(rhs), return_type, true);
  if (!cast->IsFoldable()) {
    return cast;
  }

  Value folded;
  const bool evaluated = ExpressionExecutor::TryEvaluateScalar(rewriter.context, *cast, folded);
  if (evaluated) {
    return make_uniq<BoundConstantExpression>(folded);
  }
  return cast;
}
451
452
} // namespace duckdb