Skip to content

Instantly share code, notes, and snippets.

@louisswarren
Created November 15, 2025 10:07
Show Gist options
  • Select an option

  • Save louisswarren/aac8ceb49c2424db6fc41f0668110263 to your computer and use it in GitHub Desktop.

Select an option

Save louisswarren/aac8ceb49c2424db6fc41f0668110263 to your computer and use it in GitHub Desktop.
Unit testing C via Python
#include <stdio.h>
#include <string.h>
#include <assert.h>
/*
 * Token pasting in two steps: concat() lets its arguments (for example
 * __COUNTER__) expand before _concat() glues them together with ##.
 */
#define _concat(X, Y) X##Y
#define concat(X, Y) _concat(X, Y)

/*
 * In non-test builds the tests are compiled out as file-local definitions
 * that nothing references; mark them unused so every test does not produce
 * a "defined but not used" warning (attribute available on GCC and Clang).
 */
#if defined(__GNUC__) || defined(__clang__)
#define test_unused_ __attribute__((unused))
#else
#define test_unused_
#endif

/*
 * test_(NAME, CTR) opens the definition of test number CTR; the test body
 * follows the macro. A test returns 0 on success or a failure message.
 *
 * With -Dtestmode, both the function _test_CTR and its description
 * _name_test_CTR have external linkage, so a driver can look them up by
 * name in the shared library. Otherwise both are static (file-local), so
 * they add no global symbols to the production binary.
 */
#ifdef testmode
#define test_(NAME, CTR) \
const char *concat(_name_test_, CTR) = NAME; \
const char *concat(_test_, CTR)(void)
#else
#define test_(NAME, CTR) \
test_unused_ static const char *concat(_name_test_, CTR) = NULL; \
test_unused_ static const char *concat(_test_, CTR)(void)
#endif

/* __COUNTER__ (a GCC/Clang/MSVC extension) gives each test a unique number. */
#define test(NAME) test_(NAME, __COUNTER__)
/*
 * Copy the NUL-terminated string src into dst, writing at most max bytes.
 *
 * Copying stops right after the terminating NUL has been written, or after
 * max bytes, whichever comes first. If src (including its NUL) does not fit
 * in max bytes, dst is left without a terminator.
 *
 * Returns a pointer one past the last byte written: dst + strlen(src) + 1
 * when the whole string fit, otherwise dst + max (dst itself when max == 0).
 */
char *cpystr(char *dst, const char *src, size_t max)
{
	size_t i;

	for (i = 0; i < max; i++) {
		dst[i] = src[i];
		if (dst[i] == '\0') {
			/* Count the terminator as written before stopping. */
			i++;
			break;
		}
	}
	return dst + i;
}
/* With a limit of zero, cpystr must write nothing and return dst unchanged. */
test("cpystr handles max = 0") {
	char buf[1] = {'A'};
	const char source[1] = {'B'};   /* no NUL needed: nothing may be read */
	char *end = cpystr(buf, source, 0);

	if (end != buf)
		return "Wrong return value";
	return buf[0] == 'A' ? NULL : "dst modified";
}
/*
 * A source shorter than the limit is copied up to and including its NUL,
 * and the destination bytes after that NUL are left untouched.
 */
test("cpystr copies shorter src correctly") {
	char dst[5] = "ABCD";
	char src[3] = "XY";
	/* Give cpystr the whole buffer so the limit is not what stops it: it
	 * must stop after writing src's terminator (3 bytes: 'X', 'Y', NUL). */
	if (cpystr(dst, src, sizeof dst) != dst + 3)
		return "Wrong return value";
	if (memcmp(dst, src, 3))
		return "Mismatch";
	/* Expected contents now: 'X', 'Y', NUL, 'D', NUL. */
	if (dst[3] != 'D' || dst[4] != '\0')
		return "Tail not preserved";
	return 0;
}
int main(void)
{
	/* The production binary does nothing yet; the tests are run by
	 * test.py against the separately built example.so. */
	return 0;
}
# Run the Python driver against the test build of example.c.
# test.py uses f-strings and therefore needs Python 3; override the
# interpreter with e.g. `make test PYTHON=python3.12`.
PYTHON ?= python3

.PHONY: test
test: example.so
	$(PYTHON) test.py ./$<
# Production build (no -Dtestmode). Libraries ($(LDLIBS)) must follow the
# inputs that reference them, or the linker may discard them.
example: example.c
	$(CC) $(CFLAGS) $(LDFLAGS) -o $@ $< $(LDLIBS)
# Test build: -Dtestmode exports the _test_N / _name_test_N symbols that
# test.py looks up. A shared object needs position-independent code, which
# not every toolchain enables by default, so request it with -fPIC.
example.so: example.c
	$(CC) $(CFLAGS) -fPIC $(LDFLAGS) -shared -Dtestmode -o $@ $< $(LDLIBS)
# Delete both build products (the production binary and the test library).
.PHONY: clean
clean:
	$(RM) example example.so
import subprocess
import sys
import unittest
from ctypes import *
class TestGenerated(unittest.TestCase):
    # Deliberately empty: make_tests() attaches one test_* method per
    # C test function exported by the shared library under test.
    ...
def get_symbols(libpath):
    """Yield the names of symbols defined in the dynamic symbol table of *libpath*.

    Runs ``nm -D --defined-only`` and takes the third column of each
    ``<address> <type> <name>`` line. Lines that do not have exactly three
    fields are skipped. This includes the empty output nm produces for a
    library with no such symbols, which previously raised ValueError.

    Raises:
        subprocess.CalledProcessError: if nm exits with a non-zero status.
    """
    output = subprocess.check_output(["nm", "-D", "--defined-only", libpath])
    for line in output.decode().splitlines():
        fields = line.split()
        if len(fields) != 3:
            continue
        yield fields[2]
def make_tests(libpath):
    """Add one ``TestGenerated`` method per ``_test_*`` function in *libpath*.

    Each C test function is called with no arguments and is treated as
    returning a C string. A non-NULL result fails the test with that message.
    The method's docstring is read from the matching ``_name_test_*``
    string pointer exported by the same library. The method itself is
    registered as ``test_<suffix>``.
    """
    prefix = "_test_"
    lib = cdll.LoadLibrary(libpath)
    for symbol in filter(lambda s: s.startswith(prefix), get_symbols(libpath)):
        c_func = getattr(lib, symbol)
        c_func.restype = c_char_p

        # Bind c_func as a default so each generated method keeps its own
        # function rather than the loop's final one.
        def test(self, f=c_func):
            message = f()
            if message:
                raise self.failureException(message.decode())

        description = c_char_p.in_dll(lib, "_name" + symbol)
        test.__doc__ = description.value.decode()
        setattr(TestGenerated, "test_" + symbol[len(prefix):], test)
if __name__ == "__main__":
    # Usage: python3 test.py LIBRARY [unittest options...]
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} LIBRARY [unittest options...]")
    # Remove the library path so unittest.main() does not try to parse it.
    make_tests(sys.argv.pop(1))
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment