/rust/registry/src/index.crates.io-1949cf8c6b5b557f/threadpool-1.8.1/src/lib.rs
Line | Count | Source |
1 | | // Copyright 2014 The Rust Project Developers. See the COPYRIGHT |
2 | | // file at the top-level directory of this distribution and at |
3 | | // http://rust-lang.org/COPYRIGHT. |
4 | | // |
5 | | // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or |
6 | | // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license |
7 | | // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your |
8 | | // option. This file may not be copied, modified, or distributed |
9 | | // except according to those terms. |
10 | | |
11 | | //! A thread pool used to execute functions in parallel. |
12 | | //! |
13 | | //! Spawns a specified number of worker threads and replenishes the pool if any worker threads |
14 | | //! panic. |
15 | | //! |
16 | | //! # Examples |
17 | | //! |
18 | | //! ## Synchronized with a channel |
19 | | //! |
20 | | //! Every thread sends one message over the channel, which is then collected with `take()`. |
21 | | //! |
22 | | //! ``` |
23 | | //! use threadpool::ThreadPool; |
24 | | //! use std::sync::mpsc::channel; |
25 | | //! |
26 | | //! let n_workers = 4; |
27 | | //! let n_jobs = 8; |
28 | | //! let pool = ThreadPool::new(n_workers); |
29 | | //! |
30 | | //! let (tx, rx) = channel(); |
31 | | //! for _ in 0..n_jobs { |
32 | | //! let tx = tx.clone(); |
33 | | //! pool.execute(move|| { |
34 | | //! tx.send(1).expect("channel will be there waiting for the pool"); |
35 | | //! }); |
36 | | //! } |
37 | | //! |
38 | | //! assert_eq!(rx.iter().take(n_jobs).fold(0, |a, b| a + b), 8); |
39 | | //! ``` |
40 | | //! |
41 | | //! ## Synchronized with a barrier |
42 | | //! |
43 | | //! Keep in mind, if a barrier synchronizes more jobs than you have workers in the pool, |
44 | | //! you will end up with a [deadlock](https://en.wikipedia.org/wiki/Deadlock) |
45 | | //! at the barrier which is [not considered unsafe]( |
46 | | //! https://doc.rust-lang.org/reference/behavior-not-considered-unsafe.html). |
47 | | //! |
48 | | //! ``` |
49 | | //! use threadpool::ThreadPool; |
50 | | //! use std::sync::{Arc, Barrier}; |
51 | | //! use std::sync::atomic::{AtomicUsize, Ordering}; |
52 | | //! |
53 | | //! // create at least as many workers as jobs or you will deadlock yourself |
54 | | //! let n_workers = 42; |
55 | | //! let n_jobs = 23; |
56 | | //! let pool = ThreadPool::new(n_workers); |
57 | | //! let an_atomic = Arc::new(AtomicUsize::new(0)); |
58 | | //! |
59 | | //! assert!(n_jobs <= n_workers, "too many jobs, will deadlock"); |
60 | | //! |
61 | | //! // create a barrier that waits for all jobs plus the starter thread |
62 | | //! let barrier = Arc::new(Barrier::new(n_jobs + 1)); |
63 | | //! for _ in 0..n_jobs { |
64 | | //! let barrier = barrier.clone(); |
65 | | //! let an_atomic = an_atomic.clone(); |
66 | | //! |
67 | | //! pool.execute(move|| { |
68 | | //! // do the heavy work |
69 | | //! an_atomic.fetch_add(1, Ordering::Relaxed); |
70 | | //! |
71 | | //! // then wait for the other threads |
72 | | //! barrier.wait(); |
73 | | //! }); |
74 | | //! } |
75 | | //! |
76 | | //! // wait for the threads to finish the work |
77 | | //! barrier.wait(); |
78 | | //! assert_eq!(an_atomic.load(Ordering::SeqCst), /* n_jobs = */ 23); |
79 | | //! ``` |
80 | | |
81 | | extern crate num_cpus; |
82 | | |
83 | | use std::fmt; |
84 | | use std::sync::atomic::{AtomicUsize, Ordering}; |
85 | | use std::sync::mpsc::{channel, Receiver, Sender}; |
86 | | use std::sync::{Arc, Condvar, Mutex}; |
87 | | use std::thread; |
88 | | |
89 | | trait FnBox { |
90 | | fn call_box(self: Box<Self>); |
91 | | } |
92 | | |
93 | | impl<F: FnOnce()> FnBox for F { |
94 | 0 | fn call_box(self: Box<F>) { |
95 | 0 | (*self)() |
96 | 0 | } |
97 | | } |
98 | | |
99 | | type Thunk<'a> = Box<FnBox + Send + 'a>; |
100 | | |
101 | | struct Sentinel<'a> { |
102 | | shared_data: &'a Arc<ThreadPoolSharedData>, |
103 | | active: bool, |
104 | | } |
105 | | |
106 | | impl<'a> Sentinel<'a> { |
107 | 0 | fn new(shared_data: &'a Arc<ThreadPoolSharedData>) -> Sentinel<'a> { |
108 | 0 | Sentinel { |
109 | 0 | shared_data: shared_data, |
110 | 0 | active: true, |
111 | 0 | } |
112 | 0 | } |
113 | | |
114 | | /// Cancel and destroy this sentinel. |
115 | 0 | fn cancel(mut self) { |
116 | 0 | self.active = false; |
117 | 0 | } |
118 | | } |
119 | | |
120 | | impl<'a> Drop for Sentinel<'a> { |
121 | 0 | fn drop(&mut self) { |
122 | 0 | if self.active { |
123 | 0 | self.shared_data.active_count.fetch_sub(1, Ordering::SeqCst); |
124 | 0 | if thread::panicking() { |
125 | 0 | self.shared_data.panic_count.fetch_add(1, Ordering::SeqCst); |
126 | 0 | } |
127 | 0 | self.shared_data.no_work_notify_all(); |
128 | 0 | spawn_in_pool(self.shared_data.clone()) |
129 | 0 | } |
130 | 0 | } |
131 | | } |
132 | | |
133 | | /// [`ThreadPool`] factory, which can be used in order to configure the properties of the |
134 | | /// [`ThreadPool`]. |
135 | | /// |
136 | | /// The three configuration options available: |
137 | | /// |
138 | | /// * `num_threads`: maximum number of threads that will be alive at any given moment by the built |
139 | | /// [`ThreadPool`] |
140 | | /// * `thread_name`: thread name for each of the threads spawned by the built [`ThreadPool`] |
141 | | /// * `thread_stack_size`: stack size (in bytes) for each of the threads spawned by the built |
142 | | /// [`ThreadPool`] |
143 | | /// |
144 | | /// [`ThreadPool`]: struct.ThreadPool.html |
145 | | /// |
146 | | /// # Examples |
147 | | /// |
148 | | /// Build a [`ThreadPool`] that uses a maximum of eight threads simultaneously and each thread has |
149 | | /// an 8 MB stack size: |
150 | | /// |
151 | | /// ``` |
152 | | /// let pool = threadpool::Builder::new() |
153 | | /// .num_threads(8) |
154 | | /// .thread_stack_size(8_000_000) |
155 | | /// .build(); |
156 | | /// ``` |
#[derive(Clone, Default)]
pub struct Builder {
    // Requested worker count; `None` until `num_threads` is called.
    num_threads: Option<usize>,
    // Name given to every spawned worker; `None` leaves them unnamed.
    thread_name: Option<String>,
    // Stack size in bytes for every spawned worker; `None` until set.
    thread_stack_size: Option<usize>,
}
163 | | |
164 | | impl Builder { |
165 | | /// Initiate a new [`Builder`]. |
166 | | /// |
167 | | /// [`Builder`]: struct.Builder.html |
168 | | /// |
169 | | /// # Examples |
170 | | /// |
171 | | /// ``` |
172 | | /// let builder = threadpool::Builder::new(); |
173 | | /// ``` |
174 | 0 | pub fn new() -> Builder { |
175 | 0 | Builder { |
176 | 0 | num_threads: None, |
177 | 0 | thread_name: None, |
178 | 0 | thread_stack_size: None, |
179 | 0 | } |
180 | 0 | } |
181 | | |
182 | | /// Set the maximum number of worker-threads that will be alive at any given moment by the built |
183 | | /// [`ThreadPool`]. If not specified, defaults the number of threads to the number of CPUs. |
184 | | /// |
185 | | /// [`ThreadPool`]: struct.ThreadPool.html |
186 | | /// |
187 | | /// # Panics |
188 | | /// |
189 | | /// This method will panic if `num_threads` is 0. |
190 | | /// |
191 | | /// # Examples |
192 | | /// |
193 | | /// No more than eight threads will be alive simultaneously for this pool: |
194 | | /// |
195 | | /// ``` |
196 | | /// use std::thread; |
197 | | /// |
198 | | /// let pool = threadpool::Builder::new() |
199 | | /// .num_threads(8) |
200 | | /// .build(); |
201 | | /// |
202 | | /// for _ in 0..100 { |
203 | | /// pool.execute(|| { |
204 | | /// println!("Hello from a worker thread!") |
205 | | /// }) |
206 | | /// } |
207 | | /// ``` |
208 | 0 | pub fn num_threads(mut self, num_threads: usize) -> Builder { |
209 | 0 | assert!(num_threads > 0); |
210 | 0 | self.num_threads = Some(num_threads); |
211 | 0 | self |
212 | 0 | } |
213 | | |
214 | | /// Set the thread name for each of the threads spawned by the built [`ThreadPool`]. If not |
215 | | /// specified, threads spawned by the thread pool will be unnamed. |
216 | | /// |
217 | | /// [`ThreadPool`]: struct.ThreadPool.html |
218 | | /// |
219 | | /// # Examples |
220 | | /// |
221 | | /// Each thread spawned by this pool will have the name "foo": |
222 | | /// |
223 | | /// ``` |
224 | | /// use std::thread; |
225 | | /// |
226 | | /// let pool = threadpool::Builder::new() |
227 | | /// .thread_name("foo".into()) |
228 | | /// .build(); |
229 | | /// |
230 | | /// for _ in 0..100 { |
231 | | /// pool.execute(|| { |
232 | | /// assert_eq!(thread::current().name(), Some("foo")); |
233 | | /// }) |
234 | | /// } |
235 | | /// ``` |
236 | 0 | pub fn thread_name(mut self, name: String) -> Builder { |
237 | 0 | self.thread_name = Some(name); |
238 | 0 | self |
239 | 0 | } |
240 | | |
241 | | /// Set the stack size (in bytes) for each of the threads spawned by the built [`ThreadPool`]. |
242 | | /// If not specified, threads spawned by the threadpool will have a stack size [as specified in |
243 | | /// the `std::thread` documentation][thread]. |
244 | | /// |
245 | | /// [thread]: https://doc.rust-lang.org/nightly/std/thread/index.html#stack-size |
246 | | /// [`ThreadPool`]: struct.ThreadPool.html |
247 | | /// |
248 | | /// # Examples |
249 | | /// |
250 | | /// Each thread spawned by this pool will have a 4 MB stack: |
251 | | /// |
252 | | /// ``` |
253 | | /// let pool = threadpool::Builder::new() |
254 | | /// .thread_stack_size(4_000_000) |
255 | | /// .build(); |
256 | | /// |
257 | | /// for _ in 0..100 { |
258 | | /// pool.execute(|| { |
259 | | /// println!("This thread has a 4 MB stack size!"); |
260 | | /// }) |
261 | | /// } |
262 | | /// ``` |
263 | 0 | pub fn thread_stack_size(mut self, size: usize) -> Builder { |
264 | 0 | self.thread_stack_size = Some(size); |
265 | 0 | self |
266 | 0 | } |
267 | | |
268 | | /// Finalize the [`Builder`] and build the [`ThreadPool`]. |
269 | | /// |
270 | | /// [`Builder`]: struct.Builder.html |
271 | | /// [`ThreadPool`]: struct.ThreadPool.html |
272 | | /// |
273 | | /// # Examples |
274 | | /// |
275 | | /// ``` |
276 | | /// let pool = threadpool::Builder::new() |
277 | | /// .num_threads(8) |
278 | | /// .thread_stack_size(4_000_000) |
279 | | /// .build(); |
280 | | /// ``` |
281 | 0 | pub fn build(self) -> ThreadPool { |
282 | 0 | let (tx, rx) = channel::<Thunk<'static>>(); |
283 | | |
284 | 0 | let num_threads = self.num_threads.unwrap_or_else(num_cpus::get); |
285 | | |
286 | 0 | let shared_data = Arc::new(ThreadPoolSharedData { |
287 | 0 | name: self.thread_name, |
288 | 0 | job_receiver: Mutex::new(rx), |
289 | 0 | empty_condvar: Condvar::new(), |
290 | 0 | empty_trigger: Mutex::new(()), |
291 | 0 | join_generation: AtomicUsize::new(0), |
292 | 0 | queued_count: AtomicUsize::new(0), |
293 | 0 | active_count: AtomicUsize::new(0), |
294 | 0 | max_thread_count: AtomicUsize::new(num_threads), |
295 | 0 | panic_count: AtomicUsize::new(0), |
296 | 0 | stack_size: self.thread_stack_size, |
297 | 0 | }); |
298 | | |
299 | | // Threadpool threads |
300 | 0 | for _ in 0..num_threads { |
301 | 0 | spawn_in_pool(shared_data.clone()); |
302 | 0 | } |
303 | | |
304 | 0 | ThreadPool { |
305 | 0 | jobs: tx, |
306 | 0 | shared_data: shared_data, |
307 | 0 | } |
308 | 0 | } |
309 | | } |
310 | | |
311 | | struct ThreadPoolSharedData { |
312 | | name: Option<String>, |
313 | | job_receiver: Mutex<Receiver<Thunk<'static>>>, |
314 | | empty_trigger: Mutex<()>, |
315 | | empty_condvar: Condvar, |
316 | | join_generation: AtomicUsize, |
317 | | queued_count: AtomicUsize, |
318 | | active_count: AtomicUsize, |
319 | | max_thread_count: AtomicUsize, |
320 | | panic_count: AtomicUsize, |
321 | | stack_size: Option<usize>, |
322 | | } |
323 | | |
324 | | impl ThreadPoolSharedData { |
325 | 0 | fn has_work(&self) -> bool { |
326 | 0 | self.queued_count.load(Ordering::SeqCst) > 0 || self.active_count.load(Ordering::SeqCst) > 0 |
327 | 0 | } |
328 | | |
329 | | /// Notify all observers joining this pool if there is no more work to do. |
330 | 0 | fn no_work_notify_all(&self) { |
331 | 0 | if !self.has_work() { |
332 | 0 | *self |
333 | 0 | .empty_trigger |
334 | 0 | .lock() |
335 | 0 | .expect("Unable to notify all joining threads"); |
336 | 0 | self.empty_condvar.notify_all(); |
337 | 0 | } |
338 | 0 | } |
339 | | } |
340 | | |
341 | | /// Abstraction of a thread pool for basic parallelism. |
342 | | pub struct ThreadPool { |
343 | | // How the threadpool communicates with subthreads. |
344 | | // |
345 | | // This is the only such Sender, so when it is dropped all subthreads will |
346 | | // quit. |
347 | | jobs: Sender<Thunk<'static>>, |
348 | | shared_data: Arc<ThreadPoolSharedData>, |
349 | | } |
350 | | |
351 | | impl ThreadPool { |
352 | | /// Creates a new thread pool capable of executing `num_threads` number of jobs concurrently. |
353 | | /// |
354 | | /// # Panics |
355 | | /// |
356 | | /// This function will panic if `num_threads` is 0. |
357 | | /// |
358 | | /// # Examples |
359 | | /// |
360 | | /// Create a new thread pool capable of executing four jobs concurrently: |
361 | | /// |
362 | | /// ``` |
363 | | /// use threadpool::ThreadPool; |
364 | | /// |
365 | | /// let pool = ThreadPool::new(4); |
366 | | /// ``` |
367 | 0 | pub fn new(num_threads: usize) -> ThreadPool { |
368 | 0 | Builder::new().num_threads(num_threads).build() |
369 | 0 | } |
370 | | |
371 | | /// Creates a new thread pool capable of executing `num_threads` number of jobs concurrently. |
372 | | /// Each thread will have the [name][thread name] `name`. |
373 | | /// |
374 | | /// # Panics |
375 | | /// |
376 | | /// This function will panic if `num_threads` is 0. |
377 | | /// |
378 | | /// # Examples |
379 | | /// |
380 | | /// ```rust |
381 | | /// use std::thread; |
382 | | /// use threadpool::ThreadPool; |
383 | | /// |
384 | | /// let pool = ThreadPool::with_name("worker".into(), 2); |
385 | | /// for _ in 0..2 { |
386 | | /// pool.execute(|| { |
387 | | /// assert_eq!( |
388 | | /// thread::current().name(), |
389 | | /// Some("worker") |
390 | | /// ); |
391 | | /// }); |
392 | | /// } |
393 | | /// pool.join(); |
394 | | /// ``` |
395 | | /// |
396 | | /// [thread name]: https://doc.rust-lang.org/std/thread/struct.Thread.html#method.name |
397 | 0 | pub fn with_name(name: String, num_threads: usize) -> ThreadPool { |
398 | 0 | Builder::new() |
399 | 0 | .num_threads(num_threads) |
400 | 0 | .thread_name(name) |
401 | 0 | .build() |
402 | 0 | } |
403 | | |
404 | | /// **Deprecated: Use [`ThreadPool::with_name`](#method.with_name)** |
405 | | #[inline(always)] |
406 | | #[deprecated(since = "1.4.0", note = "use ThreadPool::with_name")] |
407 | 0 | pub fn new_with_name(name: String, num_threads: usize) -> ThreadPool { |
408 | 0 | Self::with_name(name, num_threads) |
409 | 0 | } |
410 | | |
411 | | /// Executes the function `job` on a thread in the pool. |
412 | | /// |
413 | | /// # Examples |
414 | | /// |
415 | | /// Execute four jobs on a thread pool that can run two jobs concurrently: |
416 | | /// |
417 | | /// ``` |
418 | | /// use threadpool::ThreadPool; |
419 | | /// |
420 | | /// let pool = ThreadPool::new(2); |
421 | | /// pool.execute(|| println!("hello")); |
422 | | /// pool.execute(|| println!("world")); |
423 | | /// pool.execute(|| println!("foo")); |
424 | | /// pool.execute(|| println!("bar")); |
425 | | /// pool.join(); |
426 | | /// ``` |
427 | 0 | pub fn execute<F>(&self, job: F) |
428 | 0 | where |
429 | 0 | F: FnOnce() + Send + 'static, |
430 | | { |
431 | 0 | self.shared_data.queued_count.fetch_add(1, Ordering::SeqCst); |
432 | 0 | self.jobs |
433 | 0 | .send(Box::new(job)) |
434 | 0 | .expect("ThreadPool::execute unable to send job into queue."); |
435 | 0 | } |
436 | | |
437 | | /// Returns the number of jobs waiting to be executed in the pool. |
438 | | /// |
439 | | /// # Examples |
440 | | /// |
441 | | /// ``` |
442 | | /// use threadpool::ThreadPool; |
443 | | /// use std::time::Duration; |
444 | | /// use std::thread::sleep; |
445 | | /// |
446 | | /// let pool = ThreadPool::new(2); |
447 | | /// for _ in 0..10 { |
448 | | /// pool.execute(|| { |
449 | | /// sleep(Duration::from_secs(100)); |
450 | | /// }); |
451 | | /// } |
452 | | /// |
453 | | /// sleep(Duration::from_secs(1)); // wait for threads to start |
454 | | /// assert_eq!(8, pool.queued_count()); |
455 | | /// ``` |
456 | 0 | pub fn queued_count(&self) -> usize { |
457 | 0 | self.shared_data.queued_count.load(Ordering::Relaxed) |
458 | 0 | } |
459 | | |
460 | | /// Returns the number of currently active threads. |
461 | | /// |
462 | | /// # Examples |
463 | | /// |
464 | | /// ``` |
465 | | /// use threadpool::ThreadPool; |
466 | | /// use std::time::Duration; |
467 | | /// use std::thread::sleep; |
468 | | /// |
469 | | /// let pool = ThreadPool::new(4); |
470 | | /// for _ in 0..10 { |
471 | | /// pool.execute(move || { |
472 | | /// sleep(Duration::from_secs(100)); |
473 | | /// }); |
474 | | /// } |
475 | | /// |
476 | | /// sleep(Duration::from_secs(1)); // wait for threads to start |
477 | | /// assert_eq!(4, pool.active_count()); |
478 | | /// ``` |
479 | 0 | pub fn active_count(&self) -> usize { |
480 | 0 | self.shared_data.active_count.load(Ordering::SeqCst) |
481 | 0 | } |
482 | | |
483 | | /// Returns the maximum number of threads the pool will execute concurrently. |
484 | | /// |
485 | | /// # Examples |
486 | | /// |
487 | | /// ``` |
488 | | /// use threadpool::ThreadPool; |
489 | | /// |
490 | | /// let mut pool = ThreadPool::new(4); |
491 | | /// assert_eq!(4, pool.max_count()); |
492 | | /// |
493 | | /// pool.set_num_threads(8); |
494 | | /// assert_eq!(8, pool.max_count()); |
495 | | /// ``` |
496 | 0 | pub fn max_count(&self) -> usize { |
497 | 0 | self.shared_data.max_thread_count.load(Ordering::Relaxed) |
498 | 0 | } |
499 | | |
500 | | /// Returns the number of panicked threads over the lifetime of the pool. |
501 | | /// |
502 | | /// # Examples |
503 | | /// |
504 | | /// ``` |
505 | | /// use threadpool::ThreadPool; |
506 | | /// |
507 | | /// let pool = ThreadPool::new(4); |
508 | | /// for n in 0..10 { |
509 | | /// pool.execute(move || { |
510 | | /// // simulate a panic |
511 | | /// if n % 2 == 0 { |
512 | | /// panic!() |
513 | | /// } |
514 | | /// }); |
515 | | /// } |
516 | | /// pool.join(); |
517 | | /// |
518 | | /// assert_eq!(5, pool.panic_count()); |
519 | | /// ``` |
520 | 0 | pub fn panic_count(&self) -> usize { |
521 | 0 | self.shared_data.panic_count.load(Ordering::Relaxed) |
522 | 0 | } |
523 | | |
524 | | /// **Deprecated: Use [`ThreadPool::set_num_threads`](#method.set_num_threads)** |
525 | | #[deprecated(since = "1.3.0", note = "use ThreadPool::set_num_threads")] |
526 | 0 | pub fn set_threads(&mut self, num_threads: usize) { |
527 | 0 | self.set_num_threads(num_threads) |
528 | 0 | } |
529 | | |
530 | | /// Sets the number of worker-threads to use as `num_threads`. |
531 | | /// Can be used to change the threadpool size during runtime. |
532 | | /// Will not abort already running or waiting threads. |
533 | | /// |
534 | | /// # Panics |
535 | | /// |
536 | | /// This function will panic if `num_threads` is 0. |
537 | | /// |
538 | | /// # Examples |
539 | | /// |
540 | | /// ``` |
541 | | /// use threadpool::ThreadPool; |
542 | | /// use std::time::Duration; |
543 | | /// use std::thread::sleep; |
544 | | /// |
545 | | /// let mut pool = ThreadPool::new(4); |
546 | | /// for _ in 0..10 { |
547 | | /// pool.execute(move || { |
548 | | /// sleep(Duration::from_secs(100)); |
549 | | /// }); |
550 | | /// } |
551 | | /// |
552 | | /// sleep(Duration::from_secs(1)); // wait for threads to start |
553 | | /// assert_eq!(4, pool.active_count()); |
554 | | /// assert_eq!(6, pool.queued_count()); |
555 | | /// |
556 | | /// // Increase thread capacity of the pool |
557 | | /// pool.set_num_threads(8); |
558 | | /// |
559 | | /// sleep(Duration::from_secs(1)); // wait for new threads to start |
560 | | /// assert_eq!(8, pool.active_count()); |
561 | | /// assert_eq!(2, pool.queued_count()); |
562 | | /// |
563 | | /// // Decrease thread capacity of the pool |
564 | | /// // No active threads are killed |
565 | | /// pool.set_num_threads(4); |
566 | | /// |
567 | | /// assert_eq!(8, pool.active_count()); |
568 | | /// assert_eq!(2, pool.queued_count()); |
569 | | /// ``` |
570 | 0 | pub fn set_num_threads(&mut self, num_threads: usize) { |
571 | 0 | assert!(num_threads >= 1); |
572 | 0 | let prev_num_threads = self |
573 | 0 | .shared_data |
574 | 0 | .max_thread_count |
575 | 0 | .swap(num_threads, Ordering::Release); |
576 | 0 | if let Some(num_spawn) = num_threads.checked_sub(prev_num_threads) { |
577 | | // Spawn new threads |
578 | 0 | for _ in 0..num_spawn { |
579 | 0 | spawn_in_pool(self.shared_data.clone()); |
580 | 0 | } |
581 | 0 | } |
582 | 0 | } |
583 | | |
584 | | /// Block the current thread until all jobs in the pool have been executed. |
585 | | /// |
586 | | /// Calling `join` on an empty pool will cause an immediate return. |
587 | | /// `join` may be called from multiple threads concurrently. |
588 | | /// A `join` is an atomic point in time. All threads joining before the join |
589 | | /// event will exit together even if the pool is processing new jobs by the |
590 | | /// time they get scheduled. |
591 | | /// |
592 | | /// Calling `join` from a thread within the pool will cause a deadlock. This |
593 | | /// behavior is considered safe. |
594 | | /// |
595 | | /// # Examples |
596 | | /// |
597 | | /// ``` |
598 | | /// use threadpool::ThreadPool; |
599 | | /// use std::sync::Arc; |
600 | | /// use std::sync::atomic::{AtomicUsize, Ordering}; |
601 | | /// |
602 | | /// let pool = ThreadPool::new(8); |
603 | | /// let test_count = Arc::new(AtomicUsize::new(0)); |
604 | | /// |
605 | | /// for _ in 0..42 { |
606 | | /// let test_count = test_count.clone(); |
607 | | /// pool.execute(move || { |
608 | | /// test_count.fetch_add(1, Ordering::Relaxed); |
609 | | /// }); |
610 | | /// } |
611 | | /// |
612 | | /// pool.join(); |
613 | | /// assert_eq!(42, test_count.load(Ordering::Relaxed)); |
614 | | /// ``` |
615 | 0 | pub fn join(&self) { |
616 | | // fast path requires no mutex |
617 | 0 | if self.shared_data.has_work() == false { |
618 | 0 | return (); |
619 | 0 | } |
620 | | |
621 | 0 | let generation = self.shared_data.join_generation.load(Ordering::SeqCst); |
622 | 0 | let mut lock = self.shared_data.empty_trigger.lock().unwrap(); |
623 | | |
624 | 0 | while generation == self.shared_data.join_generation.load(Ordering::Relaxed) |
625 | 0 | && self.shared_data.has_work() |
626 | 0 | { |
627 | 0 | lock = self.shared_data.empty_condvar.wait(lock).unwrap(); |
628 | 0 | } |
629 | | |
630 | | // increase generation if we are the first thread to come out of the loop |
631 | 0 | self.shared_data.join_generation.compare_and_swap( |
632 | 0 | generation, |
633 | 0 | generation.wrapping_add(1), |
634 | 0 | Ordering::SeqCst, |
635 | | ); |
636 | 0 | } |
637 | | } |
638 | | |
639 | | impl Clone for ThreadPool { |
640 | | /// Cloning a pool will create a new handle to the pool. |
641 | | /// The behavior is similar to [Arc](https://doc.rust-lang.org/stable/std/sync/struct.Arc.html). |
642 | | /// |
643 | | /// We could for example submit jobs from multiple threads concurrently. |
644 | | /// |
645 | | /// ``` |
646 | | /// use threadpool::ThreadPool; |
647 | | /// use std::thread; |
648 | | /// use std::sync::mpsc::channel; |
649 | | /// |
650 | | /// let pool = ThreadPool::with_name("clone example".into(), 2); |
651 | | /// |
652 | | /// let results = (0..2) |
653 | | /// .map(|i| { |
654 | | /// let pool = pool.clone(); |
655 | | /// thread::spawn(move || { |
656 | | /// let (tx, rx) = channel(); |
657 | | /// for i in 1..12 { |
658 | | /// let tx = tx.clone(); |
659 | | /// pool.execute(move || { |
660 | | /// tx.send(i).expect("channel will be waiting"); |
661 | | /// }); |
662 | | /// } |
663 | | /// drop(tx); |
664 | | /// if i == 0 { |
665 | | /// rx.iter().fold(0, |accumulator, element| accumulator + element) |
666 | | /// } else { |
667 | | /// rx.iter().fold(1, |accumulator, element| accumulator * element) |
668 | | /// } |
669 | | /// }) |
670 | | /// }) |
671 | | /// .map(|join_handle| join_handle.join().expect("collect results from threads")) |
672 | | /// .collect::<Vec<usize>>(); |
673 | | /// |
674 | | /// assert_eq!(vec![66, 39916800], results); |
675 | | /// ``` |
676 | 0 | fn clone(&self) -> ThreadPool { |
677 | 0 | ThreadPool { |
678 | 0 | jobs: self.jobs.clone(), |
679 | 0 | shared_data: self.shared_data.clone(), |
680 | 0 | } |
681 | 0 | } |
682 | | } |
683 | | |
684 | | /// Create a thread pool with one thread per CPU. |
685 | | /// On machines with hyperthreading, |
686 | | /// this will create one thread per hyperthread. |
687 | | impl Default for ThreadPool { |
688 | 0 | fn default() -> Self { |
689 | 0 | ThreadPool::new(num_cpus::get()) |
690 | 0 | } |
691 | | } |
692 | | |
693 | | impl fmt::Debug for ThreadPool { |
694 | 0 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
695 | 0 | f.debug_struct("ThreadPool") |
696 | 0 | .field("name", &self.shared_data.name) |
697 | 0 | .field("queued_count", &self.queued_count()) |
698 | 0 | .field("active_count", &self.active_count()) |
699 | 0 | .field("max_count", &self.max_count()) |
700 | 0 | .finish() |
701 | 0 | } |
702 | | } |
703 | | |
704 | | impl PartialEq for ThreadPool { |
705 | | /// Check if you are working with the same pool |
706 | | /// |
707 | | /// ``` |
708 | | /// use threadpool::ThreadPool; |
709 | | /// |
710 | | /// let a = ThreadPool::new(2); |
711 | | /// let b = ThreadPool::new(2); |
712 | | /// |
713 | | /// assert_eq!(a, a); |
714 | | /// assert_eq!(b, b); |
715 | | /// |
716 | | /// # // TODO: change this to assert_ne in the future |
717 | | /// assert!(a != b); |
718 | | /// assert!(b != a); |
719 | | /// ``` |
720 | 0 | fn eq(&self, other: &ThreadPool) -> bool { |
721 | 0 | let a: &ThreadPoolSharedData = &*self.shared_data; |
722 | 0 | let b: &ThreadPoolSharedData = &*other.shared_data; |
723 | 0 | a as *const ThreadPoolSharedData == b as *const ThreadPoolSharedData |
724 | | // with rust 1.17 and late: |
725 | | // Arc::ptr_eq(&self.shared_data, &other.shared_data) |
726 | 0 | } |
727 | | } |
728 | | impl Eq for ThreadPool {} |
729 | | |
730 | 0 | fn spawn_in_pool(shared_data: Arc<ThreadPoolSharedData>) { |
731 | 0 | let mut builder = thread::Builder::new(); |
732 | 0 | if let Some(ref name) = shared_data.name { |
733 | 0 | builder = builder.name(name.clone()); |
734 | 0 | } |
735 | 0 | if let Some(ref stack_size) = shared_data.stack_size { |
736 | 0 | builder = builder.stack_size(stack_size.to_owned()); |
737 | 0 | } |
738 | 0 | builder |
739 | 0 | .spawn(move || { |
740 | | // Will spawn a new thread on panic unless it is cancelled. |
741 | 0 | let sentinel = Sentinel::new(&shared_data); |
742 | | |
743 | | loop { |
744 | | // Shutdown this thread if the pool has become smaller |
745 | 0 | let thread_counter_val = shared_data.active_count.load(Ordering::Acquire); |
746 | 0 | let max_thread_count_val = shared_data.max_thread_count.load(Ordering::Relaxed); |
747 | 0 | if thread_counter_val >= max_thread_count_val { |
748 | 0 | break; |
749 | 0 | } |
750 | 0 | let message = { |
751 | | // Only lock jobs for the time it takes |
752 | | // to get a job, not run it. |
753 | 0 | let lock = shared_data |
754 | 0 | .job_receiver |
755 | 0 | .lock() |
756 | 0 | .expect("Worker thread unable to lock job_receiver"); |
757 | 0 | lock.recv() |
758 | | }; |
759 | | |
760 | 0 | let job = match message { |
761 | 0 | Ok(job) => job, |
762 | | // The ThreadPool was dropped. |
763 | 0 | Err(..) => break, |
764 | | }; |
765 | | // Do not allow IR around the job execution |
766 | 0 | shared_data.active_count.fetch_add(1, Ordering::SeqCst); |
767 | 0 | shared_data.queued_count.fetch_sub(1, Ordering::SeqCst); |
768 | | |
769 | 0 | job.call_box(); |
770 | | |
771 | 0 | shared_data.active_count.fetch_sub(1, Ordering::SeqCst); |
772 | 0 | shared_data.no_work_notify_all(); |
773 | | } |
774 | | |
775 | 0 | sentinel.cancel(); |
776 | 0 | }) |
777 | 0 | .unwrap(); |
778 | 0 | } |
779 | | |
780 | | #[cfg(test)] |
781 | | mod test { |
782 | | use super::{Builder, ThreadPool}; |
783 | | use std::sync::atomic::{AtomicUsize, Ordering}; |
784 | | use std::sync::mpsc::{channel, sync_channel}; |
785 | | use std::sync::{Arc, Barrier}; |
786 | | use std::thread::{self, sleep}; |
787 | | use std::time::Duration; |
788 | | |
789 | | const TEST_TASKS: usize = 4; |
790 | | |
791 | | #[test] |
792 | | fn test_set_num_threads_increasing() { |
793 | | let new_thread_amount = TEST_TASKS + 8; |
794 | | let mut pool = ThreadPool::new(TEST_TASKS); |
795 | | for _ in 0..TEST_TASKS { |
796 | | pool.execute(move || sleep(Duration::from_secs(23))); |
797 | | } |
798 | | sleep(Duration::from_secs(1)); |
799 | | assert_eq!(pool.active_count(), TEST_TASKS); |
800 | | |
801 | | pool.set_num_threads(new_thread_amount); |
802 | | |
803 | | for _ in 0..(new_thread_amount - TEST_TASKS) { |
804 | | pool.execute(move || sleep(Duration::from_secs(23))); |
805 | | } |
806 | | sleep(Duration::from_secs(1)); |
807 | | assert_eq!(pool.active_count(), new_thread_amount); |
808 | | |
809 | | pool.join(); |
810 | | } |
811 | | |
812 | | #[test] |
813 | | fn test_set_num_threads_decreasing() { |
814 | | let new_thread_amount = 2; |
815 | | let mut pool = ThreadPool::new(TEST_TASKS); |
816 | | for _ in 0..TEST_TASKS { |
817 | | pool.execute(move || { |
818 | | assert_eq!(1, 1); |
819 | | }); |
820 | | } |
821 | | pool.set_num_threads(new_thread_amount); |
822 | | for _ in 0..new_thread_amount { |
823 | | pool.execute(move || sleep(Duration::from_secs(23))); |
824 | | } |
825 | | sleep(Duration::from_secs(1)); |
826 | | assert_eq!(pool.active_count(), new_thread_amount); |
827 | | |
828 | | pool.join(); |
829 | | } |
830 | | |
831 | | #[test] |
832 | | fn test_active_count() { |
833 | | let pool = ThreadPool::new(TEST_TASKS); |
834 | | for _ in 0..2 * TEST_TASKS { |
835 | | pool.execute(move || loop { |
836 | | sleep(Duration::from_secs(10)) |
837 | | }); |
838 | | } |
839 | | sleep(Duration::from_secs(1)); |
840 | | let active_count = pool.active_count(); |
841 | | assert_eq!(active_count, TEST_TASKS); |
842 | | let initialized_count = pool.max_count(); |
843 | | assert_eq!(initialized_count, TEST_TASKS); |
844 | | } |
845 | | |
846 | | #[test] |
847 | | fn test_works() { |
848 | | let pool = ThreadPool::new(TEST_TASKS); |
849 | | |
850 | | let (tx, rx) = channel(); |
851 | | for _ in 0..TEST_TASKS { |
852 | | let tx = tx.clone(); |
853 | | pool.execute(move || { |
854 | | tx.send(1).unwrap(); |
855 | | }); |
856 | | } |
857 | | |
858 | | assert_eq!(rx.iter().take(TEST_TASKS).fold(0, |a, b| a + b), TEST_TASKS); |
859 | | } |
860 | | |
861 | | #[test] |
862 | | #[should_panic] |
863 | | fn test_zero_tasks_panic() { |
864 | | ThreadPool::new(0); |
865 | | } |
866 | | |
867 | | #[test] |
868 | | fn test_recovery_from_subtask_panic() { |
869 | | let pool = ThreadPool::new(TEST_TASKS); |
870 | | |
871 | | // Panic all the existing threads. |
872 | | for _ in 0..TEST_TASKS { |
873 | | pool.execute(move || panic!("Ignore this panic, it must!")); |
874 | | } |
875 | | pool.join(); |
876 | | |
877 | | assert_eq!(pool.panic_count(), TEST_TASKS); |
878 | | |
879 | | // Ensure new threads were spawned to compensate. |
880 | | let (tx, rx) = channel(); |
881 | | for _ in 0..TEST_TASKS { |
882 | | let tx = tx.clone(); |
883 | | pool.execute(move || { |
884 | | tx.send(1).unwrap(); |
885 | | }); |
886 | | } |
887 | | |
888 | | assert_eq!(rx.iter().take(TEST_TASKS).fold(0, |a, b| a + b), TEST_TASKS); |
889 | | } |
890 | | |
/// Jobs that panic only after the pool handle was dropped must not bring
/// down the test thread.
#[test]
fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() {
    let pool = ThreadPool::new(TEST_TASKS);
    // TEST_TASKS jobs plus the main thread rendezvous on this barrier.
    let gate = Arc::new(Barrier::new(TEST_TASKS + 1));

    // Every job blocks on the barrier and panics as soon as it is released.
    for _ in 0..TEST_TASKS {
        let gate = Arc::clone(&gate);
        pool.execute(move || {
            gate.wait();
            panic!("Ignore this panic, it should!");
        });
    }

    // Give up the pool handle while all jobs are still parked.
    drop(pool);

    // Release the jobs so they panic after the drop.
    gate.wait();
}
910 | | |
/// Queues far more jobs than there are workers and checks the pool's
/// bookkeeping.
///
/// The first `TEST_TASKS` jobs park on two barriers so the main thread can
/// observe a fully occupied pool. Every job then sends a `1` back. Once all
/// messages have been received and the pool has been joined, no job may be
/// reported as active any more.
#[test]
fn test_massive_task_creation() {
    let test_tasks = 4_200_000;

    let pool = ThreadPool::new(TEST_TASKS);
    // Both barriers are shared by the first TEST_TASKS jobs and the main thread.
    let b0 = Arc::new(Barrier::new(TEST_TASKS + 1));
    let b1 = Arc::new(Barrier::new(TEST_TASKS + 1));

    let (tx, rx) = channel();

    for i in 0..test_tasks {
        let tx = tx.clone();
        let (b0, b1) = (b0.clone(), b1.clone());

        pool.execute(move || {
            // Wait until the pool has been filled once.
            if i < TEST_TASKS {
                b0.wait();
                // Wait so the pool can be measured.
                b1.wait();
            }

            // Discard the send result explicitly; a bare `.is_ok();` leaves a
            // `#[must_use]` value unused and warns on current compilers.
            let _ = tx.send(1);
        });
    }

    // All workers are now parked between the two barriers.
    b0.wait();
    assert_eq!(pool.active_count(), TEST_TASKS);
    b1.wait();

    // Every job must report back exactly once.
    let received: usize = rx.iter().take(test_tasks).sum();
    assert_eq!(received, test_tasks);
    pool.join();

    let atomic_active_count = pool.active_count();
    assert!(
        atomic_active_count == 0,
        "atomic_active_count: {}",
        atomic_active_count
    );
}
951 | | |
/// Shrinking the pool while jobs are running: jobs that already started keep
/// running, and the next batch runs with the reduced worker count.
#[test]
fn test_shrink() {
    let initial_workers = TEST_TASKS + 2;
    let mut pool = ThreadPool::new(initial_workers);

    // First batch: one job per initial worker, each parked between two
    // barriers that the main thread also takes part in.
    let first_enter = Arc::new(Barrier::new(initial_workers + 1));
    let first_leave = Arc::new(Barrier::new(initial_workers + 1));
    for _ in 0..initial_workers {
        let enter = Arc::clone(&first_enter);
        let leave = Arc::clone(&first_leave);
        pool.execute(move || {
            enter.wait();
            leave.wait();
        });
    }

    // Second batch: TEST_TASKS jobs, queued behind the first batch.
    let second_enter = Arc::new(Barrier::new(TEST_TASKS + 1));
    let second_leave = Arc::new(Barrier::new(TEST_TASKS + 1));
    for _ in 0..TEST_TASKS {
        let enter = Arc::clone(&second_enter);
        let leave = Arc::clone(&second_leave);
        pool.execute(move || {
            enter.wait();
            leave.wait();
        });
    }

    // With the whole first batch running, shrink the pool.
    first_enter.wait();
    pool.set_num_threads(TEST_TASKS);

    // The jobs that were already running are still counted as active.
    assert_eq!(pool.active_count(), initial_workers);
    first_leave.wait();

    // Once the second batch is running, only TEST_TASKS jobs are active.
    second_enter.wait();
    assert_eq!(pool.active_count(), TEST_TASKS);
    second_leave.wait();
}
989 | | |
/// Initial, newly added and replacement worker threads must all carry the
/// name the pool was created with.
#[test]
fn test_name() {
    /// Name of the thread this is called on; panics if the thread is unnamed.
    fn current_thread_name() -> String {
        thread::current().name().unwrap().to_owned()
    }

    let expected = "test";
    let mut pool = ThreadPool::with_name(expected.to_owned(), 2);
    // Zero-capacity channel: a send completes only when it is received.
    let (tx, rx) = sync_channel(0);

    // The two initial threads should be named "test".
    for _ in 0..2 {
        let tx = tx.clone();
        pool.execute(move || tx.send(current_thread_name()).unwrap());
    }

    // A thread spawned by growing the pool should be named "test" too.
    pool.set_num_threads(3);
    let panicking_tx = tx.clone();
    pool.execute(move || {
        panicking_tx.send(current_thread_name()).unwrap();
        panic!();
    });

    // So should whichever thread runs the job submitted after the panic.
    pool.execute(move || tx.send(current_thread_name()).unwrap());

    for reported in rx.iter().take(4) {
        assert_eq!(expected, reported);
    }
}
1024 | | |
/// Pins the exact `Debug` rendering of a pool: unnamed, named, and with one
/// job in flight.
#[test]
fn test_debug() {
    // Unnamed, idle pool.
    let unnamed = ThreadPool::new(4);
    assert_eq!(
        format!("{:?}", unnamed),
        "ThreadPool { name: None, queued_count: 0, active_count: 0, max_count: 4 }"
    );

    // Named, idle pool.
    let named = ThreadPool::with_name("hello".to_owned(), 4);
    assert_eq!(
        format!("{:?}", named),
        "ThreadPool { name: Some(\"hello\"), queued_count: 0, active_count: 0, max_count: 4 }"
    );

    // One five-second job, given one second to be picked up by a worker, so
    // exactly one job should be reported as active.
    let busy = ThreadPool::new(4);
    busy.execute(|| sleep(Duration::from_secs(5)));
    sleep(Duration::from_secs(1));
    assert_eq!(
        format!("{:?}", busy),
        "ThreadPool { name: None, queued_count: 0, active_count: 1, max_count: 4 }"
    );
}
1050 | | |
/// A pool can be joined, refilled and joined again; each join waits for the
/// whole batch submitted before it.
#[test]
fn test_repeate_join() {
    let pool = ThreadPool::with_name("repeate join test".into(), 8);
    let counter = Arc::new(AtomicUsize::new(0));

    // Queues 42 two-second jobs that each bump the shared counter once, using
    // the given memory ordering for the increment.
    let submit_batch = |order: Ordering| {
        for _ in 0..42 {
            let counter = counter.clone();
            pool.execute(move || {
                sleep(Duration::from_secs(2));
                counter.fetch_add(1, order);
            });
        }
    };

    // First batch.
    submit_batch(Ordering::Release);
    println!("{:?}", pool);
    pool.join();
    assert_eq!(42, counter.load(Ordering::Acquire));

    // Second batch on the same, already joined pool.
    submit_batch(Ordering::Relaxed);
    pool.join();
    assert_eq!(84, counter.load(Ordering::Relaxed));
}
1078 | | |
/// Jobs running in one pool join another pool while the main thread joins
/// both; all joins must return and every job must report its index.
#[test]
fn test_multi_join() {
    use std::sync::mpsc::TryRecvError;

    // Toggle the following lines to debug the deadlock
    fn debug_log(_s: String) {
        //use ::std::io::Write;
        //let stderr = ::std::io::stderr();
        //let mut stderr = stderr.lock();
        //stderr.write(&_s.as_bytes()).is_ok();
    }

    let pool0 = ThreadPool::with_name("multi join pool0".into(), 4);
    let pool1 = ThreadPool::with_name("multi join pool1".into(), 4);
    let (tx, rx) = channel();

    for i in 0..8 {
        let pool1 = pool1.clone();
        let pool0_handle = pool0.clone();
        let tx = tx.clone();
        // The outer job runs on pool0 and hands an inner job to pool1; the
        // inner job joins pool0 before sending its index to the main thread.
        pool0.execute(move || {
            pool1.execute(move || {
                debug_log(format!("p1: {} -=- {:?}\n", i, pool0_handle));
                pool0_handle.join();
                debug_log(format!("p1: send({})\n", i));
                tx.send(i).expect("send i from pool1 -> main");
            });
            debug_log(format!("p0: {}\n", i));
        });
    }
    // Only the jobs hold senders from here on, so `rx` ends once they are done.
    drop(tx);

    // The test expects that nothing has been sent yet at this point.
    assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
    debug_log(format!("{:?}\n{:?}\n", pool0, pool1));
    pool0.join();
    debug_log(format!("pool0.join() complete =-= {:?}", pool1));
    pool1.join();
    debug_log("pool1.join() complete\n".into());

    // Every index 0..8 must arrive exactly once.
    let received: i32 = rx.iter().sum();
    let expected: i32 = (0..8).sum();
    assert_eq!(received, expected);
}
1122 | | |
/// Joining a pool that never received a job must return immediately.
///
/// The test has no assertion on purpose: it passes by `join` returning, and
/// would hang if `join` blocked on an idle pool.
#[test]
fn test_empty_pool() {
    let pool = ThreadPool::new(4);

    pool.join();
}
1132 | | |
/// What happens when another thread keeps adding jobs while the pool is
/// being joined: the join must still return.
#[test]
fn test_no_fun_or_joy() {
    /// A job that just sleeps for six seconds.
    fn sleepy_function() {
        sleep(Duration::from_secs(6));
    }

    let pool = ThreadPool::with_name("no fun or joy".into(), 8);

    // One job goes in before the background submitter starts.
    pool.execute(sleepy_function);

    // A detached thread queues 23 more jobs through a clone of the pool.
    let submitter_pool = pool.clone();
    thread::spawn(move || {
        for _ in 0..23 {
            submitter_pool.execute(sleepy_function);
        }
    });

    pool.join();
}
1152 | | |
/// Clones of a pool share the same workers: two threads each join a clone,
/// submit their own jobs through it and collect the results.
#[test]
fn test_clone() {
    let pool = ThreadPool::with_name("clone example".into(), 2);

    // This batch of jobs will occupy the pool for some time
    for _ in 0..6 {
        pool.execute(|| sleep(Duration::from_secs(2)));
    }

    // The following jobs will be inserted into the pool in a random fashion

    // Adder thread: sums 0..42 from values produced by pool jobs.
    let adder = {
        let pool = pool.clone();
        thread::spawn(move || {
            // wait for the first batch of tasks to finish
            pool.join();

            let (tx, rx) = channel();
            for i in 0..42 {
                let tx = tx.clone();
                pool.execute(move || tx.send(i).expect("channel will be waiting"));
            }
            // Close the original sender so `rx` ends when the jobs are done.
            drop(tx);
            rx.iter().sum::<i32>()
        })
    };

    // Multiplier thread: multiplies 1..12 from values produced by pool jobs.
    let multiplier = {
        let pool = pool.clone();
        thread::spawn(move || {
            // wait for the first batch of tasks to finish
            pool.join();

            let (tx, rx) = channel();
            for i in 1..12 {
                let tx = tx.clone();
                pool.execute(move || tx.send(i).expect("channel will be waiting"));
            }
            // Close the original sender so `rx` ends when the jobs are done.
            drop(tx);
            rx.iter().product::<i32>()
        })
    };

    let sum = adder
        .join()
        .expect("thread 0 will return after calculating additions");
    assert_eq!(861, sum);

    let product = multiplier
        .join()
        .expect("thread 1 will return after calculating multiplications");
    assert_eq!(39916800, product);
}
1213 | | |
/// Compile-time check that the pool's shared state is `Sync`.
#[test]
fn test_sync_shared_data() {
    // Instantiating this only compiles when `T` is `Sync`.
    fn require_sync<T>()
    where
        T: Sync,
    {
    }

    require_sync::<super::ThreadPoolSharedData>();
}
1219 | | |
/// Compile-time check that the pool's shared state is `Send`.
#[test]
fn test_send_shared_data() {
    // Instantiating this only compiles when `T` is `Send`.
    fn require_send<T>()
    where
        T: Send,
    {
    }

    require_send::<super::ThreadPoolSharedData>();
}
1225 | | |
/// Compile-time check that a `ThreadPool` handle is `Send`.
#[test]
fn test_send() {
    // Instantiating this only compiles when `T` is `Send`.
    fn require_send<T>()
    where
        T: Send,
    {
    }

    require_send::<ThreadPool>();
}
1231 | | |
/// A pool compares equal to its own clone.
#[test]
fn test_cloned_eq() {
    let pool = ThreadPool::new(2);
    let duplicate = pool.clone();

    assert_eq!(pool, duplicate);
}
1238 | | |
/// Threads joining a pool must not stay stuck once their wave of joins has
/// completed: as soon as one thread joining a pool has succeeded, the other
/// threads joining the same pool must get out as well, even if the pool is
/// given new jobs while that first group is still finishing its join.
///
/// In this example the waiting jobs leave the join in groups of four,
/// because the waiter pool has four workers.
#[test]
fn test_join_wavesurfer() {
    let n_cycles = 4;
    let n_workers = 4;
    let (tx, rx) = channel();

    // Two pools built from the same configuration.
    let builder = Builder::new()
        .num_threads(n_workers)
        .thread_name("join wavesurfer".into());
    let p_waiter = builder.clone().build();
    let p_clock = builder.build();

    // Start line shared by the clock thread, the first p_clock job and the
    // main thread.
    let start = Arc::new(Barrier::new(3));
    let wave_clock = Arc::new(AtomicUsize::new(0));

    // Advances `wave_clock` by one wave per second once the start line opens.
    let clock_thread = {
        let start = Arc::clone(&start);
        let wave_clock = Arc::clone(&wave_clock);
        thread::spawn(move || {
            start.wait();
            for wave_num in 0..n_cycles {
                wave_clock.store(wave_num, Ordering::SeqCst);
                sleep(Duration::from_secs(1));
            }
        })
    };

    // Keeps p_clock busy until the start line opens.
    {
        let start = Arc::clone(&start);
        p_clock.execute(move || {
            start.wait();
            // this sleep is for stabilisation on weaker platforms
            sleep(Duration::from_millis(100));
        });
    }

    // prepare three waves of jobs
    for i in 0..3 * n_workers {
        let p_clock = p_clock.clone();
        let tx = tx.clone();
        let wave_clock = Arc::clone(&wave_clock);
        p_waiter.execute(move || {
            let wave_before = wave_clock.load(Ordering::SeqCst);
            p_clock.join();
            // submit jobs for the second wave
            p_clock.execute(|| sleep(Duration::from_secs(1)));
            let wave_after = wave_clock.load(Ordering::SeqCst);
            tx.send((wave_before, wave_after, i)).unwrap();
        });
    }
    println!("all scheduled at {}", wave_clock.load(Ordering::SeqCst));
    start.wait();

    p_clock.join();
    // p_waiter is not joined explicitly: once `tx` is dropped, draining `rx`
    // below ends only after every waiter job has released its sender.

    drop(tx);
    let mut hist = vec![0; n_cycles];
    let mut data = vec![];
    for (now, after, i) in rx.iter() {
        // Number of waves spent inside the join, clamped to the last bucket.
        let waited = after - now;
        let bucket = if waited < n_cycles - 1 {
            waited
        } else {
            n_cycles - 1
        };
        hist[bucket] += 1;

        data.push((now, after, i));
    }

    // Print a small histogram of how long the joins took.
    for (i, n) in hist.iter().enumerate() {
        let bar = (0..*n).fold(String::new(), |stars, _| stars + "*");
        println!("\t{}: {} {}", i, n, &*bar);
    }

    // The first `n_workers` jobs must leave the join within the wave they
    // entered it; every later job must leave in a later wave.
    assert!(data.iter().all(|&(cycle, stop, i)| {
        if i < n_workers {
            cycle == stop
        } else {
            cycle < stop
        }
    }));

    clock_thread.join().unwrap();
}
1329 | | } |