Coverage Report

Created: 2021-01-22 16:54

crossbeam-utils/src/sync/wait_group.rs
Line
Count
Source
1
// Necessary because the `Mutex<usize>` counter is paired with a `Condvar`, so clippy's suggestion to use an atomic does not apply
2
#![allow(clippy::mutex_atomic)]
3
4
use crate::primitive::sync::{Arc, Condvar, Mutex};
5
use std::fmt;
6
7
/// Enables threads to synchronize the beginning or end of some computation.
8
///
9
/// # Wait groups vs barriers
10
///
11
/// `WaitGroup` is very similar to [`Barrier`], but there are a few differences:
12
///
13
/// * [`Barrier`] needs to know the number of threads at construction, while `WaitGroup` is cloned to
14
///   register more threads.
15
///
16
/// * A [`Barrier`] can be reused even after all threads have synchronized, while a `WaitGroup`
17
///   synchronizes threads only once.
18
///
19
/// * All threads wait for others to reach the [`Barrier`]. With `WaitGroup`, each thread can choose
20
///   to either wait for other threads or to continue without blocking.
21
///
22
/// # Examples
23
///
24
/// ```
25
/// use crossbeam_utils::sync::WaitGroup;
26
/// use std::thread;
27
///
28
/// // Create a new wait group.
29
/// let wg = WaitGroup::new();
30
///
31
/// for _ in 0..4 {
32
///     // Create another reference to the wait group.
33
///     let wg = wg.clone();
34
///
35
///     thread::spawn(move || {
36
///         // Do some work.
37
///
38
///         // Drop the reference to the wait group.
39
///         drop(wg);
40
///     });
41
/// }
42
///
43
/// // Block until all threads have finished their work.
44
/// wg.wait();
45
/// ```
46
///
47
/// [`Barrier`]: std::sync::Barrier
48
pub struct WaitGroup {
49
    inner: Arc<Inner>,
50
}
51
52
/// State shared between all handles that belong to one `WaitGroup`.
struct Inner {
    /// Used to wake threads blocked in `WaitGroup::wait` once `count` reaches zero.
    cvar: Condvar,
    /// Number of `WaitGroup` handles currently alive.
    count: Mutex<usize>,
}
57
58
impl Default for WaitGroup {
59
31.7k
    fn default() -> Self {
60
31.7k
        Self {
61
31.7k
            inner: Arc::new(Inner {
62
31.7k
                cvar: Condvar::new(),
63
31.7k
                count: Mutex::new(1),
64
31.7k
            }),
65
31.7k
        }
66
31.7k
    }
67
}
68
69
impl WaitGroup {
70
    /// Creates a new wait group and returns the single reference to it.
71
    ///
72
    /// # Examples
73
    ///
74
    /// ```
75
    /// use crossbeam_utils::sync::WaitGroup;
76
    ///
77
    /// let wg = WaitGroup::new();
78
    /// ```
79
31.7k
    pub fn new() -> Self {
80
31.7k
        Self::default()
81
31.7k
    }
82
83
    /// Drops this reference and waits until all other references are dropped.
84
    ///
85
    /// # Examples
86
    ///
87
    /// ```
88
    /// use crossbeam_utils::sync::WaitGroup;
89
    /// use std::thread;
90
    ///
91
    /// let wg = WaitGroup::new();
92
    ///
93
    /// thread::spawn({
94
    ///     let wg = wg.clone();
95
    ///     move || {
96
    ///         // Block until both threads have reached `wait()`.
97
    ///         wg.wait();
98
    ///     }
99
    /// });
100
    ///
101
    /// // Block until both threads have reached `wait()`.
102
    /// wg.wait();
103
    /// ```
104
31.7k
    pub fn wait(self) {
105
31.7k
        if *self.inner.count.lock().unwrap() == 1 {
106
75
            return;
107
31.6k
        }
108
31.6k
109
31.6k
        let inner = self.inner.clone();
110
31.6k
        drop(self);
111
31.6k
112
31.6k
        let mut count = inner.count.lock().unwrap();
113
63.3k
        while *count > 0 {
114
31.6k
            count = inner.cvar.wait(count).unwrap();
115
31.6k
        }
116
31.7k
    }
117
}
118
119
impl Drop for WaitGroup {
120
135k
    fn drop(&mut self) {
121
135k
        let mut count = self.inner.count.lock().unwrap();
122
135k
        *count -= 1;
123
135k
124
135k
        if *count == 0 {
125
31.7k
            self.inner.cvar.notify_all();
126
103k
        }
127
135k
    }
128
}
129
130
impl Clone for WaitGroup {
131
103k
    fn clone(&self) -> WaitGroup {
132
103k
        let mut count = self.inner.count.lock().unwrap();
133
103k
        *count += 1;
134
103k
135
103k
        WaitGroup {
136
103k
            inner: self.inner.clone(),
137
103k
        }
138
103k
    }
139
}
140
141
impl fmt::Debug for WaitGroup {
142
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143
        let count: &usize = &*self.inner.count.lock().unwrap();
144
        f.debug_struct("WaitGroup").field("count", count).finish()
145
    }
146
}