Created
November 30, 2021 17:57
-
-
Save thomcc/2ccbdba27981607e18849cc389795dbd to your computer and use it in GitHub Desktop.
Quick and dirty implementation of Lamport's bakery algorithm in Rust, as an example. Untested.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//! Implementation of [Lamport's bakery algorithm][bakery]. This is somewhat
//! interesting, because it's a mutex which can be implemented even if the
//! target only has atomic load/store, and no CAS (in principle it can be even
//! more general than this).
//!
//! [bakery]: https://en.wikipedia.org/wiki/Lamport%27s_bakery_algorithm
//!
//! Major caveat: This is not tested, and this algorithm is no longer
//! appropriate for modern code. Some variations of it can be useful in a
//! thread pool, but it's mostly useful to understand the concepts.
//!
//! ## Notes:
//!
//! Debug asserts are there for easy-to-check things, but do not (and cannot)
//! detect all possible misuse.
//!
//! The thread index could be stored in a thread local, but it's probably
//! better to just pass it into the thread on construction. In practice, for
//! modern applications this is quite inconvenient, and so this algorithm is
//! hard to use, even in situations where it could be appropriate.
//!
//! P.S. Reducing from SeqCst is left as an exercise for the reader.
use core::sync::atomic::{*, Ordering::*};

/// 1-based index identifying a participating thread.
///
/// Valid values are `1..=MAX_THREADS`; every thread contending for the same
/// [`BakeryMutex`] must use a distinct value.
pub type ThreadIndex = core::num::NonZeroUsize;

/// A spin mutex built on Lamport's bakery algorithm, usable by at most
/// `MAX_THREADS` distinct threads.
///
/// Only atomic loads and stores are used (no compare-and-swap). The mutex
/// guards no data of its own: callers bracket their critical section with
/// [`lock`](Self::lock) and [`unlock`](Self::unlock).
#[derive(Debug)]
pub struct BakeryMutex<const MAX_THREADS: usize> {
    /// `entering[i]` is `true` while thread `i + 1` is choosing its ticket.
    entering: [AtomicBool; MAX_THREADS],
    /// `threads[i]` is thread `i + 1`'s ticket; `0` means it is neither
    /// waiting for nor holding the lock.
    threads: [AtomicUsize; MAX_THREADS],
}

impl<const MAX_THREADS: usize> BakeryMutex<MAX_THREADS> {
    /// An unlocked mutex: nobody entering, every ticket zero.
    ///
    /// Atomics are not `Copy`, so the arrays are built by repeating `const`
    /// items, which array-repeat expressions accept.
    const INIT: Self = {
        const FALSE: AtomicBool = AtomicBool::new(false);
        const ZERO: AtomicUsize = AtomicUsize::new(0);
        Self {
            entering: [FALSE; MAX_THREADS],
            threads: [ZERO; MAX_THREADS],
        }
    };

    /// Creates a new, unlocked mutex.
    #[inline]
    pub const fn new() -> Self {
        Self::INIT
    }

    /// Spins until the calling thread holds the lock.
    ///
    /// # Panics
    ///
    /// Panics if `MAX_THREADS == 0`, or if the next ticket would overflow
    /// `usize`. On overflow the caller's `entering` flag is cleared before
    /// panicking, so other threads are not left spinning on it forever.
    ///
    /// # Safety
    ///
    /// `thread_index` must be in range `1..=MAX_THREADS` and unique to the
    /// calling thread among all threads using this mutex. The calling thread
    /// must not already hold (or be waiting for) the lock.
    pub unsafe fn lock(&self, thread_index: ThreadIndex) {
        assert!(MAX_THREADS != 0);
        debug_assert!(
            thread_index.get() <= MAX_THREADS,
            "out of range 1..={MAX_THREADS:?}: {thread_index:?}",
        );
        let me = thread_index.get() - 1;
        debug_assert!(!self.entering[me].load(Relaxed));
        // Not holding or waiting for the lock means our ticket is zero.
        debug_assert_eq!(self.threads[me].load(Relaxed), 0);

        // Doorway: take a ticket one greater than every ticket we can see.
        self.entering[me].store(true, SeqCst);
        let highest = self
            .threads
            .iter()
            .map(|t| t.load(SeqCst))
            .max()
            .unwrap_or_default();
        let ticket = match highest.checked_add(1) {
            Some(ticket) => ticket,
            None => {
                // Overflow is theoretical (tickets rarely exceed MAX_THREADS
                // by much), but if it happens, clear our flag first: otherwise
                // every other thread spins forever in the `entering` wait.
                self.entering[me].store(false, SeqCst);
                panic!("BakeryMutex: ticket counter overflowed");
            }
        };
        self.threads[me].store(ticket, SeqCst);
        self.entering[me].store(false, SeqCst);

        // Bakery: wait until no other thread is ahead of us in line.
        for (other, (entering, other_ticket)) in
            self.entering.iter().zip(&self.threads).enumerate()
        {
            if other == me {
                continue;
            }
            // Let `other` finish choosing so we compare against its final
            // ticket rather than a half-chosen one.
            while entering.load(SeqCst) {
                core::hint::spin_loop();
            }
            // Wait while `other` holds a ticket ordered before ours. Equal
            // tickets are possible; ties go to the lower thread index.
            loop {
                let theirs = other_ticket.load(SeqCst);
                if theirs == 0 || (theirs, other) >= (ticket, me) {
                    break;
                }
                core::hint::spin_loop();
            }
        }
    }

    /// Releases the lock held by the calling thread.
    ///
    /// # Safety
    ///
    /// `thread_index` must be in range `1..=MAX_THREADS`, unique to the
    /// calling thread, and the calling thread must currently hold the lock
    /// (acquired via [`lock`](Self::lock) with the same index).
    pub unsafe fn unlock(&self, thread_index: ThreadIndex) {
        debug_assert!(
            thread_index.get() <= MAX_THREADS,
            "out of range 1..={MAX_THREADS:?}: {thread_index:?}",
        );
        let me = thread_index.get() - 1;
        debug_assert!(!self.entering[me].load(Relaxed));
        // Holding the lock means we own a non-zero ticket.
        debug_assert_ne!(self.threads[me].load(Relaxed), 0);
        // Dropping our ticket to zero lets waiting threads past us.
        self.threads[me].store(0, SeqCst);
    }
}

impl<const MAX_THREADS: usize> Default for BakeryMutex<MAX_THREADS> {
    /// Equivalent to [`BakeryMutex::new`].
    #[inline]
    fn default() -> Self {
        Self::new()
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment