Created
September 1, 2023 16:21
-
-
Save evanxg852000/dbd79dcf5498c15f1fbdf4242626710a to your computer and use it in GitHub Desktop.
A spin lock implementation in Rust
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// To run undefined behavior checks: `cargo +nightly miri run`
use std::{ | |
cell::UnsafeCell, | |
ops::{Deref, DerefMut}, | |
sync::{ | |
atomic::{AtomicBool, Ordering}, | |
Arc, | |
}, | |
thread, | |
}; | |
use itertools::Itertools; | |
/// A minimal mutual-exclusion lock that busy-waits instead of blocking.
///
/// The protected value is stored in an [`UnsafeCell`] and can only be reached
/// through the [`SpinLockGuard`] returned by [`SpinLock::lock`]. The `locked`
/// flag ensures that at most one guard exists at any time.
struct SpinLock<T> {
    /// The protected data; only touched while `locked` is held.
    value: UnsafeCell<T>,
    /// `true` while a guard is alive.
    locked: AtomicBool,
}

// SAFETY: sending the lock to another thread sends the `T` stored inside it,
// so `T: Send` is exactly what is needed.
unsafe impl<T: Send> Send for SpinLock<T> {}

// SAFETY: a shared `&SpinLock<T>` lets whichever thread holds the lock obtain
// `&mut T`, which is equivalent to moving the value between threads, so the
// required bound is `T: Send` (the same bound `std::sync::Mutex` uses).
// `T: Sync` is not needed because the lock never hands out `&T` to two
// threads at once.
unsafe impl<T: Send> Sync for SpinLock<T> {}

impl<T> SpinLock<T> {
    /// Creates an unlocked spin lock that owns `value`.
    ///
    /// This is a `const fn`, so a lock can also be placed in a `static`.
    pub const fn new(value: T) -> Self {
        Self {
            value: UnsafeCell::new(value),
            locked: AtomicBool::new(false),
        }
    }

    /// Acquires the lock, spinning until it becomes available.
    ///
    /// Returns a guard that dereferences to the protected value and releases
    /// the lock when dropped. Locking again on the same thread while a guard
    /// is still alive spins forever.
    pub fn lock(&self) -> SpinLockGuard<'_, T> {
        // `swap` returns the previous flag: `true` means another guard still
        // holds the lock. `Acquire` pairs with the `Release` store in `unlock`
        // so writes made by the previous holder are visible here.
        while self.locked.swap(true, Ordering::Acquire) {
            // Give other threads (possibly the lock holder) a chance to run.
            thread::yield_now();
        }
        SpinLockGuard { spin_lock: self }
    }

    /// Releases the lock. Only called from the guard's `Drop` implementation.
    fn unlock(&self) {
        self.locked.store(false, Ordering::Release);
    }
}

/// Proof of exclusive access to the value inside a [`SpinLock`].
///
/// The lock is released when the guard goes out of scope.
#[must_use = "the lock is released as soon as the guard is dropped"]
struct SpinLockGuard<'a, T> {
    spin_lock: &'a SpinLock<T>,
}

// SAFETY: a guard moved to another thread gives that thread `&mut T`, which
// requires `T: Send`. Releasing the lock is a plain atomic store and is not
// tied to the thread that acquired it.
unsafe impl<T: Send> Send for SpinLockGuard<'_, T> {}

// SAFETY: sharing `&SpinLockGuard` across threads only exposes `&T` (through
// `Deref`), which is sound exactly when `T: Sync`. Writing this impl
// explicitly replaces the auto impl, which would otherwise make the guard
// `Sync` for every `T: Send`.
unsafe impl<T: Sync> Sync for SpinLockGuard<'_, T> {}

impl<T> Deref for SpinLockGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        // SAFETY: the existence of this guard means the lock is held, so no
        // other guard (and therefore no `&mut T`) can exist concurrently.
        unsafe { &*self.spin_lock.value.get() }
    }
}

impl<T> DerefMut for SpinLockGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: the lock is held by this guard and the guard is borrowed
        // mutably, so this is the only live reference to the value.
        unsafe { &mut *self.spin_lock.value.get() }
    }
}

impl<T> Drop for SpinLockGuard<'_, T> {
    fn drop(&mut self) {
        self.spin_lock.unlock();
    }
}
fn main() { | |
let counter = Arc::new(SpinLock::new(0)); | |
let threads = (0..100) | |
.map(|_| { | |
let moved_counter = counter.clone(); | |
thread::spawn(move || { | |
for _ in 0..1000 { | |
let mut counter_guard = moved_counter.lock(); | |
*counter_guard += 1; | |
} | |
}) | |
}) | |
.collect_vec(); | |
threads | |
.into_iter() | |
.for_each(|thread| thread.join().unwrap()); | |
println!("COUNTER: {}", *counter.lock()); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment