Skip to content

Instantly share code, notes, and snippets.

@MMesch
Last active March 10, 2018 17:06
Show Gist options
  • Save MMesch/85356fefb778168aacb5423b4e729d22 to your computer and use it in GitHub Desktop.
Save MMesch/85356fefb778168aacb5423b4e729d22 to your computer and use it in GitHub Desktop.
Calling Haskell from Python: a mini project based on https://wiki.python.org/moin/PythonVsHaskell
{-# LANGUAGE ForeignFunctionInterface #-}
module Geometry where
import Foreign.C.String
import Foreign.C.Types
import Foreign.Marshal.Alloc (free)
import Foreign.Ptr (nullPtr)
-- | Plain addition, generic over any 'Num' instance.
add :: Num a => a -> a -> a
add = (+)
-- | C-callable: yields its 'CInt' argument increased by 42.
f1 :: CInt -> IO CInt
f1 = pure . (42 +)
-- | C-callable: yields its 'CFloat' argument increased by 10.0.
f2 :: CFloat -> IO CFloat
f2 = pure . (10.0 +)
-- | C-callable: like 'f2', but routes the addition through the generic 'add'.
f3 :: CFloat -> IO CFloat
f3 = pure . add 10.0
-- | C-callable: reads the NUL-terminated string @s@ and returns a newly
-- allocated C string holding it followed by @" world!"@.
--
-- Ownership: the result is produced by 'newCString', so the caller owns it and
-- must release it (e.g. via the exported 'freeCString' below); otherwise every
-- call leaks. The input @s@ is only read, never freed.
--
-- A NULL input yields NULL instead of dereferencing a null pointer.
f4 :: CString -> IO CString
f4 s
  | s == nullPtr = pure nullPtr
  | otherwise = do
      w <- peekCString s
      newCString (w ++ " world!")

-- | C-callable: releases a string previously returned by 'f4'.
-- Memory from 'newCString' must be released with 'free'.
freeCString :: CString -> IO ()
freeCString = free

foreign export ccall freeCString :: CString -> IO ()
-- C entry points using the ccall convention. The quoted entity names spell out
-- the exported C symbol explicitly; each matches its Haskell name.
foreign export ccall "f1" f1 :: CInt -> IO CInt
foreign export ccall "f2" f2 :: CFloat -> IO CFloat
foreign export ccall "f3" f3 :: CFloat -> IO CFloat
foreign export ccall "f4" f4 :: CString -> IO CString
# Build $(MODULE_NAME).so from the Haskell module plus the C RTS-init shim,
# delete the intermediate build files, then run the Python ctypes smoke test.
#
# The RTS library name embeds the GHC version. It is derived from the GHC that
# stack runs, so the two cannot drift apart. Override on the command line if
# needed, e.g. `make HSRTS_LIB=HSrts-ghc8.2.2`.
MODULE_NAME = Geometry
GHC_VERSION = $(shell stack exec -- ghc --numeric-version)
HSRTS_LIB = HSrts-ghc$(GHC_VERSION)

.PHONY: all build_lib clean_objects test_lib clean

# The steps of `all` depend on running in the listed order; without this,
# `make -j` could delete objects while build_lib is still producing them.
.NOTPARALLEL:

all: build_lib clean_objects test_lib

build_lib:
	stack exec -- ghc -O2 -dynamic -shared -fPIC --make -no-hs-main \
		-optl '-shared' '-l$(HSRTS_LIB)' -optc '-DMODULE=$(MODULE_NAME)' \
		-o $(MODULE_NAME).so $(MODULE_NAME).hs module_init.c

# -f: succeed even when there is nothing left to delete (e.g. on a re-run).
clean_objects:
	rm -f *.o *.hi *_stub.h

test_lib:
	python run_lib.py

clean: clean_objects
	rm -f *.so
/*
 * Token-pasting and stringizing helpers. Each public macro forwards to an
 * X-prefixed worker so that its arguments are macro-expanded first: e.g.
 * STR(MODULE) produces the value passed via -DMODULE=..., not "MODULE".
 */
#define XCAT(lhs, rhs) lhs ## rhs
#define CAT(lhs, rhs) XCAT(lhs, rhs)
#define XSTR(tok) #tok
#define STR(tok) XSTR(tok)
#include <HsFFI.h>
static void library_init (void) __attribute__ ((constructor));

/*
 * Start the GHC runtime when this shared object is loaded (GCC constructor
 * attribute), so a host that knows nothing about Haskell can call the
 * exported functions directly.
 *
 * argv[0] is "<MODULE>.so". Handing hs_init a real argc/argv looks like a
 * no-op, but it is what makes the RTS honour the GHCRTS environment variable.
 *
 * The former hs_add_root(__stginit_<MODULE>) call is omitted. It has been a
 * no-op since GHC 7.2, and the __stginit_ symbol it referenced may not be
 * emitted by newer GHC versions, which would break loading the library.
 */
static void
library_init (void)
{
  static char *argv[] = { STR (MODULE) ".so", 0 }, **argv_ = argv;
  static int argc = 1;
  hs_init (&argc, &argv_);
}
/*
 * Shut the GHC runtime down again. The destructor attribute makes this run
 * when the shared object is unloaded or the process exits, pairing with the
 * hs_init call made at load time.
 */
static void __attribute__ ((destructor))
library_exit (void)
{
  hs_exit ();
}
from ctypes import *
# Smoke test: call every function exported by Geometry.so and check its result.
lib = cdll.LoadLibrary('./Geometry.so')

# name -> (restype, argtypes, (input, expected value))
# c_char_p parameters and results are byte strings, so the f4 case uses bytes
# literals; a plain str would be rejected by ctypes on Python 3 (and on
# Python 2, b"..." is just str, so nothing changes there).
funcs = {
    'f1': (c_int, [c_int], (10, 52)),
    'f2': (c_float, [c_float], (10.0, 20.0)),
    'f3': (c_float, [c_float], (11.0, 21.0)),
    'f4': (c_char_p, [c_char_p], (b"hello", b"hello world!")),
}

for func, (restype, argtypes, (arg, expected)) in funcs.items():
    f = getattr(lib, func)
    f.restype, f.argtypes = restype, argtypes
    # NOTE: with restype c_char_p, ctypes copies f4's result into a bytes
    # object and discards the pointer, so the string allocated on the Haskell
    # side (newCString) is not freed by this script.
    result = f(arg)
    assert result == expected, '{0}({1!r}) returned {2!r}, expected {3!r}'.format(
        func, arg, result, expected)
    print('{0}({1}) == {2}'.format(func, arg, expected))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment