# Copyright MIT and Apache, Mamy Ratsimbazafy 2019

import std/atomics #, optim_hints # for prefetch

const WV_CacheLineSize = 128

type
  Enqueueable = concept x, type T
    x is ptr
    x.next is Atomic[T]

  ChannelMPSCunbounded[T: Enqueueable] = object
    front: T
    # TODO: align
    back: Atomic[T]
    dummy: T
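
# The channel is an intrusive singly-linked list of Enqueueable nodes.
# A pre-allocated dummy node keeps the list non-empty:
# - producers only touch `back` (atomic exchange) and the old back's `next`,
# - the single consumer only touches `front` and follows `next` pointers,
# so the two sides never need to update the same end of the list.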

proc initialize[T](chan: var ChannelMPSCunbounded[T], dummy: T) =
  assert not dummy.isNil
  dummy.next.store(nil, moRelaxed)

  chan.dummy = dummy
  chan.front = dummy
  chan.back.store(dummy, moRelaxed)

  assert not(chan.front.isNil)
  assert not(chan.back.load(moRelaxed).isNil)
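
# Channel invariant: `front` always points at a node whose value has already
# been consumed (initially the dummy); the channel is empty iff front.next is nil.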

proc trySend[T](chan: var ChannelMPSCunbounded[T], src: sink T): bool =
  ## Send an item to the back of the channel.
  ## As the channel has unbounded capacity, this should never fail.
  assert not(chan.front.isNil)
  assert not(chan.back.load(moRelaxed).isNil)

  src.next.store(nil, moRelaxed)
  # Without that fence the store src.next = nil could be reordered after the exchange
  # on the back and after the store to oldBack.next done by the next push, which would
  # result in the pop incorrectly seeing the queue as empty.
  # Also synchronise with the pop on oldBack.next.
  fence(moRelease)
  let oldBack = chan.back.exchange(src, moRelaxed)
  oldBack.next.store(src, moRelaxed)

  return true

proc tryRecv[T](chan: var ChannelMPSCunbounded[T], dst: var T): bool =
  ## Try receiving the next item buffered in the channel.
  ## Returns true if successful (channel was not empty).
  assert not(chan.front.isNil)
  assert not(chan.back.load(moRelaxed).isNil)

  let first = chan.front # dummy
  let next = first.next.load(moRelaxed)

  if not next.isNil:
    chan.front = next
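    # The node handed back below is also the new front: the channel keeps
    # pointing at it until the following successful recv.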
    # prefetch(first.next.load(moRelaxed))
    fence(moAcquire)
    dst = next
    return true

  dst = nil
  return false

# Sanity checks
# ------------------------------------------------------------------------------

when isMainModule:
  import strutils

  # Data structure test
  # --------------------------------------------------------

  # TODO: ensure that we don't write past the allocated buffer
  #       due to mismanagement of the front and back indices

  # Multithreading tests
  # --------------------------------------------------------
  when not compileOption("threads"):
    {.error: "This requires --threads:on compilation flag".}
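  # For example (assuming this file is saved as channel_mpsc_unbounded.nim):
  #   nim c -r --threads:on channel_mpsc_unbounded.nim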

  template sendLoop[T](chan: var ChannelMPSCunbounded[T],
                       data: sink T,
                       body: untyped): untyped =
    while not chan.trySend(data):
      body

  template recvLoop[T](chan: var ChannelMPSCunbounded[T],
                       data: var T,
                       body: untyped): untyped =
    while not chan.tryRecv(data):
      body

  type
    WorkerKind = enum
      Receiver
      Sender1
      Sender2
      Sender3

    Val = ptr ValObj
    ValObj = object
      next: Atomic[Val]
      val: int

    ThreadArgs = object
      ID: WorkerKind
      chan: ptr ChannelMPSCunbounded[Val]

  template Worker(id: WorkerKind, body: untyped): untyped {.dirty.} =
    if args.ID == id:
      body

  template Worker(id: Slice[WorkerKind], body: untyped): untyped {.dirty.} =
    if args.ID in id:
      body

  proc thread_func(args: ThreadArgs) =
    # Worker RECEIVER:
    # ---------
    # <- chan
    # <- chan
    # <- chan
    #
    # Worker SENDER:
    # ---------
    # chan <- 42
    # chan <- 53
    # chan <- 64
    Worker(Receiver):
      var counts: array[Sender1..Sender3, int]
      for j in 0 ..< 30:
        var val: Val
        args.chan[].recvLoop(val):
          # Busy loop, in prod we might want to yield the core/thread timeslice
          discard
        echo "Receiver got: ", val.val, " at address 0x", toLowerASCII toHex cast[ByteAddress](val)
        let sender = WorkerKind(val.val div 10)
        doAssert val.val == counts[sender] + ord(sender) * 10, "Incorrect value: " & $val.val
        inc counts[sender]
        freeShared(val)

    Worker(Sender1..Sender3):
      for j in 0 ..< 10:
        let val = createShared(ValObj)
        val.val = ord(args.ID) * 10 + j
        args.chan[].sendLoop(val):
          # Busy loop, in prod we might want to yield the core/thread timeslice
          discard
        const pad = spaces(18)
        echo pad.repeat(ord(args.ID)), $args.ID, " sent: ", val.val

  proc main() =
    echo "Testing if 3 threads can send data to 1 consumer"
    echo "------------------------------------------------------------------------"
    var threads: array[4, Thread[ThreadArgs]]
    let chan = createSharedU(ChannelMPSCunbounded[Val]) # CreateU is not zero-init
    let dummy = createShared(ValObj)
    chan[].initialize(dummy)

    createThread(threads[0], thread_func, ThreadArgs(ID: Receiver, chan: chan))
    createThread(threads[1], thread_func, ThreadArgs(ID: Sender1, chan: chan))
    createThread(threads[2], thread_func, ThreadArgs(ID: Sender2, chan: chan))
    createThread(threads[3], thread_func, ThreadArgs(ID: Sender3, chan: chan))

    joinThread(threads[0])
    joinThread(threads[1])
    joinThread(threads[2])
    joinThread(threads[3])

    deallocShared(dummy)
    deallocShared(chan)

    echo "------------------------------------------------------------------------"
    echo "Success"

  main()