Last active
December 13, 2019 14:55
-
-
Save nsf/466640dd541ed71f5479 to your computer and use it in GitHub Desktop.
Coroutines in Nim
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import queues | |
import locks | |
import macros | |
import sequtils | |
type
  # Request a coroutine hands to the scheduler each time it yields.
  # scDone is the first value (ordinal 0), so a zero-initialised
  # SchedulerCommand has kind scDone.
  SchedulerCommandType = enum
    scDone,              # finished; wake whoever awaits it, if anyone
    scYield,             # put me back on the run queue
    scWaitForCoroutine,  # start `coroutine`, resume me when it finishes
    scWaitForCoroutines, # start all `coroutines`, resume me when all finish

  SchedulerCommand = object
    case kind: SchedulerCommandType
    of scDone, scYield:
      discard
    of scWaitForCoroutine:
      coroutine: CoroutineBase
    of scWaitForCoroutines:
      coroutines: seq[CoroutineBase]

  # Shared by every coroutine awaited in one go: `count` is decremented as
  # each finishes and the last one re-queues `waitingCoroutine`.
  Counter = ref object
    count: int
    waitingCoroutine: CoroutineBase

  CoroutineBase = ref object of RootObj
    iter: iterator(): SchedulerCommand # advanced one step per scheduling
    counter: Counter                   # non-nil while being awaited

  # Typed coroutine; `result` holds the value produced by the body.
  Coroutine[T] = ref object of CoroutineBase
    when T is not void:
      result: T
type
  # FIFO shared between threads: `queue` is only touched while `mutex` is
  # held, and `cond` is signalled whenever an item is added.
  ThreadQueue[T] = object
    queue: Queue[T]
    mutex: Lock
    cond: Cond
proc initThreadQueue[T]: ThreadQueue[T] =
  ## Creates an empty ThreadQueue with an initialised lock and condition
  ## variable.
  ## NOTE(review): the queue (including its Lock/Cond) is returned by value;
  ## copying an already-initialised lock is not portable — confirm callers
  ## never end up with a copy.
  initLock(result.mutex)
  initCond(result.cond)
  result.queue = initQueue[T]()
proc push[T](tq: var ThreadQueue[T], item: T) =
  ## Appends `item` under the lock and wakes one thread blocked in `pop`.
  acquire(tq.mutex)
  try:
    enqueue(tq.queue, item)
    signal(tq.cond)
  finally:
    release(tq.mutex)
proc pop[T](tq: var ThreadQueue[T]): T =
  ## Removes and returns the oldest item, blocking on `cond` for as long as
  ## the queue is empty.
  acquire(tq.mutex)
  try:
    while len(tq.queue) == 0:
      wait(tq.cond, tq.mutex)
    result = dequeue(tq.queue)
  finally:
    release(tq.mutex)
# Global run queue of coroutines, popped by every worker thread.
var queue: ThreadQueue[CoroutineBase] = initThreadQueue[CoroutineBase]()
proc go(c: CoroutineBase) =
  ## Makes `c` runnable by appending it to the global run queue.
  push(queue, c)
proc schedule(c: CoroutineBase) =
  ## Resumes coroutine `c` up to its next yield and acts on the command it
  ## yielded:
  ## - scDone: if `c` is being awaited, decrement the shared counter; the
  ##   coroutine that finishes last re-queues the waiting coroutine.
  ## - scYield: put `c` straight back on the run queue.
  ## - scWaitForCoroutine(s): start the awaited coroutine(s). `c` is not
  ##   re-queued here; the last awaited coroutine to finish re-queues it.
  ## NOTE(review): a body that simply runs off its end relies on the
  ## finished iterator returning a zero-initialised command (kind scDone)
  ## — confirm against the closure-iterator semantics of the Nim in use.
  let cmd = c.iter()
  case cmd.kind
  of scDone:
    let counter = c.counter
    if counter != nil:
      # `c` is finished, so drop its link to the shared counter either way.
      c.counter = nil
      if atomicDec(counter.count) == 0:
        go counter.waitingCoroutine
  of scYield:
    go c
  of scWaitForCoroutine:
    cmd.coroutine.counter = Counter(count: 1, waitingCoroutine: c)
    go cmd.coroutine
  of scWaitForCoroutines:
    if cmd.coroutines.len == 0:
      # Nothing to wait for: with a zero count no child would ever reach
      # the "last one finished" branch, so resume the waiter right away
      # instead of parking it forever.
      go c
    else:
      let counter = Counter(count: cmd.coroutines.len, waitingCoroutine: c)
      for child in cmd.coroutines:
        child.counter = counter
        go child
proc worker(n: int) =
  ## Worker thread body: pops coroutines off the global run queue and
  ## advances each one step with `schedule`, until it pops the nil sentinel
  ## pushed by `shutdown`.
  echo "worker " & $n & " on duty"
  var next = queue.pop()
  while next != nil:
    schedule(next)
    next = queue.pop()
  echo "worker " & $n & " shutting down"
# Fixed pool of worker threads, each started with its index as argument.
# `worker` is cast to a gcsafe proc type for createThread (presumably
# because it touches the global `queue`).
var workers: array[4, Thread[int]]
for idx in countup(low(workers), high(workers)):
  createThread(workers[idx], cast[proc(n: int) {.gcsafe.}](worker), idx)
proc shutdown() =
  ## Pushes one nil sentinel per worker thread; a worker exits its loop
  ## when it pops nil.
  var remaining = workers.len
  while remaining > 0:
    queue.push(nil)
    dec remaining
template ppTree(e: expr) =
  ## Debug helper: prints the tree representation and the source code of
  ## NimNode `e`, framed by separator lines.
  echo "-------------------- AST ---------------------"
  echo e.treeRepr
  echo "------------------- Code ---------------------"
  echo e.toStrLit
  echo "----------------------------------------------"
proc wrapAwaitValue(tmpSym, cmd, n: NimNode): NimNode =
  ## Builds the statement list
  ##   let <tmpSym> = coroutineAwaitValue(<cmd[1]>)
  ##   yield <tmpSym>.command
  ## where `cmd` is an `await expr` node (cmd[1] is the awaited expression).
  ## The list takes its line info from `n`.
  let awaiterInit = newLetStmt(
    tmpSym, newCall(newIdentNode(!"coroutineAwaitValue"), cmd[1]))
  let suspend = newNimNode(nnkYieldStmt)
  suspend.add(newDotExpr(tmpSym, newIdentNode(!"command")))
  result = newNimNode(nnkStmtList, n)
  result.add(awaiterInit)
  result.add(suspend)
# Recursively converts the parts of a coroutine's body that need it:
# 1. return expr
#    ->
#    retCoroutine.result = expr
#    yield SchedulerCommand(kind: scDone)
#
# 2. return
#    ->
#    retCoroutine.result = result
#    yield SchedulerCommand(kind: scDone)
#    or, in a coroutine without a result:
#    yield SchedulerCommand(kind: scDone)
#
# 3. await expr
#    ->
#    yield coroutineAwait(expr)
#
# 4. let x[: T] = await expr
#    ->
#    let tmp = coroutineAwaitValue(expr)
#    yield tmp.command
#    let x[: T] = tmp.value()
#
# 5. var x = await expr (same as 4)
# 6. x = await expr (same as 4; x may be any assignable expression)
# 7. discard await expr (same as 4, without the final binding)
# 8. try statements are not allowed
#
# Nested routine definitions (procs, lambdas, iterators, ...) are left
# untouched: a `return` inside them belongs to them, not to the coroutine.

proc isAwaitCall(n: NimNode): bool =
  ## True when `n` is an `await expr` command or an `await(expr)` call.
  n.kind in {nnkCommand, nnkCall} and n.len > 0 and
    n[0].kind == nnkIdent and n[0].ident == !"await"

proc isAwaitDef(def: NimNode): bool =
  ## True for a single-name definition `x[: T] = await expr` inside a
  ## let/var section.
  def.kind == nnkIdentDefs and def.len == 3 and isAwaitCall(def[2])

proc awaitTmpName(target: NimNode, fallback: string): string =
  ## Name hint for the temporary holding an Awaiter (readability only).
  ## Only a plain identifier has a usable name; any other target (`a.b`,
  ## `a[i]`, the `retCoroutine.result` produced for `return await x`, ...)
  ## gets `fallback`, because `.ident` is invalid on such nodes.
  if target.kind == nnkIdent: "await" & $target.ident
  else: fallback

proc convertToCoroutineBody(n, retSym: NimNode, hasResult: bool): NimNode =
  ## Rewrites `n` and, recursively, its children according to the rules
  ## above. `retSym` is the symbol of the Coroutine object under
  ## construction; `hasResult` is false for void coroutines. Children are
  ## replaced in place; the (possibly new) node for `n` is returned.
  if n.kind in {nnkProcDef, nnkMethodDef, nnkIteratorDef, nnkConverterDef,
                nnkMacroDef, nnkTemplateDef, nnkLambda, nnkDo}:
    # A nested routine has its own `return`/`result`; leave it alone.
    return n
  result = n
  case n.kind
  of nnkReturnStmt:
    result = newNimNode(nnkStmtList, n)
    if n[0].kind == nnkEmpty:
      # return
      if hasResult:
        result.add(newAssignment(
          newDotExpr(retSym, newIdentNode(!"result")),
          newIdentNode(!"result")))
    else:
      # return expr
      if not hasResult:
        error("Non-void return inside a void coroutine")
      result.add(newAssignment(
        newDotExpr(retSym, newIdentNode(!"result")),
        n[0]))
    result.add(newNimNode(nnkYieldStmt).add(
      newNimNode(nnkObjConstr).add(
        bindSym"SchedulerCommand",
        newColonExpr(newIdentNode(!"kind"), bindSym"scDone"))))
  of nnkCommand, nnkCall:
    if isAwaitCall(n):
      # await expr
      expectLen(n, 2)
      result = newNimNode(nnkYieldStmt, n).add(
        newCall(newIdentNode(!"coroutineAwait"), n[1]))
  of nnkVarSection, nnkLetSection:
    var anyAwait = false
    for i in 0..<n.len:
      if isAwaitDef(n[i]):
        anyAwait = true
    if anyAwait:
      # let x[: T] = await expr
      # Every definition gets its own section, so each awaiting one can be
      # preceded by its yield while all definitions keep their order.
      result = newNimNode(nnkStmtList, n)
      for i in 0..<n.len:
        let def = n[i]
        if isAwaitDef(def):
          let cmd = def[2]
          expectLen(cmd, 2)
          let tmpSym = genSym(nskLet, awaitTmpName(def[0], "awaitLet"))
          result.add(wrapAwaitValue(tmpSym, cmd, n).add(
            newNimNode(n.kind, n).add(
              newNimNode(nnkIdentDefs, def).add(
                def[0],
                def[1], # keep an explicit type annotation, if any
                newCall(newDotExpr(tmpSym, newIdentNode(!"value")))))))
        else:
          result.add(newNimNode(n.kind, n).add(def))
  of nnkAsgn:
    let cmd = n[1]
    if isAwaitCall(cmd):
      # x = await expr
      expectLen(cmd, 2)
      let tmpSym = genSym(nskLet, awaitTmpName(n[0], "awaitAsgn"))
      result = wrapAwaitValue(tmpSym, cmd, n).add(
        newAssignment(
          n[0],
          newCall(newDotExpr(tmpSym, newIdentNode(!"value")))))
  of nnkDiscardStmt:
    let cmd = n[0]
    if isAwaitCall(cmd):
      # discard await x
      expectLen(cmd, 2)
      let tmpSym = genSym(nskLet, "awaitDiscard")
      result = wrapAwaitValue(tmpSym, cmd, n)
  of nnkTryStmt:
    error("try statements are not allowed in coroutine functions")
  else: discard
  # TODO: implicit return?
  for i in 0..<result.len:
    result[i] = convertToCoroutineBody(result[i], retSym, hasResult)
# Turns a proc into a coroutine constructor. The new proc body is:
#   let retCoroutine = Coroutine[T]()
#   retCoroutine.iter = iterator(): SchedulerCommand =
#     {.push warning[resultshadowed]: off.}
#     var result: T
#     {.pop.}
#     # <<< body, rewritten by convertToCoroutineBody >>>
#     retCoroutine.result = result
#   retCoroutine
# The shadowing `result` and the final store are only emitted when T is
# not void. The constructor returns the coroutine without queueing it.
#
# Compile with -d:coroutineDebug to dump the generated AST and code.

proc routineName(n: NimNode): string =
  ## Printable name of routine definition `n`, used in diagnostics only.
  ## Exported procs have a `name*` postfix node and lambdas have no name,
  ## so `.ident` cannot be applied to `n[0]` unconditionally.
  case n[0].kind
  of nnkIdent:
    result = $n[0].ident
  of nnkPostfix:
    if n[0][1].kind == nnkIdent: result = $n[0][1].ident
    else: result = "<unknown>"
  of nnkEmpty:
    result = "<anonymous>"
  else:
    result = "<unknown>"

proc convertToCoroutine(n: NimNode): NimNode =
  ## Rewrites proc/lambda definition `n` into a coroutine constructor as
  ## sketched above and returns the modified definition.
  if n.kind notin {nnkProcDef, nnkLambda}:
    error("Cannot transform this node kind into a coroutine")
  hint("Converting " & routineName(n) & " to coroutine")
  if n[6].kind == nnkEmpty:
    error("A coroutine needs a body")
  let unRetType = n[3][0]
  var retType: NimNode
  case unRetType.kind
  of nnkBracketExpr:
    if unRetType[0].kind != nnkIdent or unRetType[0].ident != !"Coroutine":
      error("Return type of a coroutine should be Coroutine[T] or void")
    retType = unRetType[1]
  of nnkEmpty:
    # No return type means a void coroutine. The constructor still ends in
    # the `retCoroutine` expression, so declare the matching return type;
    # otherwise that trailing expression would not compile.
    retType = newIdentNode(!"void")
    n[3][0] = newNimNode(nnkBracketExpr, n[3]).add(
      newIdentNode(!"Coroutine"), newIdentNode(!"void"))
  else:
    error("Return type of a coroutine should be Coroutine[T] or void")
  let retSym = genSym(nskLet, "retCoroutine")
  # `retType` may be any type expression (e.g. `seq[int]`), so only a bare
  # `void` identifier means "no result".
  let hasResult = not (retType.kind == nnkIdent and retType.ident == !"void")
  let coBody = newNimNode(nnkStmtList, n[6]) # second arg is used for line info
  var itBody = convertToCoroutineBody(n[6], retSym, hasResult)
  if itBody.kind != nnkStmtList:
    # A single-statement body has no statement list to insert into.
    itBody = newNimNode(nnkStmtList, n[6]).add(itBody)
  if hasResult:
    itBody.insert(0,
      newNimNode(nnkPragma).add(
        newIdentNode("push"),
        newNimNode(nnkExprColonExpr).add(
          newNimNode(nnkBracketExpr).add(
            newIdentNode("warning"),
            newIdentNode("resultshadowed")),
          newIdentNode("off"))))
    itBody.insert(1,
      newNimNode(nnkVarSection, n[6]).add(
        newIdentDefs(newIdentNode("result"), retType)))
    itBody.insert(2,
      newNimNode(nnkPragma).add(newIdentNode("pop")))
    itBody.add(newAssignment(
      newDotExpr(retSym, newIdentNode(!"result")),
      newIdentNode(!"result")))
  coBody.add(newLetStmt(
    retSym,
    newCall(newNimNode(nnkBracketExpr, n[6]).add(
      newIdentNode(!"Coroutine"), retType))))
  coBody.add(newAssignment(
    newDotExpr(retSym, newIdentNode(!"iter")),
    newProc(procType = nnkIteratorDef,
            params = [bindSym"SchedulerCommand"],
            body = itBody)))
  coBody.add(retSym)
  result = n
  # TODO: do I need this?
  # for i in 0..<result[4].len:
  #   if result[4][i].kind == nnkIdent and result[4][i].ident == !"coroutine":
  #     result[4].del(i)
  result[6] = coBody
  when defined(coroutineDebug):
    ppTree(result)
macro coroutine(n: stmt): stmt {.immediate.} =
  ## Pragma macro: `proc f(...): Coroutine[T] {.coroutine.} = body` turns
  ## `f` into a proc that builds (but does not queue) a coroutine running
  ## `body`.
  result = convertToCoroutine(n)
#============================================================================== | |
type
  # Produced by coroutineAwaitValue: `command` is yielded to the scheduler
  # to suspend the awaiting coroutine, and `value` is called after it is
  # resumed to read the awaited result(s).
  Awaiter[T] = object
    command: SchedulerCommand
    value: proc(): T
proc coroutineAwait(coroutines: seq[CoroutineBase]): SchedulerCommand =
  ## Command that starts every coroutine in `coroutines` and suspends the
  ## caller until all of them have finished.
  result = SchedulerCommand(kind: scWaitForCoroutines)
  result.coroutines = coroutines
proc coroutineAwait(coroutine: CoroutineBase): SchedulerCommand =
  ## Command that starts `coroutine` and suspends the caller until it has
  ## finished.
  result = SchedulerCommand(kind: scWaitForCoroutine)
  result.coroutine = coroutine
proc coroutineAwaitValue[T](coroutines: seq[Coroutine[T]]): Awaiter[seq[T]] =
  ## Awaiter for `let xs = await someSeq`: suspends until every coroutine in
  ## `coroutines` has finished; `value()` then collects their results in
  ## the same order as `coroutines`.
  # seq[Coroutine[T]] does not convert to seq[CoroutineBase], so upcast
  # element by element.
  let cb = map(coroutines, proc(c: Coroutine[T]): CoroutineBase = c)
  result.command = SchedulerCommand(kind: scWaitForCoroutines, coroutines: cb)
  result.value = proc(): seq[T] =
    result = newSeq[T](coroutines.len)
    for i in 0..high(coroutines):
      result[i] = coroutines[i].result

proc coroutineAwaitValue[T](coroutine: Coroutine[T]): Awaiter[T] =
  ## Awaiter for `let x = await someCoroutine` (also `x = await c` and
  ## `discard await c`): suspends until `coroutine` has finished; `value()`
  ## then returns its result. For a void coroutine `value()` does nothing,
  ## so it can still be used with `discard await`.
  result.command = SchedulerCommand(kind: scWaitForCoroutine,
                                    coroutine: coroutine)
  result.value = proc(): T =
    when T is not void:
      result = coroutine.result
proc computeNumber(i, n: int): Coroutine[int] {.coroutine.} =
  ## Demo coroutine: logs `i`, then returns i + (0 + 1 + ... + (n - 1)).
  echo "computing... ", i
  var total = i
  for step in countup(0, n - 1):
    inc(total, step)
  return total
proc computeAll(n: int): Coroutine[void] {.coroutine.} =
  ## Demo coroutine: builds 100 computeNumber coroutines, awaits them all,
  ## prints their results in order and then shuts the worker pool down.
  var pending: seq[Coroutine[int]] = @[]
  for i in 0..<100:
    pending.add(computeNumber(i, n))
  let results = await pending
  for r in results:
    echo r
  echo "done"
  shutdown()
# Queue the root coroutine, then block until every worker thread has exited
# (computeAll calls shutdown() once its work is done).
go computeAll(10000)
joinThreads(workers)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment