Coverage Report

Created: 2024-10-16 07:58

/rust/registry/src/index.crates.io-6f17d22bba15001f/rayon-core-1.12.1/src/broadcast/mod.rs
Line
Count
Source (jump to first uncovered line)
1
use crate::job::{ArcJob, StackJob};
2
use crate::latch::{CountLatch, LatchRef};
3
use crate::registry::{Registry, WorkerThread};
4
use std::fmt;
5
use std::marker::PhantomData;
6
use std::sync::Arc;
7
8
mod test;
9
10
/// Executes `op` within every thread in the current threadpool. If this is
11
/// called from a non-Rayon thread, it will execute in the global threadpool.
12
/// Any attempts to use `join`, `scope`, or parallel iterators will then operate
13
/// within that threadpool. When the call has completed on each thread, returns
14
/// a vector containing all of their return values.
15
///
16
/// For more information, see the [`ThreadPool::broadcast()`][m] method.
17
///
18
/// [m]: struct.ThreadPool.html#method.broadcast
19
0
pub fn broadcast<OP, R>(op: OP) -> Vec<R>
20
0
where
21
0
    OP: Fn(BroadcastContext<'_>) -> R + Sync,
22
0
    R: Send,
23
0
{
24
0
    // We assert that current registry has not terminated.
25
0
    unsafe { broadcast_in(op, &Registry::current()) }
26
0
}
27
28
/// Spawns an asynchronous task on every thread in this thread-pool. This task
29
/// will run in the implicit, global scope, which means that it may outlast the
30
/// current stack frame -- therefore, it cannot capture any references onto the
31
/// stack (you will likely need a `move` closure).
32
///
33
/// For more information, see the [`ThreadPool::spawn_broadcast()`][m] method.
34
///
35
/// [m]: struct.ThreadPool.html#method.spawn_broadcast
36
0
pub fn spawn_broadcast<OP>(op: OP)
37
0
where
38
0
    OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
39
0
{
40
0
    // We assert that current registry has not terminated.
41
0
    unsafe { spawn_broadcast_in(op, &Registry::current()) }
42
0
}
43
44
/// Provides context to a closure called by `broadcast`.
45
pub struct BroadcastContext<'a> {
46
    worker: &'a WorkerThread,
47
48
    /// Make sure to prevent auto-traits like `Send` and `Sync`.
49
    _marker: PhantomData<&'a mut dyn Fn()>,
50
}
51
52
impl<'a> BroadcastContext<'a> {
53
0
    pub(super) fn with<R>(f: impl FnOnce(BroadcastContext<'_>) -> R) -> R {
54
0
        let worker_thread = WorkerThread::current();
55
0
        assert!(!worker_thread.is_null());
56
0
        f(BroadcastContext {
57
0
            worker: unsafe { &*worker_thread },
58
0
            _marker: PhantomData,
59
0
        })
60
0
    }
61
62
    /// Our index amongst the broadcast threads (ranges from `0..self.num_threads()`).
63
    #[inline]
64
0
    pub fn index(&self) -> usize {
65
0
        self.worker.index()
66
0
    }
67
68
    /// The number of threads receiving the broadcast in the thread pool.
69
    ///
70
    /// # Future compatibility note
71
    ///
72
    /// Future versions of Rayon might vary the number of threads over time, but
73
    /// this method will always return the number of threads which are actually
74
    /// receiving your particular `broadcast` call.
75
    #[inline]
76
0
    pub fn num_threads(&self) -> usize {
77
0
        self.worker.registry().num_threads()
78
0
    }
79
}
80
81
impl<'a> fmt::Debug for BroadcastContext<'a> {
82
0
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
83
0
        fmt.debug_struct("BroadcastContext")
84
0
            .field("index", &self.index())
85
0
            .field("num_threads", &self.num_threads())
86
0
            .field("pool_id", &self.worker.registry().id())
87
0
            .finish()
88
0
    }
89
}
90
91
/// Execute `op` on every thread in the pool. It will be executed on each
92
/// thread when they have nothing else to do locally, before they try to
93
/// steal work from other threads. This function will not return until all
94
/// threads have completed the `op`.
95
///
96
/// Unsafe because `registry` must not yet have terminated.
97
0
pub(super) unsafe fn broadcast_in<OP, R>(op: OP, registry: &Arc<Registry>) -> Vec<R>
98
0
where
99
0
    OP: Fn(BroadcastContext<'_>) -> R + Sync,
100
0
    R: Send,
101
0
{
102
0
    let f = move |injected: bool| {
103
0
        debug_assert!(injected);
104
0
        BroadcastContext::with(&op)
105
0
    };
106
107
0
    let n_threads = registry.num_threads();
108
0
    let current_thread = WorkerThread::current().as_ref();
109
0
    let latch = CountLatch::with_count(n_threads, current_thread);
110
0
    let jobs: Vec<_> = (0..n_threads)
111
0
        .map(|_| StackJob::new(&f, LatchRef::new(&latch)))
112
0
        .collect();
113
0
    let job_refs = jobs.iter().map(|job| job.as_job_ref());
114
0
115
0
    registry.inject_broadcast(job_refs);
116
0
117
0
    // Wait for all jobs to complete, then collect the results, maybe propagating a panic.
118
0
    latch.wait(current_thread);
119
0
    jobs.into_iter().map(|job| job.into_result()).collect()
120
0
}
121
122
/// Execute `op` on every thread in the pool. It will be executed on each
123
/// thread when they have nothing else to do locally, before they try to
124
/// steal work from other threads. This function returns immediately after
125
/// injecting the jobs.
126
///
127
/// Unsafe because `registry` must not yet have terminated.
128
0
pub(super) unsafe fn spawn_broadcast_in<OP>(op: OP, registry: &Arc<Registry>)
129
0
where
130
0
    OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
131
0
{
132
0
    let job = ArcJob::new({
133
0
        let registry = Arc::clone(registry);
134
0
        move || {
135
0
            registry.catch_unwind(|| BroadcastContext::with(&op));
136
0
            registry.terminate(); // (*) permit registry to terminate now
137
0
        }
138
0
    });
139
0
140
0
    let n_threads = registry.num_threads();
141
0
    let job_refs = (0..n_threads).map(|_| {
142
0
        // Ensure that registry cannot terminate until this job has executed
143
0
        // on each thread. This ref is decremented at the (*) above.
144
0
        registry.increment_terminate_count();
145
0
146
0
        ArcJob::as_static_job_ref(&job)
147
0
    });
148
0
149
0
    registry.inject_broadcast(job_refs);
150
0
}