/src/serenity/AK/Coroutine.h
Line | Count | Source (jump to first uncovered line) |
1 | | /* |
2 | | * Copyright (c) 2024, Dan Klishch <danilklishch@gmail.com> |
3 | | * |
4 | | * SPDX-License-Identifier: BSD-2-Clause |
5 | | */ |
6 | | |
7 | | #pragma once |
8 | | |
9 | | #include <AK/Concepts.h> |
10 | | #include <AK/Noncopyable.h> |
11 | | #include <coroutine> |
12 | | |
13 | | namespace AK { |
14 | | |
15 | | namespace Detail { |
16 | | |
17 | | // FIXME: GCC ICEs when an implementation of CO_TRY_OR_FAIL with statement expressions is used. See also LibTest/AsyncTestCase.h. |
18 | | #if defined(AK_COROUTINE_STATEMENT_EXPRS_BROKEN) && !defined(AK_COROUTINE_DESTRUCTION_BROKEN) |
19 | | # define AK_USE_TRY_OR_FAIL_AWAITER |
20 | | #endif |
21 | | |
22 | | #ifdef AK_USE_TRY_OR_FAIL_AWAITER |
23 | | namespace Test { |
24 | | |
25 | | template<typename T> |
26 | | struct TryOrFailAwaiter; |
27 | | |
28 | | } |
29 | | #endif |
30 | | |
31 | | struct SuspendNever { |
32 | | // Even though we set -fno-exceptions, Clang really wants these to be noexcept. |
33 | 550k | bool await_ready() const noexcept { return true; } |
34 | 0 | void await_suspend(std::coroutine_handle<>) const noexcept { } |
35 | 550k | void await_resume() const noexcept { } |
36 | | }; |
37 | | |
38 | | struct SuspendAlways { |
39 | 131M | bool await_ready() const noexcept { return false; } |
40 | 131M | void await_suspend(std::coroutine_handle<>) const noexcept { } |
41 | 130M | void await_resume() const noexcept { } |
42 | | }; |
43 | | |
44 | | struct SymmetricControlTransfer { |
45 | | SymmetricControlTransfer(std::coroutine_handle<> handle) |
46 | 0 | : m_handle(handle ? handle : std::noop_coroutine()) |
47 | 0 | { |
48 | 0 | } |
49 | | |
50 | 0 | bool await_ready() const noexcept { return false; } |
51 | 0 | auto await_suspend(std::coroutine_handle<>) const noexcept { return m_handle; } |
52 | 0 | void await_resume() const noexcept { } |
53 | | |
54 | | std::coroutine_handle<> m_handle; |
55 | | }; |
56 | | |
57 | | template<typename T> |
58 | | struct TryAwaiter; |
59 | | |
60 | | template<typename T> |
61 | | struct ValueHolder { |
62 | | alignas(T) u8 m_return_value[sizeof(T)]; |
63 | | }; |
64 | | |
65 | | template<> |
66 | | struct ValueHolder<void> { }; |
67 | | } |
68 | | |
69 | | template<typename T> |
70 | | class [[nodiscard]] Coroutine : private Detail::ValueHolder<T> { |
71 | | struct CoroutinePromiseVoid; |
72 | | struct CoroutinePromiseValue; |
73 | | |
74 | | AK_MAKE_NONCOPYABLE(Coroutine); |
75 | | |
76 | | public: |
77 | | using ReturnType = T; |
78 | | using promise_type = Conditional<SameAs<T, void>, CoroutinePromiseVoid, CoroutinePromiseValue>; |
79 | | |
80 | | ~Coroutine() |
81 | 0 | { |
82 | 0 | VERIFY(await_ready()); |
83 | | if constexpr (!SameAs<T, void>) |
84 | | return_value()->~T(); |
85 | 0 | if (m_handle) |
86 | 0 | m_handle.destroy(); |
87 | 0 | } |
88 | | |
89 | | Coroutine(Coroutine&& other) |
90 | 0 | { |
91 | 0 | m_handle = AK::exchange(other.m_handle, {}); |
92 | 0 | if (!await_ready()) |
93 | 0 | m_handle.promise().m_coroutine = this; |
94 | 0 | else if constexpr (!IsVoid<T>) |
95 | 0 | new (return_value()) T(move(*other.return_value())); |
96 | 0 | } |
97 | | |
98 | | Coroutine& operator=(Coroutine&& other) |
99 | | { |
100 | | if (this != &other) { |
101 | | this->~Coroutine(); |
102 | | new (this) Coroutine(move(other)); |
103 | | } |
104 | | return *this; |
105 | | } |
106 | | |
107 | | bool await_ready() const |
108 | 0 | { |
109 | 0 | return !m_handle || m_handle.done(); |
110 | 0 | } Unexecuted instantiation: AK::Coroutine<AK::ErrorOr<void, AK::Error> >::await_ready() const Unexecuted instantiation: AK::Coroutine<AK::ErrorOr<AK::NonnullOwnPtr<Core::TCPSocket>, AK::Error> >::await_ready() const Unexecuted instantiation: AK::Coroutine<void>::await_ready() const Unexecuted instantiation: AK::Coroutine<AK::ErrorOr<unsigned long, AK::Error> >::await_ready() const |
111 | | |
112 | | void await_suspend(std::coroutine_handle<> awaiter) |
113 | 0 | { |
114 | 0 | m_handle.promise().m_awaiter = awaiter; |
115 | 0 | } Unexecuted instantiation: AK::Coroutine<void>::await_suspend(std::__1::coroutine_handle<void>) Unexecuted instantiation: AK::Coroutine<AK::ErrorOr<void, AK::Error> >::await_suspend(std::__1::coroutine_handle<void>) |
116 | | |
117 | | // Do NOT bind the result of await_resume() on a temporary coroutine (or the result of CO_TRY) to auto&&! |
118 | | [[nodiscard]] decltype(auto) await_resume() |
119 | 0 | { |
120 | | if constexpr (SameAs<T, void>) |
121 | 0 | return; |
122 | | else |
123 | | return static_cast<T&&>(*return_value()); |
124 | 0 | } Unexecuted instantiation: AK::Coroutine<void>::await_resume() Unexecuted instantiation: AK::Coroutine<AK::ErrorOr<void, AK::Error> >::await_resume() |
125 | | |
126 | | private: |
127 | | template<typename U> |
128 | | friend struct Detail::TryAwaiter; |
129 | | |
130 | | #ifdef AK_USE_TRY_OR_FAIL_AWAITER |
131 | | template<typename U> |
132 | | friend struct AK::Detail::Test::TryOrFailAwaiter; |
133 | | #endif |
134 | | |
135 | | // You cannot just have return_value and return_void defined in the same promise type because C++. |
136 | | struct CoroutinePromiseBase { |
137 | 0 | CoroutinePromiseBase() = default; |
138 | | |
139 | | Coroutine get_return_object() |
140 | 0 | { |
141 | 0 | return { std::coroutine_handle<promise_type>::from_promise(*static_cast<promise_type*>(this)) }; |
142 | 0 | } Unexecuted instantiation: AK::Coroutine<AK::ErrorOr<void, AK::Error> >::CoroutinePromiseBase::get_return_object() Unexecuted instantiation: AK::Coroutine<AK::ErrorOr<AK::NonnullOwnPtr<Core::TCPSocket>, AK::Error> >::CoroutinePromiseBase::get_return_object() Unexecuted instantiation: AK::Coroutine<AK::ErrorOr<unsigned long, AK::Error> >::CoroutinePromiseBase::get_return_object() |
143 | | |
144 | 0 | Detail::SuspendNever initial_suspend() { return {}; } Unexecuted instantiation: AK::Coroutine<AK::ErrorOr<void, AK::Error> >::CoroutinePromiseBase::initial_suspend() Unexecuted instantiation: AK::Coroutine<AK::ErrorOr<AK::NonnullOwnPtr<Core::TCPSocket>, AK::Error> >::CoroutinePromiseBase::initial_suspend() Unexecuted instantiation: AK::Coroutine<AK::ErrorOr<unsigned long, AK::Error> >::CoroutinePromiseBase::initial_suspend() |
145 | | |
146 | | Detail::SymmetricControlTransfer final_suspend() noexcept |
147 | 0 | { |
148 | 0 | return { m_awaiter }; |
149 | 0 | } Unexecuted instantiation: AK::Coroutine<AK::ErrorOr<void, AK::Error> >::CoroutinePromiseBase::final_suspend() Unexecuted instantiation: AK::Coroutine<AK::ErrorOr<AK::NonnullOwnPtr<Core::TCPSocket>, AK::Error> >::CoroutinePromiseBase::final_suspend() Unexecuted instantiation: AK::Coroutine<AK::ErrorOr<unsigned long, AK::Error> >::CoroutinePromiseBase::final_suspend() |
150 | | |
151 | | void unhandled_exception() = delete; |
152 | | |
153 | | std::coroutine_handle<> m_awaiter; |
154 | | Coroutine* m_coroutine { nullptr }; |
155 | | }; |
156 | | |
157 | | struct CoroutinePromiseValue : CoroutinePromiseBase { |
158 | | template<typename U> |
159 | | requires requires { { T(forward<U>(declval<U>())) }; } |
160 | | void return_value(U&& returned_object) |
161 | 0 | { |
162 | 0 | new (this->m_coroutine->return_value()) T(forward<U>(returned_object)); |
163 | 0 | } Unexecuted instantiation: _ZN2AK9CoroutineINS_7ErrorOrIvNS_5ErrorEEEE21CoroutinePromiseValue12return_valueIS2_QrqXcvT_cl7forwardITL0__Ecl7declvalIS8_EEEEEEvOS7_ Unexecuted instantiation: _ZN2AK9CoroutineINS_7ErrorOrINS_13NonnullOwnPtrIN4Core9TCPSocketEEENS_5ErrorEEEE21CoroutinePromiseValue12return_valueIS6_QrqXcvT_cl7forwardITL0__Ecl7declvalISC_EEEEEEvOSB_ Unexecuted instantiation: _ZN2AK9CoroutineINS_7ErrorOrINS_13NonnullOwnPtrIN4Core9TCPSocketEEENS_5ErrorEEEE21CoroutinePromiseValue12return_valueIS5_QrqXcvT_cl7forwardITL0__Ecl7declvalISC_EEEEEEvOSB_ Unexecuted instantiation: _ZN2AK9CoroutineINS_7ErrorOrImNS_5ErrorEEEE21CoroutinePromiseValue12return_valueIS2_QrqXcvT_cl7forwardITL0__Ecl7declvalIS8_EEEEEEvOS7_ |
164 | | |
165 | | void return_value(T&& returned_object) |
166 | 0 | { |
167 | 0 | new (this->m_coroutine->return_value()) T(move(returned_object)); |
168 | 0 | } Unexecuted instantiation: AK::Coroutine<AK::ErrorOr<void, AK::Error> >::CoroutinePromiseValue::return_value(AK::ErrorOr<void, AK::Error>&&) Unexecuted instantiation: AK::Coroutine<AK::ErrorOr<unsigned long, AK::Error> >::CoroutinePromiseValue::return_value(AK::ErrorOr<unsigned long, AK::Error>&&) |
169 | | }; |
170 | | |
171 | | struct CoroutinePromiseVoid : CoroutinePromiseBase { |
172 | | void return_void() { } |
173 | | }; |
174 | | |
175 | | Coroutine(std::coroutine_handle<promise_type>&& handle) |
176 | 0 | : m_handle(move(handle)) |
177 | 0 | { |
178 | 0 | m_handle.promise().m_coroutine = this; |
179 | 0 | } |
180 | | |
181 | | T* return_value() |
182 | 0 | { |
183 | 0 | return reinterpret_cast<T*>(this->m_return_value); |
184 | 0 | } Unexecuted instantiation: AK::Coroutine<AK::ErrorOr<void, AK::Error> >::return_value() Unexecuted instantiation: AK::Coroutine<AK::ErrorOr<AK::NonnullOwnPtr<Core::TCPSocket>, AK::Error> >::return_value() Unexecuted instantiation: AK::Coroutine<AK::ErrorOr<unsigned long, AK::Error> >::return_value() |
185 | | |
186 | | std::coroutine_handle<promise_type> m_handle; |
187 | | }; |
188 | | |
189 | | template<typename T> |
190 | | T must_sync(Coroutine<ErrorOr<T>>&& coroutine) |
191 | | { |
192 | | VERIFY(coroutine.await_ready()); |
193 | | auto&& object = coroutine.await_resume(); |
194 | | VERIFY(!object.is_error()); |
195 | | if constexpr (!IsSame<T, void>) |
196 | | return object.release_value(); |
197 | | } |
198 | | |
199 | | #ifndef AK_COROUTINE_DESTRUCTION_BROKEN |
200 | | namespace Detail { |
201 | | template<typename T> |
202 | | struct TryAwaiter { |
203 | | TryAwaiter(T& expression) |
204 | | requires(!IsSpecializationOf<T, Coroutine>) |
205 | | : m_expression(&expression) |
206 | | { |
207 | | } |
208 | | |
209 | | TryAwaiter(T&& expression) |
210 | | requires(!IsSpecializationOf<T, Coroutine>) |
211 | | : m_expression(&expression) |
212 | | { |
213 | | } |
214 | | |
215 | | bool await_ready() { return false; } |
216 | | |
217 | | template<typename U> |
218 | | requires IsSpecializationOf<T, ErrorOr> |
219 | | std::coroutine_handle<> await_suspend(std::coroutine_handle<U> handle) |
220 | | { |
221 | | if (!m_expression->is_error()) { |
222 | | return handle; |
223 | | } else { |
224 | | auto awaiter = handle.promise().m_awaiter; |
225 | | auto* coroutine = handle.promise().m_coroutine; |
226 | | using ReturnType = RemoveReference<decltype(*coroutine)>::ReturnType; |
227 | | static_assert(IsSpecializationOf<ReturnType, ErrorOr>, |
228 | | "CO_TRY can only be used inside functions returning a specialization of ErrorOr"); |
229 | | |
230 | | // Move error to the user-visible AK::Coroutine |
231 | | new (coroutine->return_value()) ReturnType(m_expression->release_error()); |
232 | | // ... and tell it that there's a result available. |
233 | | coroutine->m_handle = {}; |
234 | | |
235 | | // Run destructors for locals in the coroutine that failed. |
236 | | handle.destroy(); |
237 | | |
238 | | // Lastly, transfer control to the parent (or nothing, if parent is not yet suspended). |
239 | | if (awaiter) |
240 | | return awaiter; |
241 | | return std::noop_coroutine(); |
242 | | } |
243 | | } |
244 | | |
245 | | void await_resume() |
246 | | requires(IsSame<T, ErrorOr<void>>) |
247 | | { |
248 | | (void)m_expression->release_value(); |
249 | | } |
250 | | |
251 | | decltype(auto) await_resume() |
252 | | { |
253 | | return m_expression->release_value(); |
254 | | } |
255 | | |
256 | | T* m_expression { nullptr }; |
257 | | }; |
258 | | } |
259 | | |
260 | | # ifndef AK_COROUTINE_TYPE_DEDUCTION_BROKEN |
261 | | # define CO_TRY(expression) (co_await ::AK::Detail::TryAwaiter { (expression) }) |
262 | | # else |
263 | | namespace Detail { |
264 | | template<typename T> |
265 | | auto declval_coro_result(Coroutine<T>&&) -> T; |
266 | | template<typename T> |
267 | | auto declval_coro_result(T&&) -> T; |
268 | | } |
269 | | |
270 | | // GCC cannot handle CO_TRY(...CO_TRY(...)...), this hack ensures that it always has the right type information available. |
271 | | // FIXME: Remove this once GCC can correctly infer the result type of `co_await TryAwaiter { ... }`. |
272 | | # define CO_TRY(expression) static_cast<AddRvalueReference<typename RemoveReference<decltype(expression)>::ResultType>>(co_await ::AK::Detail::TryAwaiter { (expression) }) |
273 | | # endif |
274 | | #elifndef AK_COROUTINE_STATEMENT_EXPRS_BROKEN |
275 | | # define CO_TRY(expression) \ |
276 | 0 | ({ \ |
277 | 0 | AK_IGNORE_DIAGNOSTIC("-Wshadow", \ |
278 | 0 | auto _temporary_result = (expression)); \ |
279 | 0 | if (_temporary_result.is_error()) [[unlikely]] \ |
280 | 0 | co_return _temporary_result.release_error(); \ |
281 | 0 | _temporary_result.release_value(); \ |
282 | 0 | }) |
283 | | #else |
284 | | # error Unable to work around compiler bugs in definiton of CO_TRY. |
285 | | #endif |
286 | | |
287 | | } |
288 | | |
289 | | #ifdef USING_AK_GLOBALLY |
290 | | using AK::Coroutine; |
291 | | #endif |