Last active
October 30, 2023 17:50
-
-
Save CryZe/0fde6bcdc24f8679ebac85e4da528734 to your computer and use it in GitHub Desktop.
Protected memory
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::{ | |
marker::PhantomData, | |
mem, | |
ops::{Deref, DerefMut}, | |
ptr::{self, NonNull}, | |
}; | |
use bytemuck::{AnyBitPattern, Zeroable}; | |
mod utils; | |
/// Owning handle to memory obtained from the `utils` guarded allocator.
///
/// The constructors leave the memory inaccessible (`utils::mprotect_noaccess`). It is only
/// readable while a [`ReadGuard`] from [`Protected::read`] is alive, or writable while a
/// [`WriteGuard`] from [`Protected::write`] is alive. Dropping the handle drops the contents
/// (when `T` has drop glue) and releases the memory with `utils::free`.
pub struct Protected<T: ?Sized> { | |
// Non-null user pointer returned by `utils::malloc` / `utils::allocarray`.
ptr: NonNull<T>, | |
} | |
impl<T: AnyBitPattern> Protected<[T]> { | |
pub fn new_slice(len: usize) -> Self { | |
unsafe { | |
let ptr = utils::allocarray(len, mem::size_of::<T>()); | |
let ptr = NonNull::new(ptr::slice_from_raw_parts_mut(ptr.cast(), len)).unwrap(); | |
utils::mprotect_noaccess(ptr.as_ptr().cast()); | |
Self { ptr } | |
} | |
} | |
} | |
impl<T: AnyBitPattern> Protected<T> { | |
pub fn new() -> Self { | |
unsafe { | |
let ptr = utils::malloc(mem::size_of::<T>().next_multiple_of(mem::align_of::<T>())); | |
let ptr = NonNull::<T>::new(ptr.cast()).unwrap(); | |
utils::mprotect_noaccess(ptr.as_ptr().cast()); | |
Self { ptr } | |
} | |
} | |
} | |
impl<T: AnyBitPattern> Default for Protected<T> { | |
#[inline] | |
fn default() -> Self { | |
Self::new() | |
} | |
} | |
impl<T: Zeroable> Protected<T> { | |
pub fn new_zeroed() -> Self { | |
unsafe { | |
let ptr = utils::malloc(mem::size_of::<T>().next_multiple_of(mem::align_of::<T>())); | |
let ptr = NonNull::<T>::new(ptr.cast()).unwrap(); | |
utils::mprotect_readwrite(ptr.as_ptr().cast()); | |
ptr::write_bytes(ptr.as_ptr(), 0, 1); | |
utils::mprotect_noaccess(ptr.as_ptr().cast()); | |
Self { ptr } | |
} | |
} | |
} | |
impl<T: ?Sized> Protected<T> { | |
#[inline] | |
pub fn read(&mut self) -> ReadGuard<'_, T> { | |
unsafe { | |
utils::mprotect_readonly(self.ptr.as_ptr().cast()); | |
ReadGuard { | |
ptr: self.ptr, | |
_phantom: PhantomData, | |
} | |
} | |
} | |
#[inline] | |
pub fn write(&mut self) -> WriteGuard<'_, T> { | |
unsafe { | |
utils::mprotect_readwrite(self.ptr.as_ptr().cast()); | |
WriteGuard { | |
ptr: self.ptr, | |
_phantom: PhantomData, | |
} | |
} | |
} | |
} | |
impl<T: ?Sized> Drop for Protected<T> { | |
#[inline] | |
fn drop(&mut self) { | |
unsafe { | |
if mem::needs_drop::<T>() { | |
utils::mprotect_readwrite(self.ptr.as_ptr().cast()); | |
ptr::drop_in_place(self.ptr.as_ptr()); | |
} | |
utils::free(self.ptr.as_ptr().cast()); | |
} | |
} | |
} | |
/// Guard returned by [`Protected::read`].
///
/// The protected memory stays readable (read-only) while the guard is alive and is made
/// inaccessible again when it is dropped. Marked `#[must_use]` because discarding the guard
/// immediately re-protects the memory, which makes the `read()` call pointless.
#[must_use = "the memory is re-protected as soon as the guard is dropped; bind it to a variable"]
pub struct ReadGuard<'a, T: ?Sized> {
    /// Pointer to the protected value, copied from the owning `Protected`.
    ptr: NonNull<T>,
    /// Ties the guard to the `&mut Protected<T>` borrow it was created from.
    _phantom: PhantomData<&'a T>,
}
impl<T: ?Sized> Drop for ReadGuard<'_, T> { | |
#[inline] | |
fn drop(&mut self) { | |
unsafe { | |
utils::mprotect_noaccess(self.ptr.as_ptr().cast()); | |
} | |
} | |
} | |
impl<T: ?Sized> Deref for ReadGuard<'_, T> { | |
type Target = T; | |
#[inline] | |
fn deref(&self) -> &Self::Target { | |
unsafe { &*self.ptr.as_ptr() } | |
} | |
} | |
/// Guard returned by [`Protected::write`].
///
/// The protected memory stays readable and writable while the guard is alive and is made
/// inaccessible again when it is dropped. Marked `#[must_use]` because discarding the guard
/// immediately re-protects the memory, which makes the `write()` call pointless.
#[must_use = "the memory is re-protected as soon as the guard is dropped; bind it to a variable"]
pub struct WriteGuard<'a, T: ?Sized> {
    /// Pointer to the protected value, copied from the owning `Protected`.
    ptr: NonNull<T>,
    /// Ties the guard to the `&mut Protected<T>` borrow it was created from.
    _phantom: PhantomData<&'a mut T>,
}
impl<T: ?Sized> Drop for WriteGuard<'_, T> { | |
#[inline] | |
fn drop(&mut self) { | |
unsafe { | |
utils::mprotect_noaccess(self.ptr.as_ptr().cast()); | |
} | |
} | |
} | |
impl<T: ?Sized> Deref for WriteGuard<'_, T> { | |
type Target = T; | |
#[inline] | |
fn deref(&self) -> &Self::Target { | |
unsafe { &*self.ptr.as_ptr() } | |
} | |
} | |
impl<T: ?Sized> DerefMut for WriteGuard<'_, T> { | |
#[inline] | |
fn deref_mut(&mut self) -> &mut Self::Target { | |
unsafe { &mut *self.ptr.as_ptr() } | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::{mem, ptr, slice, sync::OnceLock}; | |
#[cfg(unix)] | |
use nix::libc; | |
#[cfg(windows)] | |
use windows_sys::Win32::System::Memory; | |
type Canary = [u8; 16]; | |
fn get_canary() -> &'static Canary { | |
static CANARY: OnceLock<Canary> = OnceLock::new(); | |
CANARY.get_or_init(|| { | |
let mut data = [0; 16]; | |
getrandom::getrandom(&mut data).unwrap(); | |
data | |
}) | |
} | |
fn get_page_size() -> usize { | |
static PAGE_SIZE: OnceLock<usize> = OnceLock::new(); | |
*PAGE_SIZE.get_or_init(|| { | |
#[cfg(unix)] | |
let page_size = nix::unistd::sysconf(nix::unistd::SysconfVar::PAGE_SIZE) | |
.ok() | |
.flatten() | |
.and_then(|v| usize::try_from(v).ok()) | |
.unwrap_or(0x10000); | |
#[cfg(windows)] | |
let page_size = 4 << 10; // TODO: | |
if page_size < 16 || page_size < mem::size_of::<usize>() { | |
panic!(); | |
} | |
page_size | |
}) | |
} | |
/// Compares two byte buffers in constant time.
///
/// Every byte pair is XOR-ed and OR-ed into an accumulator through volatile reads and writes.
/// The loop therefore always touches all `len` bytes and is not short-circuited on the first
/// difference.
///
/// Returns `0` if the buffers are equal and `-1` otherwise (it does not order them).
///
/// # Safety
/// `b1` and `b2` must both be valid for reads of `len` bytes.
unsafe fn memcmp(b1: *const u8, b2: *const u8, len: usize) -> i32 {
    let mut diff: u8 = 0;
    let diff_ptr: *mut u8 = &mut diff;
    for offset in 0..len {
        let so_far = diff_ptr.read_volatile();
        let lhs = b1.add(offset).read_volatile();
        let rhs = b2.add(offset).read_volatile();
        diff_ptr.write_volatile(so_far | (lhs ^ rhs));
    }
    // Branch-free mapping:
    //   diff == 0        -> (0 - 1) >> 8 == -1 -> (-1 & 1) - 1 == 0
    //   diff in 1..=255  -> (diff - 1) >> 8 == 0 -> (0 & 1) - 1 == -1
    let shifted = (i32::from(diff) - 1) >> 8;
    (shifted & 1) - 1
}
pub unsafe fn mlock(_addr: *mut u8, _len: usize) -> i32 { | |
#[cfg(unix)] | |
{ | |
libc::madvise(_addr.cast(), _len, libc::MADV_DONTDUMP); | |
libc::mlock(_addr.cast(), _len) | |
} | |
#[cfg(windows)] | |
{ | |
-((Memory::VirtualLock(_addr.cast(), _len) == 0) as i32) | |
} | |
#[cfg(not(any(unix, windows)))] | |
-1 | |
} | |
pub unsafe fn munlock(addr: *mut u8, len: usize) -> i32 { | |
zeroize::Zeroize::zeroize(&mut *slice::from_raw_parts_mut(addr, len)); | |
#[cfg(unix)] | |
{ | |
libc::madvise(addr.cast(), len, libc::MADV_DODUMP); | |
libc::munlock(addr.cast(), len) as _ | |
} | |
#[cfg(windows)] | |
{ | |
-((Memory::VirtualUnlock(addr.cast(), len) == 0) as i32) | |
} | |
#[cfg(not(any(unix, windows)))] | |
-1 | |
} | |
unsafe fn mprotect_noaccess_inner(_ptr: *mut u8, _size: usize) -> i32 { | |
#[cfg(unix)] | |
{ | |
libc::mprotect(_ptr.cast(), _size, libc::PROT_NONE) as _ | |
} | |
#[cfg(windows)] | |
{ | |
-((Memory::VirtualProtect(_ptr.cast(), _size, Memory::PAGE_NOACCESS, &mut 0) == 0) as i32) | |
} | |
#[cfg(not(any(unix, windows)))] | |
-1 | |
} | |
unsafe fn mprotect_readonly_inner(_ptr: *mut u8, _size: usize) -> i32 { | |
#[cfg(unix)] | |
{ | |
libc::mprotect(_ptr.cast(), _size, libc::PROT_READ) as _ | |
} | |
#[cfg(windows)] | |
{ | |
-((Memory::VirtualProtect(_ptr.cast(), _size, Memory::PAGE_READONLY, &mut 0) == 0) as i32) | |
} | |
#[cfg(not(any(unix, windows)))] | |
-1 | |
} | |
unsafe fn mprotect_readwrite_inner(_ptr: *mut u8, _size: usize) -> i32 { | |
#[cfg(unix)] | |
{ | |
libc::mprotect(_ptr.cast(), _size, libc::PROT_READ | libc::PROT_WRITE) as _ | |
} | |
#[cfg(windows)] | |
{ | |
-((Memory::VirtualProtect(_ptr.cast(), _size, Memory::PAGE_READWRITE, &mut 0) == 0) as i32) | |
} | |
#[cfg(not(any(unix, windows)))] | |
-1 | |
} | |
#[inline] | |
fn page_round(size: usize) -> usize { | |
let page_mask: usize = get_page_size().wrapping_sub(1); | |
size.wrapping_add(page_mask) & !page_mask | |
} | |
fn alloc_aligned(_size: usize) -> *mut u8 { | |
#[cfg(unix)] | |
{ | |
let map_nocore = 0; // TODO: | |
let ptr = unsafe { | |
libc::mmap( | |
ptr::null_mut(), | |
_size, | |
libc::PROT_READ | libc::PROT_WRITE, | |
libc::MAP_ANON | libc::MAP_PRIVATE | map_nocore, | |
-1, | |
0, | |
) | |
}; | |
if ptr == libc::MAP_FAILED { | |
ptr::null_mut() | |
} else { | |
ptr.cast() | |
} | |
} | |
#[cfg(windows)] | |
unsafe { | |
Memory::VirtualAlloc( | |
ptr::null(), | |
_size, | |
Memory::MEM_COMMIT | Memory::MEM_RESERVE, | |
Memory::PAGE_READWRITE, | |
) | |
.cast() | |
} | |
#[cfg(not(any(unix, windows)))] | |
ptr::null_mut() | |
} | |
unsafe fn free_aligned(_ptr: *mut u8, _size: usize) { | |
#[cfg(unix)] | |
libc::munmap(_ptr.cast(), _size); | |
#[cfg(windows)] | |
unsafe { | |
Memory::VirtualFree(_ptr.cast(), 0, Memory::MEM_RELEASE); | |
} | |
} | |
unsafe fn unprotected_ptr_from_user_ptr(ptr: *mut u8) -> *mut u8 { | |
let canary_ptr: *mut u8 = ptr.offset(-(mem::size_of::<Canary>() as isize)); | |
let page_size = get_page_size(); | |
let page_mask = page_size.wrapping_sub(1); | |
let unprotected_ptr_u = canary_ptr as usize & !page_mask; | |
if unprotected_ptr_u <= page_size.wrapping_mul(2) { | |
panic!(); | |
} | |
unprotected_ptr_u as *mut u8 | |
} | |
/// Allocates `size` bytes in a dedicated mapping surrounded by guard pages.
///
/// Returns a pointer to the first usable byte, or null if `size` is too large or the mapping
/// fails. The user bytes are not initialized here.
///
/// Layout of the `total_size`-byte mapping starting at `base_ptr`:
///
/// ```text
/// base_ptr
/// +-- page_size bytes:        header holding `unprotected_size` (a usize); made read-only
/// +-- page_size bytes:        guard page; no access
/// +-- unprotected_size bytes: page_round(16 + size); read-write and passed to `mlock`
/// |     ... unused ... | canary (16 bytes) | user data (`size` bytes)
/// +-- page_size bytes:        guard page whose first 16 bytes hold a canary copy; no access
/// ```
///
/// The user data ends exactly at the end of the unprotected region, so even a one-byte
/// overrun lands in the trailing no-access guard page.
///
/// NOTE(review): the layout appears to mirror libsodium's `_sodium_malloc`.
///
/// # Panics
/// Panics if the page size is not larger than the canary or is smaller than a `usize`, or if
/// the computed user pointer does not map back to its unprotected region.
fn malloc_inner(size: usize) -> *mut u8 { | |
let page_size = get_page_size(); | |
// Reject sizes for which the three extra pages and rounding could overflow.
if size >= usize::MAX.wrapping_sub(page_size.wrapping_mul(4)) { | |
return ptr::null_mut(); | |
} | |
if page_size <= mem::size_of::<Canary>() || page_size < mem::size_of::<usize>() { | |
panic!(); | |
} | |
let size_with_canary = mem::size_of::<Canary>().wrapping_add(size); | |
let unprotected_size = page_round(size_with_canary); | |
// header page + leading guard page + unprotected region + trailing guard page
let total_size = page_size | |
.wrapping_add(page_size) | |
.wrapping_add(unprotected_size) | |
.wrapping_add(page_size); | |
let base_ptr = alloc_aligned(total_size); | |
if base_ptr.is_null() { | |
return ptr::null_mut(); | |
} | |
unsafe { | |
let unprotected_ptr = base_ptr.add(page_size.wrapping_mul(2)); | |
// Leading guard page (second page of the mapping).
mprotect_noaccess_inner(base_ptr.add(page_size), page_size); | |
let canary = get_canary(); | |
// The canary copy must be written before the trailing guard page is locked down.
*unprotected_ptr.add(unprotected_size).cast() = *canary; | |
mprotect_noaccess_inner(unprotected_ptr.add(unprotected_size), page_size); | |
// Return values of the mprotect/mlock calls are not checked here.
mlock(unprotected_ptr, unprotected_size); | |
// Place canary + user data flush against the end of the unprotected region.
let canary_ptr = unprotected_ptr | |
.add(page_round(size_with_canary)) | |
.offset(-(size_with_canary as isize)); | |
let user_ptr = canary_ptr.add(mem::size_of::<Canary>()); | |
*canary_ptr.cast() = *canary; | |
// Record the region size in the header so `free`/`mprotect` can recover the layout.
*base_ptr.cast() = unprotected_size; | |
mprotect_readonly_inner(base_ptr, page_size); | |
assert_eq!(unprotected_ptr_from_user_ptr(user_ptr), unprotected_ptr); | |
user_ptr | |
} | |
} | |
pub fn malloc(size: usize) -> *mut u8 { | |
let ptr = malloc_inner(size); | |
if ptr.is_null() { | |
return ptr::null_mut(); | |
} | |
unsafe { | |
ptr::write_bytes(ptr, 0xdb, size); | |
} | |
ptr | |
} | |
pub fn allocarray(count: usize, size: usize) -> *mut u8 { | |
if count > 0 && size >= usize::MAX.wrapping_div(count) { | |
return ptr::null_mut(); | |
} | |
malloc(count.wrapping_mul(size)) | |
} | |
pub unsafe fn free(ptr: *mut u8) { | |
if ptr.is_null() { | |
return; | |
} | |
let page_size = get_page_size(); | |
let canary_ptr = ptr.offset(-(mem::size_of::<Canary>() as isize)); | |
let unprotected_ptr = unprotected_ptr_from_user_ptr(ptr); | |
let base_ptr = unprotected_ptr.offset(-(page_size.wrapping_mul(2) as isize)); | |
let unprotected_size = *base_ptr.cast(); | |
let total_size = page_size | |
.wrapping_add(page_size) | |
.wrapping_add(unprotected_size) | |
.wrapping_add(page_size); | |
mprotect_readwrite_inner(base_ptr, total_size); | |
if memcmp(canary_ptr, get_canary().as_ptr(), mem::size_of::<Canary>()) != 0 { | |
panic!(); | |
} | |
if memcmp( | |
unprotected_ptr.add(unprotected_size), | |
get_canary().as_ptr(), | |
mem::size_of::<Canary>(), | |
) != 0 | |
{ | |
panic!(); | |
} | |
munlock(unprotected_ptr, unprotected_size); | |
free_aligned(base_ptr, total_size); | |
} | |
unsafe fn mprotect(ptr: *mut u8, cb: unsafe fn(*mut u8, usize) -> i32) -> i32 { | |
let unprotected_ptr = unprotected_ptr_from_user_ptr(ptr); | |
let base_ptr = unprotected_ptr.offset(-(get_page_size().wrapping_mul(2) as isize)); | |
let unprotected_size = *base_ptr.cast(); | |
cb(unprotected_ptr, unprotected_size) | |
} | |
pub unsafe fn mprotect_noaccess(ptr: *mut u8) -> i32 { | |
mprotect( | |
ptr, | |
mprotect_noaccess_inner as unsafe fn(*mut u8, usize) -> i32, | |
) | |
} | |
pub unsafe fn mprotect_readonly(ptr: *mut u8) -> i32 { | |
mprotect( | |
ptr, | |
mprotect_readonly_inner as unsafe fn(*mut u8, usize) -> i32, | |
) | |
} | |
pub unsafe fn mprotect_readwrite(ptr: *mut u8) -> i32 { | |
mprotect( | |
ptr, | |
mprotect_readwrite_inner as unsafe fn(*mut u8, usize) -> i32, | |
) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment