Skip to content

Instantly share code, notes, and snippets.

@cfsamson
Last active July 24, 2019 00:51
Show Gist options
  • Save cfsamson/3f3c694f7e235226f02188ee2f284c37 to your computer and use it in GitHub Desktop.
Save cfsamson/3f3c694f7e235226f02188ee2f284c37 to your computer and use it in GitHub Desktop.
Working Windows Context Switch in Crystal
# Minimal context-switch proof of concept for Crystal.
# The two interesting pieces are `FiberPoc#makecontext` (lays out the initial
# stack of a fiber) and `Scheduler.swapcontext` (saves and restores the
# registers; on Windows that includes XMM6-XMM15 and two NT_TIB stack fields).
#
# For setting up WSL and cross-compiling for Windows 10, see:
# https://github.com/crystal-lang/crystal/wiki/Porting-to-Windows
# https://github.com/crystal-lang/crystal/issues/7932

# Size in bytes of the stack buffer allocated for every fiber.
STACK_SIZE = 8 * 1024 * 1024

# Pause in seconds after each printed line; gives Windows some time to kill
# the process if NT_TIB is not set correctly.
SLEEP_INTERVAL = 0.2

puts "Context switch test for Crystal"
Executor.new.run
# Demo driver: owns a `Scheduler` and runs two counting fibers on it.
class Executor
  @scheduler = Scheduler.new

  # Queues both demo fibers (slots 1 and 2) and hands control to the scheduler.
  def run
    first = counting_fiber("Fiber1: ", 5)
    second = counting_fiber("Fiber2: ", 10)
    @scheduler.spawn(1, first)
    @scheduler.spawn(2, second)
    @scheduler.run
  end

  # Builds a fiber body that prints `label` followed by every number from 0
  # through `last` (inclusive). After each line it flushes, sleeps for
  # SLEEP_INTERVAL and yields to the scheduler; once the count is done it
  # reports completion through `Scheduler#ret`.
  private def counting_fiber(label : String, last : Int32)
    ->{
      0.upto(last) do |i|
        STDOUT << label << i.to_s << "\n"
        STDOUT.flush
        sleep SLEEP_INTERVAL
        @scheduler.yield_control
      end
      @scheduler.ret
    }
  end
end
# A fiber for this proof of concept: a stack buffer, the stack pointer saved
# for it, and the proc it executes.
# `makecontext` is the interesting part to get working on Windows; the
# platform-specific portion of the stack layout is selected at compile time.
class FiberPoc
  # Backing memory used as this fiber's stack.
  property stack : Array(UInt8)
  # Stack pointer value passed to `Scheduler.swapcontext` for this fiber.
  property rsp : Void* = Pointer(Void).null
  # Proc executed by `run`.
  property f = Proc(Void).new { }

  def initialize
    # OK as long as the array is never pushed to or popped from, since that
    # could reallocate the buffer.
    @stack = Array.new(STACK_SIZE, 0_u8)
  end

  # Stores `proc` as the fiber body and writes the initial stack contents
  # that `Scheduler.swapcontext` pops when this fiber is switched to for the
  # first time. Offsets below are relative to the entry slot.
  def makecontext(proc : Proc)
    @f = proc

    low = @stack.to_unsafe
    # The stack grows down, so the high address is the bottom; align it on 16 bytes.
    high = Pointer(UInt8).new((low.as(UInt8*) + STACK_SIZE).address & ~15)

    # Entry slot: holds the address of the proc that calls `run`.
    # NOTE: the offset of 32 differs from the original Crystal implementation
    # (which uses 16, if memory serves).
    entry = high - 32
    entry.as(UInt64*).value = ->(f : FiberPoc) { f.run }.pointer.address

    # First parameter of the entry proc: this fiber.
    (entry - 8).as(UInt64*).value = self.as(Void*).address

    {% if flag?(:win32) %}
      # https://en.wikipedia.org/wiki/Win32_Thread_Information_Block
      # "Stack base" = high address of the stack.
      (entry - 16).as(UInt64*).value = high.address
      # "Stack limit" = low address of the stack.
      (entry - 24).as(UInt64*).value = low.address
      # 8 registers + 2 qwords for NT_TIB + 1 parameter + 10 128-bit XMM registers.
      @rsp = (entry - (11*8 + 10*16)).as(Void*)
    {% else %}
      # 7 registers are popped before returning.
      @rsp = (entry - 7*8).as(Void*)
    {% end %}
  end

  # Invokes the stored proc.
  def run
    @f.call
  end
end
# Tiny scheduler for the proof of concept. It manages exactly three
# `FiberPoc` slots: slot 0 is the base thread, slots 1 and 2 are the workers.
# `swapcontext` is the interesting part since it mimics Crystal's `swapcontext`.
class Scheduler
  @fibers = Array(FiberPoc).new(3)
  # Index of the fiber that is currently executing.
  @current = 0
  # Number of live contexts, counting the base thread.
  @running = 1
  # Just a hack since there are only two worker fibers: the one that finished first.
  @finished : Int32 = 0

  # Saves the callee-saved state of the running context on its stack, stores
  # the resulting stack pointer through `current`, loads `to` as the new stack
  # pointer and restores the state found there.
  #
  # On Windows the saved state additionally contains the first-argument
  # register, the two NT_TIB stack fields and XMM6-XMM15. The NT_TIB fields
  # are saved as gs:0x08 (stack base) followed by gs:0x10 (stack limit) and
  # restored in exactly the reverse order. This matches the layout written by
  # `FiberPoc#makecontext`: stack base 16 bytes and stack limit 24 bytes below
  # the entry slot. Previously the restore used the octal literal `010`
  # (= 0x08), so gs:0x10 was never restored and gs:0x08 was written twice.
  @[NoInline]
  @[Naked]
  def self.swapcontext(current : Pointer(Pointer(Void)), to : Pointer(Void)) : Nil
    {% if flag?(:win32) %}
      asm("
        pushq %rcx
        pushq %gs:0x08
        pushq %gs:0x10
        pushq %rdi
        pushq %rbx
        pushq %rbp
        pushq %rsi
        pushq %r12
        pushq %r13
        pushq %r14
        pushq %r15
        # XMM registers
        subq $$160, %rsp
        movups %xmm6, 0x00(%rsp)
        movups %xmm7, 0x10(%rsp)
        movups %xmm8, 0x20(%rsp)
        movups %xmm9, 0x30(%rsp)
        movups %xmm10, 0x40(%rsp)
        movups %xmm11, 0x50(%rsp)
        movups %xmm12, 0x60(%rsp)
        movups %xmm13, 0x70(%rsp)
        movups %xmm14, 0x80(%rsp)
        movups %xmm15, 0x90(%rsp)
        movq %rsp, ($0)
        movq $1, %rsp
        # XMM registers
        movups 0x00(%rsp), %xmm6
        movups 0x10(%rsp), %xmm7
        movups 0x20(%rsp), %xmm8
        movups 0x30(%rsp), %xmm9
        movups 0x40(%rsp), %xmm10
        movups 0x50(%rsp), %xmm11
        movups 0x60(%rsp), %xmm12
        movups 0x70(%rsp), %xmm13
        movups 0x80(%rsp), %xmm14
        movups 0x90(%rsp), %xmm15
        addq $$160, %rsp
        popq %r15
        popq %r14
        popq %r13
        popq %r12
        popq %rsi
        popq %rbp
        popq %rbx
        popq %rdi
        popq %gs:0x10
        popq %gs:0x08
        popq %rcx
        "
              :
              : "r"(current), "r"(to)
      )
    {% else %}
      asm("
        pushq %rdi
        pushq %rbx
        pushq %rbp
        pushq %r12
        pushq %r13
        pushq %r14
        pushq %r15
        movq %rsp, ($0)
        movq $1, %rsp
        popq %r15
        popq %r14
        popq %r13
        popq %r12
        popq %rbp
        popq %rbx
        popq %rdi
        "
              :: "r"(current), "r"(to))
    {% end %}
  end

  # Creates the three fiber slots and gives the base thread (slot 0) an empty body.
  def initialize
    3.times { @fibers << FiberPoc.new }
    # base thread
    @fibers[0].makecontext(->{})
  end

  # Registers `f` as the body of the fiber in slot `fiber_no` and counts it as running.
  def spawn(fiber_no : UInt64, f : Proc)
    puts "Queuing fiber #{fiber_no}"
    @running += 1
    @fibers[fiber_no].makecontext(f)
  end

  # Starts scheduling by yielding from the base thread.
  def run
    self.yield_control
  end

  # Hands the CPU to the next fiber. What "next" means depends on how many
  # contexts are still alive.
  def yield_control
    if @running == 2
      # One worker has finished. Only the call coming from that finished
      # fiber (via `ret`) switches, to continue on the remaining worker; the
      # remaining worker itself keeps running without switching.
      if @current == @finished
        next_ctx = @finished == 2 ? 1 : 2
        current = @fibers[next_ctx]
        @current = next_ctx
        # The outgoing stack pointer is stored into the same slot that is
        # being switched to; the finished fiber's context is not kept.
        Scheduler.swapcontext(pointerof(current.@rsp), current.@rsp)
      end
    elsif @running == 3
      # Base thread (0) enters fiber 1; afterwards fibers 1 and 2 alternate.
      current = @fibers[@current]
      @current = @current == 1 ? 2 : 1
      Scheduler.swapcontext(pointerof(current.@rsp), @fibers[@current].@rsp)
    end
  end

  # Called by a fiber body when it is done: records which fiber finished,
  # exits the process once only the base thread is left, and otherwise
  # switches to the remaining worker.
  def ret
    STDOUT << "ret\n"
    STDOUT.flush
    @finished = @current
    @running -= 1
    if @running == 1
      STDOUT << "Done. Exiting."
      STDOUT.flush
      Process.exit(0)
    end
    self.yield_control
  end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment