Skip to content

Instantly share code, notes, and snippets.

@ammarfaizi2
Created May 18, 2022 23:47
Show Gist options
  • Save ammarfaizi2/9072d28a270218567fe4ae9ab4166f0b to your computer and use it in GitHub Desktop.
Save ammarfaizi2/9072d28a270218567fe4ae9ab4166f0b to your computer and use it in GitHub Desktop.
// SPDX-License-Identifier: GPL-2.0-only
/*
* Copyright (C) 2022 Ammar Faizi <[email protected]>
*
* A Simple multi-threaded coroutine written in C.
* Running on Linux x86-64.
*/
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif
#include <sys/types.h>
#include <stdio.h>
#include <sys/mman.h>
#include <unistd.h>
#include <string.h>
#include <stdlib.h>
#include <stdbool.h>
#include <stdint.h>
#include <errno.h>
#include <pthread.h>
/*
 * Branch-prediction hints built on __builtin_expect(). Each macro is
 * only defined when the name is not already taken.
 */
#ifndef likely
#define likely(EXPR) __builtin_expect((EXPR), 1)
#endif
#ifndef unlikely
#define unlikely(EXPR) __builtin_expect((EXPR), 0)
#endif
/* Shorthand for the compiler attribute that forbids inlining. */
#ifndef noinline
#define noinline __attribute__((__noinline__))
#endif
/* Cast an arbitrary pointer/integer expression to a generic pointer. */
#define PTRU(PTR) ((void *)(PTR))
/* Cast a pointer to an unsigned integer for address arithmetic. */
#define UPTR(PTR) ((uintptr_t)(PTR))
/* Page size used by the alignment helper below (hard-coded). */
#define PAGE_SIZE 4096UL
/* Round the address @PTR down to a PAGE_SIZE boundary. */
#define ALIGN_PAGE_SIZE(PTR) PTRU(UPTR(PTR) & UPTR(-PAGE_SIZE))
/* Size in bytes of each coroutine stack mapping (8 MiB). */
#define CORO_STACK_SIZE (1024*1024*8UL)
/*
 * Initial stack pointer for a stack mapping whose base address is @PTR:
 * base + CORO_STACK_SIZE, rounded down to a page boundary.
 */
#define RESET_CORO_SP(PTR) ALIGN_PAGE_SIZE(UPTR(PTR) + CORO_STACK_SIZE)
/*
 * Inline-ASM fragment that pushes %r12-%r15, %rbp and %rbx onto the
 * current stack. Written for use inside an extended asm template, hence
 * the doubled '%'.
 */
#define ASM_PUSH_CALL_SAVED_REGS \
"pushq %%r12\n\t" \
"pushq %%r13\n\t" \
"pushq %%r14\n\t" \
"pushq %%r15\n\t" \
"pushq %%rbp\n\t" \
"pushq %%rbx\n\t"
/*
 * Counterpart of ASM_PUSH_CALL_SAVED_REGS: pops the same registers in
 * the exact reverse order of the pushes.
 */
#define ASM_POP_CALL_SAVED_REGS \
"popq %%rbx\n\t" \
"popq %%rbp\n\t" \
"popq %%r15\n\t" \
"popq %%r14\n\t" \
"popq %%r13\n\t" \
"popq %%r12\n\t"
struct coro_task_ctx {
void *sp;
void *cur_sp;
void *cur_ip;
void (*func)(void *user);
void *user_data;
};
/*
 * Byte offsets of the struct coro_task_ctx fields, spelled as string
 * literals so they can be concatenated into inline ASM templates.
 *
 * The values assume 8-byte pointers with no padding between fields and
 * must be kept in sync with the struct definition.
 */
#define CT_STACK "0"
#define CT_CUR_RSP "8"
#define CT_CUR_RIP "16"
#define CT_FUNC "24"
#define CT_USER_DATA "32"
/*
 * Saved %rsp of the main (scheduler) context of this thread. Written by
 * coro_start_task()/coro_resume_task() right before switching stacks and
 * read back by schedule() and __coro_task_entry() to switch back.
 */
static __thread void *g_main_rsp;
/*
 * Number of slots in @g_tasks, as passed to coro_init().
 */
static __thread size_t g_nr_tasks = 0;
/*
 * Index in @g_tasks of the task being executed; set by coro_run() and
 * read by schedule().
 */
static __thread size_t g_task_cur = 0;
/*
 * Per-thread task table allocated by coro_init(); NULL when there is no
 * table.
 */
static __thread struct coro_task_ctx *g_tasks = NULL;
/*
 * Prototypes of the non-static functions defined in this file, kept
 * together so they can be moved to a header file later.
 *
 * __coro_task_entry() is listed because the inline ASM in
 * coro_start_task() jumps to it by symbol name.
 */
extern void coro_destroy(void);
extern int coro_init(size_t nr_tasks);
extern int coro_add_func(void (*func)(void *user), void *user);
extern void __coro_task_entry(struct coro_task_ctx *task);
extern void coro_run(void);
extern void schedule(void);
void coro_destroy(void)
{
size_t i;
if (!g_tasks)
return;
for (i = 0; i < g_nr_tasks; i++) {
if (!g_tasks[i].sp)
continue;
munmap(g_tasks[i].sp, CORO_STACK_SIZE);
}
free(g_tasks);
g_tasks = NULL;
}
/*
 * Allocate the per-thread task table with @nr_tasks slots and map one
 * CORO_STACK_SIZE stack for each slot.
 *
 * Return: 0 on success, or a negative errno value:
 *   -EBUSY   the calling thread already has a task table; call
 *            coro_destroy() first (re-initializing would leak the
 *            previous table and its stacks).
 *   -ENOMEM  the task table could not be allocated.
 *   -errno   as set by a failing mmap(); everything allocated so far
 *            is released before returning.
 */
int coro_init(size_t nr_tasks)
{
	size_t i;

	if (g_tasks)
		return -EBUSY;

	g_tasks = calloc(nr_tasks, sizeof(*g_tasks));
	if (!g_tasks)
		return -ENOMEM;

	/* coro_destroy() walks g_nr_tasks slots, so publish it early. */
	g_nr_tasks = nr_tasks;

	for (i = 0; i < nr_tasks; i++) {
		void *stack;

		stack = mmap(NULL, CORO_STACK_SIZE, PROT_READ|PROT_WRITE,
			     MAP_PRIVATE|MAP_ANONYMOUS|MAP_STACK|MAP_GROWSDOWN,
			     -1, 0);
		if (stack == MAP_FAILED) {
			/* Save errno before the cleanup can overwrite it. */
			int ret = -errno;

			/* Unmapped slots still have sp == NULL (calloc). */
			coro_destroy();
			/* Never leave a slot count without a table. */
			g_nr_tasks = 0;
			return ret;
		}

		g_tasks[i].sp = stack;
		g_tasks[i].cur_sp = RESET_CORO_SP(stack);
		g_tasks[i].cur_ip = NULL;
		g_tasks[i].func = NULL;
	}

	return 0;
}
/*
 * Register @func to be run as a coroutine task in the first free slot
 * of the calling thread's task table. @user is passed to @func as its
 * only argument.
 *
 * Return: 0 on success, or a negative errno value:
 *   -EINVAL  @func is NULL. A NULL function pointer is what marks a
 *            slot as free, so such a task could never be scheduled.
 *   -EAGAIN  there is no free slot (this includes the case where the
 *            thread has no task table at all).
 */
int coro_add_func(void (*func)(void *user), void *user)
{
	size_t i;

	if (!func)
		return -EINVAL;

	if (!g_tasks)
		return -EAGAIN;

	for (i = 0; i < g_nr_tasks; i++) {
		if (g_tasks[i].func)
			continue;

		g_tasks[i].func = func;
		g_tasks[i].user_data = user;
		return 0;
	}

	return -EAGAIN;
}
/*
 * This is the entry point for all coroutine tasks.
 *
 * coro_start_task() jumps here (it does not call) after switching %rsp
 * to the task's own stack, with @task in %rdi.
 *
 * This function never returns to its caller in the usual way: once
 * @task->func() is done, it switches back to the main task's stack and
 * executes "retq", which consumes the return address that
 * coro_start_task() or coro_resume_task() pushed there.
 */
noinline void __coro_task_entry(struct coro_task_ctx *task)
{
/*
 * The @task->func() should call schedule() periodically to
 * fairly execute all coroutine tasks.
 */
task->func(task->user_data);
/*
 * The coroutine task has finished, reset its state and give
 * the control back to the main task.
 *
 * A NULL @func makes coro_run() skip the slot and lets
 * coro_add_func() reuse it; a NULL @cur_ip makes the next
 * coro_switch_to() start the slot from scratch.
 */
task->func = NULL;
task->cur_ip = NULL;
task->user_data = NULL;
task->cur_sp = RESET_CORO_SP(task->sp);
__asm__ volatile (
/*
 * Load the main task's %rsp.
 */
"movq %[g_main_rsp], %%rsp\n\t"
/*
 * Give the control back to the main task.
 */
"retq"
:
: [g_main_rsp]"m"(g_main_rsp)
/*
 * We don't care about registers clobbering here,
 * because this task has finished, all of them are
 * unused after the "retq".
 */
: "memory", "cc"
);
/* Control never reaches this point; tell the compiler so. */
__builtin_unreachable();
}
/*
 * Suspend the coroutine task execution and give the control
 * back to the main task. Must only be called from a coroutine
 * task context.
 *
 * The suspended state (the registers pushed by
 * ASM_PUSH_CALL_SAVED_REGS plus a resume address) lives on the task's
 * own stack; only the resulting %rsp and the resume address are
 * recorded in the task context. The function returns to its caller
 * when the main task later resumes this task.
 */
noinline void schedule(void)
{
/* The running task is the one coro_run() recorded in g_task_cur. */
struct coro_task_ctx *task = &g_tasks[g_task_cur];
__asm__ volatile (
/*
 * Save the coroutine task registers.
 */
ASM_PUSH_CALL_SAVED_REGS
/*
 * Save the coroutine task %rsp and return address.
 *
 * The return address is pushed so that a plain "retq" on
 * this stack resumes at .Lschedule_ret_addr. A non-NULL
 * cur_ip also marks the task as started.
 */
"leaq .Lschedule_ret_addr(%%rip), %%rax\n\t"
"pushq %%rax\n\t"
"movq %%rsp, " CT_CUR_RSP "(%[task])\n\t"
"movq %%rax, " CT_CUR_RIP "(%[task])\n\t"
/*
 * Load the main task %rsp.
 */
"movq %[g_main_rsp], %%rsp\n\t"
/*
 * Give the control back to the main task.
 */
"retq\n"
/*
 * When the main task calls coro_switch_to(task), which in
 * turn calls coro_resume_task(task), it will jump back here!
 *
 * Then we will return to the coroutine task to resume
 * the execution.
 *
 * Restore the coroutine task registers below...
 */
".Lschedule_ret_addr:\n\t"
ASM_POP_CALL_SAVED_REGS
:
/*
 * NOTE(review): %rdi is declared as an input only. On resume
 * it holds whatever coro_resume_task() had in %rdi, i.e. its
 * own @task argument.
 */
: [task]"D"(task),
[g_main_rsp]"m"(g_main_rsp)
/*
 * The task sees this as a function call.
 * Clobber all call-clobbered registers here!
 */
: "rax", "rsi", "rdx", "rcx", "r8", "r9", "r10", "r11",
"memory", "cc"
);
}
/*
* Start the coroutine @task. Must only be called for a
* @task that hasn't been started.
*
* (IOW, we must have @task->cur_ip == NULL here).
*/
static void coro_start_task(struct coro_task_ctx *task)
{
__asm__ volatile (
/*
* Save the main task registers and return address.
*/
ASM_PUSH_CALL_SAVED_REGS
"leaq .Lcoro_start_task_ret_addr(%%rip), %%rax\n\t"
"pushq %%rax\n\t"
/*
* Save the main task %rsp.
*/
"movq %%rsp, %[g_main_rsp]\n\t"
/*
* Load the coroutine task %rsp.
* At this point, %rsp mod 4096 == 0.
*/
"movq " CT_CUR_RSP "(%[task]), %%rsp\n\t"
/*
* Zero the frame pointer for a good backtrace.
*
* The original %rbp has been saved by
* ASM_PUSH_CALL_SAVED_REGS. Will be restored
* when the coroutine task calls schedule().
*/
"xorl %%ebp, %%ebp\n\t"
/*
* The System V ABI x86-64 mandates:
* On function entry, we must have %rsp mod 16 == 8.
*/
"subq $8, %%rsp\n\t"
"jmp __coro_task_entry\n"
/*
* When the coroutine task calls the first schedule(),
* it will jump back here!
*/
".Lcoro_start_task_ret_addr:\n\t"
ASM_POP_CALL_SAVED_REGS
: [g_main_rsp]"=m"(g_main_rsp)
: [task]"D"(task)
/*
* The task sees this as a function call.
* Clobber all call-clobbered registers here!
*
* Call-saved registers have already been
* preserved by ASM_{PUSH,POP}_CALL_SAVED_REGS.
*/
: "rax", "rsi", "rdx", "rcx", "r8", "r9", "r10", "r11",
"memory", "cc"
);
}
/*
* Resume a coroutine @task that has been suspended by a
* schedule() call. Must only be called for already started
* @task (IOW, @task->cur_rip != NULL).
*/
static void coro_resume_task(struct coro_task_ctx *task)
{
__asm__ volatile (
/*
* Save the main task registers and return address.
*/
ASM_PUSH_CALL_SAVED_REGS
"leaq .Lcoro_resume_task_ret_addr(%%rip), %%rax\n\t"
"pushq %%rax\n\t"
/*
* Save the main task %rsp.
*/
"movq %%rsp, %[g_main_rsp]\n\t"
/*
* Load the coroutine task's %rsp.
*/
"movq " CT_CUR_RSP "(%[task]), %%rsp\n\t"
/*
* Give the control to the coroutine task.
*/
"retq\n"
/*
* When the coroutine task calls schedule(),
* it will jump back here!
*/
".Lcoro_resume_task_ret_addr:\n\t"
ASM_POP_CALL_SAVED_REGS
: [g_main_rsp]"=m"(g_main_rsp)
: [task]"D"(task)
/*
* The task sees this as a function call.
* Clobber all call-clobbered registers here!
*
* Call-saved registers have already been
* preserved by ASM_{PUSH,POP}_CALL_SAVED_REGS.
*/
: "rax", "rsi", "rdx", "rcx", "r8", "r9", "r10", "r11",
"memory", "cc"
);
}
/*
 * Hand the CPU to @task: resume it if it was suspended by schedule(),
 * otherwise start it from its entry point. A task counts as started
 * once its @cur_ip is non-NULL.
 */
static void coro_switch_to(struct coro_task_ctx *task)
{
	const bool already_started = (task->cur_ip != NULL);

	if (likely(already_started)) {
		/* Suspended in schedule(): pick up where it left off. */
		coro_resume_task(task);
		return;
	}

	/* Never ran before: enter it through the task entry point. */
	coro_start_task(task);
}
void coro_run(void)
{
struct coro_task_ctx *task, *tasks = g_tasks;
bool all_clear;
size_t i;
repeat:
all_clear = true;
for (i = 0; i < g_nr_tasks; i++) {
task = &tasks[i];
if (!task->func)
continue;
all_clear = false;
/*
* Let the schedule() function know which task is
* currently running.
*/
g_task_cur = i;
coro_switch_to(task);
}
if (!all_clear)
goto repeat;
}
/*
 * Demo task: three rounds of "sleep half a second, print the thread id
 * and this function's name, yield to the scheduler". @user is unused.
 */
static void func_a(void *user)
{
	unsigned int rounds_left = 3;

	(void)user;

	while (rounds_left--) {
		usleep(500000);
		printf("tid = %d, %s\n", gettid(), __func__);
		schedule();
	}
}
/*
 * Demo task: three rounds of "sleep half a second, print the thread id
 * and this function's name, yield to the scheduler". @user is unused.
 */
static void func_b(void *user)
{
	unsigned int rounds_left = 3;

	(void)user;

	while (rounds_left--) {
		usleep(500000);
		printf("tid = %d, %s\n", gettid(), __func__);
		schedule();
	}
}
/*
 * Demo task: three rounds of "sleep half a second, print the thread id
 * and this function's name, yield to the scheduler". @user is unused.
 */
static void func_c(void *user)
{
	unsigned int rounds_left = 3;

	(void)user;

	while (rounds_left--) {
		usleep(500000);
		printf("tid = %d, %s\n", gettid(), __func__);
		schedule();
	}
}
/*
 * Thread body of the demo: create a 10-slot coroutine table, register
 * func_a, func_b and func_c (in that order), run them to completion and
 * tear the table down.
 *
 * Failures are reported with perror(). @arg is unused and the return
 * value is always NULL.
 */
void *thread_fn(void *arg)
{
	void (*const entries[])(void *) = { func_a, func_b, func_c };
	const size_t nr_entries = sizeof(entries) / sizeof(entries[0]);
	size_t idx;
	int err;

	(void)arg;

	err = coro_init(10);
	if (err) {
		/* The coro_*() API returns negative errno values. */
		errno = -err;
		perror("coro_init");
		return NULL;
	}

	for (idx = 0; idx < nr_entries; idx++) {
		err = coro_add_func(entries[idx], NULL);
		if (!err)
			continue;

		errno = -err;
		perror("coro_add_func");
		coro_destroy();
		return NULL;
	}

	coro_run();
	coro_destroy();
	return NULL;
}
#define NR_THREADS 10
/*
 * Demo driver: spawn NR_THREADS threads that each run their own set of
 * coroutines, then wait for all of them.
 *
 * Return: 0 when every thread was created and joined, 1 otherwise. On
 * a creation failure no further threads are started, but the ones
 * already running are still joined.
 */
int main(void)
{
	pthread_t tr[NR_THREADS];
	int ret = 0;
	int err;
	int i;

	for (i = 0; i < NR_THREADS; i++) {
		/* pthread_*() return the error number, not -1/errno. */
		err = pthread_create(&tr[i], NULL, thread_fn, NULL);
		if (err) {
			fprintf(stderr, "pthread_create: %s\n", strerror(err));
			ret = 1;
			break;
		}
	}

	/* @i is the number of threads actually created. */
	while (i--) {
		err = pthread_join(tr[i], NULL);
		if (err) {
			fprintf(stderr, "pthread_join: %s\n", strerror(err));
			ret = 1;
		}
	}

	return ret;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment