Skip to content

Instantly share code, notes, and snippets.

@KJTsanaktsidis
Created January 10, 2024 23:25
Show Gist options
  • Save KJTsanaktsidis/40e2a8e23012bf16af823db9ff9a890e to your computer and use it in GitHub Desktop.
Save KJTsanaktsidis/40e2a8e23012bf16af823db9ff9a890e to your computer and use it in GitHub Desktop.
User fault handling example
#define _GNU_SOURCE
#include <err.h>
#include <fcntl.h>
#include <linux/userfaultfd.h>
#include <poll.h>
#include <pthread.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/syscall.h>
#include <stdbool.h>
#include <unistd.h>
/* State shared between main() and fault_handler_thread(). */
/* The test page mmap()ed by main(); the handler only accepts faults at this address. */
static char *page;
/* Size of one page in bytes, as returned by getpagesize() in main(). */
static size_t page_size;
/* Where main()'s most recent mremap() moved the page; the handler remaps from here back onto `page`. */
static void *stashed_page_location;
/* userfaultfd file descriptor, created in main() and registered over `page`. */
static int uffd;
/* Read end of the pipe main() writes to when the handler thread should return. */
static int wakepipe_r;
/*
 * Read one message from the userfaultfd and resolve the page fault it
 * describes: move the stashed page back to its original address, re-register
 * that range with the userfaultfd, and wake the faulting thread.
 *
 * Any failure or unexpected message terminates the process.
 */
static void service_pagefault(void) {
    struct uffd_msg msg;
    ssize_t nread = read(uffd, &msg, sizeof(msg));
    if (nread == -1) {
        err(1, "failed to read userfaultfd");
    }
    if (nread == 0) {
        fprintf(stderr, "read EOF on userfaultfd\n");
        exit(1);
    }
    if ((size_t)nread != sizeof(msg)) {
        /* A partial message would leave the fields below uninitialized. */
        fprintf(stderr, "short read of %zd bytes on userfaultfd\n", nread);
        exit(1);
    }
    if (msg.event != UFFD_EVENT_PAGEFAULT) {
        fprintf(stderr, "unexpected event %d read from userfaultfd\n", msg.event);
        exit(1);
    }

    /* needs to be stderr, because otherwise we'll deadlock accessing stdout whilst
     * the other thread is also trying to access it, and blocked on a page fault we're handling */
    fprintf(
        stderr, "received pagefault: addr=0x%llx flags=0x%llx\n",
        (unsigned long long)msg.arg.pagefault.address,
        (unsigned long long)msg.arg.pagefault.flags
    );

    /* In Ruby, we'd look up the page based on the address */
    if (msg.arg.pagefault.address != (uintptr_t)page) {
        fprintf(
            stderr, "unexpected faulting address (want 0x%lx)\n",
            (unsigned long)(uintptr_t)page
        );
        exit(1);
    }

    /* Here, we would do what the signal handler does to undo the object move of
     * this page */

    /* Then, put the original page mapping back and wake up the faulting thread */
    void *orig_loc = mremap(
        stashed_page_location, page_size, page_size, MREMAP_MAYMOVE | MREMAP_FIXED, page
    );
    if (orig_loc == MAP_FAILED) {
        err(1, "mremap inside userfaultfd handler failed");
    }

    /* Needs to be reregistered after remapping */
    struct uffdio_register registration = {
        .range = { .start = (uintptr_t)page, .len = page_size },
        .mode = UFFDIO_REGISTER_MODE_MISSING,
    };
    if (ioctl(uffd, UFFDIO_REGISTER, &registration) == -1) {
        err(1, "failed to call ioctl(UFFDIO_REGISTER) for remap");
    }

    struct uffdio_range wake_range = {
        .start = msg.arg.pagefault.address,
        .len = page_size,
    };
    if (ioctl(uffd, UFFDIO_WAKE, &wake_range) == -1) {
        err(1, "failed to ioctl(UFFDIO_WAKE)");
    }
}

/*
 * Thread entry point: wait on the userfaultfd and the wake pipe.
 *
 * Page-fault events on the userfaultfd are serviced as they arrive. A byte on
 * the wake pipe makes the thread return. If both descriptors are readable in
 * the same poll() round, the fault is serviced before returning.
 *
 * arg is unused. Always returns NULL.
 */
static void * fault_handler_thread(void *arg) {
    (void)arg;

    while (true) {
        struct pollfd pfd[2] = {0};
        pfd[0].fd = uffd;
        pfd[0].events = POLLIN;
        pfd[1].fd = wakepipe_r;
        pfd[1].events = POLLIN;

        int nready = poll(pfd, 2, -1);
        if (nready == -1) {
            err(1, "failed to call poll");
        }

        if (pfd[0].revents & POLLIN) {
            service_pagefault();
        }

        if (pfd[1].revents & POLLIN) {
            char buf;
            ssize_t nread = read(wakepipe_r, &buf, sizeof(buf));
            if (nread == -1) {
                err(1, "failed to read wakepipe");
            }
            return NULL;
        }
    }
}
/*
 * Demonstrates resolving userfaultfd "missing page" faults by remapping.
 *
 * Steps:
 *   1. Map and populate one anonymous page, and register it with a userfaultfd.
 *   2. Start fault_handler_thread() to service faults on that page.
 *   3. Move the page away with mremap(MREMAP_DONTUNMAP) and read the original
 *      address directly from user space.
 *   4. Move it away again and have the kernel access the original address on
 *      our behalf, via write() into a memfd.
 *   5. Tell the handler thread to exit through the wake pipe and join it.
 *
 * argc and argv are unused. Returns 0 on success; any failure exits with
 * status 1 after printing a diagnostic.
 */
int main(int argc, char **argv) {
    (void)argc;
    (void)argv;

    page_size = getpagesize();
    page = mmap(NULL, page_size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
    if (page == MAP_FAILED) {
        err(1, "failed to mmap the test page");
    }
    /* fault the page in */
    memset(page, 'X', page_size);

    uffd = syscall(SYS_userfaultfd, O_CLOEXEC | O_NONBLOCK);
    if (uffd == -1) {
        err(1, "failed to call userfaultfd");
    }

    struct uffdio_api api_handshake = { .api = UFFD_API, .features = 0 };
    if (ioctl(uffd, UFFDIO_API, &api_handshake) == -1) {
        err(1, "failed to call ioctl(UFFDIO_API)");
    }

    struct uffdio_register registration = {
        .range = { .start = (uintptr_t)page, .len = page_size },
        .mode = UFFDIO_REGISTER_MODE_MISSING,
    };
    if (ioctl(uffd, UFFDIO_REGISTER, &registration) == -1) {
        err(1, "failed to call ioctl(UFFDIO_REGISTER)");
    }

    int pipefds[2];
    if (pipe2(pipefds, 0) == -1) {
        err(1, "failed to call pipe2");
    }
    wakepipe_r = pipefds[0];
    int wakepipe_w = pipefds[1];

    pthread_t thr;
    int ret = pthread_create(&thr, NULL, fault_handler_thread, NULL);
    if (ret != 0) {
        /* pthread functions return the error number instead of setting errno */
        errx(1, "failed to call pthread_create: %s", strerror(ret));
    }

    /* Unmap the page and put it somewhere else */
    stashed_page_location = mremap(page, page_size, page_size, MREMAP_MAYMOVE | MREMAP_DONTUNMAP);
    if (stashed_page_location == MAP_FAILED) {
        err(1, "failed to remap page to dummy location");
    }
    printf(
        "migrated page at 0x%lx to 0x%lx\n",
        (unsigned long)(uintptr_t)page, (unsigned long)(uintptr_t)stashed_page_location
    );
    printf("reading original location...\n");
    /* Flush before touching the page so buffered output is already written
     * when the read below faults. */
    fflush(stdout);
    printf("read: %.*s\n", 12, page);
    fflush(stdout);

    /* Do it again, with a systemcall */
    stashed_page_location = mremap(page, page_size, page_size, MREMAP_MAYMOVE | MREMAP_DONTUNMAP);
    if (stashed_page_location == MAP_FAILED) {
        err(1, "failed to remap page to dummy location");
    }
    printf(
        "migrated page at 0x%lx to 0x%lx\n",
        (unsigned long)(uintptr_t)page, (unsigned long)(uintptr_t)stashed_page_location
    );

    int buffer_fd = memfd_create("buffer", MFD_CLOEXEC);
    if (buffer_fd == -1) {
        err(1, "failed to call memfd_create");
    }
    printf("writing page into memfd...\n");
    /* The kernel reads from `page` here, so the fault is taken inside write(). */
    if (write(buffer_fd, page, 12) == -1) {
        err(1, "failed to write page into memfd");
    }
    /* lseek takes (fd, offset, whence): rewind to the start of the memfd */
    if (lseek(buffer_fd, 0, SEEK_SET) == (off_t)-1) {
        err(1, "failed to seek memfd");
    }
    /* Pre-filled with 'g' so that stale contents are recognizable in the output. */
    char readback[12] = { 'g', 'g', 'g', 'g', 'g', 'g', 'g', 'g', 'g', 'g', 'g', 'g' };
    if (read(buffer_fd, readback, 12) == -1) {
        err(1, "failed to read page from memfd");
    }
    printf("read back: %.*s\n", 12, readback);
    close(buffer_fd);

    /* Any byte on the wake pipe tells the handler thread to return. */
    char dot = '.';
    if (write(wakepipe_w, &dot, sizeof(dot)) == -1) {
        err(1, "failed to write wakepipe");
    }
    ret = pthread_join(thr, NULL);
    if (ret != 0) {
        errx(1, "failed to call pthread_join: %s", strerror(ret));
    }

    close(wakepipe_r);
    close(wakepipe_w);
    close(uffd);
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment