Last active
December 19, 2024 15:15
-
-
Save riga/32ba6c77943eaf34f059663cd11d90df to your computer and use it in GitHub Desktop.
Numpy random generator with C API
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
Interface for accessing NumPy's random number generators and configurable
bit generators from C++ via the Python C API. A usage example (main.cpp)
follows below.
Author: Marcel Rieger
*/
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <Python.h>
#include <numpy/arrayobject.h>

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <utility>
class NumpyGenerator { | |
public: | |
NumpyGenerator(std::string bit_generator_name) | |
: np_(nullptr), | |
random_mod_(nullptr), | |
gen_cls_(nullptr), | |
bit_gen_cls_(nullptr), | |
bit_gen_(nullptr), | |
gen_(nullptr), | |
gen_normal_(nullptr), | |
gen_uniform_(nullptr), | |
bit_generator_name_(bit_generator_name) { | |
// initialize python if not already done | |
if (!Py_IsInitialized()) { | |
Py_Initialize(); | |
} | |
// initialize numpy | |
if (!PyArray_API) { | |
_import_array(); | |
assert(PyArray_API); | |
} | |
// load module and class references | |
np_ = PyImport_ImportModule("numpy"); | |
random_mod_ = PyObject_GetAttrString(np_, "random"); | |
gen_cls_ = PyObject_GetAttrString(random_mod_, "Generator"); | |
bit_gen_cls_ = PyObject_GetAttrString(random_mod_, bit_generator_name.c_str()); | |
assert(bit_gen_cls_ != nullptr); | |
} | |
~NumpyGenerator() { | |
// reset all objects, no need to reset function and module pointers | |
reset(gen_); | |
reset(bit_gen_); | |
reset(bit_gen_cls_); | |
reset(gen_cls_); | |
} | |
void set_seed(uint64_t seed) { | |
// reset all python objects | |
reset(gen_normal_); | |
reset(gen_uniform_); | |
reset(gen_); | |
reset(bit_gen_); | |
// create bit generator | |
bit_gen_ = PyObject_CallFunctionObjArgs(bit_gen_cls_, PyLong_FromUnsignedLong(seed), NULL); | |
// create generator | |
gen_ = PyObject_CallFunctionObjArgs(gen_cls_, bit_gen_, NULL); | |
// lookup methods | |
gen_normal_ = PyObject_GetAttrString(gen_, "normal"); | |
gen_uniform_ = PyObject_GetAttrString(gen_, "uniform"); | |
} | |
double normal(size_t n, uint64_t seed = 0) { | |
if (seed > 0) { | |
set_seed(seed); | |
} | |
return next(gen_normal_, n); | |
} | |
double uniform(size_t n, uint64_t seed = 0) { | |
if (seed > 0) { | |
set_seed(seed); | |
} | |
return next(gen_uniform_, n); | |
} | |
private: | |
PyObject* np_; | |
PyObject* random_mod_; | |
PyObject* gen_cls_; | |
PyObject* bit_gen_cls_; | |
PyObject* bit_gen_; | |
PyObject* gen_; | |
PyObject* gen_normal_; | |
PyObject* gen_uniform_; | |
std::string bit_generator_name_; | |
void reset(PyObject*& obj) { | |
if (obj != nullptr) { | |
Py_DECREF(obj); | |
obj = nullptr; | |
} | |
} | |
double next(PyObject*& gen_func, size_t n) { | |
// all state objects must be initialized | |
assert(gen_func != nullptr); | |
// generate n random numbers, use the last one | |
PyObject* ret = nullptr; | |
for (size_t i = 0; i <= n; i++) { | |
ret = PyObject_CallFunctionObjArgs(gen_func, NULL); | |
} | |
// cast and cleanup | |
double value = PyFloat_AsDouble(ret); | |
Py_DECREF(ret); | |
return value; | |
} | |
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
Compile and run via:
> g++ numpy_gen_c.cpp \
    -o numpy_rnd \
    -std=c++14 \
    -I$(python3 -c "import numpy; print(numpy.get_include())") \
    $(python3-config --includes) \
    $(python3-config --ldflags) \
    -lpython$(python3 -c "import sys; print('{0.major}.{0.minor}'.format(sys.version_info))")
> ./numpy_rnd
*/
// The interface header includes Python.h, which must come before standard headers.
#include "gen_interface_numpy.h"

#include <iomanip>
#include <iostream>

// Demo: prints a few normal and uniform samples drawn through NumpyGenerator
// with the SFC64 bit generator, next to the equivalent numpy Python code.
int main() {
  {
    // Scope the generator so that its destructor releases its Python
    // references while the interpreter is still alive, i.e. before the
    // Py_Finalize() call below.
    NumpyGenerator ng("SFC64");
    std::cout << std::setprecision(16);

    // first, second and fourth "normal" sampled numbers for seed 1
    std::cout << "setting seed 1\n";
    std::cout << "normal, 0: " << ng.normal(0, 1) << '\n';
    std::cout << "normal, 1: " << ng.normal(0) << '\n';
    std::cout << "normal, 3: " << ng.normal(1) << '\n';  // skips the third draw
    // same as
    //   import numpy as np
    //   gen = np.random.Generator(np.random.SFC64(1))
    //   print(list(gen.normal(size=4)[[0, 1, 3]]))

    // using the largest possible seed (max uint64); the ULL suffix is needed
    // because the unsuffixed decimal literal fits no signed integer type
    ng.set_seed(uint64_t(18446744073709551615ULL));
    std::cout << "setting largest seed\n";
    for (size_t i = 0; i < 5; i++) {
      std::cout << "normal, " << i << ": " << ng.normal(0) << '\n';
    }

    // first "uniform" sampled number for seed 1
    std::cout << "setting seed 1\n";
    std::cout << "uniform, 0: " << ng.uniform(0, 1) << '\n';
    // same as
    //   import numpy as np
    //   gen = np.random.Generator(np.random.SFC64(1))
    //   print(gen.uniform())
  }

  // shut down the python interface once no Python objects are held anymore
  Py_Finalize();
  return 0;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment