Skip to content

Instantly share code, notes, and snippets.

@mkitti
Created December 24, 2024 04:37
Show Gist options
  • Save mkitti/bc13a8c030502e0b00a415f732cf32bb to your computer and use it in GitHub Desktop.
Save mkitti/bc13a8c030502e0b00a415f732cf32bb to your computer and use it in GitHub Desktop.
Accelerate calling Julia from Python via juliacall and ctypes
#!/usr/bin/env python
# https://discourse.julialang.org/t/accelerating-calling-a-julia-function-from-python-via-juliacall-and-ctypes/124143
# mamba create -n juliacall_test python numpy ipython pyjuliacall
import ctypes
import numpy as np
import juliapkg
from juliacall import Main as jl
from timeit import timeit
# Julia setup
# NOTE(review): juliapkg.add may only record BenchmarkTools as a dependency of the
# juliacall project; confirm it is actually installed (e.g. via juliapkg.resolve())
# before the `using BenchmarkTools` in the timing section runs.
juliapkg.add("BenchmarkTools", "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf")
# Generic sum of squares; this is the method the juliacall path dispatches to.
jl.seval("sumsquares(v) = sum(x->x^2, v)")
# Raw-pointer method for the C entry point: wrap the caller's buffer without
# copying (own = false, so Julia never frees Python's memory). A non-positive
# length returns 0.0 up front, because unsafe_wrap/sum are not guaranteed to
# accept it and an exception raised inside a @cfunction cannot be reported back
# to the ctypes caller, which only receives a Float64.
jl.seval("sumsquares(v::Ptr{Float64}, len::Int) = len <= 0 ? 0.0 : sumsquares(unsafe_wrap(Array, v, len; own = false))")
# Address of a C-callable entry point for the pointer method, as a plain integer.
p = jl.seval("Int(@cfunction(sumsquares, Float64, (Ptr{Float64}, Int)))")
# Julia's Int is pointer-sized (Int64 on 64-bit, Int32 on 32-bit), which is what
# ctypes.c_ssize_t is on every platform; c_int64 only matched on 64-bit builds.
FUNCTYPE = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.POINTER(ctypes.c_double), ctypes.c_ssize_t)
jl_sumsquares = FUNCTYPE(p)
# Python comparison
def sumsquares_numpy(v):
    """Return the sum of the squares of the elements of *v* using NumPy.

    ``v**2`` builds a temporary squared array, which ``np.sum`` then reduces.
    """
    return np.sum(v**2)
def sumsquares_purepython(v):
    """Return the sum of the squares of the elements of *v*, one element at a time.

    Each square is computed by a Python-level lambda, which is the cost this
    benchmark baseline is meant to show. An empty iterable gives 0.
    """
    return sum(map(lambda x: x**2, v))
# Timing setup
v = np.random.rand(100_000)
# The ctypes path hands Julia a raw pointer plus an element count, so the buffer
# must be a C-contiguous float64 array; anything else would be read as garbage
# without any error.
if v.dtype != np.float64 or not v.flags.c_contiguous:
    raise TypeError("v must be a C-contiguous float64 array for the ctypes call")
# Declare a Main.vcopy global, then store a copy of v (made by Julia's `copy`)
# in it for the pure-Julia @belapsed baseline.
jl.seval("global vcopy")
jl.vcopy = jl.copy(v)
# Call Julia's sumsquares via juliacall, warmup
juliacall_result = jl.sumsquares(v)
# Call Julia's sumsquares via ctypes, warmup
ctypes_result = jl_sumsquares(v.ctypes.data_as(ctypes.POINTER(ctypes.c_double)), len(v))
# Both paths end in the same Julia sumsquares body, so they must agree before
# their timings are worth comparing (isclose allows for summation-order noise).
if not np.isclose(juliacall_result, ctypes_result):
    raise RuntimeError(
        f"juliacall ({juliacall_result}) and ctypes ({ctypes_result}) results disagree"
    )
# Perform actual timing
number_of_calls = 1_000
# Pure-Julia baseline. @belapsed reports seconds (converted to microseconds
# below). `$vcopy` interpolates the untyped global into the benchmark, so the
# measurement excludes the dynamic dispatch cost of reading a non-constant global.
julia_time = jl.seval("""
using BenchmarkTools;
@belapsed sumsquares($vcopy)
""")
print(f"sumsquares(vcopy) in Julia took {julia_time*10**6} microseconds")
# Python-side timings. Each statement is executed number_of_calls times; timeit
# returns the total in seconds, which is printed as microseconds per call.
# globals=globals() evaluates the statements in this module's namespace, so the
# benchmark works whether or not this file is run as __main__ (a
# `from __main__ import ...` setup only works in that case).
juliacall_time = timeit("jl.sumsquares(v)", globals=globals(), number=number_of_calls)
print(f"jl.sumsquares(v) in Python via juliacall took {juliacall_time/number_of_calls*10**6} microseconds")

# The pointer conversion is inside the timed statement, so its cost is included
# in the ctypes measurement.
ctypes_time = timeit(
    "jl_sumsquares(v.ctypes.data_as(ctypes.POINTER(ctypes.c_double)), len(v))",
    globals=globals(),
    number=number_of_calls,
)
print(f"jl_sumsquares(v) in Python via ctypes took {ctypes_time/number_of_calls*10**6} microseconds")

numpy_time = timeit("sumsquares_numpy(v)", globals=globals(), number=number_of_calls)
print(f"sumsquares_numpy(v) in Python via numpy took {numpy_time/number_of_calls*10**6} microseconds")

# NOTE: at the ~27.7 ms per call shown in the example output, this loop alone
# takes roughly half a minute.
purepython_time = timeit("sumsquares_purepython(v)", globals=globals(), number=number_of_calls)
print(f"sumsquares_purepython(v) in pure Python took {purepython_time/number_of_calls*10**6} microseconds")
# Example output from one run of this script. The string below is a bare
# expression that is never used at runtime; absolute numbers depend on the
# machine, library versions and system load.
"""
$ ./juliacall_ctypes_test.py
sumsquares(vcopy) in Julia took 31.698 microseconds
jl.sumsquares(v) in Python via juliacall took 136.38922899917816 microseconds
jl_sumsquares(v) in Python via ctypes took 40.09708300145576 microseconds
sumsquares_numpy(v) in Python via numpy took 80.66400000097929 microseconds
sumsquares_purepython(v) in pure Python took 27677.13364700103 microseconds
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment