Skip to content

Instantly share code, notes, and snippets.

@mratsim
Created November 20, 2019 20:11
Show Gist options
  • Save mratsim/6e9716d5aaaaa8c9f01ae10db6751432 to your computer and use it in GitHub Desktop.
Save mratsim/6e9716d5aaaaa8c9f01ae10db6751432 to your computer and use it in GitHub Desktop.
# Copyright MIT and Apache, Mamy Ratsimbazafy 2019
import std/atomics #, optim_hints # for prefetch
# Padding granularity used to keep producer-side and consumer-side fields of
# the channel on distinct cache lines (128 rather than 64 presumably to also
# cover CPUs that prefetch adjacent line pairs — TODO confirm).
const WV_CacheLineSize = 128

type
  # A channel item: a pointer to a node carrying an intrusive atomic `next`
  # link of its own type. The channel links items through that field, so an
  # item must not be enqueued in two channels at once.
  Enqueueable = concept x, type T
    x is ptr
    x.next is Atomic[T]

  # Unbounded multi-producer / single-consumer channel of intrusive nodes.
  ChannelMPSCunbounded[T: Enqueueable] = object
    # Receiving side: `front` is only read/written by the consumer
    # (and by `initialize`).
    front: T
    # Stub node handed to `initialize`; it never carries user data.
    dummy: T
    # Keep `back`, which every producer updates, off the consumer's cache
    # line to avoid false sharing.
    pad0: array[WV_CacheLineSize - 2 * sizeof(pointer), byte]
    # Sending side: last node of the list, swapped by every `trySend`.
    back: Atomic[T]
    # Keep whatever follows the channel in memory off `back`'s cache line.
    pad1: array[WV_CacheLineSize - sizeof(pointer), byte]
proc initialize[T](chan: var ChannelMPSCunbounded[T], dummy: T) =
  ## Reset `chan` to the empty state, using `dummy` as its stub node.
  ##
  ## `dummy` must be non-nil. Its `next` link is cleared and it becomes the
  ## channel's `dummy`, `front` and `back` at once, which is how an empty
  ## channel is represented. This proc does not free `dummy`; the caller must
  ## keep it valid while the channel may still reference it.
  assert not dummy.isNil
  let stub = dummy
  # Terminate the stub so the list ends right after it.
  stub.next.store(nil, moRelaxed)
  # Both ends point at the stub: nothing is buffered yet.
  chan.back.store(stub, moRelaxed)
  chan.front = stub
  chan.dummy = stub
  assert not chan.back.load(moRelaxed).isNil
  assert not chan.front.isNil
proc trySend[T](chan: var ChannelMPSCunbounded[T], src: sink T): bool =
  ## Send an item to the back of the channel.
  ##
  ## May be called concurrently by any number of producer threads.
  ## The channel has unbounded capacity, so this never fails and always
  ## returns true; the bool result mirrors `tryRecv` for use in retry loops.
  ## `src` is linked into the channel through its own `next` field, so it
  ## must not be touched by the sender after this call.
  assert not(chan.front.isNil)
  assert not(chan.back.load(moRelaxed).isNil)
  assert not src.isNil

  # Terminate the new node before anyone can observe it.
  src.next.store(nil, moRelaxed)
  # Serialization point between producers:
  # - release publishes `src.next = nil` (and the payload) to the producer
  #   that will later receive `src` as its `oldBack` and write `src.next`;
  # - acquire makes the previous producer's `oldBack.next = nil` happen-before
  #   our store below, so that stale nil can never overwrite our link and
  #   make the consumer see the queue as truncated.
  let oldBack = chan.back.exchange(src, moAcquireRelease)
  # Link point with the consumer: pairs with its acquire load of `next`,
  # making `src`'s contents visible before the node becomes reachable.
  oldBack.next.store(src, moRelease)
  return true
proc tryRecv[T](chan: var ChannelMPSCunbounded[T], dst: var T): bool =
  ## Try receiving the next item buffered in the channel.
  ##
  ## Must only be called from the single consumer thread.
  ## On success returns true and stores in `dst` a node that is fully
  ## unlinked from the channel: neither `front` nor `back` refers to it
  ## anymore, so the caller may free or re-send it.
  ## On failure returns false and sets `dst` to nil. This can fail spuriously
  ## when the only buffered item is being followed by a concurrent `trySend`
  ## that has swapped `back` but not linked `next` yet; callers should retry.
  assert not(chan.front.isNil)
  assert not(chan.back.load(moRelaxed).isNil)

  var first = chan.front
  # Acquire pairs with the producer's publication of `next`.
  var next = first.next.load(moAcquire)

  # The stub never carries data: step over it.
  if first == chan.dummy:
    if next.isNil:
      dst = nil
      return false
    chan.front = next
    first = next
    next = first.next.load(moAcquire)

  # Common case: `first` has a successor, so it can be detached safely.
  if not next.isNil:
    chan.front = next
    dst = first
    return true

  # `first` looks like the last node. If `back` moved past it, a producer
  # is midway through linking a new node after `first`: report empty.
  if first != chan.back.load(moAcquire):
    dst = nil
    return false

  # `first` really is the last node. Re-enqueue the stub behind it so that
  # `first` gains a successor and can be handed out without leaving
  # `front`/`back` pointing at memory the caller may free.
  let stub = chan.dummy
  stub.next.store(nil, moRelaxed)
  let prev = chan.back.exchange(stub, moAcquireRelease)
  prev.next.store(stub, moRelease)

  next = first.next.load(moAcquire)
  if not next.isNil:
    chan.front = next
    dst = first
    return true

  # A producer slipped in between the check and the stub re-enqueue and has
  # not linked its node yet; a later call will pick `first` up.
  dst = nil
  return false
# Sanity checks
# ------------------------------------------------------------------------------

when isMainModule:
  import strutils

  # Data structure test
  # --------------------------------------------------------

  # TODO: add single-threaded tests of the data structure
  # (FIFO order, receiving from an empty channel, stub re-enqueue)

  # Multithreading tests
  # --------------------------------------------------------

  when not compileOption("threads"):
    {.error: "This requires --threads:on compilation flag".}

  template sendLoop[T](chan: var ChannelMPSCunbounded[T],
                       data: sink T,
                       body: untyped): untyped =
    ## Retry `trySend` until it succeeds, running `body` between attempts.
    while not chan.trySend(data):
      body

  template recvLoop[T](chan: var ChannelMPSCunbounded[T],
                       data: var T,
                       body: untyped): untyped =
    ## Retry `tryRecv` until it yields an item, running `body` between attempts.
    while not chan.tryRecv(data):
      body

  type
    WorkerKind = enum
      Receiver
      Sender1
      Sender2
      Sender3

    Val = ptr ValObj
    ValObj = object
      next: Atomic[Val]
      val: int

    ThreadArgs = object
      ID: WorkerKind
      chan: ptr ChannelMPSCunbounded[Val]

  template Worker(id: WorkerKind, body: untyped): untyped {.dirty.} =
    ## Run `body` only on the thread whose role is `id`.
    if args.ID == id:
      body

  template Worker(id: Slice[WorkerKind], body: untyped): untyped {.dirty.} =
    ## Run `body` only on the threads whose role lies in `id`.
    if args.ID in id:
      body

  proc thread_func(args: ThreadArgs) {.thread.} =
    ## Thread entry point; the role is selected by `args.ID`.

    # Worker RECEIVER:
    # ---------
    # <- chan
    # <- chan
    # <- chan
    #
    # Worker SENDER:
    # ---------
    # chan <- 42
    # chan <- 53
    # chan <- 64
    Worker(Receiver):
      var counts: array[Sender1..Sender3, int]
      for j in 0 ..< 30:
        var val: Val
        args.chan[].recvLoop(val):
          # Busy loop, in prod we might want to yield the core/thread timeslice
          discard
        echo "Receiver got: ", val.val,
          " at address 0x", toHex(cast[int](val)).toLowerAscii
        # Each sender encodes its identity in the tens digit and a per-sender
        # sequence number in the units digit: check per-sender FIFO order.
        let sender = WorkerKind(val.val div 10)
        doAssert val.val == counts[sender] + ord(sender) * 10,
          "Incorrect value: " & $val.val
        inc counts[sender]
        freeShared(val)

    Worker(Sender1..Sender3):
      for j in 0 ..< 10:
        let val = createShared(ValObj)
        val.val = ord(args.ID) * 10 + j
        # Read the payload before sending: once sent, the node belongs to the
        # receiver, which may free it at any time.
        let sent = val.val
        args.chan[].sendLoop(val):
          # Busy loop, in prod we might want to yield the core/thread timeslice
          discard
        const pad = spaces(18)
        echo pad.repeat(ord(args.ID)), $args.ID, " sent: ", sent

  proc main() =
    ## Run one receiver and three senders on a shared channel.
    echo "Testing if 3 threads can send data to 1 consumer"
    echo "------------------------------------------------------------------------"

    var threads: array[WorkerKind, Thread[ThreadArgs]]
    let chan = createSharedU(ChannelMPSCunbounded[Val]) # CreateU is not zero-init
    let dummy = createShared(ValObj)
    chan[].initialize(dummy)

    for kind in WorkerKind:
      createThread(threads[kind], thread_func, ThreadArgs(ID: kind, chan: chan))
    for kind in WorkerKind:
      joinThread(threads[kind])

    deallocShared(dummy)
    deallocShared(chan)
    echo "------------------------------------------------------------------------"
    echo "Success"

  main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment