Protected memory (gist by @CryZe, last active October 30, 2023 17:50)
use std::{
    marker::PhantomData,
    mem,
    ops::{Deref, DerefMut},
    ptr::{self, NonNull},
};

use bytemuck::{AnyBitPattern, Zeroable};

mod utils;

/// Heap memory that stays inaccessible (no reads, no writes) except while a
/// guard returned by `read` or `write` is alive, in the style of libsodium's
/// guarded heap allocations.
pub struct Protected<T: ?Sized> {
    ptr: NonNull<T>,
}
impl<T: AnyBitPattern> Protected<[T]> {
    /// Allocates a protected slice of `len` elements. The allocator fills the
    /// memory with `0xdb` bytes, which `AnyBitPattern` guarantees is a valid
    /// bit pattern for `T`.
    pub fn new_slice(len: usize) -> Self {
        unsafe {
            let ptr = utils::allocarray(len, mem::size_of::<T>());
            let ptr = NonNull::new(ptr::slice_from_raw_parts_mut(ptr.cast(), len)).unwrap();
            utils::mprotect_noaccess(ptr.as_ptr().cast());
            Self { ptr }
        }
    }
}
impl<T: AnyBitPattern> Protected<T> {
    /// Allocates protected storage for a `T`, left filled with the
    /// allocator's `0xdb` garbage bytes.
    pub fn new() -> Self {
        unsafe {
            let ptr = utils::malloc(mem::size_of::<T>().next_multiple_of(mem::align_of::<T>()));
            let ptr = NonNull::<T>::new(ptr.cast()).unwrap();
            utils::mprotect_noaccess(ptr.as_ptr().cast());
            Self { ptr }
        }
    }
}

impl<T: AnyBitPattern> Default for Protected<T> {
    #[inline]
    fn default() -> Self {
        Self::new()
    }
}
impl<T: Zeroable> Protected<T> {
    /// Allocates protected storage for a `T` and zeroes it, which `Zeroable`
    /// guarantees is a valid bit pattern.
    pub fn new_zeroed() -> Self {
        unsafe {
            let ptr = utils::malloc(mem::size_of::<T>().next_multiple_of(mem::align_of::<T>()));
            let ptr = NonNull::<T>::new(ptr.cast()).unwrap();
            utils::mprotect_readwrite(ptr.as_ptr().cast());
            ptr::write_bytes(ptr.as_ptr(), 0, 1);
            utils::mprotect_noaccess(ptr.as_ptr().cast());
            Self { ptr }
        }
    }
}
impl<T: ?Sized> Protected<T> {
    /// Makes the memory readable for the lifetime of the returned guard.
    /// Takes `&mut self` because dropping a guard flips the page protection
    /// back to no-access, so no second guard may be alive at that point.
    #[inline]
    pub fn read(&mut self) -> ReadGuard<'_, T> {
        unsafe {
            utils::mprotect_readonly(self.ptr.as_ptr().cast());
            ReadGuard {
                ptr: self.ptr,
                _phantom: PhantomData,
            }
        }
    }

    /// Makes the memory readable and writable for the lifetime of the
    /// returned guard.
    #[inline]
    pub fn write(&mut self) -> WriteGuard<'_, T> {
        unsafe {
            utils::mprotect_readwrite(self.ptr.as_ptr().cast());
            WriteGuard {
                ptr: self.ptr,
                _phantom: PhantomData,
            }
        }
    }
}
impl<T: ?Sized> Drop for Protected<T> {
    #[inline]
    fn drop(&mut self) {
        unsafe {
            // The memory is normally inaccessible, so it has to be made
            // writable before the destructor may run over it.
            if mem::needs_drop::<T>() {
                utils::mprotect_readwrite(self.ptr.as_ptr().cast());
                ptr::drop_in_place(self.ptr.as_ptr());
            }
            utils::free(self.ptr.as_ptr().cast());
        }
    }
}
pub struct ReadGuard<'a, T: ?Sized> {
    ptr: NonNull<T>,
    _phantom: PhantomData<&'a T>,
}

impl<T: ?Sized> Drop for ReadGuard<'_, T> {
    #[inline]
    fn drop(&mut self) {
        unsafe {
            utils::mprotect_noaccess(self.ptr.as_ptr().cast());
        }
    }
}

impl<T: ?Sized> Deref for ReadGuard<'_, T> {
    type Target = T;

    #[inline]
    fn deref(&self) -> &Self::Target {
        unsafe { &*self.ptr.as_ptr() }
    }
}

pub struct WriteGuard<'a, T: ?Sized> {
    ptr: NonNull<T>,
    _phantom: PhantomData<&'a mut T>,
}

impl<T: ?Sized> Drop for WriteGuard<'_, T> {
    #[inline]
    fn drop(&mut self) {
        unsafe {
            utils::mprotect_noaccess(self.ptr.as_ptr().cast());
        }
    }
}

impl<T: ?Sized> Deref for WriteGuard<'_, T> {
    type Target = T;

    #[inline]
    fn deref(&self) -> &Self::Target {
        unsafe { &*self.ptr.as_ptr() }
    }
}

impl<T: ?Sized> DerefMut for WriteGuard<'_, T> {
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        unsafe { &mut *self.ptr.as_ptr() }
    }
}
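A minimal usage sketch (not part of the gist): the `Key` type and its field are
made up for illustration, and the `Pod`/`Zeroable` derives assume bytemuck's
`derive` feature is enabled. Only the `Protected` API itself comes from the
code above.

use bytemuck::{Pod, Zeroable};

#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
struct Key {
    bytes: [u8; 32],
}

fn example() {
    let mut secret = Protected::<Key>::new_zeroed();
    {
        // The memory is readable and writable only while the guard is alive.
        let mut guard = secret.write();
        guard.bytes[0] = 0x42;
    } // Guard dropped: the pages are no-access again.
    {
        let guard = secret.read();
        assert_eq!(guard.bytes[0], 0x42);
    }
    // Dropping `secret` runs destructors (if any), zeroizes, and frees.
}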
utils.rs:

use std::{mem, ptr, slice, sync::OnceLock};

#[cfg(unix)]
use nix::libc;
#[cfg(windows)]
use windows_sys::Win32::System::Memory;
type Canary = [u8; 16];

/// A process-wide random canary, generated once and placed around each
/// allocation to detect out-of-bounds writes.
fn get_canary() -> &'static Canary {
    static CANARY: OnceLock<Canary> = OnceLock::new();
    CANARY.get_or_init(|| {
        let mut data = [0; 16];
        getrandom::getrandom(&mut data).unwrap();
        data
    })
}
fn get_page_size() -> usize {
    static PAGE_SIZE: OnceLock<usize> = OnceLock::new();
    *PAGE_SIZE.get_or_init(|| {
        #[cfg(unix)]
        let page_size = nix::unistd::sysconf(nix::unistd::SysconfVar::PAGE_SIZE)
            .ok()
            .flatten()
            .and_then(|v| usize::try_from(v).ok())
            .unwrap_or(0x10000);
        #[cfg(windows)]
        let page_size = 4 << 10; // TODO:
        if page_size < 16 || page_size < mem::size_of::<usize>() {
            panic!();
        }
        page_size
    })
}
/// Constant-time comparison in the style of libsodium's `sodium_memcmp`.
/// `d` accumulates the XOR of every byte pair; the volatile reads and writes
/// keep the compiler from short-circuiting the loop. If all bytes match,
/// `d` is 0, `(d as i32 - 1) >> 8` is -1, and the function returns 0;
/// otherwise `d` is in `1..=255`, the shift yields 0, and it returns -1.
unsafe fn memcmp(b1: *const u8, b2: *const u8, len: usize) -> i32 {
    let mut d = 0;
    for i in 0..len {
        ptr::write_volatile(
            &mut d,
            ptr::read_volatile(&d) | (b1.add(i).read_volatile() ^ b2.add(i).read_volatile()),
        );
    }
    (1 & (d as i32 - 1) >> 8) - 1
}
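// Hypothetical sanity check (not in the gist), illustrating the contract:
//
//     let a = [1u8; 16];
//     let mut b = a;
//     unsafe { assert_eq!(memcmp(a.as_ptr(), b.as_ptr(), 16), 0) };
//     b[7] ^= 1;
//     unsafe { assert_eq!(memcmp(a.as_ptr(), b.as_ptr(), 16), -1) };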
/// Locks the pages so they can't be swapped out, and (on Unix) excludes them
/// from core dumps via `MADV_DONTDUMP`.
pub unsafe fn mlock(_addr: *mut u8, _len: usize) -> i32 {
    #[cfg(unix)]
    {
        libc::madvise(_addr.cast(), _len, libc::MADV_DONTDUMP);
        libc::mlock(_addr.cast(), _len)
    }
    #[cfg(windows)]
    {
        -((Memory::VirtualLock(_addr.cast(), _len) == 0) as i32)
    }
    #[cfg(not(any(unix, windows)))]
    -1
}

/// Zeroizes the region before unlocking it, so the secret never reaches swap
/// or a core dump in plaintext.
pub unsafe fn munlock(addr: *mut u8, len: usize) -> i32 {
    zeroize::Zeroize::zeroize(&mut *slice::from_raw_parts_mut(addr, len));
    #[cfg(unix)]
    {
        libc::madvise(addr.cast(), len, libc::MADV_DODUMP);
        libc::munlock(addr.cast(), len) as _
    }
    #[cfg(windows)]
    {
        -((Memory::VirtualUnlock(addr.cast(), len) == 0) as i32)
    }
    #[cfg(not(any(unix, windows)))]
    -1
}
unsafe fn mprotect_noaccess_inner(_ptr: *mut u8, _size: usize) -> i32 {
    #[cfg(unix)]
    {
        libc::mprotect(_ptr.cast(), _size, libc::PROT_NONE) as _
    }
    #[cfg(windows)]
    {
        -((Memory::VirtualProtect(_ptr.cast(), _size, Memory::PAGE_NOACCESS, &mut 0) == 0) as i32)
    }
    #[cfg(not(any(unix, windows)))]
    -1
}

unsafe fn mprotect_readonly_inner(_ptr: *mut u8, _size: usize) -> i32 {
    #[cfg(unix)]
    {
        libc::mprotect(_ptr.cast(), _size, libc::PROT_READ) as _
    }
    #[cfg(windows)]
    {
        -((Memory::VirtualProtect(_ptr.cast(), _size, Memory::PAGE_READONLY, &mut 0) == 0) as i32)
    }
    #[cfg(not(any(unix, windows)))]
    -1
}

unsafe fn mprotect_readwrite_inner(_ptr: *mut u8, _size: usize) -> i32 {
    #[cfg(unix)]
    {
        libc::mprotect(_ptr.cast(), _size, libc::PROT_READ | libc::PROT_WRITE) as _
    }
    #[cfg(windows)]
    {
        -((Memory::VirtualProtect(_ptr.cast(), _size, Memory::PAGE_READWRITE, &mut 0) == 0) as i32)
    }
    #[cfg(not(any(unix, windows)))]
    -1
}
/// Rounds `size` up to the next multiple of the page size, which is assumed
/// to be a power of two.
#[inline]
fn page_round(size: usize) -> usize {
    let page_mask: usize = get_page_size().wrapping_sub(1);
    size.wrapping_add(page_mask) & !page_mask
}
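// Worked example (assuming 4 KiB pages, i.e. page_mask == 0xfff):
//
//     page_round(1)    == 0x1000
//     page_round(4096) == 0x1000
//     page_round(4097) == 0x2000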
fn alloc_aligned(_size: usize) -> *mut u8 {
    #[cfg(unix)]
    {
        let map_nocore = 0; // TODO:
        let ptr = unsafe {
            libc::mmap(
                ptr::null_mut(),
                _size,
                libc::PROT_READ | libc::PROT_WRITE,
                libc::MAP_ANON | libc::MAP_PRIVATE | map_nocore,
                -1,
                0,
            )
        };
        if ptr == libc::MAP_FAILED {
            ptr::null_mut()
        } else {
            ptr.cast()
        }
    }
    #[cfg(windows)]
    unsafe {
        Memory::VirtualAlloc(
            ptr::null(),
            _size,
            Memory::MEM_COMMIT | Memory::MEM_RESERVE,
            Memory::PAGE_READWRITE,
        )
        .cast()
    }
    #[cfg(not(any(unix, windows)))]
    ptr::null_mut()
}

unsafe fn free_aligned(_ptr: *mut u8, _size: usize) {
    #[cfg(unix)]
    libc::munmap(_ptr.cast(), _size);
    #[cfg(windows)]
    unsafe {
        Memory::VirtualFree(_ptr.cast(), 0, Memory::MEM_RELEASE);
    }
}
/// Recovers the start of the unprotected region from a user pointer by
/// stepping back over the leading canary and rounding down to the page
/// boundary.
unsafe fn unprotected_ptr_from_user_ptr(ptr: *mut u8) -> *mut u8 {
    let canary_ptr: *mut u8 = ptr.offset(-(mem::size_of::<Canary>() as isize));
    let page_size = get_page_size();
    let page_mask = page_size.wrapping_sub(1);
    let unprotected_ptr_u = canary_ptr as usize & !page_mask;
    if unprotected_ptr_u <= page_size.wrapping_mul(2) {
        panic!();
    }
    unprotected_ptr_u as *mut u8
}
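// Layout of one allocation as built by `malloc_inner` below (pages left to
// right; the user pointer is placed so the data ends exactly at the trailing
// guard page, with a canary copy immediately before it):
//
//     [ size page (read-only, stores unprotected_size) ]
//     [ guard page (no access)                         ]
//     [ unprotected region ... canary | user data      ]  <- mlock'ed
//     [ guard page (no access, canary copy at start)   ]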
fn malloc_inner(size: usize) -> *mut u8 {
    let page_size = get_page_size();
    if size >= usize::MAX.wrapping_sub(page_size.wrapping_mul(4)) {
        return ptr::null_mut();
    }
    if page_size <= mem::size_of::<Canary>() || page_size < mem::size_of::<usize>() {
        panic!();
    }
    let size_with_canary = mem::size_of::<Canary>().wrapping_add(size);
    let unprotected_size = page_round(size_with_canary);
    let total_size = page_size
        .wrapping_add(page_size)
        .wrapping_add(unprotected_size)
        .wrapping_add(page_size);
    let base_ptr = alloc_aligned(total_size);
    if base_ptr.is_null() {
        return ptr::null_mut();
    }
    unsafe {
        let unprotected_ptr = base_ptr.add(page_size.wrapping_mul(2));
        mprotect_noaccess_inner(base_ptr.add(page_size), page_size);
        let canary = get_canary();
        // Canary copy at the start of the trailing guard page.
        *unprotected_ptr.add(unprotected_size).cast() = *canary;
        mprotect_noaccess_inner(unprotected_ptr.add(unprotected_size), page_size);
        mlock(unprotected_ptr, unprotected_size);
        // Place the user data flush against the trailing guard page, with the
        // canary right in front of it.
        let canary_ptr = unprotected_ptr
            .add(page_round(size_with_canary))
            .offset(-(size_with_canary as isize));
        let user_ptr = canary_ptr.add(mem::size_of::<Canary>());
        *canary_ptr.cast() = *canary;
        // The first page stores the region size and is then made read-only.
        *base_ptr.cast() = unprotected_size;
        mprotect_readonly_inner(base_ptr, page_size);
        assert_eq!(unprotected_ptr_from_user_ptr(user_ptr), unprotected_ptr);
        user_ptr
    }
}
pub fn malloc(size: usize) -> *mut u8 {
    let ptr = malloc_inner(size);
    if ptr.is_null() {
        return ptr::null_mut();
    }
    unsafe {
        // Fill fresh allocations with 0xdb garbage so reads of uninitialized
        // memory are at least deterministic.
        ptr::write_bytes(ptr, 0xdb, size);
    }
    ptr
}

pub fn allocarray(count: usize, size: usize) -> *mut u8 {
    // Reject products that would overflow before multiplying.
    if count > 0 && size >= usize::MAX.wrapping_div(count) {
        return ptr::null_mut();
    }
    malloc(count.wrapping_mul(size))
}
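// Illustrative calls (hypothetical, not in the gist):
//
//     allocarray(1000, 8)       // requests 8000 bytes of user data
//     allocarray(usize::MAX, 2) // overflow check fails, returns null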
pub unsafe fn free(ptr: *mut u8) {
    if ptr.is_null() {
        return;
    }
    let page_size = get_page_size();
    let canary_ptr = ptr.offset(-(mem::size_of::<Canary>() as isize));
    let unprotected_ptr = unprotected_ptr_from_user_ptr(ptr);
    let base_ptr = unprotected_ptr.offset(-(page_size.wrapping_mul(2) as isize));
    let unprotected_size = *base_ptr.cast();
    let total_size = page_size
        .wrapping_add(page_size)
        .wrapping_add(unprotected_size)
        .wrapping_add(page_size);
    mprotect_readwrite_inner(base_ptr, total_size);
    // Verify both canaries before releasing anything; a mismatch means the
    // buffer was written past either end.
    if memcmp(canary_ptr, get_canary().as_ptr(), mem::size_of::<Canary>()) != 0 {
        panic!();
    }
    if memcmp(
        unprotected_ptr.add(unprotected_size),
        get_canary().as_ptr(),
        mem::size_of::<Canary>(),
    ) != 0
    {
        panic!();
    }
    // `munlock` also zeroizes the region before it is unmapped.
    munlock(unprotected_ptr, unprotected_size);
    free_aligned(base_ptr, total_size);
}
/// Applies `cb` to the whole unprotected region that `ptr` belongs to, using
/// the size stored in the allocation's read-only first page.
unsafe fn mprotect(ptr: *mut u8, cb: unsafe fn(*mut u8, usize) -> i32) -> i32 {
    let unprotected_ptr = unprotected_ptr_from_user_ptr(ptr);
    let base_ptr = unprotected_ptr.offset(-(get_page_size().wrapping_mul(2) as isize));
    let unprotected_size = *base_ptr.cast();
    cb(unprotected_ptr, unprotected_size)
}

pub unsafe fn mprotect_noaccess(ptr: *mut u8) -> i32 {
    mprotect(
        ptr,
        mprotect_noaccess_inner as unsafe fn(*mut u8, usize) -> i32,
    )
}

pub unsafe fn mprotect_readonly(ptr: *mut u8) -> i32 {
    mprotect(
        ptr,
        mprotect_readonly_inner as unsafe fn(*mut u8, usize) -> i32,
    )
}

pub unsafe fn mprotect_readwrite(ptr: *mut u8) -> i32 {
    mprotect(
        ptr,
        mprotect_readwrite_inner as unsafe fn(*mut u8, usize) -> i32,
    )
}
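A minimal sketch of driving the allocator directly, bypassing the `Protected`
wrapper (hypothetical usage, not part of the gist; assumes `utils` is in
scope):

fn raw_example() {
    unsafe {
        // 32 user bytes, surrounded by canaries and guard pages.
        let p = utils::malloc(32);
        assert!(!p.is_null());
        p.write_bytes(0xaa, 32);

        // Lock the region against all access while it is at rest.
        utils::mprotect_noaccess(p);
        // ...later, open it up again before touching it.
        utils::mprotect_readwrite(p);
        assert_eq!(*p, 0xaa);

        // Checks both canaries, zeroizes, unlocks, and unmaps.
        utils::free(p);
    }
}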