Created
May 18, 2022 23:47
-
-
Save ammarfaizi2/9072d28a270218567fe4ae9ab4166f0b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// SPDX-License-Identifier: GPL-2.0-only | |
/* | |
* Copyright (C) 2022 Ammar Faizi <[email protected]> | |
* | |
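 * A simple multi-threaded coroutine library written in C: every thread
 * runs its own, cooperatively scheduled set of coroutines.
 * Linux x86-64 only (the context switch is x86-64 inline assembly).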
*/ | |
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif
#include <errno.h>
#include <pthread.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/types.h>
#include <unistd.h>
/*
 * Compiler hints. Only defined here when the build has not already
 * provided them.
 */
#ifndef likely
#define likely(EXPR) __builtin_expect((EXPR), 1)
#endif
#ifndef unlikely
#define unlikely(EXPR) __builtin_expect((EXPR), 0)
#endif
#ifndef noinline
#define noinline __attribute__((__noinline__))
#endif

/* Integer <-> pointer conversion shorthands. */
#define UPTR(PTR) ((uintptr_t)(PTR))
#define PTRU(PTR) ((void *)(PTR))

/* Page size, and the fixed size of every coroutine stack (8 MiB). */
#define PAGE_SIZE 4096UL
#define CORO_STACK_SIZE (8UL << 20)

/* Round PTR down to the start of the page containing it. */
#define ALIGN_PAGE_SIZE(PTR) PTRU(UPTR(PTR) & ~UPTR(PAGE_SIZE - 1))

/*
 * Initial (empty) stack pointer for the CORO_STACK_SIZE stack mapping
 * that begins at PTR: its page-aligned top end.
 */
#define RESET_CORO_SP(PTR) ALIGN_PAGE_SIZE(UPTR(PTR) + CORO_STACK_SIZE)

/*
 * Save / restore the callee-saved general-purpose registers of the
 * System V x86-64 ABI (%rbx, %rbp, %r12-%r15). The pop sequence must
 * stay the exact mirror image of the push sequence.
 */
#define ASM_PUSH_CALL_SAVED_REGS					\
	"pushq %%r12\n\t" "pushq %%r13\n\t" "pushq %%r14\n\t"		\
	"pushq %%r15\n\t" "pushq %%rbp\n\t" "pushq %%rbx\n\t"

#define ASM_POP_CALL_SAVED_REGS						\
	"popq %%rbx\n\t" "popq %%rbp\n\t" "popq %%r15\n\t"		\
	"popq %%r14\n\t" "popq %%r13\n\t" "popq %%r12\n\t"
/*
 * Per-coroutine slot. The context-switch inline assembly accesses these
 * fields through the hard-coded CT_* byte offsets below, so the layout
 * must not change without updating them. The static assertions turn any
 * mismatch into a compile-time error instead of silent stack corruption.
 */
struct coro_task_ctx {
	void *sp;			/* lowest address of the mmap()ed stack */
	void *cur_sp;			/* initial %rsp, or saved %rsp while suspended */
	void *cur_ip;			/* resume address; NULL = not started yet */
	void (*func)(void *user);	/* entry function; NULL = free slot */
	void *user_data;		/* argument passed to @func */
};

/*
 * struct coro_task_ctx field offsets, as strings for use in inline ASM.
 */
#define CT_STACK "0"
#define CT_CUR_RSP "8"
#define CT_CUR_RIP "16"
#define CT_FUNC "24"
#define CT_USER_DATA "32"

_Static_assert(offsetof(struct coro_task_ctx, sp) == 0, "CT_STACK is stale");
_Static_assert(offsetof(struct coro_task_ctx, cur_sp) == 8, "CT_CUR_RSP is stale");
_Static_assert(offsetof(struct coro_task_ctx, cur_ip) == 16, "CT_CUR_RIP is stale");
_Static_assert(offsetof(struct coro_task_ctx, func) == 24, "CT_FUNC is stale");
_Static_assert(offsetof(struct coro_task_ctx, user_data) == 32, "CT_USER_DATA is stale");
/*
 * Scheduler state. Every thread drives its own independent set of
 * coroutines, so all of it is thread-local. Objects with static storage
 * duration start out zeroed, hence no explicit initializers.
 */

/* Main (scheduler) context %rsp, saved while a coroutine is running. */
static __thread void *g_main_rsp;

/* Slot array: allocated by coro_init(), released by coro_destroy(). */
static __thread struct coro_task_ctx *g_tasks;
static __thread size_t g_nr_tasks;

/* Index into g_tasks of the coroutine that is executing right now. */
static __thread size_t g_task_cur;
struct coro_task_ctx;

/*
 * Public API; these prototypes are meant to move to a header later.
 */
int coro_init(size_t nr_tasks);
int coro_add_func(void (*func)(void *user), void *user);
void coro_run(void);
void schedule(void);
void coro_destroy(void);

/*
 * Internal entry trampoline. It is kept non-static because the
 * context-switch assembly jumps to it by symbol name.
 */
void __coro_task_entry(struct coro_task_ctx *task);
void coro_destroy(void) | |
{ | |
size_t i; | |
if (!g_tasks) | |
return; | |
for (i = 0; i < g_nr_tasks; i++) { | |
if (!g_tasks[i].sp) | |
continue; | |
munmap(g_tasks[i].sp, CORO_STACK_SIZE); | |
} | |
free(g_tasks); | |
g_tasks = NULL; | |
} | |
int coro_init(size_t nr_tasks) | |
{ | |
size_t i; | |
g_tasks = calloc(nr_tasks, sizeof(*g_tasks)); | |
if (!g_tasks) | |
return -ENOMEM; | |
g_nr_tasks = nr_tasks; | |
for (i = 0; i < nr_tasks; i++) { | |
void *rsp; | |
rsp = mmap(NULL, CORO_STACK_SIZE, PROT_READ|PROT_WRITE, | |
MAP_PRIVATE|MAP_ANONYMOUS|MAP_STACK|MAP_GROWSDOWN, | |
-1, 0); | |
if (rsp == MAP_FAILED) { | |
int ret = -errno; | |
coro_destroy(); | |
return ret; | |
} | |
g_tasks[i].sp = rsp; | |
g_tasks[i].cur_sp = RESET_CORO_SP(rsp); | |
g_tasks[i].cur_ip = NULL; | |
g_tasks[i].func = NULL; | |
} | |
return 0; | |
} | |
int coro_add_func(void (*func)(void *user), void *user) | |
{ | |
size_t i; | |
for (i = 0; i < g_nr_tasks; i++) { | |
if (g_tasks[i].func) | |
continue; | |
g_tasks[i].func = func; | |
g_tasks[i].user_data = user; | |
return 0; | |
} | |
return -EAGAIN; | |
} | |
/* | |
* This is the entry point for all coroutine tasks. | |
*/ | |
noinline void __coro_task_entry(struct coro_task_ctx *task) | |
{ | |
/* | |
* The @task->func() should call schedule() periodically to | |
* fairly execute all coroutine tasks. | |
*/ | |
task->func(task->user_data); | |
/* | |
* The coroutine task has finished, reset its state and give | |
* the control back to the main task. | |
*/ | |
task->func = NULL; | |
task->cur_ip = NULL; | |
task->user_data = NULL; | |
task->cur_sp = RESET_CORO_SP(task->sp); | |
__asm__ volatile ( | |
/* | |
* Load the main task's %rsp. | |
*/ | |
"movq %[g_main_rsp], %%rsp\n\t" | |
/* | |
* Give the control back to the main task. | |
*/ | |
"retq" | |
: | |
: [g_main_rsp]"m"(g_main_rsp) | |
/* | |
* We don't care about registers clobbering here, | |
* because this task has finished, all of them are | |
* unused after the "retq". | |
*/ | |
: "memory", "cc" | |
); | |
__builtin_unreachable(); | |
} | |
/* | |
* Suspend the coroutine task execution and give the control | |
* back to the main task. Must only be called from a coroutine | |
* task context. | |
*/ | |
noinline void schedule(void) | |
{ | |
struct coro_task_ctx *task = &g_tasks[g_task_cur]; | |
__asm__ volatile ( | |
/* | |
* Save the coroutine task registers. | |
*/ | |
ASM_PUSH_CALL_SAVED_REGS | |
/* | |
* Save the coroutine task %rsp and return address. | |
*/ | |
"leaq .Lschedule_ret_addr(%%rip), %%rax\n\t" | |
"pushq %%rax\n\t" | |
"movq %%rsp, " CT_CUR_RSP "(%[task])\n\t" | |
"movq %%rax, " CT_CUR_RIP "(%[task])\n\t" | |
/* | |
* Load the main task %rsp. | |
*/ | |
"movq %[g_main_rsp], %%rsp\n\t" | |
/* | |
* Give the control back to the main task. | |
*/ | |
"retq\n" | |
/* | |
* When the main task calls coro_switch_to(task), it will | |
* jump back here! | |
* | |
* Then we will return to the coroutine task to resume | |
* the execution. | |
* | |
* Restore the coroutine task registers below... | |
*/ | |
".Lschedule_ret_addr:\n\t" | |
ASM_POP_CALL_SAVED_REGS | |
: | |
: [task]"D"(task), | |
[g_main_rsp]"m"(g_main_rsp) | |
/* | |
* The task sees this as a function call. | |
* Clobber all call-clobbered registers here! | |
*/ | |
: "rax", "rsi", "rdx", "rcx", "r8", "r9", "r10", "r11", | |
"memory", "cc" | |
); | |
} | |
/* | |
* Start the coroutine @task. Must only be called for a | |
* @task that hasn't been started. | |
* | |
* (IOW, we must have @task->cur_ip == NULL here). | |
*/ | |
static void coro_start_task(struct coro_task_ctx *task) | |
{ | |
__asm__ volatile ( | |
/* | |
* Save the main task registers and return address. | |
*/ | |
ASM_PUSH_CALL_SAVED_REGS | |
"leaq .Lcoro_start_task_ret_addr(%%rip), %%rax\n\t" | |
"pushq %%rax\n\t" | |
/* | |
* Save the main task %rsp. | |
*/ | |
"movq %%rsp, %[g_main_rsp]\n\t" | |
/* | |
* Load the coroutine task %rsp. | |
* At this point, %rsp mod 4096 == 0. | |
*/ | |
"movq " CT_CUR_RSP "(%[task]), %%rsp\n\t" | |
/* | |
* Zero the frame pointer for a good backtrace. | |
* | |
* The original %rbp has been saved by | |
* ASM_PUSH_CALL_SAVED_REGS. Will be restored | |
* when the coroutine task calls schedule(). | |
*/ | |
"xorl %%ebp, %%ebp\n\t" | |
/* | |
* The System V ABI x86-64 mandates: | |
* On function entry, we must have %rsp mod 16 == 8. | |
*/ | |
"subq $8, %%rsp\n\t" | |
"jmp __coro_task_entry\n" | |
/* | |
* When the coroutine task calls the first schedule(), | |
* it will jump back here! | |
*/ | |
".Lcoro_start_task_ret_addr:\n\t" | |
ASM_POP_CALL_SAVED_REGS | |
: [g_main_rsp]"=m"(g_main_rsp) | |
: [task]"D"(task) | |
/* | |
* The task sees this as a function call. | |
* Clobber all call-clobbered registers here! | |
* | |
* Call-saved registers have already been | |
* preserved by ASM_{PUSH,POP}_CALL_SAVED_REGS. | |
*/ | |
: "rax", "rsi", "rdx", "rcx", "r8", "r9", "r10", "r11", | |
"memory", "cc" | |
); | |
} | |
/* | |
* Resume a coroutine @task that has been suspended by a | |
* schedule() call. Must only be called for already started | |
* @task (IOW, @task->cur_rip != NULL). | |
*/ | |
static void coro_resume_task(struct coro_task_ctx *task) | |
{ | |
__asm__ volatile ( | |
/* | |
* Save the main task registers and return address. | |
*/ | |
ASM_PUSH_CALL_SAVED_REGS | |
"leaq .Lcoro_resume_task_ret_addr(%%rip), %%rax\n\t" | |
"pushq %%rax\n\t" | |
/* | |
* Save the main task %rsp. | |
*/ | |
"movq %%rsp, %[g_main_rsp]\n\t" | |
/* | |
* Load the coroutine task's %rsp. | |
*/ | |
"movq " CT_CUR_RSP "(%[task]), %%rsp\n\t" | |
/* | |
* Give the control to the coroutine task. | |
*/ | |
"retq\n" | |
/* | |
* When the coroutine task calls schedule(), | |
* it will jump back here! | |
*/ | |
".Lcoro_resume_task_ret_addr:\n\t" | |
ASM_POP_CALL_SAVED_REGS | |
: [g_main_rsp]"=m"(g_main_rsp) | |
: [task]"D"(task) | |
/* | |
* The task sees this as a function call. | |
* Clobber all call-clobbered registers here! | |
* | |
* Call-saved registers have already been | |
* preserved by ASM_{PUSH,POP}_CALL_SAVED_REGS. | |
*/ | |
: "rax", "rsi", "rdx", "rcx", "r8", "r9", "r10", "r11", | |
"memory", "cc" | |
); | |
} | |
static void coro_switch_to(struct coro_task_ctx *task) | |
{ | |
if (unlikely(!task->cur_ip)) | |
/* | |
* This @task hasn't been started, start it! | |
*/ | |
coro_start_task(task); | |
else | |
/* | |
* This @task has been started, but it's suspended | |
* by a schedule() call, resume it! | |
*/ | |
coro_resume_task(task); | |
} | |
void coro_run(void) | |
{ | |
struct coro_task_ctx *task, *tasks = g_tasks; | |
bool all_clear; | |
size_t i; | |
repeat: | |
all_clear = true; | |
for (i = 0; i < g_nr_tasks; i++) { | |
task = &tasks[i]; | |
if (!task->func) | |
continue; | |
all_clear = false; | |
/* | |
* Let the schedule() function know which task is | |
* currently running. | |
*/ | |
g_task_cur = i; | |
coro_switch_to(task); | |
} | |
if (!all_clear) | |
goto repeat; | |
} | |
/*
 * Demo coroutine. Three times: sleep 500 ms, print the thread id and
 * this function's name, then yield to the scheduler.
 */
static void func_a(void *user)
{
	size_t left = 3;

	(void)user;
	while (left--) {
		usleep(500000);
		printf("tid = %d, %s\n", gettid(), __func__);
		schedule();
	}
}
/*
 * Demo coroutine. Three rounds of: sleep 500 ms, report the thread id
 * and this function's name, yield to the scheduler.
 */
static void func_b(void *user)
{
	size_t round;

	(void)user;
	for (round = 3; round > 0; round--) {
		usleep(500000);
		printf("tid = %d, %s\n", gettid(), __func__);
		schedule();
	}
}
/*
 * Demo coroutine. Performs three iterations, each of which sleeps for
 * 500 ms, prints the thread id and this function's name, and yields.
 */
static void func_c(void *user)
{
	size_t done = 0;

	(void)user;
	do {
		usleep(500000);
		printf("tid = %d, %s\n", gettid(), __func__);
		schedule();
	} while (++done < 3);
}
void *thread_fn(void *arg) | |
{ | |
int ret; | |
ret = coro_init(10); | |
if (ret) { | |
errno = -ret; | |
perror("coro_init"); | |
return NULL; | |
} | |
ret = coro_add_func(func_a, NULL); | |
if (ret) | |
goto err_add; | |
ret = coro_add_func(func_b, NULL); | |
if (ret) | |
goto err_add; | |
ret = coro_add_func(func_c, NULL); | |
if (ret) | |
goto err_add; | |
coro_run(); | |
out: | |
coro_destroy(); | |
(void)arg; | |
return NULL; | |
err_add: | |
errno = -ret; | |
perror("coro_add_func"); | |
goto out; | |
} | |
#define NR_THREADS 10 | |
int main(void) | |
{ | |
pthread_t tr[NR_THREADS]; | |
int i; | |
for (i = 0; i < NR_THREADS; i++) { | |
if (pthread_create(&tr[i], NULL, thread_fn, NULL)) | |
break; | |
} | |
for (; i--;) | |
pthread_join(tr[i], NULL); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment