Created
November 7, 2018 09:21
-
-
Save usbuild/ba21ff0079264260a222085e45615a71 to your computer and use it in GitHub Desktop.
A simple C++ coroutine implementation (x86-64 System V only)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "co.hpp"
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <map>
#include <mutex>
#include <new>
using namespace pm::common;
// Trampoline that starts a brand-new coroutine. It is never called directly:
// Coroutine::init_jmp_buf_ stores its address as the saved rip, so the first
// co_jump into a fresh coroutine lands here with
//   %rsp = top of the coroutine stack (slot 1 of the co_jmp_buf),
//   %r12 = entry function to run      (slot 4),
//   %r13 = argument for that entry    (slot 5).
// The x86-64 System V ABI requires %rsp % 16 == 8 on function entry (as if a
// `call` had just pushed a return address). How far %rsp has moved before the
// asm runs depends on whether the compiler emitted a prologue for this
// function (e.g. `push %rbp` at -O0, nothing at -O2). So the stack is
// realigned explicitly and a zero "return address" is pushed. That gives the
// entry function a correctly aligned frame, and gives debuggers a null return
// address instead of stack garbage.
static void co_wrap_main(void) {
  __asm__ __volatile__("\tmovq %r13, %rdi\n"  // %rdi is the first argument
                       "\tandq $-16, %rsp\n"  // realign to a 16-byte boundary
                       "\tpushq $0\n"         // fake return address: %rsp % 16 == 8
                       "\tjmpq *%r12\n");     // tail-jump into the entry function
}
// Switches execution from the current context to `to`, saving the current
// context into `from`. Each co_jmp_buf holds, in order:
//   [0] rip  [1] rsp  [2] rbp  [3] rbx  [4] r12  [5] r13  [6] r14  [7] r15
// i.e. the resume address plus the callee-saved general-purpose registers of
// the x86-64 System V ABI. When a later co_jump targets `from`, execution
// continues at local label 1 below and this function returns normally.
//
// Every register that is NOT stored in the buffer is declared clobbered, so
// the compiler never keeps a live value in it across the switch. That covers
// the caller-saved GPRs, the flags, memory, and the SSE registers
// %xmm0-%xmm15, which the other context is free to overwrite before control
// comes back.
// NOTE(review): the MXCSR and x87 control words are not saved/restored, so
// floating-point mode changes made inside one context leak into the other.
static inline void co_jump(co_jmp_buf from, co_jmp_buf to) {
  __asm__ __volatile__(
      // Save the current context into `from` (%rsi); rip = label 1.
      "leaq 1f(%%rip), %%rax\n\t"
      "movq %%rax, (%0)\n\t"
      "movq %%rsp, 8(%0)\n\t"
      "movq %%rbp, 16(%0)\n\t"
      "movq %%rbx, 24(%0)\n\t"
      "movq %%r12, 32(%0)\n\t"
      "movq %%r13, 40(%0)\n\t"
      "movq %%r14, 48(%0)\n\t"
      "movq %%r15, 56(%0)\n\t"
      // Restore the target context from `to` (%rdi) and jump to its rip.
      "movq 56(%1), %%r15\n\t"
      "movq 48(%1), %%r14\n\t"
      "movq 40(%1), %%r13\n\t"
      "movq 32(%1), %%r12\n\t"
      "movq 24(%1), %%rbx\n\t"
      "movq 16(%1), %%rbp\n\t"
      "movq 8(%1), %%rsp\n\t"
      "jmpq *(%1)\n"
      "1:\n"
      : "+S"(from), "+D"(to)
      :
      : "rax", "rcx", "rdx", "r8", "r9", "r10", "r11", "memory", "cc",
        "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7",
        "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15");
}
// Registry of live coroutines, keyed by the base address of each coroutine's
// stack allocation. GetCoroutine() uses it to map the current stack address
// back to its owning Coroutine. File-local: nothing outside this translation
// unit should touch it.
static std::map<uintptr_t, Coroutine *> g_co_map;
// Guards every access to g_co_map.
static std::mutex g_co_mu;
// Size in bytes of the stack allocated for each coroutine (16 MiB).
static constexpr size_t CO_STACK_SIZE = 16 * 1024 * 1024;
// Creates a coroutine that will run `start_routine` on its own
// CO_STACK_SIZE-byte heap-allocated stack. Nothing runs until the first
// resume(). The stack is registered in g_co_map so GetCoroutine() can find it.
// Throws std::bad_alloc if the stack cannot be allocated. If registration
// throws, the stack is released before the exception propagates.
Coroutine::Coroutine(start_routine_t start_routine) {
  stack_ = malloc(CO_STACK_SIZE);
  if (stack_ == nullptr) {
    throw std::bad_alloc();
  }
  // The stack grows downward, so execution starts at the end of the block.
  sp_ = static_cast<char *>(stack_) + CO_STACK_SIZE;
  start_routine_ = start_routine;
  status_ = Status::CREATE;
  resume_arg_ = nullptr;
  yield_arg_ = nullptr;
  force_unwind_ = false;
  tag_ = nullptr;
  try {
    std::lock_guard<std::mutex> lk(g_co_mu);
    g_co_map[reinterpret_cast<uintptr_t>(stack_)] = this;
  } catch (...) {
    // The destructor does not run for a half-constructed object.
    free(stack_);
    throw;
  }
}
// Returns the Coroutine whose stack the caller is currently executing on, or
// nullptr when the caller runs on a stack no registered Coroutine owns (for
// example, a thread's original stack). Looks up g_co_map under g_co_mu.
Coroutine *Coroutine::GetCoroutine() {
  int probe;  // only its address is used: it lies on the current stack
  const uintptr_t here = reinterpret_cast<uintptr_t>(&probe);

  std::lock_guard<std::mutex> lk(g_co_mu);
  // First stack whose base is above `here`; the candidate owner is the one
  // just before it (the greatest base <= here). Also handles an empty map.
  auto it = g_co_map.upper_bound(here);
  if (it == g_co_map.begin()) return nullptr;
  --it;

  Coroutine *co = it->second;
  // Compare as integers, like the map keys: relational comparison of pointers
  // into unrelated objects is unspecified in C++.
  const uintptr_t low = reinterpret_cast<uintptr_t>(co->stack_);
  const uintptr_t high = reinterpret_cast<uintptr_t>(co->sp_);
  return (here > low && here < high) ? co : nullptr;
}
// Destroys the coroutine. If it is suspended mid-routine, it is resumed once
// with force_unwind_ set, so the pending yield() throws ForceUnwind and the
// routine's stack frames are unwound before the stack is released.
// The registry entry is removed BEFORE the stack is freed. Otherwise another
// thread could receive the same address from malloc and register a new
// coroutine under this key, and that entry would then be erased by mistake.
Coroutine::~Coroutine() {
  if (status_ == Status::SUSPEND) {
    force_unwind_ = true;
    resume(nullptr);
  }
  {
    std::lock_guard<std::mutex> lk(g_co_mu);
    g_co_map.erase(reinterpret_cast<uintptr_t>(stack_));
  }
  free(stack_);
}
void *Coroutine::resume(void *arg) { | |
assert(status_ == Status::CREATE || status_ == Status::SUSPEND); | |
co_jmp_buf target; | |
resume_arg_ = arg; | |
if (status_ == Status::CREATE) { | |
init_jmp_buf_(target); | |
} else if (status_ == Status::SUSPEND) { | |
memcpy(target, saved_ctx_, sizeof(target)); | |
} | |
status_ = Status::RUNNING; | |
co_jump(saved_ctx_, target); | |
if (status_ != Status::EXIT) | |
status_ = Status::SUSPEND; | |
return yield_arg_; | |
} | |
void *Coroutine::yield(void *ret) { | |
yield_arg_ = ret; | |
co_jmp_buf target; | |
memcpy(target, saved_ctx_, sizeof(target)); | |
co_jump(saved_ctx_, target); | |
if (force_unwind_) { | |
throw ForceUnwind{}; | |
} | |
return resume_arg_; | |
} | |
// Fills `regs` (layout: rip, rsp, rbp, rbx, r12, r13, r14, r15) so that the
// first co_jump into it starts this coroutine. Execution begins in
// co_wrap_main on this coroutine's own stack, which moves r13 (`this`) into
// the first-argument register and jumps to r12 (the entry function below).
// All other slots, including rbp, start out null.
void Coroutine::init_jmp_buf_(co_jmp_buf regs) {
  // Body of every coroutine: run the user routine, record its result, mark the
  // coroutine finished, then switch back to whoever resumed it. It never
  // returns, because the fresh stack has no caller frame to return to.
  void (*entry)(Coroutine *) = [](Coroutine *self) {
    try {
      self->yield_arg_ = self->start_routine_(self->resume_arg_);
    } catch (const ForceUnwind &) {
      // Thrown by yield() when the destructor forces an unwind; the routine's
      // frames are gone now, so simply finish.
    }
    self->status_ = Status::EXIT;
    co_jmp_buf discarded;  // this context is never resumed again
    co_jump(discarded, self->saved_ctx_);
  };

  for (size_t slot = 0; slot < sizeof(co_jmp_buf) / sizeof(regs[0]); ++slot) {
    regs[slot] = NULL;
  }
  regs[0] = (void *)(co_wrap_main);           // rip
  regs[1] = this->sp_;                        // rsp: top of the coroutine stack
  regs[4] = reinterpret_cast<void *>(entry);  // r12: where the trampoline jumps
  regs[5] = this;                             // r13: argument for `entry`
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once | |
#include <stdint.h> | |
namespace pm {
namespace common {

/// Saved execution context used by the context switch (x86-64 only).
typedef void *co_jmp_buf[8]; /* rip, rsp, rbp, rbx, r12, r13, r14, r15 */

/// Stackful coroutine running a `void *(void *)` routine on its own
/// heap-allocated stack. Control moves between the resumer and the coroutine
/// through resume() and yield().
/// Instances own their stack and are registered by address in a global
/// registry, so they can be neither copied nor moved.
class Coroutine {
  typedef void *(*start_routine_t)(void *);

 public:
  /// CREATE: never resumed; SUSPEND: paused in yield(); RUNNING: currently
  /// executing; EXIT: the start routine has returned (or was force-unwound).
  enum class Status { CREATE, SUSPEND, RUNNING, EXIT };

  /// Thrown out of yield() when a suspended coroutine is destroyed, so its
  /// frames unwind. Routines must not swallow it.
  struct ForceUnwind {};

  /// Returns the coroutine whose stack the caller is running on, or nullptr.
  static Coroutine *GetCoroutine();

 private:
  co_jmp_buf saved_ctx_;  // context to switch to on the next resume/yield
  void *stack_ = nullptr;  // base of the stack allocation
  void *sp_ = nullptr;     // initial stack pointer (end of the allocation)
  start_routine_t start_routine_ = nullptr;
  Status status_ = Status::CREATE;
  void *resume_arg_ = nullptr;  // value passed into the coroutine
  void *yield_arg_ = nullptr;   // value passed out of the coroutine
  bool force_unwind_ = false;   // set by the destructor to unwind a suspended routine
  const void *tag_ = nullptr;   // opaque user data, see setTag()/getTag()
  void init_jmp_buf_(co_jmp_buf regs);

 public:
  /// Creates a coroutine for `start_routine`; nothing runs until resume().
  Coroutine(start_routine_t start_routine);
  /// Unwinds a suspended routine, then releases the stack.
  ~Coroutine();

  // Copying would double-free the stack and leave the copy unregistered.
  Coroutine(const Coroutine &) = delete;
  Coroutine &operator=(const Coroutine &) = delete;

  /// Runs the coroutine until it yields or finishes; returns the yielded value
  /// or the routine's return value.
  void *resume(void *arg);
  /// From inside the coroutine: hands `ret` to the resumer and suspends;
  /// returns the argument of the next resume().
  void *yield(void *ret);

  Status state() const { return this->status_; }
  void setTag(const void *tag) { this->tag_ = tag; }
  const void *getTag() const { return this->tag_; }
};

}  // namespace common
}  // namespace pm
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment