Skip to content

Instantly share code, notes, and snippets.

@riga
Last active December 19, 2024 15:15
Show Gist options
  • Save riga/32ba6c77943eaf34f059663cd11d90df to your computer and use it in GitHub Desktop.
Save riga/32ba6c77943eaf34f059663cd11d90df to your computer and use it in GitHub Desktop.
Numpy random generator with C API
/*
Interface for using NumPy's random number generator (numpy.random.Generator) with a
configurable bit generator from C++ through the Python C API. See the usage example in main.cpp below.
Author: Marcel Rieger
*/
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <Python.h>
#include <numpy/arrayobject.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
class NumpyGenerator {
public:
NumpyGenerator(std::string bit_generator_name)
: np_(nullptr),
random_mod_(nullptr),
gen_cls_(nullptr),
bit_gen_cls_(nullptr),
bit_gen_(nullptr),
gen_(nullptr),
gen_normal_(nullptr),
gen_uniform_(nullptr),
bit_generator_name_(bit_generator_name) {
// initialize python if not already done
if (!Py_IsInitialized()) {
Py_Initialize();
}
// initialize numpy
if (!PyArray_API) {
_import_array();
assert(PyArray_API);
}
// load module and class references
np_ = PyImport_ImportModule("numpy");
random_mod_ = PyObject_GetAttrString(np_, "random");
gen_cls_ = PyObject_GetAttrString(random_mod_, "Generator");
bit_gen_cls_ = PyObject_GetAttrString(random_mod_, bit_generator_name.c_str());
assert(bit_gen_cls_ != nullptr);
}
~NumpyGenerator() {
// reset all objects, no need to reset function and module pointers
reset(gen_);
reset(bit_gen_);
reset(bit_gen_cls_);
reset(gen_cls_);
}
void set_seed(uint64_t seed) {
// reset all python objects
reset(gen_normal_);
reset(gen_uniform_);
reset(gen_);
reset(bit_gen_);
// create bit generator
bit_gen_ = PyObject_CallFunctionObjArgs(bit_gen_cls_, PyLong_FromUnsignedLong(seed), NULL);
// create generator
gen_ = PyObject_CallFunctionObjArgs(gen_cls_, bit_gen_, NULL);
// lookup methods
gen_normal_ = PyObject_GetAttrString(gen_, "normal");
gen_uniform_ = PyObject_GetAttrString(gen_, "uniform");
}
double normal(size_t n, uint64_t seed = 0) {
if (seed > 0) {
set_seed(seed);
}
return next(gen_normal_, n);
}
double uniform(size_t n, uint64_t seed = 0) {
if (seed > 0) {
set_seed(seed);
}
return next(gen_uniform_, n);
}
private:
PyObject* np_;
PyObject* random_mod_;
PyObject* gen_cls_;
PyObject* bit_gen_cls_;
PyObject* bit_gen_;
PyObject* gen_;
PyObject* gen_normal_;
PyObject* gen_uniform_;
std::string bit_generator_name_;
void reset(PyObject*& obj) {
if (obj != nullptr) {
Py_DECREF(obj);
obj = nullptr;
}
}
double next(PyObject*& gen_func, size_t n) {
// all state objects must be initialized
assert(gen_func != nullptr);
// generate n random numbers, use the last one
PyObject* ret = nullptr;
for (size_t i = 0; i <= n; i++) {
ret = PyObject_CallFunctionObjArgs(gen_func, NULL);
}
// cast and cleanup
double value = PyFloat_AsDouble(ret);
Py_DECREF(ret);
return value;
}
};
/*
Compile and run via:
> g++ numpy_gen_c.cpp \
-o numpy_rnd \
-std=c++14 \
-I$(python3 -c "import numpy; print(numpy.get_include())") \
$(python3-config --includes) \
$(python3-config --ldflags) \
-lpython$(python3 -c "import sys; print('{0.major}.{0.minor}'.format(sys.version_info))")
> ./numpy_rnd
*/
#include <iostream>
#include <iomanip>
#include "gen_interface_numpy.h"
int main() {
  {
    // The generator holds references to Python objects, so it lives in its own
    // scope and is destroyed before the interpreter is shut down below.
    // setup the generator with the SFC64 bit generator
    NumpyGenerator ng("SFC64");
    // setprecision is sticky, so applying it once covers every printed number
    std::cout << std::setprecision(16);
    // first, second and fourth "normal" sampled numbers for seed 1
    std::cout << "setting seed 1" << '\n';
    std::cout << "normal, 0: " << ng.normal(0, 1) << '\n';
    std::cout << "normal, 1: " << ng.normal(0) << '\n';
    std::cout << "normal, 3: " << ng.normal(1) << '\n';
    // same as
    // import numpy as np
    // gen = np.random.Generator(np.random.SFC64(1))
    // print(list(gen.normal(size=4)[[0, 1, 3]]))
    // using some large seed; the ULL suffix makes the literal a valid unsigned value
    ng.set_seed(18446744073709551615ULL);  // max uint64
    std::cout << "setting largest seed" << '\n';
    for (size_t i = 0; i < 5; i++) {
      std::cout << "normal, " << i << ": " << ng.normal(0) << '\n';
    }
    // first "uniform" sampled numbers for seed 1
    std::cout << "setting seed 1" << '\n';
    std::cout << "uniform, 0: " << ng.uniform(0, 1) << '\n';
    // same as
    // import numpy as np
    // gen = np.random.Generator(np.random.SFC64(1))
    // print(gen.uniform())
  }
  // you might want to shutdown the python interface at the end of your application;
  // this must happen only after all NumpyGenerator instances are gone
  Py_Finalize();
  return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment