Coverage Report

Created: 2025-06-12 07:25

/src/duckdb/src/function/window/window_rownumber_function.cpp
Line
Count
Source (jump to first uncovered line)
1
#include "duckdb/function/window/window_rownumber_function.hpp"
2
#include "duckdb/function/window/window_shared_expressions.hpp"
3
#include "duckdb/function/window/window_token_tree.hpp"
4
#include "duckdb/planner/expression/bound_window_expression.hpp"
5
6
namespace duckdb {
7
8
//===--------------------------------------------------------------------===//
9
// WindowRowNumberGlobalState
10
//===--------------------------------------------------------------------===//
11
class WindowRowNumberGlobalState : public WindowExecutorGlobalState {
12
public:
13
  WindowRowNumberGlobalState(const WindowRowNumberExecutor &executor, const idx_t payload_count,
14
                             const ValidityMask &partition_mask, const ValidityMask &order_mask)
15
0
      : WindowExecutorGlobalState(executor, payload_count, partition_mask, order_mask),
16
0
        ntile_idx(executor.ntile_idx) {
17
0
    if (!executor.arg_order_idx.empty()) {
18
0
      use_framing = true;
19
20
      //  If the argument order is prefix of the partition ordering,
21
      //  then we can just use the partition ordering.
22
0
      auto &wexpr = executor.wexpr;
23
0
      auto &arg_orders = executor.wexpr.arg_orders;
24
0
      const auto optimize = ClientConfig::GetConfig(executor.context).enable_optimizer;
25
0
      if (!optimize || BoundWindowExpression::GetSharedOrders(wexpr.orders, arg_orders) != arg_orders.size()) {
26
        //  "The ROW_NUMBER function can be computed by disambiguating duplicate elements based on their
27
        //  position in the input data, such that two elements never compare as equal."
28
0
        token_tree = make_uniq<WindowTokenTree>(executor.context, executor.wexpr.arg_orders,
29
0
                                                executor.arg_order_idx, payload_count, true);
30
0
      }
31
0
    }
32
0
  }
33
34
  //! Use framing instead of partitions (ORDER BY arguments)
35
  bool use_framing = false;
36
37
  //! The token tree for ORDER BY arguments
38
  unique_ptr<WindowTokenTree> token_tree;
39
40
  //! The evaluation index for NTILE
41
  const column_t ntile_idx;
42
};
43
44
//===--------------------------------------------------------------------===//
45
// WindowRowNumberLocalState
46
//===--------------------------------------------------------------------===//
47
class WindowRowNumberLocalState : public WindowExecutorBoundsState {
48
public:
49
  explicit WindowRowNumberLocalState(const WindowRowNumberGlobalState &grstate)
50
0
      : WindowExecutorBoundsState(grstate), grstate(grstate) {
51
0
    if (grstate.token_tree) {
52
0
      local_tree = grstate.token_tree->GetLocalState();
53
0
    }
54
0
  }
55
56
  //! Accumulate the secondary sort values
57
  void Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk,
58
            idx_t input_idx) override;
59
  //! Finish the sinking and prepare to scan
60
  void Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) override;
61
62
  //! The corresponding global peer state
63
  const WindowRowNumberGlobalState &grstate;
64
  //! The optional sorting state for secondary sorts
65
  unique_ptr<WindowAggregatorState> local_tree;
66
};
67
68
void WindowRowNumberLocalState::Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk,
69
0
                                     idx_t input_idx) {
70
0
  WindowExecutorBoundsState::Sink(gstate, sink_chunk, coll_chunk, input_idx);
71
72
0
  if (local_tree) {
73
0
    auto &local_tokens = local_tree->Cast<WindowMergeSortTreeLocalState>();
74
0
    local_tokens.SinkChunk(sink_chunk, input_idx, nullptr, 0);
75
0
  }
76
0
}
77
78
0
//! Finalizes the base bounds state; when a local token tree state exists,
//! sorts it and then builds the associated window tree.
void WindowRowNumberLocalState::Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) {
  WindowExecutorBoundsState::Finalize(gstate, collection);

  //  Nothing more to do when there is no secondary sort.
  if (!local_tree) {
    return;
  }
  auto &sort_state = local_tree->Cast<WindowMergeSortTreeLocalState>();
  sort_state.Sort();
  sort_state.window_tree.Build();
}
87
88
//===--------------------------------------------------------------------===//
89
// WindowRowNumberExecutor
90
//===--------------------------------------------------------------------===//
91
//! Constructs the executor and registers each argument ORDER BY expression as a sink column,
//! remembering the returned column indexes in arg_order_idx (in the same order).
WindowRowNumberExecutor::WindowRowNumberExecutor(BoundWindowExpression &wexpr, ClientContext &context,
                                                 WindowSharedExpressions &shared)
    : WindowExecutor(wexpr, context, shared) {

  const auto &secondary_orders = wexpr.arg_orders;
  for (idx_t order_no = 0; order_no < secondary_orders.size(); ++order_no) {
    const auto sink_column = shared.RegisterSink(secondary_orders[order_no].expression);
    arg_order_idx.emplace_back(sink_column);
  }
}
99
100
//! Creates the global state for this executor from the payload count and the two masks.
unique_ptr<WindowExecutorGlobalState> WindowRowNumberExecutor::GetGlobalState(const idx_t payload_count,
                                                                              const ValidityMask &partition_mask,
                                                                              const ValidityMask &order_mask) const {
  unique_ptr<WindowExecutorGlobalState> global_state =
      make_uniq<WindowRowNumberGlobalState>(*this, payload_count, partition_mask, order_mask);
  return global_state;
}
105
106
//! Creates a local state bound to the given global state,
//! which is cast to WindowRowNumberGlobalState.
unique_ptr<WindowExecutorLocalState>
WindowRowNumberExecutor::GetLocalState(const WindowExecutorGlobalState &gstate) const {
  const auto &row_number_state = gstate.Cast<WindowRowNumberGlobalState>();
  return make_uniq<WindowRowNumberLocalState>(row_number_state);
}
110
111
void WindowRowNumberExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate,
112
                                               DataChunk &eval_chunk, Vector &result, idx_t count,
113
0
                                               idx_t row_idx) const {
114
0
  auto &grstate = gstate.Cast<WindowRowNumberGlobalState>();
115
0
  auto &lrstate = lstate.Cast<WindowRowNumberLocalState>();
116
0
  auto rdata = FlatVector::GetData<uint64_t>(result);
117
118
0
  if (grstate.use_framing) {
119
0
    auto frame_begin = FlatVector::GetData<const idx_t>(lrstate.bounds.data[FRAME_BEGIN]);
120
0
    auto frame_end = FlatVector::GetData<const idx_t>(lrstate.bounds.data[FRAME_END]);
121
0
    if (grstate.token_tree) {
122
0
      for (idx_t i = 0; i < count; ++i, ++row_idx) {
123
        // Row numbers are unique ranks
124
0
        rdata[i] = grstate.token_tree->Rank(frame_begin[i], frame_end[i], row_idx);
125
0
      }
126
0
    } else {
127
0
      for (idx_t i = 0; i < count; ++i, ++row_idx) {
128
0
        rdata[i] = row_idx - frame_begin[i] + 1;
129
0
      }
130
0
    }
131
0
    return;
132
0
  }
133
134
0
  auto partition_begin = FlatVector::GetData<const idx_t>(lrstate.bounds.data[PARTITION_BEGIN]);
135
0
  for (idx_t i = 0; i < count; ++i, ++row_idx) {
136
0
    rdata[i] = row_idx - partition_begin[i] + 1;
137
0
  }
138
0
}
139
140
//===--------------------------------------------------------------------===//
141
// WindowNtileExecutor
142
//===--------------------------------------------------------------------===//
143
//! Constructs the NTILE executor on top of the row number executor and
//! registers the single NTILE argument for evaluation.
WindowNtileExecutor::WindowNtileExecutor(BoundWindowExpression &wexpr, ClientContext &context,
                                         WindowSharedExpressions &shared)
    : WindowRowNumberExecutor(wexpr, context, shared) {

  // NTILE has one argument: the first child expression
  const auto &bucket_count_expr = wexpr.children[0];
  ntile_idx = shared.RegisterEvaluate(bucket_count_expr);
}
150
151
void WindowNtileExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate,
152
0
                                           DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx) const {
153
0
  auto &grstate = gstate.Cast<WindowRowNumberGlobalState>();
154
0
  auto &lrstate = lstate.Cast<WindowRowNumberLocalState>();
155
0
  auto partition_begin = FlatVector::GetData<const idx_t>(lrstate.bounds.data[PARTITION_BEGIN]);
156
0
  auto partition_end = FlatVector::GetData<const idx_t>(lrstate.bounds.data[PARTITION_END]);
157
0
  if (grstate.use_framing) {
158
    // With secondary sorts, we restrict to the frame boundaries, but everything else should compute the same.
159
0
    partition_begin = FlatVector::GetData<const idx_t>(lrstate.bounds.data[FRAME_BEGIN]);
160
0
    partition_end = FlatVector::GetData<const idx_t>(lrstate.bounds.data[FRAME_END]);
161
0
  }
162
0
  auto rdata = FlatVector::GetData<int64_t>(result);
163
0
  WindowInputExpression ntile_col(eval_chunk, ntile_idx);
164
0
  for (idx_t i = 0; i < count; ++i, ++row_idx) {
165
0
    if (ntile_col.CellIsNull(i)) {
166
0
      FlatVector::SetNull(result, i, true);
167
0
    } else {
168
0
      auto n_param = ntile_col.GetCell<int64_t>(i);
169
0
      if (n_param < 1) {
170
0
        throw InvalidInputException("Argument for ntile must be greater than zero");
171
0
      }
172
      // With thanks from SQLite's ntileValueFunc()
173
0
      auto n_total = NumericCast<int64_t>(partition_end[i] - partition_begin[i]);
174
0
      if (n_param > n_total) {
175
        // more groups allowed than we have values
176
        // map every entry to a unique group
177
0
        n_param = n_total;
178
0
      }
179
0
      int64_t n_size = (n_total / n_param);
180
      // find the row idx within the group
181
0
      D_ASSERT(row_idx >= partition_begin[i]);
182
0
      idx_t partition_idx = 0;
183
0
      if (grstate.token_tree) {
184
0
        partition_idx = grstate.token_tree->Rank(partition_begin[i], partition_end[i], row_idx) - 1;
185
0
      } else {
186
0
        partition_idx = row_idx - partition_begin[i];
187
0
      }
188
0
      auto adjusted_row_idx = NumericCast<int64_t>(partition_idx);
189
190
      // now compute the ntile
191
0
      int64_t n_large = n_total - n_param * n_size;
192
0
      int64_t i_small = n_large * (n_size + 1);
193
0
      int64_t result_ntile;
194
195
0
      D_ASSERT((n_large * (n_size + 1) + (n_param - n_large) * n_size) == n_total);
196
197
0
      if (adjusted_row_idx < i_small) {
198
0
        result_ntile = 1 + adjusted_row_idx / (n_size + 1);
199
0
      } else {
200
0
        result_ntile = 1 + n_large + (adjusted_row_idx - i_small) / n_size;
201
0
      }
202
      // result has to be between [1, NTILE]
203
0
      D_ASSERT(result_ntile >= 1 && result_ntile <= n_param);
204
0
      rdata[i] = result_ntile;
205
0
    }
206
0
  }
207
0
}
208
209
} // namespace duckdb