A simple wait-group implementation in Rust
When writing concurrent code in Go, wait-groups are used ubiquitously to synchronize goroutines. In Rust, however, wait-groups are often unnecessary, because Rust supports classic fork-join parallelism of threads: you can spawn one or more threads and then wait for all of them to finish by joining them.
Despite this, the popular crossbeam crate provides a wait-group implementation, described as a mechanism that enables threads to synchronize the beginning or end of some computation.
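For comparison, here is a minimal sketch of fork-join in plain Rust, assuming Rust 1.63 or later for std::thread::scope; the helper name square_all is hypothetical. Every thread spawned inside the scope is joined before scope returns, so no wait-group is involved.

```rust
use std::thread;

/// Hypothetical helper: squares each input on its own thread (fork),
/// then collects the results in input order (join).
fn square_all(inputs: &[u64]) -> Vec<u64> {
    thread::scope(|s| {
        // Collecting into a Vec forces every spawn to happen before any join.
        let handles: Vec<_> = inputs
            .iter()
            .map(|&x| s.spawn(move || x * x))
            .collect();
        // Join each thread; a panic in a worker surfaces here via unwrap.
        handles.into_iter().map(|h| h.join().unwrap()).collect()
    })
}
```

The older equivalent is to collect the JoinHandle values returned by thread::spawn and call join on each one.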
This is an attempt to replicate the wait-group API of crossbeam with a simple approach using an atomic counter and a condition variable (condvar).
To implement a wait-group, we need a few things:
- A counter variable that can be incremented and decremented. It determines whether a waiting thread should block or be released.
- A mechanism to re-check the current value of the counter whenever it changes, without busy-waiting.
- A way to implement everything within the bounds of safe Rust.
For the counter variable, I’ll use an AtomicUsize, so that it can be shared and mutated by multiple threads without data races.
To check the current value of the counter, I’ll use a Condvar (condition variable), which can block a thread so that it consumes no CPU time while waiting for an event to occur. A naive approach would have been to use a separate thread that checks the counter in an infinite loop. However, such an approach wastes CPU cycles for as long as the wait lasts.
We also need a Mutex to go with the condition variable, since a condvar is typically paired with a boolean predicate and a mutex that protects that predicate. I’ll explain the flow of the condvar in detail later in this article.
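As a quick illustration of why an atomic is enough for the counter itself, here is a minimal sketch in which several threads increment one shared AtomicUsize; the helper concurrent_count is hypothetical, and an Arc provides the shared ownership. Because fetch_add is a single atomic read-modify-write, no increments are lost even though no Mutex is involved.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;

/// Hypothetical helper: `threads` threads each increment a shared counter
/// `per_thread` times; returns the final count.
fn concurrent_count(threads: usize, per_thread: usize) -> usize {
    let counter = Arc::new(AtomicUsize::new(0));
    let handles: Vec<_> = (0..threads)
        .map(|_| {
            let counter = Arc::clone(&counter);
            thread::spawn(move || {
                for _ in 0..per_thread {
                    // Atomic read-modify-write: concurrent increments never collide.
                    counter.fetch_add(1, Ordering::Relaxed);
                }
            })
        })
        .collect();
    for h in handles {
        // Joining makes all of the thread's increments visible to us.
        h.join().unwrap();
    }
    counter.load(Ordering::Relaxed)
}
```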
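The usual pairing of a boolean predicate, a Mutex and a Condvar looks like the following sketch, which follows the pattern shown in the std::sync::Condvar documentation; the function name wait_for_start is hypothetical. The main thread sleeps until another thread sets the flag and notifies it.

```rust
use std::sync::{Arc, Condvar, Mutex};
use std::thread;

/// Hypothetical demo: blocks until a spawned thread flips `started` to true.
fn wait_for_start() -> bool {
    // The predicate lives inside the Mutex; the Condvar is used alongside it.
    let pair = Arc::new((Mutex::new(false), Condvar::new()));
    let pair2 = Arc::clone(&pair);

    thread::spawn(move || {
        let (lock, cvar) = &*pair2;
        let mut started = lock.lock().unwrap();
        *started = true;   // change the predicate while holding the lock
        cvar.notify_one(); // wake the waiting thread
    });

    let (lock, cvar) = &*pair;
    let mut started = lock.lock().unwrap();
    // A loop rather than an `if`: wait() may return without a notification.
    while !*started {
        // Releases the lock while asleep and re-acquires it before returning.
        started = cvar.wait(started).unwrap();
    }
    *started
}
```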
The code below defines the structure of the wait-group along with a constructor function.
Notice that our wait-group must be shareable across thread boundaries, with every thread able to modify the counter. Mutation through a shared reference is already handled by interior mutability: AtomicUsize, Mutex and Condvar can all be used through a shared reference. What is still missing is shared ownership, so I wrap the shared state in an Arc smart pointer, which keeps it alive for as long as any thread holds a handle to it.
Now that our wait-group can be shared safely across threads without the Rust compiler shouting at us, let’s look at how we can increment and decrement the counter safely.
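Since the original listing is not reproduced here, the following is a minimal sketch of such a structure, not necessarily the author’s code. It assumes the counter, mutex and condvar live in an inner struct behind an Arc, and that the count starts at 1 for the handle returned by new (crossbeam’s WaitGroup also counts its first handle). The count accessor is a hypothetical addition for inspection.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex};

/// State shared by every handle to the same wait-group.
struct Inner {
    count: AtomicUsize, // number of live WaitGroup handles
    mutex: Mutex<()>,   // paired with `cvar`; guards no data of its own
    cvar: Condvar,      // lets waiters sleep until `count` reaches 0
}

/// A handle to the wait-group; `Arc` gives every thread shared ownership
/// of `Inner`, while the atomic and condvar allow mutation through `&self`.
pub struct WaitGroup {
    inner: Arc<Inner>,
}

impl WaitGroup {
    /// Starts the count at 1, standing for this first handle.
    pub fn new() -> Self {
        WaitGroup {
            inner: Arc::new(Inner {
                count: AtomicUsize::new(1),
                mutex: Mutex::new(()),
                cvar: Condvar::new(),
            }),
        }
    }

    /// Hypothetical accessor: reads the current count.
    pub fn count(&self) -> usize {
        self.inner.count.load(Ordering::Relaxed)
    }
}
```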
Typically, when using Arc smart pointers, we share a value across threads by cloning the Arc and moving the clone into each thread. We can use the same pattern and perform the increment at the point of cloning by implementing the Clone trait for the WaitGroup struct. Similarly, each clone is dropped when its thread exits, so we can perform the decrement by implementing the Drop trait for the WaitGroup struct.
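A minimal sketch of the Clone and Drop pair, assuming the shared state (counter, mutex, condvar) sits behind an Arc, the count starts at 1 for the first handle, and a hypothetical count accessor exists for observation. One deliberate difference from a plain fetch_sub followed by notify_one: this sketch decrements while holding the mutex and wakes every waiter with notify_all, which prevents a waiter from missing the final notification.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex};

struct Inner {
    count: AtomicUsize, // number of live WaitGroup handles
    mutex: Mutex<()>,   // paired with `cvar`
    cvar: Condvar,      // wakes waiters once `count` reaches 0
}

pub struct WaitGroup {
    inner: Arc<Inner>,
}

impl WaitGroup {
    pub fn new() -> Self {
        let inner = Inner {
            count: AtomicUsize::new(1), // the handle returned here counts as 1
            mutex: Mutex::new(()),
            cvar: Condvar::new(),
        };
        WaitGroup { inner: Arc::new(inner) }
    }

    /// Hypothetical accessor, used here only to observe the count.
    pub fn count(&self) -> usize {
        self.inner.count.load(Ordering::Relaxed)
    }
}

impl Clone for WaitGroup {
    /// Each clone registers one more participant.
    fn clone(&self) -> Self {
        self.inner.count.fetch_add(1, Ordering::Relaxed);
        WaitGroup { inner: Arc::clone(&self.inner) }
    }
}

impl Drop for WaitGroup {
    /// Runs when a handle goes out of scope, e.g. when a worker thread exits.
    fn drop(&mut self) {
        // Decrement under the mutex so a thread about to wait cannot check
        // the count and then miss this notification.
        let _guard = self.inner.mutex.lock().unwrap();
        // fetch_sub returns the value *before* subtracting: 1 means we were last.
        if self.inner.count.fetch_sub(1, Ordering::Relaxed) == 1 {
            self.inner.cvar.notify_all(); // wake every waiter, not just one
        }
    }
}
```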
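If you look at the above code, in the clone function I’m adding 1 to the counter. Notice that operations on atomic variables use specialized methods such as fetch_add, which adds the given value to the current value and returns the previous value, and that each of them takes an Ordering parameter. Memory ordering is a complicated topic: it specifies how atomic operations synchronize memory between threads. Since my implementation is very basic, I use the weakest ordering, Relaxed, which guarantees only that the operation on the counter itself is atomic and imposes no ordering on other memory accesses. That keeps the count consistent, but on its own it does not guarantee that writes made by a worker thread are visible to the waiting thread after wait returns; for that, the decrement and the check need to go through the mutex, or use Release and Acquire orderings.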
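To show what Relaxed alone does not provide, here is a small sketch of Release/Acquire pairing on a counter; publish_with_release_acquire is a hypothetical name. A worker stores a result and then decrements a pending counter with Release; an observer that reads 0 with Acquire is guaranteed to see the stored result. The spin loop only keeps the demo short; a wait-group would sleep on a condvar instead.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;

/// Hypothetical demo of publishing a worker's write through a counter.
fn publish_with_release_acquire() -> usize {
    let pending = Arc::new(AtomicUsize::new(1));
    let slot = Arc::new(AtomicUsize::new(0));

    let (p, s) = (Arc::clone(&pending), Arc::clone(&slot));
    thread::spawn(move || {
        s.store(42, Ordering::Relaxed);    // the "work": a plain relaxed store
        p.fetch_sub(1, Ordering::Release); // publishes every write made before it
    });

    // Busy-wait for brevity only.
    while pending.load(Ordering::Acquire) != 0 {
        std::hint::spin_loop();
    }
    // The Acquire load synchronized with the Release decrement, so this reads 42.
    slot.load(Ordering::Relaxed)
}
```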
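In the drop implementation, I’m deducting 1 from the counter using the fetch_sub method. Notice that I also signal the condvar using notify_one, which wakes up one thread blocked on this condvar (notify_all would wake every such thread). For the wakeup to be reliable, the decrement should happen while holding the mutex paired with the condvar; otherwise a waiting thread could read a non-zero count, miss the notification that follows, and then sleep indefinitely.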
Now let’s look at the wait function implementation.
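A minimal sketch of a complete WaitGroup with a wait method, offered as an illustration rather than the author’s original listing. It assumes a design similar to crossbeam’s: the count starts at 1, wait consumes the caller’s own handle (crossbeam’s signature is wait(self)), each decrement happens under the mutex, and the final drop wakes every waiter.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex};

struct Inner {
    count: AtomicUsize, // number of live WaitGroup handles
    mutex: Mutex<()>,   // paired with `cvar`
    cvar: Condvar,      // wakes waiters once `count` reaches 0
}

pub struct WaitGroup {
    inner: Arc<Inner>,
}

impl WaitGroup {
    pub fn new() -> Self {
        let inner = Inner {
            count: AtomicUsize::new(1),
            mutex: Mutex::new(()),
            cvar: Condvar::new(),
        };
        WaitGroup { inner: Arc::new(inner) }
    }

    /// Consumes this handle, then blocks until every other handle is dropped.
    pub fn wait(self) {
        // Keep the shared state alive, then drop our own handle so it stops counting.
        let inner = Arc::clone(&self.inner);
        drop(self);
        let mut guard = inner.mutex.lock().unwrap();
        // Loop, not `if`: the condvar may wake us spuriously.
        while inner.count.load(Ordering::Relaxed) != 0 {
            // Atomically releases the lock and sleeps; re-acquires it on wakeup.
            guard = inner.cvar.wait(guard).unwrap();
        }
    }
}

impl Clone for WaitGroup {
    fn clone(&self) -> Self {
        self.inner.count.fetch_add(1, Ordering::Relaxed);
        WaitGroup { inner: Arc::clone(&self.inner) }
    }
}

impl Drop for WaitGroup {
    fn drop(&mut self) {
        // Decrement under the lock so the waiter cannot miss the last notification.
        let _guard = self.inner.mutex.lock().unwrap();
        if self.inner.count.fetch_sub(1, Ordering::Relaxed) == 1 {
            self.inner.cvar.notify_all();
        }
    }
}
```

Because every decrement and every check happens under the same mutex, the mutex (rather than the atomic orderings) is what makes the workers’ writes visible once wait returns.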
Before discussing the above code, I’d like to explain the flow of condition variables in a bit more detail. A condition variable is usually associated with a mutex and a boolean predicate. The waiting thread first acquires the lock and then enters a loop: if the predicate isn’t satisfied, it calls wait on the condition variable, which blocks the thread and puts it to sleep. The call to wait also releases the lock, atomically, so that other threads can acquire it and modify the predicate. When another thread calls notify, the waiting thread wakes up and re-acquires the lock before wait returns. It then re-checks the predicate: if it is satisfied, the thread exits the loop; if not, it calls wait again, releasing the lock, and the cycle repeats. The re-check also protects against spurious wakeups, where wait returns without any notification.
So, the WaitGroup’s wait method simply checks the counter using the load method, and once it reads 0, it exits the loop, unblocking the thread.
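The same predicate loop can also be written with Condvar::wait_while, which keeps waiting for as long as its closure returns true. The sketch below applies it to a countdown stored in a Mutex<usize> instead of an atomic; countdown is a hypothetical helper, and this is an alternative design shown only to illustrate the loop, not the approach used in this article.

```rust
use std::sync::{Arc, Condvar, Mutex};
use std::thread;

/// Hypothetical helper: blocks until `workers` threads have each
/// decremented a shared counter, then returns the final value.
fn countdown(workers: usize) -> usize {
    let pair = Arc::new((Mutex::new(workers), Condvar::new()));
    for _ in 0..workers {
        let pair = Arc::clone(&pair);
        thread::spawn(move || {
            let (lock, cvar) = &*pair;
            let mut remaining = lock.lock().unwrap();
            *remaining -= 1; // mutate the predicate under the lock
            if *remaining == 0 {
                cvar.notify_all();
            }
        });
    }
    let (lock, cvar) = &*pair;
    // wait_while re-checks the closure after every wakeup, spurious or not,
    // and returns the guard once the closure yields false.
    let remaining = cvar
        .wait_while(lock.lock().unwrap(), |remaining| *remaining > 0)
        .unwrap();
    *remaining
}
```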
That covers the implementation, so let’s look at the complete code along with an example.
In the example, I spawn a set of threads, each of which sleeps for a different duration, and then wait for the whole set to finish using the wait method. This produces the output below.
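Here is a self-contained sketch of the whole wait-group together with an example, not the author’s original program. It assumes the count starts at 1, wait consumes the caller’s handle, decrements happen under the mutex and the last drop calls notify_all. The example function run_example and its sleep durations are hypothetical; each detached worker sleeps, records its index and drops its clone, and the printed order depends on scheduling.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
use std::time::{Duration, Instant};

struct Inner {
    count: AtomicUsize, // live handles
    mutex: Mutex<()>,   // paired with cvar
    cvar: Condvar,      // wakes waiters when count hits 0
}

pub struct WaitGroup {
    inner: Arc<Inner>,
}

impl WaitGroup {
    pub fn new() -> Self {
        let inner = Inner {
            count: AtomicUsize::new(1),
            mutex: Mutex::new(()),
            cvar: Condvar::new(),
        };
        WaitGroup { inner: Arc::new(inner) }
    }

    /// Drops this handle, then blocks until all other handles are gone.
    pub fn wait(self) {
        let inner = Arc::clone(&self.inner);
        drop(self);
        let mut guard = inner.mutex.lock().unwrap();
        while inner.count.load(Ordering::Relaxed) != 0 {
            guard = inner.cvar.wait(guard).unwrap();
        }
    }
}

impl Clone for WaitGroup {
    fn clone(&self) -> Self {
        self.inner.count.fetch_add(1, Ordering::Relaxed);
        WaitGroup { inner: Arc::clone(&self.inner) }
    }
}

impl Drop for WaitGroup {
    fn drop(&mut self) {
        let _guard = self.inner.mutex.lock().unwrap();
        if self.inner.count.fetch_sub(1, Ordering::Relaxed) == 1 {
            self.inner.cvar.notify_all();
        }
    }
}

/// Hypothetical example: one detached thread per entry in `durations_ms`.
pub fn run_example(durations_ms: &[u64]) -> Vec<usize> {
    let wg = WaitGroup::new();
    let finished = Arc::new(Mutex::new(Vec::new()));
    for (id, &ms) in durations_ms.iter().enumerate() {
        let wg = wg.clone(); // count += 1 for this worker
        let finished = Arc::clone(&finished);
        thread::spawn(move || {
            thread::sleep(Duration::from_millis(ms));
            finished.lock().unwrap().push(id);
            println!("thread {id} finished after {ms} ms");
            drop(wg); // count -= 1; would also happen when the closure ends
        });
    }
    let start = Instant::now();
    wg.wait(); // blocks until every worker has dropped its clone
    println!("all threads done in {:?}", start.elapsed());
    let ids = finished.lock().unwrap().clone();
    ids
}
```

Calling run_example(&[30, 10, 20]) from main runs three workers; the returned indices arrive in completion order, so sort them if a fixed order is needed.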