/rust/registry/src/index.crates.io-6f17d22bba15001f/rayon-core-1.12.1/src/broadcast/mod.rs
Line | Count | Source (jump to first uncovered line) |
1 | | use crate::job::{ArcJob, StackJob}; |
2 | | use crate::latch::{CountLatch, LatchRef}; |
3 | | use crate::registry::{Registry, WorkerThread}; |
4 | | use std::fmt; |
5 | | use std::marker::PhantomData; |
6 | | use std::sync::Arc; |
7 | | |
8 | | mod test; |
9 | | |
10 | | /// Executes `op` within every thread in the current threadpool. If this is |
11 | | /// called from a non-Rayon thread, it will execute in the global threadpool. |
12 | | /// Any attempts to use `join`, `scope`, or parallel iterators will then operate |
13 | | /// within that threadpool. When the call has completed on each thread, returns |
14 | | /// a vector containing all of their return values. |
15 | | /// |
16 | | /// For more information, see the [`ThreadPool::broadcast()`][m] method. |
17 | | /// |
18 | | /// [m]: struct.ThreadPool.html#method.broadcast |
19 | 0 | pub fn broadcast<OP, R>(op: OP) -> Vec<R> |
20 | 0 | where |
21 | 0 | OP: Fn(BroadcastContext<'_>) -> R + Sync, |
22 | 0 | R: Send, |
23 | 0 | { |
24 | 0 | // We assert that current registry has not terminated. |
25 | 0 | unsafe { broadcast_in(op, &Registry::current()) } |
26 | 0 | } |
27 | | |
28 | | /// Spawns an asynchronous task on every thread in this thread-pool. This task |
29 | | /// will run in the implicit, global scope, which means that it may outlast the |
30 | | /// current stack frame -- therefore, it cannot capture any references onto the |
31 | | /// stack (you will likely need a `move` closure). |
32 | | /// |
33 | | /// For more information, see the [`ThreadPool::spawn_broadcast()`][m] method. |
34 | | /// |
35 | | /// [m]: struct.ThreadPool.html#method.spawn_broadcast |
36 | 0 | pub fn spawn_broadcast<OP>(op: OP) |
37 | 0 | where |
38 | 0 | OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static, |
39 | 0 | { |
40 | 0 | // We assert that current registry has not terminated. |
41 | 0 | unsafe { spawn_broadcast_in(op, &Registry::current()) } |
42 | 0 | } |
43 | | |
/// Provides context to a closure called by `broadcast` or `spawn_broadcast`.
///
/// The context borrows the `WorkerThread` that is running the closure. It
/// exposes that thread's index and the number of threads in its pool.
pub struct BroadcastContext<'a> {
    /// The worker thread this context was created on.
    worker: &'a WorkerThread,

    /// Make sure to prevent auto-traits like `Send` and `Sync`.
    /// `dyn Fn()` carries no `Send`/`Sync` bounds, so this marker (and
    /// therefore the context) is neither `Send` nor `Sync`.
    _marker: PhantomData<&'a mut dyn Fn()>,
}
51 | | |
52 | | impl<'a> BroadcastContext<'a> { |
53 | 0 | pub(super) fn with<R>(f: impl FnOnce(BroadcastContext<'_>) -> R) -> R { |
54 | 0 | let worker_thread = WorkerThread::current(); |
55 | 0 | assert!(!worker_thread.is_null()); |
56 | 0 | f(BroadcastContext { |
57 | 0 | worker: unsafe { &*worker_thread }, |
58 | 0 | _marker: PhantomData, |
59 | 0 | }) |
60 | 0 | } |
61 | | |
62 | | /// Our index amongst the broadcast threads (ranges from `0..self.num_threads()`). |
63 | | #[inline] |
64 | 0 | pub fn index(&self) -> usize { |
65 | 0 | self.worker.index() |
66 | 0 | } |
67 | | |
68 | | /// The number of threads receiving the broadcast in the thread pool. |
69 | | /// |
70 | | /// # Future compatibility note |
71 | | /// |
72 | | /// Future versions of Rayon might vary the number of threads over time, but |
73 | | /// this method will always return the number of threads which are actually |
74 | | /// receiving your particular `broadcast` call. |
75 | | #[inline] |
76 | 0 | pub fn num_threads(&self) -> usize { |
77 | 0 | self.worker.registry().num_threads() |
78 | 0 | } |
79 | | } |
80 | | |
81 | | impl<'a> fmt::Debug for BroadcastContext<'a> { |
82 | 0 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { |
83 | 0 | fmt.debug_struct("BroadcastContext") |
84 | 0 | .field("index", &self.index()) |
85 | 0 | .field("num_threads", &self.num_threads()) |
86 | 0 | .field("pool_id", &self.worker.registry().id()) |
87 | 0 | .finish() |
88 | 0 | } |
89 | | } |
90 | | |
/// Execute `op` on every thread in the pool. It will be executed on each
/// thread when that thread has nothing else to do locally, before it tries
/// to steal work from other threads. This function will not return until
/// all threads have completed the `op`.
///
/// Returns one value per job, and one job is created per thread of
/// `registry`. Results come back in the order of the internal job list.
/// NOTE(review): presumably job `i` runs on worker `i`; confirm in
/// `Registry::inject_broadcast`. A panic in `op` may be propagated to the
/// caller when the results are collected.
///
/// # Safety
///
/// Unsafe because `registry` must not yet have terminated.
pub(super) unsafe fn broadcast_in<OP, R>(op: OP, registry: &Arc<Registry>) -> Vec<R>
where
    OP: Fn(BroadcastContext<'_>) -> R + Sync,
    R: Send,
{
    // Adapt `op` to the `bool`-taking closure shape that `StackJob` runs.
    // `injected` is expected to be true because these jobs are delivered
    // through `inject_broadcast` below.
    let f = move |injected: bool| {
        debug_assert!(injected);
        BroadcastContext::with(&op)
    };

    let n_threads = registry.num_threads();
    // `Some(worker)` when the caller is itself a worker thread; `None` (the
    // pointer was null) otherwise.
    let current_thread = WorkerThread::current().as_ref();
    // One count per job. Presumably each job's `LatchRef` counts it down once
    // on completion, so the latch opens after all jobs finish.
    let latch = CountLatch::with_count(n_threads, current_thread);
    // One stack-allocated job per thread, each borrowing `f` and `latch`. The
    // registry only receives `JobRef`s pointing at these, so `jobs` must stay
    // alive and unmoved until `latch.wait` returns below.
    let jobs: Vec<_> = (0..n_threads)
        .map(|_| StackJob::new(&f, LatchRef::new(&latch)))
        .collect();
    let job_refs = jobs.iter().map(|job| job.as_job_ref());

    registry.inject_broadcast(job_refs);

    // Wait for all jobs to complete, then collect the results, maybe propagating a panic.
    latch.wait(current_thread);
    jobs.into_iter().map(|job| job.into_result()).collect()
}
121 | | |
/// Execute `op` on every thread in the pool. It will be executed on each
/// thread when that thread has nothing else to do locally, before it tries
/// to steal work from other threads. This function returns immediately after
/// injecting the jobs.
///
/// # Safety
///
/// Unsafe because `registry` must not yet have terminated.
pub(super) unsafe fn spawn_broadcast_in<OP>(op: OP, registry: &Arc<Registry>)
where
    OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
{
    // A single shared job; every thread receives a `JobRef` to this same
    // `ArcJob`. The closure owns its own `Arc` to the registry, so it can run
    // after this function has returned.
    let job = ArcJob::new({
        let registry = Arc::clone(registry);
        move || {
            // Panics from `op` go through the registry's `catch_unwind`
            // (presumably reported to the pool's panic handler — confirm).
            registry.catch_unwind(|| BroadcastContext::with(&op));
            registry.terminate(); // (*) permit registry to terminate now
        }
    });

    let n_threads = registry.num_threads();
    // Lazy iterator: each increment below runs only when `inject_broadcast`
    // pulls the next ref. The terminate count therefore tracks the refs
    // actually handed out.
    let job_refs = (0..n_threads).map(|_| {
        // Ensure that registry cannot terminate until this job has executed
        // on each thread. This ref is decremented at the (*) above.
        registry.increment_terminate_count();

        ArcJob::as_static_job_ref(&job)
    });

    registry.inject_broadcast(job_refs);
}