libs/capy/include/boost/capy/when_all.hpp

96.9% Lines (95/98) 91.2% Functions (466/511) 100.0% Branches (24/24)
libs/capy/include/boost/capy/when_all.hpp
Line Branch Hits Source Code
1 //
2 // Copyright (c) 2026 Steve Gerbino
3 //
4 // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 //
7 // Official repository: https://github.com/cppalliance/capy
8 //
9
10 #ifndef BOOST_CAPY_WHEN_ALL_HPP
11 #define BOOST_CAPY_WHEN_ALL_HPP
12
13 #include <boost/capy/detail/config.hpp>
14 #include <boost/capy/concept/executor.hpp>
15 #include <boost/capy/concept/io_awaitable.hpp>
16 #include <coroutine>
17 #include <boost/capy/ex/io_env.hpp>
18 #include <boost/capy/ex/frame_allocator.hpp>
19 #include <boost/capy/task.hpp>
20
21 #include <array>
22 #include <atomic>
23 #include <exception>
24 #include <optional>
25 #include <stop_token>
26 #include <tuple>
27 #include <type_traits>
28 #include <utility>
29
30 namespace boost {
31 namespace capy {
32
33 namespace detail {
34
/** Maps one type to its contribution to a concatenated tuple.

    Yields `std::tuple<>` when `T` is `void` and `std::tuple<T>`
    otherwise, so concatenating the contributions drops every `void`.
*/
template<typename T>
using wrap_non_void_t = std::conditional_t<
    !std::is_void_v<T>,
    std::tuple<T>,
    std::tuple<>>;

/** Tuple of the non-void types among `Ts...`, in their original order.

    Void-returning tasks do not contribute a value to the result tuple.

    Example: `filter_void_tuple_t<int, void, std::string>` is
    `std::tuple<int, std::string>`.
*/
template<typename... Ts>
using filter_void_tuple_t = decltype(
    std::tuple_cat(std::declval<wrap_non_void_t<Ts>>()...));
47
/** Storage slot for the value produced by one child of `when_all`.

    The slot starts empty. `set` fills it when the child completes, and
    `get` moves the value out when the results are collected.

    @tparam T The child's result type.
*/
template<typename T>
struct result_holder
{
    std::optional<T> value_;

    /// Store `v` in the slot, replacing any previously stored value.
    void set(T v)
    {
        value_ = std::move(v);
    }

    /** Move the stored value out of the slot.

        No emptiness check is performed; the caller must have called
        `set` beforehand.
    */
    T get() &&
    {
        return *std::move(value_);
    }
};

/** Void children produce no value, so this slot stores nothing. */
template<>
struct result_holder<void>
{
};
72
73 /** Shared state for when_all operation.
74
75 @tparam Ts The result types of the tasks.
76 */
77 template<typename... Ts>
78 struct when_all_state
79 {
80 static constexpr std::size_t task_count = sizeof...(Ts);
81
82 // Completion tracking - when_all waits for all children
83 std::atomic<std::size_t> remaining_count_;
84
85 // Result storage in input order
86 std::tuple<result_holder<Ts>...> results_;
87
88 // Runner handles - destroyed in await_resume while allocator is valid
89 std::array<std::coroutine_handle<>, task_count> runner_handles_{};
90
91 // Exception storage - first error wins, others discarded
92 std::atomic<bool> has_exception_{false};
93 std::exception_ptr first_exception_;
94
95 // Stop propagation - on error, request stop for siblings
96 std::stop_source stop_source_;
97
98 // Connects parent's stop_token to our stop_source
99 struct stop_callback_fn
100 {
101 std::stop_source* source_;
102 4 void operator()() const { source_->request_stop(); }
103 };
104 using stop_callback_t = std::stop_callback<stop_callback_fn>;
105 std::optional<stop_callback_t> parent_stop_callback_;
106
107 // Parent resumption
108 std::coroutine_handle<> continuation_;
109 io_env const* caller_env_ = nullptr;
110
111 34 when_all_state()
112
1/1
✓ Branch 5 taken 34 times.
34 : remaining_count_(task_count)
113 {
114 34 }
115
116 // Runners self-destruct in final_suspend. No destruction needed here.
117
118 /** Capture an exception (first one wins).
119 */
120 11 void capture_exception(std::exception_ptr ep)
121 {
122 11 bool expected = false;
123
2/2
✓ Branch 1 taken 8 times.
✓ Branch 2 taken 3 times.
11 if(has_exception_.compare_exchange_strong(
124 expected, true, std::memory_order_relaxed))
125 8 first_exception_ = ep;
126 11 }
127
128 };
129
130 /** Wrapper coroutine that intercepts task completion.
131
132 This runner awaits its assigned task and stores the result in
133 the shared state, or captures the exception and requests stop.
134 */
135 template<typename T, typename... Ts>
136 struct when_all_runner
137 {
138 struct promise_type // : frame_allocating_base // DISABLED FOR TESTING
139 {
140 when_all_state<Ts...>* state_ = nullptr;
141 io_env env_;
142
143 80 when_all_runner get_return_object()
144 {
145 80 return when_all_runner(std::coroutine_handle<promise_type>::from_promise(*this));
146 }
147
148 80 std::suspend_always initial_suspend() noexcept
149 {
150 80 return {};
151 }
152
153 80 auto final_suspend() noexcept
154 {
155 struct awaiter
156 {
157 promise_type* p_;
158
159 8 bool await_ready() const noexcept
160 {
161 8 return false;
162 }
163
164 8 std::coroutine_handle<> await_suspend(std::coroutine_handle<> h) noexcept
165 {
166 // Extract everything needed before self-destruction.
167 8 auto* state = p_->state_;
168 8 auto* counter = &state->remaining_count_;
169 8 auto* caller_env = state->caller_env_;
170 8 auto cont = state->continuation_;
171
172 8 h.destroy();
173
174 // If last runner, dispatch parent for symmetric transfer.
175 8 auto remaining = counter->fetch_sub(1, std::memory_order_acq_rel);
176
2/2
✓ Branch 0 taken 4 times.
✓ Branch 1 taken 4 times.
8 if(remaining == 1)
177 4 return caller_env->executor.dispatch(cont);
178 4 return std::noop_coroutine();
179 }
180
181 void await_resume() const noexcept
182 {
183 }
184 };
185 80 return awaiter{this};
186 }
187
188 69 void return_void()
189 {
190 69 }
191
192 11 void unhandled_exception()
193 {
194 11 state_->capture_exception(std::current_exception());
195 // Request stop for sibling tasks
196 11 state_->stop_source_.request_stop();
197 11 }
198
199 template<class Awaitable>
200 struct transform_awaiter
201 {
202 std::decay_t<Awaitable> a_;
203 promise_type* p_;
204
205 80 bool await_ready()
206 {
207 80 return a_.await_ready();
208 }
209
210 80 decltype(auto) await_resume()
211 {
212 80 return a_.await_resume();
213 }
214
215 template<class Promise>
216 79 auto await_suspend(std::coroutine_handle<Promise> h)
217 {
218 79 return a_.await_suspend(h, &p_->env_);
219 }
220 };
221
222 template<class Awaitable>
223 80 auto await_transform(Awaitable&& a)
224 {
225 using A = std::decay_t<Awaitable>;
226 if constexpr (IoAwaitable<A>)
227 {
228 return transform_awaiter<Awaitable>{
229 160 std::forward<Awaitable>(a), this};
230 }
231 else
232 {
233 static_assert(sizeof(A) == 0, "requires IoAwaitable");
234 }
235 80 }
236 };
237
238 std::coroutine_handle<promise_type> h_;
239
240 80 explicit when_all_runner(std::coroutine_handle<promise_type> h)
241 80 : h_(h)
242 {
243 80 }
244
245 // Enable move for all clang versions - some versions need it
246 when_all_runner(when_all_runner&& other) noexcept : h_(std::exchange(other.h_, nullptr)) {}
247
248 // Non-copyable
249 when_all_runner(when_all_runner const&) = delete;
250 when_all_runner& operator=(when_all_runner const&) = delete;
251 when_all_runner& operator=(when_all_runner&&) = delete;
252
253 80 auto release() noexcept
254 {
255 80 return std::exchange(h_, nullptr);
256 }
257 };
258
259 /** Create a runner coroutine for a single awaitable.
260
261 Awaitable is passed directly to ensure proper coroutine frame storage.
262 */
263 template<std::size_t Index, IoAwaitable Awaitable, typename... Ts>
264 when_all_runner<awaitable_result_t<Awaitable>, Ts...>
265
1/1
✓ Branch 1 taken 80 times.
80 make_when_all_runner(Awaitable inner, when_all_state<Ts...>* state)
266 {
267 using T = awaitable_result_t<Awaitable>;
268 if constexpr (std::is_void_v<T>)
269 {
270 co_await std::move(inner);
271 }
272 else
273 {
274 std::get<Index>(state->results_).set(co_await std::move(inner));
275 }
276 160 }
277
278 /** Internal awaitable that launches all runner coroutines and waits.
279
280 This awaitable is used inside the when_all coroutine to handle
281 the concurrent execution of child awaitables.
282 */
283 template<IoAwaitable... Awaitables>
284 class when_all_launcher
285 {
286 using state_type = when_all_state<awaitable_result_t<Awaitables>...>;
287
288 std::tuple<Awaitables...>* awaitables_;
289 state_type* state_;
290
291 public:
292 34 when_all_launcher(
293 std::tuple<Awaitables...>* awaitables,
294 state_type* state)
295 34 : awaitables_(awaitables)
296 34 , state_(state)
297 {
298 34 }
299
300 34 bool await_ready() const noexcept
301 {
302 34 return sizeof...(Awaitables) == 0;
303 }
304
305 34 std::coroutine_handle<> await_suspend(std::coroutine_handle<> continuation, io_env const* caller_env)
306 {
307 34 state_->continuation_ = continuation;
308 34 state_->caller_env_ = caller_env;
309
310 // Forward parent's stop requests to children
311
2/2
✓ Branch 1 taken 8 times.
✓ Branch 2 taken 26 times.
34 if(caller_env->stop_token.stop_possible())
312 {
313 16 state_->parent_stop_callback_.emplace(
314 8 caller_env->stop_token,
315 8 typename state_type::stop_callback_fn{&state_->stop_source_});
316
317
2/2
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 4 times.
8 if(caller_env->stop_token.stop_requested())
318 4 state_->stop_source_.request_stop();
319 }
320
321 // CRITICAL: If the last task finishes synchronously then the parent
322 // coroutine resumes, destroying its frame, and destroying this object
323 // prior to the completion of await_suspend. Therefore, await_suspend
324 // must ensure `this` cannot be referenced after calling `launch_one`
325 // for the last time.
326 34 auto token = state_->stop_source_.get_token();
327 [&]<std::size_t... Is>(std::index_sequence<Is...>) {
328
2/2
✓ Branch 2 taken 5 times.
✓ Branch 6 taken 5 times.
5 (..., launch_one<Is>(caller_env->executor, token));
329
2/2
✓ Branch 1 taken 29 times.
✓ Branch 1 taken 5 times.
34 }(std::index_sequence_for<Awaitables...>{});
330
331 // Let signal_completion() handle resumption
332 68 return std::noop_coroutine();
333 34 }
334
335 34 void await_resume() const noexcept
336 {
337 // Results are extracted by the when_all coroutine from state
338 34 }
339
340 private:
341 template<std::size_t I>
342 80 void launch_one(executor_ref caller_ex, std::stop_token token)
343 {
344
1/1
✓ Branch 2 taken 80 times.
80 auto runner = make_when_all_runner<I>(
345 80 std::move(std::get<I>(*awaitables_)), state_);
346
347 80 auto h = runner.release();
348 80 h.promise().state_ = state_;
349 80 h.promise().env_ = io_env{caller_ex, token, state_->caller_env_->allocator};
350
351 80 std::coroutine_handle<> ch{h};
352 80 state_->runner_handles_[I] = ch;
353
1/1
✓ Branch 1 taken 80 times.
80 state_->caller_env_->executor.post(ch);
354 160 }
355 };
356
357 /** Compute the result type for when_all.
358
359 Returns void when all tasks are void (P2300 aligned),
360 otherwise returns a tuple with void types filtered out.
361 */
362 template<typename... Ts>
363 using when_all_result_t = std::conditional_t<
364 std::is_same_v<filter_void_tuple_t<Ts...>, std::tuple<>>,
365 void,
366 filter_void_tuple_t<Ts...>>;
367
368 /** Helper to extract a single result, returning empty tuple for void.
369 This is a separate function to work around a GCC-11 ICE that occurs
370 when using nested immediately-invoked lambdas with pack expansion.
371 */
372 template<std::size_t I, typename... Ts>
373 57 auto extract_single_result(when_all_state<Ts...>& state)
374 {
375 using T = std::tuple_element_t<I, std::tuple<Ts...>>;
376 if constexpr (std::is_void_v<T>)
377 4 return std::tuple<>();
378 else
379
1/1
✓ Branch 4 taken 53 times.
53 return std::make_tuple(std::move(std::get<I>(state.results_)).get());
380 }
381
382 /** Extract results from state, filtering void types.
383 */
384 template<typename... Ts>
385 24 auto extract_results(when_all_state<Ts...>& state)
386 {
387 24 return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
388
5/5
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 1 time.
✓ Branch 4 taken 4 times.
✓ Branch 5 taken 1 time.
✓ Branch 7 taken 4 times.
5 return std::tuple_cat(extract_single_result<Is>(state)...);
389
1/1
✓ Branch 1 taken 24 times.
48 }(std::index_sequence_for<Ts...>{});
390 }
391
392 } // namespace detail
393
394 /** Execute multiple awaitables concurrently and collect their results.
395
396 Launches all awaitables simultaneously and waits for all to complete
397 before returning. Results are collected in input order. If any
398 awaitable throws, cancellation is requested for siblings and the first
399 exception is rethrown after all awaitables complete.
400
401 @li All child awaitables run concurrently on the caller's executor
402 @li Results are returned as a tuple in input order
403 @li Void-returning awaitables do not contribute to the result tuple
404 @li If all awaitables return void, `when_all` returns `task<void>`
405 @li First exception wins; subsequent exceptions are discarded
406 @li Stop is requested for siblings on first error
407 @li Completes only after all children have finished
408
409 @par Thread Safety
410 The returned task must be awaited from a single execution context.
411 Child awaitables execute concurrently but complete through the caller's
412 executor.
413
414 @param awaitables The awaitables to execute concurrently. Each must
415 satisfy @ref IoAwaitable and is consumed (moved-from) when
416 `when_all` is awaited.
417
418 @return A task yielding a tuple of non-void results. Returns
419 `task<void>` when all input awaitables return void.
420
421 @par Example
422
423 @code
424 task<> example()
425 {
426 // Concurrent fetch, results collected in order
427 auto [user, posts] = co_await when_all(
428 fetch_user( id ), // task<User>
429 fetch_posts( id ) // task<std::vector<Post>>
430 );
431
432 // Void awaitables don't contribute to result
433 co_await when_all(
434 log_event( "start" ), // task<void>
435 notify_user( id ) // task<void>
436 );
437 // Returns task<void>, no result tuple
438 }
439 @endcode
440
441 @see IoAwaitable, task
442 */
443 template<IoAwaitable... As>
444
1/1
✓ Branch 1 taken 34 times.
34 [[nodiscard]] auto when_all(As... awaitables)
445 -> task<detail::when_all_result_t<detail::awaitable_result_t<As>...>>
446 {
447 using result_type = detail::when_all_result_t<detail::awaitable_result_t<As>...>;
448
449 // State is stored in the coroutine frame, using the frame allocator
450 detail::when_all_state<detail::awaitable_result_t<As>...> state;
451
452 // Store awaitables in the frame
453 std::tuple<As...> awaitable_tuple(std::move(awaitables)...);
454
455 // Launch all awaitables and wait for completion
456 co_await detail::when_all_launcher<As...>(&awaitable_tuple, &state);
457
458 // Propagate first exception if any.
459 // Safe without explicit acquire: capture_exception() is sequenced-before
460 // signal_completion()'s acq_rel fetch_sub, which synchronizes-with the
461 // last task's decrement that resumes this coroutine.
462 if(state.first_exception_)
463 std::rethrow_exception(state.first_exception_);
464
465 // Extract and return results
466 if constexpr (std::is_void_v<result_type>)
467 co_return;
468 else
469 co_return detail::extract_results(state);
470 68 }
471
472 /// Compute the result type of `when_all` for the given task types.
473 template<typename... Ts>
474 using when_all_result_type = detail::when_all_result_t<Ts...>;
475
476 } // namespace capy
477 } // namespace boost
478
479 #endif
480