Last active
July 24, 2019 00:51
-
-
Save cfsamson/3f3c694f7e235226f02188ee2f284c37 to your computer and use it in GitHub Desktop.
Working Windows Context Switch in Crystal
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Simple context switch implementation (proof of concept) for Crystal.
#
# The two interesting functions are `FiberPoc#makecontext` and
# `Scheduler.swapcontext`. On Windows, `swapcontext` also saves and restores
# XMM6-XMM15 in addition to the general purpose registers.
#
# For setting up WSL and cross compiling for Windows 10 see:
# https://github.com/crystal-lang/crystal/wiki/Porting-to-Windows
# https://github.com/crystal-lang/crystal/issues/7932

# Size in bytes of the buffer that backs every fiber stack (8 MiB).
STACK_SIZE = 8 * 1024 * 1024

# Seconds each fiber sleeps per iteration; gives Windows some time to kill the
# process if the NT_TIB is not set correctly.
SLEEP_INTERVAL = 0.2

puts "Context switch test for Crystal"
Executor.new.run
# Demo driver: queues two printing fibers on a `Scheduler` and starts it.
class Executor
  @scheduler = Scheduler.new

  # Builds both fiber bodies, registers them in slots 1 and 2 and hands
  # control over to the scheduler.
  def run
    # Prints 0..5, yielding to the scheduler after every line.
    first_task = ->{
      0.upto(5) do |step|
        STDOUT << "Fiber1: " << step.to_s << "\n"
        STDOUT.flush
        sleep SLEEP_INTERVAL
        @scheduler.yield_control
      end
      # Must stay the last call: it tells the scheduler this fiber is finished.
      @scheduler.ret
    }

    # Prints 0..10, yielding to the scheduler after every line.
    second_task = ->{
      0.upto(10) do |step|
        STDOUT << "Fiber2: " << step.to_s << "\n"
        STDOUT.flush
        sleep SLEEP_INTERVAL
        @scheduler.yield_control
      end
      # Must stay the last call: it tells the scheduler this fiber is finished.
      @scheduler.ret
    }

    @scheduler.spawn(1, first_task)
    @scheduler.spawn(2, second_task)
    @scheduler.run
  end
end
# makecontext is conditionally compiled and the interesting part to get to work on Windows. | |
# A single fiber of the proof of concept: the buffer used as its stack, the
# saved stack pointer and the proc it executes.
#
# `makecontext` prepares the stack so that the first `Scheduler.swapcontext`
# into this fiber pops a well-defined set of values and then `ret`s into a
# small trampoline that calls `run`.
class FiberPoc
  {% if flag?(:win32) %}
    # Prepares the stack for the Windows variant of `Scheduler.swapcontext`.
    #
    # Layout relative to `s_start` (addresses grow upwards):
    #
    #   s_start        return address -> trampoline calling `f.run`
    #   s_start -  8   popped into %rcx      (first parameter: self)
    #   s_start - 16   popped into %gs:0x10  (NT_TIB StackLimit, low address)
    #   s_start - 24   popped into %gs:0x08  (NT_TIB StackBase, high address)
    #   below          8 general purpose registers, then 10 XMM registers
    def makecontext(proc : Proc)
      @f = proc
      s_top = @stack.to_unsafe
      s_bottom = s_top.as(UInt8*) + STACK_SIZE              # stack grows down -> high address is bottom
      s_bottom = Pointer(UInt8).new(s_bottom.address & ~15) # align stack on 16 bytes
      s_start = s_bottom - 32

      # Our entry
      # 8 registers + 2 qwords for NT_TIB + 1 parameter + 10 128bit XMM registers
      @rsp = (s_start - (11*8 + 10*16)).as(Void*)
      s_start_ptr = s_start.as(UInt64*)
      s_start_ptr.value = ->(f : FiberPoc) { f.run }.pointer.address

      # https://en.wikipedia.org/wiki/Win32_Thread_Information_Block
      # The slots must mirror the pop order in `swapcontext`: the lower slot
      # goes to %gs:0x08, the higher one to %gs:0x10.

      # %gs:0x08 = "stack base" = high address
      stack_base = (s_start - 24).as(UInt64*) # where to put the value on our stack
      stack_base.value = s_bottom.address

      # %gs:0x10 = "stack limit" = low address
      stack_limit = (s_start - 16).as(UInt64*) # where to put the value on our stack
      stack_limit.value = s_top.address

      # First parameter
      first_param = (s_start - 8).as(UInt64*)
      first_param.value = self.as(Void*).address
    end
  {% else %}
    # Prepares the stack for the non-Windows variant of
    # `Scheduler.swapcontext`: 7 register slots below `s_start`, the topmost of
    # which is popped into %rdi (first parameter: self), and the trampoline
    # address at `s_start` itself.
    def makecontext(proc : Proc)
      @f = proc
      s_top = @stack.to_unsafe
      s_bottom = s_top.as(UInt8*) + STACK_SIZE              # stack grows down -> high address is bottom
      s_bottom = Pointer(UInt8).new(s_bottom.address & ~15) # align stack on 16 bytes
      # note that this is different from the original crystal impl (uses offset of 16 IIRC)
      s_start = s_bottom - 32

      # Our entry
      # we pop 7 registers before we return
      @rsp = (s_start - 7*8).as(Void*)
      s_start_ptr = s_start.as(UInt64*)
      s_start_ptr.value = ->(f : FiberPoc) { f.run }.pointer.address

      # First parameter
      first_param = (s_start - 8).as(UInt64*)
      first_param.value = self.as(Void*).address
    end
  {% end %}

  # Buffer used as this fiber's stack.
  property stack : Array(UInt8)
  # Saved stack pointer; written by `makecontext` and by `Scheduler.swapcontext`.
  property rsp : Void* = Pointer(Void).null
  # The proc executed by `run`.
  property f = Proc(Void).new { }

  def initialize
    @stack = Array.new(STACK_SIZE, 0_u8) # OK, as long as we don't push/pop -> reallocate
  end

  # Entry point reached through the trampoline installed by `makecontext`.
  def run
    @f.call
  end
end

# Swapcontext is the interesting part in Scheduler since it mimics Crystal's `swapcontext`.
#
# The scheduler itself is a hard-coded hack for exactly three contexts:
# slot 0 is the base thread, slots 1 and 2 are the spawned fibers.
class Scheduler
  @fibers = Array(FiberPoc).new(3)
  @current = 0 # slot of the context that is currently executing
  @running = 1 # number of live contexts, the base thread included
  @finished : Int32 = 0 # just a hack since we only have two executing fibers, this is the one that finished first

  # Saves the callee-saved state on the current stack, stores the resulting
  # stack pointer through `current`, switches the stack pointer to `to` and
  # restores the state found there. The pops must mirror the pushes exactly,
  # and must match the layout written by `FiberPoc#makecontext`.
  @[NoInline]
  @[Naked]
  def self.swapcontext(current : Pointer(Pointer(Void)), to : Pointer(Void)) : Nil
    {% if flag?(:win32) %}
      # Besides the registers, Windows needs the NT_TIB stack bounds
      # (%gs:0x08 StackBase, %gs:0x10 StackLimit) swapped per context.
      # Offsets are written in hex on purpose: a leading zero without `x`
      # would be read as an octal literal by the assembler.
      asm("
        pushq %rcx
        pushq %gs:0x10
        pushq %gs:0x08
        pushq %rdi
        pushq %rbx
        pushq %rbp
        pushq %rsi
        pushq %r12
        pushq %r13
        pushq %r14
        pushq %r15
        # XMM registers
        subq $$160, %rsp
        movups %xmm6, 0x00(%rsp)
        movups %xmm7, 0x10(%rsp)
        movups %xmm8, 0x20(%rsp)
        movups %xmm9, 0x30(%rsp)
        movups %xmm10, 0x40(%rsp)
        movups %xmm11, 0x50(%rsp)
        movups %xmm12, 0x60(%rsp)
        movups %xmm13, 0x70(%rsp)
        movups %xmm14, 0x80(%rsp)
        movups %xmm15, 0x90(%rsp)
        movq %rsp, ($0)
        movq $1, %rsp
        # XMM registers
        movups 0x00(%rsp), %xmm6
        movups 0x10(%rsp), %xmm7
        movups 0x20(%rsp), %xmm8
        movups 0x30(%rsp), %xmm9
        movups 0x40(%rsp), %xmm10
        movups 0x50(%rsp), %xmm11
        movups 0x60(%rsp), %xmm12
        movups 0x70(%rsp), %xmm13
        movups 0x80(%rsp), %xmm14
        movups 0x90(%rsp), %xmm15
        addq $$160, %rsp
        popq %r15
        popq %r14
        popq %r13
        popq %r12
        popq %rsi
        popq %rbp
        popq %rbx
        popq %rdi
        popq %gs:0x08
        popq %gs:0x10
        popq %rcx
        "
              :
              : "r"(current), "r"(to)
              )
    {% else %}
      asm("
        pushq %rdi
        pushq %rbx
        pushq %rbp
        pushq %r12
        pushq %r13
        pushq %r14
        pushq %r15
        movq %rsp, ($0)
        movq $1, %rsp
        popq %r15
        popq %r14
        popq %r13
        popq %r12
        popq %rbp
        popq %rbx
        popq %rdi
        "
              :: "r"(current), "r"(to))
    {% end %}
  end

  # Creates the three context slots and prepares slot 0 (the base thread)
  # with an empty proc.
  def initialize
    @fibers = Array(FiberPoc).new
    3.times { @fibers << FiberPoc.new }
    # base thread
    @fibers[0].makecontext(->{})
  end

  # Registers `f` as the body of the fiber in slot `fiber_no` and counts it
  # as running. It does not start executing until the scheduler switches to it.
  def spawn(fiber_no : UInt64, f : Proc)
    puts "Queuing fiber #{fiber_no}"
    @running += 1
    @fibers[fiber_no].makecontext(f)
  end

  # Starts scheduling by yielding from the base thread.
  def run
    self.yield_control
  end

  # Switches to the next context.
  #
  # With all three contexts alive the order is 0 -> 1, 1 -> 2, 2 -> 1, so the
  # base thread is never switched back to. With two alive (one fiber has
  # finished) a switch only happens when called from the fiber that just
  # finished; the surviving fiber simply keeps running when it yields.
  def yield_control
    # if we only have two threads, one is finished so we only swap if it just finished to continue on the last
    if @running == 2
      if @current == @finished
        next_ctx = @finished == 2 ? 1 : 2
        current = @fibers[next_ctx]
        @current = next_ctx
        # `to` is passed by value, so reusing the target's own `@rsp` as the
        # save slot is fine: the finished context is never resumed.
        Scheduler.swapcontext(pointerof(current.@rsp), current.@rsp)
      end
    elsif @running == 3
      if @current == 0
        current = @fibers[0]
        @current = 1
        Scheduler.swapcontext(pointerof(current.@rsp), @fibers[1].@rsp)
      elsif @current == 1
        current = @fibers[1]
        @current = 2
        Scheduler.swapcontext(pointerof(current.@rsp), @fibers[2].@rsp)
      elsif @current == 2
        @current = 1
        current = @fibers[2]
        Scheduler.swapcontext(pointerof(current.@rsp), @fibers[1].@rsp)
      end
    end
  end

  # Called by a fiber as its last action: marks it as finished, exits the
  # process once only the base thread is left, otherwise switches to the
  # remaining fiber.
  def ret
    STDOUT << "ret\n"
    STDOUT.flush
    @finished = @current
    @running -= 1
    if @running == 1
      STDOUT << "Done. Exiting."
      STDOUT.flush
      Process.exit(0)
    end
    self.yield_control
  end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment