Last active
November 25, 2024 09:13
-
-
Save maedoc/3ce7084078f915f3846b2b66068c81fa to your computer and use it in GitHub Desktop.
Fused kernels for simulations
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>
/*
 * Simulation descriptor shared by all kernels in this file.
 *
 * The field order and types are mirrored by the ctypes `Sim` structure in the
 * Python driver, so fields must not be reordered, added or removed without
 * updating that structure as well.
 *
 * Shapes in the comments are written in row-major order (last index fastest).
 */
struct sim {
    const int rng_seed;
    const int num_item;
    const int num_node;
    const int num_svar;
    const int num_time;
    const int num_params;
    const int num_spatial_params;
    const int num_simd;
    const int num_batch;

    // time step
    const float dt;
    // TODO ode substeps for stability
    // const int num_substeps;

    // number of dt steps to take per output sample
    const int num_skip;
    float *state_trace; // (num_time/num_skip, num_svar, num_nodes, num_item)

    // noise scaling per svar, sigma[:]*sqrt(dt)
    const float *z_scale; // (num_svar, num_item)

    // full states
    float *states; // (num_svar, num_nodes, num_item)

    // delay buffer for coupling variables; the kernels index the time axis
    // with `t & horizon_minus_1`, which is only a valid wrap-around when
    // horizon is a power of two
    const int horizon;
    const int horizon_minus_1;
    const bool horizon_is_pow_of_2;
    float *delay_buffer; // (num_nodes, horizon, num_item)

    // parameters
    const float *params;         // (num_params, batch_size)
    const float *spatial_params; // (num_spatial_params, num_nodes, batch_size)

    // CSR connectivity
    const int num_nonzero;
    const float *weights; // (num_nonzero,)
    const int *indices;   // (num_nonzero,)
    const int *indptr;    // (num_nodes+1,)
    const int *idelays;   // (num_nonzero,)
};
/* Index triple (node, time step, batch item). Not referenced by the kernels in this file. */
struct idx {
    int node;
    int time;
    int item;
};
/* Read-only view of the simulation descriptor, as taken by every kernel. */
typedef const struct sim the_sim;

/* Inlining control for the small kernels (GCC/Clang attribute syntax). */
#define NOINLINE __attribute__((noinline))
#define DOINLINE inline __attribute__((always_inline))
#define INLINE DOINLINE

/*
 * aligned(p): optional promise that pointer `p` is 32-byte aligned.
 *
 * The relaxed (no-op) form is the default, because callers are not required
 * to provide 32-byte aligned buffers. Define FUSED_ASSUME_ALIGNED only when
 * every pointer handed to the kernels really is 32-byte aligned; a false
 * promise is undefined behaviour.
 */
#ifdef FUSED_ASSUME_ALIGNED
#define aligned(var) ((var) = __builtin_assume_aligned((var), 32))
#else
#define aligned(var) ((void)(var))
#endif
/*
 * Accumulate a scaled 8-lane vector: dst[i] += w * src[i] for i in 0..7.
 *
 * dst and src must each point at 8 readable floats; src is only read.
 */
static void INLINE inc8_aligned(float *dst, const float *src, float w) {
    aligned(dst);
    aligned(src);
    // clang does not want to vectorize this :o
    #pragma omp simd
    for (int i = 0; i < 8; i++) {
        dst[i] += w * src[i];
    }
}
/*
 * Copy 8 floats from src to dst (dst[i] = src[i] for i in 0..7).
 *
 * Despite the name this is a plain copy and is used in this file both to
 * load state into locals and to store locals back; src is only read.
 */
static void INLINE load8(float *dst, const float *src) {
    aligned(dst);
    aligned(src);
    #pragma omp simd
    for (int i = 0; i < 8; i++) {
        dst[i] = src[i];
    }
}
/* Set the 8 floats starting at dst to 0. */
static void INLINE zero8(float * __restrict dst) {
    aligned(dst);
    #pragma omp simd
    for (int lane = 0; lane < 8; lane++) {
        dst[lane] = 0.f;
    }
}
/*
 * Look up one CSR entry `nz` and return what is needed to apply it:
 *   *w  - the connection weight,
 *   *b1 - 8 delayed samples of the source node at step (i_time - delay),
 *   *b2 - 8 delayed samples of the source node at step (i_time - delay + 1).
 *
 * The delay buffer is addressed as (node, time, 8 lanes); the time index is
 * wrapped with `& horizon_minus_1`, which requires horizon to be a power of 2.
 *
 * No alignment promise is made for the output pointers themselves: callers
 * pass addresses of ordinary stack locals.
 */
static void INLINE prep_ij(the_sim *s,
    const int i_time, const int nz, float **b1, float **b2, float *w)
{
    const int horizon = s->horizon;
    const int mask = s->horizon_minus_1;
    // start of the source node's ring buffer
    float *node_buf = s->delay_buffer + horizon * s->indices[nz] * 8;
    // adding horizon keeps the value non-negative before masking
    const int t0 = horizon + i_time - s->idelays[nz];

    *w = s->weights[nz];
    *b1 = node_buf + ((t0 + 0) & mask) * 8;
    *b2 = node_buf + ((t0 + 1) & mask) * 8;
}
/*
 * Apply a single CSR entry `nz`: add weight * delayed source samples to the
 * two coupling accumulators (cx1 for the first Heun stage, cx2 for the second).
 */
static void INLINE csr_ij(
    the_sim *s, const int nz, const int i_time,
    float *cx1, float *cx2
) {
    float *src1;
    float *src2;
    float weight;

    prep_ij(s, i_time, nz, &src1, &src2, &weight);
    inc8_aligned(cx1, src1, weight);
    inc8_aligned(cx2, src2, weight);
}
/*
 * Vectorized prep_ij for the 4 consecutive CSR entries nz..nz+3. Only 4 are
 * handled at a time because the outputs are pointers (8 bytes each on the
 * 64-bit targets this was written for).
 *
 * Outputs: w[0..3] weights, b1[0..3] / b2[0..3] pointers to the 8 delayed
 * samples at steps (i_time - delay) and (i_time - delay + 1).
 *
 * No alignment is assumed for the inputs: `nz` is an arbitrary CSR offset, so
 * weights+nz, indices+nz and idelays+nz are in general only 4-byte aligned,
 * and promising 32-byte alignment to the compiler would be undefined
 * behaviour.
 */
static void INLINE prep4_ij(the_sim *s,
    const int i_time, const int nz,
    float ** __restrict b1, float ** __restrict b2, float * __restrict w)
{
    // inputs
    const float *weights = s->weights + nz;
    const int *indices = s->indices + nz;
    const int *idelays = s->idelays + nz;
    float *buf = s->delay_buffer;
    const int horizon = s->horizon;
    const int mask = s->horizon_minus_1;

    // get busy
    #pragma omp simd
    for (int i = 0; i < 4; i++)
    {
        // XXX (original author's note) clang doesn't vectorize this, gcc does
        w[i] = weights[i];
        const int t0 = horizon + i_time - idelays[i];
        float *node_buf = buf + horizon * indices[i] * 8;
        b1[i] = node_buf + ((t0 + 0) & mask) * 8;
        b2[i] = node_buf + ((t0 + 1) & mask) * 8;
    }
}
/*
 * Apply the 4 consecutive CSR entries nz..nz+3 to the coupling accumulators.
 * Entries are accumulated in index order, cx1 before cx2 for each entry.
 */
static void INLINE csr4_ij(
    the_sim *s, const int nz, const int i_time,
    float *cx1, float *cx2
) {
    float *b1[4], *b2[4], w[4];

    prep4_ij(s, i_time, nz, b1, b2, w);
    #pragma GCC unroll 4
    for (int j = 0; j < 4; j++)
    {
        inc8_aligned(cx1, b1[j], w[j]);
        inc8_aligned(cx2, b2[j], w[j]);
    }
}
/*
 * Right-hand side of the two-variable node model for 8 lanes at once:
 *
 *   dx = tau * (x - x^3/3 + y)
 *   dy = (a + k*cx - x) / tau
 *
 * x, y    - current state variables
 * cx      - coupling input
 * a,tau,k - per-lane parameters
 * dx, dy  - outputs
 *
 * All pointers refer to 8 floats and must not overlap.
 */
static void INLINE dfun8(
    // restrict required for autovectorization to work
    const float * __restrict x, const float *__restrict y,
    const float *__restrict cx, const float *__restrict a,
    const float *__restrict tau, const float *__restrict k,
    float *__restrict dx, float *__restrict dy)
{
    aligned(x);
    aligned(y);
    aligned(cx);
    aligned(a);
    aligned(tau);
    aligned(k);
    aligned(dx);
    aligned(dy);
    #pragma omp simd
    for (int lane = 0; lane < 8; lane++)
    {
        dx[lane] = tau[lane] * (x[lane] - x[lane]*x[lane]*x[lane]/3.0f + y[lane]);
        dy[lane] = (1.0f/tau[lane]) * (a[lane] + k[lane]*cx[lane] - x[lane]);
    }
}
/*
 * Heun predictor (explicit Euler) for 8 lanes: xi[i] = x[i] + dt * dx[i].
 * x and dx are read, xi is written; the three ranges must not overlap.
 */
static void INLINE heun1(
    const float * __restrict x, float * __restrict xi,
    const float * __restrict dx, const float dt)
{
    aligned(x);
    aligned(xi);
    aligned(dx);
    #pragma omp simd
    for (int lane = 0; lane < 8; lane++)
    {
        xi[lane] = x[lane] + dt*dx[lane];
    }
}
/*
 * Heun corrector for 8 lanes, in place:
 *   x[i] += dt * 0.5 * (dx1[i] + dx2[i])
 * where dx1 is the derivative at the start of the step and dx2 the derivative
 * at the predicted state.
 */
static void INLINE heun2(
    float * __restrict x, const float * __restrict dx1,
    const float * __restrict dx2, const float dt)
{
    aligned(x);
    aligned(dx1);
    aligned(dx2);
    #pragma omp simd
    for (int lane = 0; lane < 8; lane++)
    {
        x[lane] += dt*0.5f*(dx1[lane] + dx2[lane]);
    }
}
/*
 * Advance one node by one Heun step and publish its first state variable to
 * the delay buffer.
 *
 * s      - simulation descriptor; s->states and s->delay_buffer are updated
 * i_node - node to update
 * i_time - current step, used (masked) as the delay-buffer write slot
 * cx1    - coupling input for the first stage (8 lanes)
 * cx2    - coupling input for the second stage (8 lanes)
 *
 * The kernel is specialised to 2 state variables and 8 lanes, and reads the
 * parameters a, tau, k from rows 0, 1, 2 of s->params (8 floats per row).
 * No noise is added here.
 */
static void INLINE step_ode(
    the_sim *s,
    const int i_node,
    const int i_time,
    float *cx1,
    float *cx2
)
{
    // Only x needs loading; xi, dx1 and dx2 are fully written before use.
    float x[2*8], xi[2*8], dx1[2*8], dx2[2*8];
    const float *a = s->params;
    const float *tau = s->params + 8;
    const float *k = s->params + 16;

    #pragma GCC unroll 2
    for (int svar = 0; svar < 2; svar++)
    {
        load8(x + svar*8, s->states + 8*(i_node + s->num_node*svar));
    }

    // stage 1: derivative at the current state, then Euler predictor
    dfun8(x, x + 8, cx1, a, tau, k, dx1, dx1 + 8);
    #pragma GCC unroll 2
    for (int svar = 0; svar < 2; svar++)
    {
        heun1(x + svar*8, xi + svar*8, dx1 + svar*8, s->dt);
    }

    // stage 2: derivative at the predicted state, then corrector and store
    dfun8(xi, xi + 8, cx2, a, tau, k, dx2, dx2 + 8);
    #pragma GCC unroll 2
    for (int svar = 0; svar < 2; svar++)
    {
        heun2(x + svar*8, dx1 + svar*8, dx2 + svar*8, s->dt);
        load8(s->states + 8*(i_node + s->num_node*svar), x + svar*8);
    }

    // the first state variable is the coupling variable seen by other nodes
    const int write_time = i_time & s->horizon_minus_1;
    load8(s->delay_buffer + 8*(i_node*s->horizon + write_time),
          s->states + 8*(i_node + s->num_node*0));
}
/*
 * Advance every node by one time step.
 *
 * For each node the incoming CSR row is accumulated into two coupling vectors
 * (one per Heun stage), four entries at a time with a scalar tail, and then
 * the node's ODE step is taken. Nodes are processed in index order.
 */
void step_nodes(the_sim *s, int i_time)
{
    for (int i_node = 0; i_node < s->num_node; i_node++)
    {
        float cx1[8], cx2[8];
        zero8(cx1);
        zero8(cx2);

        const int row_begin = s->indptr[i_node];
        const int row_end = s->indptr[i_node + 1];
        const int num_quads = (row_end - row_begin) / 4;
        const int tail_begin = row_begin + num_quads * 4;

        // groups of four entries
        for (int q = 0; q < num_quads; q++)
        {
            csr4_ij(s, row_begin + q*4, i_time, cx1, cx2);
        }
        // remaining 0..3 entries
        for (int nz = tail_begin; nz < row_end; nz++)
        {
            csr_ij(s, nz, i_time, cx1, cx2);
        }

        step_ode(s, i_node, i_time, cx1, cx2);
    }
}
/*
 * Copy the current states into frame `t_trace` of s->state_trace.
 *
 * The frame offset is computed in size_t so that long traces do not overflow
 * int arithmetic. Like the rest of the kernels, the copy itself is
 * specialised to 2 state variables and 8 lanes per node.
 */
static void save_trace(the_sim *s, const int t_trace)
{
    const size_t frame_len =
        (size_t)s->num_svar * (size_t)s->num_node * (size_t)s->num_item;
    float *out = s->state_trace + (size_t)t_trace * frame_len;

    for (int i_node = 0; i_node < s->num_node; i_node++)
    {
        for (int svar = 0; svar < 2; svar++)
        {
            const int offset = (svar * s->num_node + i_node) * 8;
            load8(out + offset, s->states + offset);
        }
    }
}
/*
 * Run the whole simulation: num_time / num_skip output frames, each preceded
 * by num_skip integration steps. Frame t is written after step
 * (t + 1) * num_skip - 1.
 *
 * Does nothing when num_skip is not positive (which would otherwise be a
 * division by zero).
 */
void run_batches(the_sim *s)
{
    if (s->num_skip <= 0)
    {
        return;
    }

    const int num_trace = s->num_time / s->num_skip;
    for (int t_trace = 0; t_trace < num_trace; t_trace++)
    {
        const int t_first = t_trace * s->num_skip;
        for (int t_skip = 0; t_skip < s->num_skip; t_skip++)
        {
            step_nodes(s, t_first + t_skip);
        }
        save_trace(s, t_trace);
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import ctypes as ct | |
import scipy.sparse | |
import time | |
import subprocess | |
class Sim(ct.Structure):
    """ctypes mirror of the C ``struct sim``.

    The field order and types must match the C declaration exactly, since
    instances are passed by reference to the compiled kernels.
    """
    _fields_ = [
        # sizes and bookkeeping
        ('rng_seed', ct.c_int32),
        ('num_item', ct.c_int32),
        ('num_node', ct.c_int32),
        ('num_svar', ct.c_int32),
        ('num_time', ct.c_int32),
        ('num_params', ct.c_int32),
        ('num_spatial_params', ct.c_int32),
        ('num_simd', ct.c_int32),
        ('num_batch', ct.c_int32),
        # time stepping and output
        ('dt', ct.c_float),
        ('num_skip', ct.c_int32),
        ('state_trace', ct.POINTER(ct.c_float)),
        ('z_scale', ct.POINTER(ct.c_float)),
        ('states', ct.POINTER(ct.c_float)),
        # delay buffer
        ('horizon', ct.c_int32),
        ('horizon_minus_1', ct.c_int32),
        ('horizon_is_pow_of_2', ct.c_bool),
        ('delay_buffer', ct.POINTER(ct.c_float)),
        # parameters
        ('params', ct.POINTER(ct.c_float)),
        ('spatial_params', ct.POINTER(ct.c_float)),
        # CSR connectivity
        ('num_nonzero', ct.c_int32),
        ('weights', ct.POINTER(ct.c_float)),
        ('indices', ct.POINTER(ct.c_int32)),
        ('indptr', ct.POINTER(ct.c_int32)),
        ('idelays', ct.POINTER(ct.c_int32)),
    ]
def map_array(sim, key, array):
    """Point the pointer field ``key`` of ``sim`` at the data of ``array``.

    The field's declared ctypes type is looked up in ``sim._fields_`` and the
    array's data pointer is cast to it. The caller must keep ``array`` alive
    for as long as ``sim`` is used.

    Raises:
        KeyError: if ``sim`` has no field named ``key``.
    """
    for name, ctype in sim._fields_:
        if name == key:
            setattr(sim, key, array.ctypes.data_as(ctype))
            return
    raise KeyError(f'{type(sim).__name__} has no field named {key!r}')
def make_sim(
    csr_weights: scipy.sparse.csr_matrix,
    idelays: np.ndarray,
    sim_params: np.ndarray,
    z_scale: np.ndarray,
    horizon: int,
    rng_seed=43, num_item=8, num_node=90, num_svar=2, num_time=1000, dt=0.1,
    num_skip=5, num_simd=8
):
    """Build a ``Sim`` structure and the NumPy arrays backing its pointers.

    Args:
        csr_weights: connectivity in CSR form; data, indices and indptr are
            copied into float32 / int32 arrays.
        idelays: integer delay (in steps) per nonzero entry of ``csr_weights``.
        sim_params: array of shape ``(num_params, num_item)``.
        z_scale: noise scale array, converted to float32.
        horizon: length of the delay ring buffer; must be a power of two
            because the C kernels wrap time indices with a bit mask.
        rng_seed, num_item, num_node, num_svar, num_time, dt, num_skip,
        num_simd: scalar fields copied into the structure. ``num_item`` must
            be a multiple of ``num_simd``.

    Returns:
        ``(sim, sim_arrays)``. ``sim_arrays`` holds every array that ``sim``
        points into and must be kept alive while ``sim`` is in use;
        ``sim_arrays[0]`` is the state trace with shape
        ``(num_time // num_skip + 2, num_svar, num_node, num_item)``.

    Side effect:
        Writes a ``main.c`` driver stub with the same scalar settings into
        the current directory.
    """
    assert horizon > 0 and (horizon & (horizon - 1)) == 0, \
        'horizon must be a power of 2'
    assert num_item >= num_simd
    num_batch = num_item // num_simd
    assert num_batch*num_simd == num_item

    sim = Sim()
    sim.rng_seed = rng_seed
    sim.num_item = num_item
    sim.num_simd = num_simd
    sim.num_batch = num_batch
    sim.num_node = num_node
    sim.num_svar = num_svar
    sim.num_time = num_time
    sim.dt = dt
    sim.num_skip = num_skip
    sim.horizon = horizon
    sim.horizon_minus_1 = horizon - 1
    sim.horizon_is_pow_of_2 = True
    sim.num_params = sim_params.shape[0]
    sim.num_spatial_params = 0
    sim.num_nonzero = csr_weights.nnz

    # Text of the generated main.c, written in one go at the end.
    main_c = [f'''#include "fused.c"
#include<stdlib.h>
int main()
{{
struct sim sim = {{
.rng_seed = {rng_seed},
.num_item = {num_item},
.num_simd = {num_simd},
.num_batch = {num_item // num_simd},
.num_node = {num_node},
.num_svar = {num_svar},
.num_time = {num_time},
.dt = {dt},
.num_skip = {num_skip},
.horizon = {horizon},
.horizon_minus_1 = {horizon - 1},
.horizon_is_pow_of_2 = true,
.num_params = {sim_params.shape[0]},
.num_spatial_params = {0},
.num_nonzero = {csr_weights.nnz}
}};
''']

    # Every array the structure points into is collected here so the caller
    # can keep the memory alive.
    sim_arrays = []

    def zeros(shape, dtype='f'):
        arr = np.zeros(shape, dtype)
        sim_arrays.append(arr)
        return arr

    # state_trace must stay first: callers read it as sim_arrays[0]
    sim_state_trace = zeros((sim.num_time // sim.num_skip + 2,
                             sim.num_svar,
                             sim.num_node,
                             sim.num_item), 'f')
    map_array(sim, 'state_trace', sim_state_trace)
    main_c.append('sim.state_trace = malloc(sizeof(float)*(sim.num_time / sim.num_skip + 1)*sim.num_svar*sim.num_node*sim.num_item);')

    sim_states = zeros((sim.num_svar, sim.num_node, sim.num_item), 'f')
    map_array(sim, 'states', sim_states)

    # horizon needs to be a power of 2 (checked above)
    sim_delay_buffer = zeros((sim.num_node, sim.horizon, sim.num_item), 'f')
    map_array(sim, 'delay_buffer', sim_delay_buffer)

    sim_spatial_params = zeros((sim.num_spatial_params,
                                sim.num_node,
                                sim.num_item), 'f')
    map_array(sim, 'spatial_params', sim_spatial_params)

    sim_weights = zeros((sim.num_nonzero,), 'f')
    sim_indices = zeros((sim.num_nonzero,), np.int32)
    sim_indptr = zeros((sim.num_node+1,), np.int32)
    sim_idelays = zeros((sim.num_nonzero,), np.int32)
    sim_weights[:] = csr_weights.data.astype('f')
    sim_indices[:] = csr_weights.indices.astype('i')
    sim_indptr[:] = csr_weights.indptr.astype('i')
    sim_idelays[:] = idelays.astype('i')
    map_array(sim, 'weights', sim_weights)
    map_array(sim, 'indices', sim_indices)
    map_array(sim, 'indptr', sim_indptr)
    map_array(sim, 'idelays', sim_idelays)

    # TODO make z_scale varying per item as well
    # These two are fresh float32 copies; they are appended (after the
    # arrays above, so existing indices are unchanged) so that they are not
    # freed while the structure still points at them.
    sim_z_scale = z_scale.astype('f')
    sim_arrays.append(sim_z_scale)
    map_array(sim, 'z_scale', sim_z_scale)

    assert sim_params.shape == (sim.num_params, sim.num_item)
    sim_params = sim_params.copy().astype('f')
    sim_arrays.append(sim_params)
    map_array(sim, 'params', sim_params)

    main_c.append('''
run_batches(&sim);
return 0;
}
''')
    with open('main.c', 'w') as fd:
        fd.write(''.join(main_c))

    return sim, sim_arrays
def load_c():
    """Compile ``fused.c`` with gcc into ``./fused.so`` and load it.

    The object file is compiled as position-independent code (``-fPIC``),
    as required for linking into a shared library.

    Returns:
        The loaded library as a ``ctypes.CDLL``.

    Raises:
        subprocess.CalledProcessError: if compilation or linking fails.
    """
    flags = ('-O3 -g -funroll-loops -ffast-math -march=native -mtune=native '
             '-fopenmp-simd -fPIC')
    subprocess.check_call(['gcc', *flags.split(), '-c', 'fused.c'])
    subprocess.check_call(['gcc', '-shared', 'fused.o', '-o', 'fused.so', '-lm'])
    return ct.CDLL('./fused.so')
def run_sim_np(
    csr_weights: scipy.sparse.csr_matrix,
    idelays: np.ndarray,
    sim_params: np.ndarray,
    z_scale: np.ndarray,
    horizon: int,
    rng_seed=43, num_item=8, num_node=90, num_svar=2, num_time=1000, dt=0.1,
    num_skip=5
):
    """NumPy reference implementation of the delayed-coupling Heun scheme.

    Args:
        csr_weights: connectivity in CSR form, one row per target node.
        idelays: integer delay (in steps) per nonzero entry; every delay must
            be smaller than ``horizon - 2``.
        sim_params: array whose three rows are the parameters ``a``, ``tau``
            and ``k``, each broadcastable against ``(num_node, num_item)``.
        z_scale: unused (noise is disabled, as in the C code).
        horizon: length of the delay ring buffer.
        rng_seed: unused.
        num_item, num_node, num_svar, num_time, dt, num_skip: simulation
            sizes and step settings.

    Returns:
        float32 array of shape
        ``(num_time // num_skip + 1, num_svar, num_node, num_item)``;
        frame ``t`` holds the state after ``(t + 1) * num_skip`` steps.
    """
    trace_shape = num_time // num_skip + 1, num_svar, num_node, num_item
    trace = np.zeros(trace_shape, 'f')
    if idelays.size:
        assert idelays.max() < horizon-2
    # row 0: delay for the first Heun stage, row 1: one step later
    idelays2 = -horizon + np.c_[idelays, idelays-1].T
    assert idelays2.shape == (2, csr_weights.nnz)
    buffer = np.zeros((num_node, horizon, num_item))

    # np.add.reduceat does not return 0 for empty segments, so only the rows
    # that actually have incoming connections are reduced; the others keep a
    # coupling of exactly 0.
    indptr = np.asarray(csr_weights.indptr)
    row_has_input = np.diff(indptr) > 0
    row_starts = indptr[:-1][row_has_input]

    def cfun(t):
        # (2, nnz, num_item): delayed source values for both stages
        cx = buffer[csr_weights.indices, (t-idelays2) % horizon]
        cx *= csr_weights.data.reshape(-1, 1)
        out = np.zeros((2, num_node, num_item))
        if row_starts.size:
            out[:, row_has_input] = np.add.reduceat(cx, row_starts, axis=1)
        return out  # (2, num_node, num_item)

    def dfun(x, cx):
        a, tau, k = sim_params
        return np.array([
            tau*(x[0] - x[0]**3/3 + x[1]),
            (1/tau)*(a + k*cx - x[0])
        ])  # (2, num_node, num_item)

    def heun(x, cx):
        # noise is skipped, matching the C code for the moment
        dx1 = dfun(x, cx[0])
        dx2 = dfun(x + dt*dx1, cx[1])
        return x + dt/2*(dx1 + dx2)

    x = np.zeros((2, num_node, num_item), 'f')
    for t in range(trace.shape[0]):
        for tt in range(num_skip):
            ttt = t*num_skip + tt
            cx = cfun(ttt)
            x = heun(x, cx)
            # the first state variable is the coupling variable
            buffer[:, ttt % horizon] = x[0]
        trace[t] = x
    return trace
if __name__ == '__main__':
    # Benchmark the compiled kernels and check them against the NumPy
    # reference implementation on a random sparse network.
    np.random.seed(42)
    num_item = 8
    num_node = 90
    num_skip = 50
    dt = 0.1
    sparsity = 0.3
    horizon = 256
    num_time = int(10e3/dt)

    # rows: a, tau, k
    sim_params = np.zeros((3, num_item), 'f')
    sim_params[0] = 1.001
    sim_params[1] = 1.0
    sim_params[2] = np.logspace(-1.8, -2.0, num_item)/num_node*80  # k
    z_scale = np.sqrt(dt)*np.r_[0.01, 0.1].astype('f')*1e-8

    # random weights and tract lengths; keep roughly `sparsity` of the entries
    weights, lengths = np.random.rand(2, num_node, num_node).astype('f')
    lengths[:] *= 0.8
    lengths *= (horizon*dt*0.8)
    zero_mask = weights < (1-sparsity)
    weights[zero_mask] = 0
    csr_weights = scipy.sparse.csr_matrix(weights)
    # +2 keeps every delay at least two steps
    idelays = (lengths[~zero_mask]/dt).astype('i')+2

    run_args = csr_weights, idelays, sim_params, z_scale, horizon
    run_kwargs = dict(num_item=num_item, num_node=num_node, num_time=num_time,
                      dt=dt, num_skip=num_skip)

    lib_c = load_c()
    lib_c.run_batches.argtypes = [ct.POINTER(Sim)]
    lib_c.run_batches.restype = None
    traces = {}

    # NumPy reference; it is needed by the comparison at the end. The last
    # frame is dropped because the reference records one frame more than
    # num_time // num_skip.
    tic = time.time()
    traces['numpy'] = run_sim_np(*run_args, **run_kwargs)[:-1]
    tok = time.time()
    print(tok - tic, 's', num_time*num_item/(tok-tic)*1e-3, 'Kiter/s numpy')
    ntnp = traces['numpy'].shape[0]

    # compiled kernels
    sim, sim_arrays = make_sim(*run_args, **run_kwargs)
    traces['c'] = sim_arrays[0]
    tic = time.time()
    sim_ref = ct.byref(sim)
    lib_c.run_batches(sim_ref)
    tok = time.time()
    print(tok - tic, 's', num_time*num_item/(tok-tic)*1e-3, 'Kiter/s c')

    np.testing.assert_allclose(traces['c'][:ntnp], traces['numpy'], 1e-5, 1e-5)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment