forked from mirrors/linux
		
	Replace the examples in the documentation by the ones from the user-space version and introduce the standalone examples from the user-space version such as the `CMutex<T>` type. The `CMutex<T>` example from the pinned-init repository [1] is used in several documentation examples in the user-space version instead of the kernel `Mutex<T>` type (as it's not available). In order to split off the pin-init crate, all examples need to be free of kernel-specific types. Link: https://github.com/rust-for-Linux/pinned-init [1] Signed-off-by: Benno Lossin <benno.lossin@proton.me> Reviewed-by: Fiona Behrens <me@kloenk.dev> Tested-by: Andreas Hindborg <a.hindborg@kernel.org> Link: https://lore.kernel.org/r/20250308110339.2997091-6-benno.lossin@proton.me Signed-off-by: Miguel Ojeda <ojeda@kernel.org>
		
			
				
	
	
		
			209 lines
		
	
	
	
		
			5.3 KiB
		
	
	
	
		
			Rust
		
	
	
	
	
	
			
		
		
	
	
			209 lines
		
	
	
	
		
			5.3 KiB
		
	
	
	
		
			Rust
		
	
	
	
	
	
// SPDX-License-Identifier: Apache-2.0 OR MIT
 | 
						|
 | 
						|
#![allow(clippy::undocumented_unsafe_blocks)]
 | 
						|
#![cfg_attr(feature = "alloc", feature(allocator_api))]
 | 
						|
#![allow(clippy::missing_safety_doc)]
 | 
						|
 | 
						|
use core::{
 | 
						|
    cell::{Cell, UnsafeCell},
 | 
						|
    marker::PhantomPinned,
 | 
						|
    ops::{Deref, DerefMut},
 | 
						|
    pin::Pin,
 | 
						|
    sync::atomic::{AtomicBool, Ordering},
 | 
						|
};
 | 
						|
use std::{
 | 
						|
    sync::Arc,
 | 
						|
    thread::{self, park, sleep, Builder, Thread},
 | 
						|
    time::Duration,
 | 
						|
};
 | 
						|
 | 
						|
use pin_init::*;
 | 
						|
#[expect(unused_attributes)]
 | 
						|
#[path = "./linked_list.rs"]
 | 
						|
pub mod linked_list;
 | 
						|
use linked_list::*;
 | 
						|
 | 
						|
/// A minimal test-and-test-and-set spin lock.
///
/// The lock guards no data by itself; it only provides mutual exclusion. It is taken with
/// [`SpinLock::acquire`] and released when the returned [`SpinLockGuard`] is dropped.
pub struct SpinLock {
    // `true` while a `SpinLockGuard` for this lock is alive.
    inner: AtomicBool,
}

impl SpinLock {
    /// Acquires the lock, spinning (and yielding the thread) until it becomes available.
    ///
    /// The lock is released when the returned guard is dropped.
    #[inline]
    pub fn acquire(&self) -> SpinLockGuard<'_> {
        // `Acquire` on success pairs with the `Release` store in `SpinLockGuard::drop`, so the
        // previous holder's writes are visible to the new holder.
        while self
            .inner
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            // Wait with plain loads until the lock looks free before retrying the
            // read-modify-write, instead of issuing a failing CAS on every iteration.
            while self.inner.load(Ordering::Relaxed) {
                thread::yield_now();
            }
        }
        SpinLockGuard(self)
    }

    /// Creates a new, unlocked spin lock.
    #[inline]
    // `Default::default` cannot be a `const fn`; this example only needs the const constructor.
    #[allow(clippy::new_without_default)]
    pub const fn new() -> Self {
        Self {
            inner: AtomicBool::new(false),
        }
    }
}

/// Proof that a [`SpinLock`] is held; the lock is released when this guard is dropped.
#[must_use = "the spin lock is released as soon as the guard is dropped"]
pub struct SpinLockGuard<'a>(&'a SpinLock);

impl Drop for SpinLockGuard<'_> {
    #[inline]
    fn drop(&mut self) {
        // `Release` publishes the critical section's writes to the next acquirer.
        self.0.inner.store(false, Ordering::Release);
    }
}
 | 
						|
 | 
						|
/// A mutex protecting a `T`, built from a [`SpinLock`] and an intrusive list of waiting threads.
///
/// Contended callers of `lock` enqueue a `WaitEntry` on `wait_list` and park until the
/// guard's `Drop` impl unparks them.
#[pin_data]
pub struct CMutex<T> {
    // Intrusive list of `WaitEntry`s for threads blocked in `lock`.
    #[pin]
    wait_list: ListHead,
    // Protects `locked` and `wait_list`.
    spin_lock: SpinLock,
    // Whether a `CMutexGuard` currently exists; only read/written while `spin_lock` is held.
    locked: Cell<bool>,
    // The protected value. Structurally pinned: it is initialized in place from an
    // `impl PinInit<T>` (see `new`), so `T` may rely on never being moved.
    #[pin]
    data: UnsafeCell<T>,
}
 | 
						|
 | 
						|
impl<T> CMutex<T> {
 | 
						|
    #[inline]
 | 
						|
    pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> {
 | 
						|
        pin_init!(CMutex {
 | 
						|
            wait_list <- ListHead::new(),
 | 
						|
            spin_lock: SpinLock::new(),
 | 
						|
            locked: Cell::new(false),
 | 
						|
            data <- unsafe {
 | 
						|
                pin_init_from_closure(|slot: *mut UnsafeCell<T>| {
 | 
						|
                    val.__pinned_init(slot.cast::<T>())
 | 
						|
                })
 | 
						|
            },
 | 
						|
        })
 | 
						|
    }
 | 
						|
 | 
						|
    #[inline]
 | 
						|
    pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> {
 | 
						|
        let mut sguard = self.spin_lock.acquire();
 | 
						|
        if self.locked.get() {
 | 
						|
            stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list));
 | 
						|
            // println!("wait list length: {}", self.wait_list.size());
 | 
						|
            while self.locked.get() {
 | 
						|
                drop(sguard);
 | 
						|
                park();
 | 
						|
                sguard = self.spin_lock.acquire();
 | 
						|
            }
 | 
						|
            // This does have an effect, as the ListHead inside wait_entry implements Drop!
 | 
						|
            #[expect(clippy::drop_non_drop)]
 | 
						|
            drop(wait_entry);
 | 
						|
        }
 | 
						|
        self.locked.set(true);
 | 
						|
        unsafe {
 | 
						|
            Pin::new_unchecked(CMutexGuard {
 | 
						|
                mtx: self,
 | 
						|
                _pin: PhantomPinned,
 | 
						|
            })
 | 
						|
        }
 | 
						|
    }
 | 
						|
 | 
						|
    #[allow(dead_code)]
 | 
						|
    pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T {
 | 
						|
        // SAFETY: we have an exclusive reference and thus nobody has access to data.
 | 
						|
        unsafe { &mut *self.data.get() }
 | 
						|
    }
 | 
						|
}
 | 
						|
 | 
						|
unsafe impl<T: Send> Send for CMutex<T> {}
 | 
						|
unsafe impl<T: Send> Sync for CMutex<T> {}
 | 
						|
 | 
						|
pub struct CMutexGuard<'a, T> {
 | 
						|
    mtx: &'a CMutex<T>,
 | 
						|
    _pin: PhantomPinned,
 | 
						|
}
 | 
						|
 | 
						|
impl<T> Drop for CMutexGuard<'_, T> {
 | 
						|
    #[inline]
 | 
						|
    fn drop(&mut self) {
 | 
						|
        let sguard = self.mtx.spin_lock.acquire();
 | 
						|
        self.mtx.locked.set(false);
 | 
						|
        if let Some(list_field) = self.mtx.wait_list.next() {
 | 
						|
            let wait_entry = list_field.as_ptr().cast::<WaitEntry>();
 | 
						|
            unsafe { (*wait_entry).thread.unpark() };
 | 
						|
        }
 | 
						|
        drop(sguard);
 | 
						|
    }
 | 
						|
}
 | 
						|
 | 
						|
impl<T> Deref for CMutexGuard<'_, T> {
 | 
						|
    type Target = T;
 | 
						|
 | 
						|
    #[inline]
 | 
						|
    fn deref(&self) -> &Self::Target {
 | 
						|
        unsafe { &*self.mtx.data.get() }
 | 
						|
    }
 | 
						|
}
 | 
						|
 | 
						|
impl<T> DerefMut for CMutexGuard<'_, T> {
 | 
						|
    #[inline]
 | 
						|
    fn deref_mut(&mut self) -> &mut Self::Target {
 | 
						|
        unsafe { &mut *self.mtx.data.get() }
 | 
						|
    }
 | 
						|
}
 | 
						|
 | 
						|
/// A node in a `CMutex`'s wait list, representing one thread blocked in `CMutex::lock`.
///
/// `#[repr(C)]` with `wait_list` as the first field is what lets `CMutexGuard`'s `Drop` impl
/// cast a pointer to the list node back into a pointer to the whole `WaitEntry`.
#[pin_data]
#[repr(C)]
struct WaitEntry {
    // Links this entry into the mutex's `wait_list`; must stay the first field (see above).
    #[pin]
    wait_list: ListHead,
    // The blocked thread; unparked when the mutex is released.
    thread: Thread,
}
 | 
						|
 | 
						|
impl WaitEntry {
 | 
						|
    #[inline]
 | 
						|
    fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ {
 | 
						|
        pin_init!(Self {
 | 
						|
            thread: thread::current(),
 | 
						|
            wait_list <- ListHead::insert_prev(list),
 | 
						|
        })
 | 
						|
    }
 | 
						|
}
 | 
						|
 | 
						|
// The example below is only built with the `std` or `alloc` feature (presumably needed for
// `Arc::pin_init`); without either, `main` is a no-op.
#[cfg(not(any(feature = "std", feature = "alloc")))]
fn main() {}
 | 
						|
 | 
						|
#[allow(dead_code)]
#[cfg_attr(test, test)]
#[cfg(any(feature = "std", feature = "alloc"))]
fn main() {
    // One counter shared by all worker threads, protected by a pinned `CMutex`.
    let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
    let thread_count = 20;
    // Miri is much slower than native execution, so it gets a smaller workload.
    let workload = if cfg!(miri) { 100 } else { 1_000 };

    // Each worker increments the counter `workload` times, sleeps for a thread-specific
    // duration, then increments it `workload` more times.
    let handles: Vec<_> = (0..thread_count)
        .map(|i| {
            let shared = Pin::clone(&mtx);
            Builder::new()
                .name(format!("worker #{i}"))
                .spawn(move || {
                    for _ in 0..workload {
                        *shared.lock() += 1;
                    }
                    println!("{i} halfway");
                    sleep(Duration::from_millis((i as u64) * 10));
                    for _ in 0..workload {
                        *shared.lock() += 1;
                    }
                    println!("{i} finished");
                })
                .expect("should not fail")
        })
        .collect();

    handles
        .into_iter()
        .for_each(|handle| handle.join().expect("thread panicked"));

    println!("{:?}", &*mtx.lock());
    assert_eq!(*mtx.lock(), workload * thread_count * 2);
}
 |