Created
December 24, 2024 04:37
-
-
Save mkitti/bc13a8c030502e0b00a415f732cf32bb to your computer and use it in GitHub Desktop.
Accelerate calling Julia from Python via juliacall and ctypes
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
"""Benchmark calling a Julia function from Python via juliacall and via ctypes.

``sumsquares`` is timed natively inside Julia (BenchmarkTools), from Python
through juliacall, and from Python through a ctypes function pointer to a
Julia ``@cfunction``, alongside NumPy and pure-Python baselines.

Discussion:
https://discourse.julialang.org/t/accelerating-calling-a-julia-function-from-python-via-juliacall-and-ctypes/124143

Environment:
    mamba create -n juliacall_test python numpy ipython pyjuliacall
"""
import ctypes
from timeit import timeit

import juliapkg
import numpy as np

# Register BenchmarkTools *before* importing juliacall. juliapkg.add only
# records the dependency; it is installed when juliapkg resolves, which
# juliacall does when it is imported and starts Julia. Adding it afterwards
# would leave `using BenchmarkTools` unavailable on a fresh environment.
juliapkg.add("BenchmarkTools", "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf")

from juliacall import Main as jl  # noqa: E402  (must follow juliapkg.add)

# Julia setup
# Generic method: sum of x^2 over any iterable, e.g. a NumPy array passed in
# through juliacall.
jl.seval("sumsquares(v) = sum(x->x^2, v)")
# Raw-pointer method for C callers: view `len` Float64 values starting at `v`
# as a Julia Array without copying. own = false means Julia never frees the
# caller's (here NumPy's) buffer.
jl.seval("sumsquares(v::Ptr{Float64}, len::Int) = sumsquares(unsafe_wrap(Array, v, len; own = false))")
# Address of a C-callable entry point for the pointer method, as a Python int.
p = jl.seval("Int(@cfunction(sumsquares, Float64, (Ptr{Float64}, Int)))")
# C signature: double f(double *, <Julia Int>). Julia's `Int` is a signed,
# pointer-sized integer (Int64 on 64-bit, Int32 on 32-bit platforms);
# ctypes.c_ssize_t has the same width, and equals c_int64 on 64-bit builds.
FUNCTYPE = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.POINTER(ctypes.c_double), ctypes.c_ssize_t)
jl_sumsquares = FUNCTYPE(p)
# Python comparison
def sumsquares_numpy(v):
    """Return the sum of the squared elements of *v*, using NumPy vector ops."""
    squared = v ** 2
    return np.sum(squared)
def sumsquares_purepython(v):
    """Return the sum of squares of the items of *v*, one Python object at a time."""
    return sum(item ** 2 for item in v)
# Timing setup: 100_000 random float64 values in [0, 1).
v = np.random.rand(100_000)
# Give Julia its own copy of the data for the pure-Julia baseline below.
# `global vcopy` declares the binding in Main before it is assigned from Python.
jl.seval("global vcopy")
jl.vcopy = jl.copy(v)

# Warm-up calls, so one-time costs (such as Julia compiling the method for
# this argument type) are not counted in the timings below.
# Julia's sumsquares via juliacall:
jl.sumsquares(v)
# Julia's sumsquares via the ctypes function pointer. The pointer refers to
# v's own buffer; np.random.rand returns a contiguous float64 array, which the
# (Ptr{Float64}, length) signature relies on.
jl_sumsquares(v.ctypes.data_as(ctypes.POINTER(ctypes.c_double)), len(v))

# Perform actual timing
number_of_calls = 1_000


def _total_seconds(stmt):
    """Return the total wall time, in seconds, of running *stmt* ``number_of_calls`` times.

    *stmt* is evaluated in this module's global namespace through timeit's
    ``globals`` argument, so it works whether or not this file runs as
    ``__main__``.
    """
    return timeit(stmt, globals=globals(), number=number_of_calls)


def _report(description, total_seconds):
    """Print *total_seconds* (measured over ``number_of_calls`` runs) as microseconds per call."""
    print(f"{description} took {total_seconds/number_of_calls*10**6} microseconds")


# Pure-Julia baseline. `$vcopy` interpolates the global into the benchmark, as
# BenchmarkTools recommends, so the measurement excludes looking up and
# dispatching on an untyped global. @belapsed returns the minimum time in seconds.
julia_time = jl.seval("""
using BenchmarkTools;
@belapsed sumsquares($vcopy)
""")
print(f"sumsquares(vcopy) in Julia took {julia_time*10**6} microseconds")

juliacall_time = _total_seconds("jl.sumsquares(v)")
_report("jl.sumsquares(v) in Python via juliacall", juliacall_time)

# Each call also pays for building the ctypes pointer from v.
ctypes_time = _total_seconds("jl_sumsquares(v.ctypes.data_as(ctypes.POINTER(ctypes.c_double)), len(v))")
_report("jl_sumsquares(v) in Python via ctypes", ctypes_time)

numpy_time = _total_seconds("sumsquares_numpy(v)")
_report("sumsquares_numpy(v) in Python via numpy", numpy_time)

purepython_time = _total_seconds("sumsquares_purepython(v)")
_report("sumsquares_purepython(v) in pure Python", purepython_time)

# Example output from one run:
#
#   $ ./juliacall_ctypes_test.py
#   sumsquares(vcopy) in Julia took 31.698 microseconds
#   jl.sumsquares(v) in Python via juliacall took 136.38922899917816 microseconds
#   jl_sumsquares(v) in Python via ctypes took 40.09708300145576 microseconds
#   sumsquares_numpy(v) in Python via numpy took 80.66400000097929 microseconds
#   sumsquares_purepython(v) in pure Python took 27677.13364700103 microseconds
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment