/src/duckdb/src/function/window/window_rownumber_function.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | #include "duckdb/function/window/window_rownumber_function.hpp" |
2 | | #include "duckdb/function/window/window_shared_expressions.hpp" |
3 | | #include "duckdb/function/window/window_token_tree.hpp" |
4 | | #include "duckdb/planner/expression/bound_window_expression.hpp" |
5 | | |
6 | | namespace duckdb { |
7 | | |
8 | | //===--------------------------------------------------------------------===// |
9 | | // WindowRowNumberGlobalState |
10 | | //===--------------------------------------------------------------------===// |
11 | | class WindowRowNumberGlobalState : public WindowExecutorGlobalState { |
12 | | public: |
13 | | WindowRowNumberGlobalState(const WindowRowNumberExecutor &executor, const idx_t payload_count, |
14 | | const ValidityMask &partition_mask, const ValidityMask &order_mask) |
15 | 0 | : WindowExecutorGlobalState(executor, payload_count, partition_mask, order_mask), |
16 | 0 | ntile_idx(executor.ntile_idx) { |
17 | 0 | if (!executor.arg_order_idx.empty()) { |
18 | 0 | use_framing = true; |
19 | | |
20 | | // If the argument order is prefix of the partition ordering, |
21 | | // then we can just use the partition ordering. |
22 | 0 | auto &wexpr = executor.wexpr; |
23 | 0 | auto &arg_orders = executor.wexpr.arg_orders; |
24 | 0 | const auto optimize = ClientConfig::GetConfig(executor.context).enable_optimizer; |
25 | 0 | if (!optimize || BoundWindowExpression::GetSharedOrders(wexpr.orders, arg_orders) != arg_orders.size()) { |
26 | | // "The ROW_NUMBER function can be computed by disambiguating duplicate elements based on their |
27 | | // position in the input data, such that two elements never compare as equal." |
28 | 0 | token_tree = make_uniq<WindowTokenTree>(executor.context, executor.wexpr.arg_orders, |
29 | 0 | executor.arg_order_idx, payload_count, true); |
30 | 0 | } |
31 | 0 | } |
32 | 0 | } |
33 | | |
34 | | //! Use framing instead of partitions (ORDER BY arguments) |
35 | | bool use_framing = false; |
36 | | |
37 | | //! The token tree for ORDER BY arguments |
38 | | unique_ptr<WindowTokenTree> token_tree; |
39 | | |
40 | | //! The evaluation index for NTILE |
41 | | const column_t ntile_idx; |
42 | | }; |
43 | | |
44 | | //===--------------------------------------------------------------------===// |
45 | | // WindowRowNumberLocalState |
46 | | //===--------------------------------------------------------------------===// |
47 | | class WindowRowNumberLocalState : public WindowExecutorBoundsState { |
48 | | public: |
49 | | explicit WindowRowNumberLocalState(const WindowRowNumberGlobalState &grstate) |
50 | 0 | : WindowExecutorBoundsState(grstate), grstate(grstate) { |
51 | 0 | if (grstate.token_tree) { |
52 | 0 | local_tree = grstate.token_tree->GetLocalState(); |
53 | 0 | } |
54 | 0 | } |
55 | | |
56 | | //! Accumulate the secondary sort values |
57 | | void Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk, |
58 | | idx_t input_idx) override; |
59 | | //! Finish the sinking and prepare to scan |
60 | | void Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) override; |
61 | | |
62 | | //! The corresponding global peer state |
63 | | const WindowRowNumberGlobalState &grstate; |
64 | | //! The optional sorting state for secondary sorts |
65 | | unique_ptr<WindowAggregatorState> local_tree; |
66 | | }; |
67 | | |
68 | | void WindowRowNumberLocalState::Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk, |
69 | 0 | idx_t input_idx) { |
70 | 0 | WindowExecutorBoundsState::Sink(gstate, sink_chunk, coll_chunk, input_idx); |
71 | |
|
72 | 0 | if (local_tree) { |
73 | 0 | auto &local_tokens = local_tree->Cast<WindowMergeSortTreeLocalState>(); |
74 | 0 | local_tokens.SinkChunk(sink_chunk, input_idx, nullptr, 0); |
75 | 0 | } |
76 | 0 | } |
77 | | |
78 | 0 | void WindowRowNumberLocalState::Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) { |
79 | 0 | WindowExecutorBoundsState::Finalize(gstate, collection); |
80 | |
|
81 | 0 | if (local_tree) { |
82 | 0 | auto &local_tokens = local_tree->Cast<WindowMergeSortTreeLocalState>(); |
83 | 0 | local_tokens.Sort(); |
84 | 0 | local_tokens.window_tree.Build(); |
85 | 0 | } |
86 | 0 | } |
87 | | |
88 | | //===--------------------------------------------------------------------===// |
89 | | // WindowRowNumberExecutor |
90 | | //===--------------------------------------------------------------------===// |
91 | | WindowRowNumberExecutor::WindowRowNumberExecutor(BoundWindowExpression &wexpr, ClientContext &context, |
92 | | WindowSharedExpressions &shared) |
93 | 0 | : WindowExecutor(wexpr, context, shared) { |
94 | |
|
95 | 0 | for (const auto &order : wexpr.arg_orders) { |
96 | 0 | arg_order_idx.emplace_back(shared.RegisterSink(order.expression)); |
97 | 0 | } |
98 | 0 | } |
99 | | |
100 | | unique_ptr<WindowExecutorGlobalState> WindowRowNumberExecutor::GetGlobalState(const idx_t payload_count, |
101 | | const ValidityMask &partition_mask, |
102 | 0 | const ValidityMask &order_mask) const { |
103 | 0 | return make_uniq<WindowRowNumberGlobalState>(*this, payload_count, partition_mask, order_mask); |
104 | 0 | } |
105 | | |
106 | | unique_ptr<WindowExecutorLocalState> |
107 | 0 | WindowRowNumberExecutor::GetLocalState(const WindowExecutorGlobalState &gstate) const { |
108 | 0 | return make_uniq<WindowRowNumberLocalState>(gstate.Cast<WindowRowNumberGlobalState>()); |
109 | 0 | } |
110 | | |
111 | | void WindowRowNumberExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, |
112 | | DataChunk &eval_chunk, Vector &result, idx_t count, |
113 | 0 | idx_t row_idx) const { |
114 | 0 | auto &grstate = gstate.Cast<WindowRowNumberGlobalState>(); |
115 | 0 | auto &lrstate = lstate.Cast<WindowRowNumberLocalState>(); |
116 | 0 | auto rdata = FlatVector::GetData<uint64_t>(result); |
117 | |
|
118 | 0 | if (grstate.use_framing) { |
119 | 0 | auto frame_begin = FlatVector::GetData<const idx_t>(lrstate.bounds.data[FRAME_BEGIN]); |
120 | 0 | auto frame_end = FlatVector::GetData<const idx_t>(lrstate.bounds.data[FRAME_END]); |
121 | 0 | if (grstate.token_tree) { |
122 | 0 | for (idx_t i = 0; i < count; ++i, ++row_idx) { |
123 | | // Row numbers are unique ranks |
124 | 0 | rdata[i] = grstate.token_tree->Rank(frame_begin[i], frame_end[i], row_idx); |
125 | 0 | } |
126 | 0 | } else { |
127 | 0 | for (idx_t i = 0; i < count; ++i, ++row_idx) { |
128 | 0 | rdata[i] = row_idx - frame_begin[i] + 1; |
129 | 0 | } |
130 | 0 | } |
131 | 0 | return; |
132 | 0 | } |
133 | | |
134 | 0 | auto partition_begin = FlatVector::GetData<const idx_t>(lrstate.bounds.data[PARTITION_BEGIN]); |
135 | 0 | for (idx_t i = 0; i < count; ++i, ++row_idx) { |
136 | 0 | rdata[i] = row_idx - partition_begin[i] + 1; |
137 | 0 | } |
138 | 0 | } |
139 | | |
140 | | //===--------------------------------------------------------------------===// |
141 | | // WindowNtileExecutor |
142 | | //===--------------------------------------------------------------------===// |
143 | | WindowNtileExecutor::WindowNtileExecutor(BoundWindowExpression &wexpr, ClientContext &context, |
144 | | WindowSharedExpressions &shared) |
145 | 0 | : WindowRowNumberExecutor(wexpr, context, shared) { |
146 | | |
147 | | // NTILE has one argument |
148 | 0 | ntile_idx = shared.RegisterEvaluate(wexpr.children[0]); |
149 | 0 | } |
150 | | |
151 | | void WindowNtileExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, |
152 | 0 | DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx) const { |
153 | 0 | auto &grstate = gstate.Cast<WindowRowNumberGlobalState>(); |
154 | 0 | auto &lrstate = lstate.Cast<WindowRowNumberLocalState>(); |
155 | 0 | auto partition_begin = FlatVector::GetData<const idx_t>(lrstate.bounds.data[PARTITION_BEGIN]); |
156 | 0 | auto partition_end = FlatVector::GetData<const idx_t>(lrstate.bounds.data[PARTITION_END]); |
157 | 0 | if (grstate.use_framing) { |
158 | | // With secondary sorts, we restrict to the frame boundaries, but everything else should compute the same. |
159 | 0 | partition_begin = FlatVector::GetData<const idx_t>(lrstate.bounds.data[FRAME_BEGIN]); |
160 | 0 | partition_end = FlatVector::GetData<const idx_t>(lrstate.bounds.data[FRAME_END]); |
161 | 0 | } |
162 | 0 | auto rdata = FlatVector::GetData<int64_t>(result); |
163 | 0 | WindowInputExpression ntile_col(eval_chunk, ntile_idx); |
164 | 0 | for (idx_t i = 0; i < count; ++i, ++row_idx) { |
165 | 0 | if (ntile_col.CellIsNull(i)) { |
166 | 0 | FlatVector::SetNull(result, i, true); |
167 | 0 | } else { |
168 | 0 | auto n_param = ntile_col.GetCell<int64_t>(i); |
169 | 0 | if (n_param < 1) { |
170 | 0 | throw InvalidInputException("Argument for ntile must be greater than zero"); |
171 | 0 | } |
172 | | // With thanks from SQLite's ntileValueFunc() |
173 | 0 | auto n_total = NumericCast<int64_t>(partition_end[i] - partition_begin[i]); |
174 | 0 | if (n_param > n_total) { |
175 | | // more groups allowed than we have values |
176 | | // map every entry to a unique group |
177 | 0 | n_param = n_total; |
178 | 0 | } |
179 | 0 | int64_t n_size = (n_total / n_param); |
180 | | // find the row idx within the group |
181 | 0 | D_ASSERT(row_idx >= partition_begin[i]); |
182 | 0 | idx_t partition_idx = 0; |
183 | 0 | if (grstate.token_tree) { |
184 | 0 | partition_idx = grstate.token_tree->Rank(partition_begin[i], partition_end[i], row_idx) - 1; |
185 | 0 | } else { |
186 | 0 | partition_idx = row_idx - partition_begin[i]; |
187 | 0 | } |
188 | 0 | auto adjusted_row_idx = NumericCast<int64_t>(partition_idx); |
189 | | |
190 | | // now compute the ntile |
191 | 0 | int64_t n_large = n_total - n_param * n_size; |
192 | 0 | int64_t i_small = n_large * (n_size + 1); |
193 | 0 | int64_t result_ntile; |
194 | |
|
195 | 0 | D_ASSERT((n_large * (n_size + 1) + (n_param - n_large) * n_size) == n_total); |
196 | |
|
197 | 0 | if (adjusted_row_idx < i_small) { |
198 | 0 | result_ntile = 1 + adjusted_row_idx / (n_size + 1); |
199 | 0 | } else { |
200 | 0 | result_ntile = 1 + n_large + (adjusted_row_idx - i_small) / n_size; |
201 | 0 | } |
202 | | // result has to be between [1, NTILE] |
203 | 0 | D_ASSERT(result_ntile >= 1 && result_ntile <= n_param); |
204 | 0 | rdata[i] = result_ntile; |
205 | 0 | } |
206 | 0 | } |
207 | 0 | } |
208 | | |
209 | | } // namespace duckdb |