Skip to content

Instantly share code, notes, and snippets.

@maedoc
Last active November 25, 2024 09:13
Show Gist options
  • Save maedoc/3ce7084078f915f3846b2b66068c81fa to your computer and use it in GitHub Desktop.
Save maedoc/3ce7084078f915f3846b2b66068c81fa to your computer and use it in GitHub Desktop.
Fused kernels for simulations
#include<stdbool.h>
#include<stdio.h>
// Description of one simulation, shared by every kernel in this file.
// The accompanying Python driver mirrors this struct member-for-member as a
// ctypes.Structure, so member order and types must stay in sync with it.
// Array shapes are listed slowest axis first; the kernels index the item
// axis as the fastest one, in groups of 8 floats.
struct sim {
const int rng_seed; // not read by the kernels in this file
const int num_item;
const int num_node;
const int num_svar;
const int num_time;
const int num_params;
const int num_spatial_params;
const int num_simd;
const int num_batch;
// time step used by both Heun stages
const float dt;
// TODO ode substeps for stability
// const int num_substeps;
// number of dt steps taken per saved output sample
const int num_skip;
// output samples, one frame written by save_trace per num_skip steps
float *state_trace; // (num_time//num_skip, num_svar, num_nodes, num_item)
// noise scaling per svar, sigma[:]*sqrt(dt); not read by the kernels in this file
const float *z_scale; // (num_svar, num_item)
// full states, updated in place by step_ode
float *states; // (num_svar, num_nodes, num_item)
// delay buffer for coupling variables: a ring of `horizon` past samples per
// node. Time indices are wrapped with `& horizon_minus_1`, which is only a
// modulo when horizon is a power of two.
const int horizon;
const int horizon_minus_1;
const bool horizon_is_pow_of_2; // not checked by the kernels in this file
float *delay_buffer; // (num_nodes, horizon, num_item)
// parameters; step_ode passes rows 0, 1 and 2 to dfun8 as a, tau and k
const float *params; // (num_params, batch_size)
const float *spatial_params; // (num_spatial_params, num_nodes, batch_size)
// CSR connectivity: the nonzeros of target node i are
// indptr[i] .. indptr[i+1]-1; indices[] holds the source node and
// idelays[] the delay in time steps of each nonzero
const int num_nonzero;
const float *weights; // (num_nonzero,)
const int *indices; // (num_nonzero,)
const int *indptr; // (num_nodes+1,)
const int *idelays; // (num_nonzero,)
};
// Bundle of (node, time, item) indices.
// NOTE(review): nothing in the code visible in this file references this
// struct; confirm whether it is still needed.
struct idx {
int node;
int time;
int item;
};
// Read-only view of the simulation description; every kernel takes one.
typedef const struct sim the_sim;

// Inlining control for the small 8-lane helpers.
#define NOINLINE __attribute__((noinline))
// `inline` is spelled out because GCC warns that an always_inline function
// "might not be inlinable" when the keyword is missing.
#define DOINLINE inline __attribute__((always_inline))
#define INLINE DOINLINE

// aligned(p) is an optional promise to the compiler that pointer p is
// 32-byte aligned. The macro used to be defined twice in a row (the strict
// form immediately overridden by a relaxed one), which is a macro
// redefinition; there is now a single definition selected at build time.
// The relaxed no-op form stays the default because it is the one that was
// in effect. Only define FUSED_ASSUME_ALIGNED when every pointer handed to
// aligned() really is 32-byte aligned: a false promise is undefined
// behaviour.
#ifdef FUSED_ASSUME_ALIGNED
#define aligned(var) var = __builtin_assume_aligned(var, 32)
#else
#define aligned(var) ((void)(var))
#endif
// Weighted accumulate over one 8-float group: dst[j] += w * src[j].
static void INLINE inc8_aligned(float *dst, float *src, float w) {
    aligned(dst);
    aligned(src);
    // clang does not want to vectorize this :o
    #pragma omp simd
    for (int lane = 0; lane < 8; ++lane) {
        dst[lane] = dst[lane] + w * src[lane];
    }
}
// Copy one 8-float group from src to dst (despite the name, this is a
// plain copy and is used for stores to memory as well as loads).
static void INLINE load8(float *dst, float *src) {
    aligned(dst);
    aligned(src);
    #pragma omp simd
    for (int lane = 0; lane < 8; ++lane) {
        dst[lane] = src[lane];
    }
}
// Clear one 8-float group.
static void INLINE zero8(float * __restrict dst) {
    aligned(dst);
    #pragma omp simd
    for (int lane = 0; lane < 8; ++lane) {
        dst[lane] = 0.f;
    }
}
// Look up one CSR nonzero `nz` at time step `i_time`.
// Outputs: *w receives the connection weight; *b1 and *b2 receive pointers
// to the source node's 8-float history groups at the delayed time and at
// the delayed time plus one step (used for the two Heun stages).
static void INLINE prep_ij(the_sim *s,
    const int i_time, const int nz, float **b1, float **b2, float *w)
{
    aligned(b1);
    aligned(b2);
    aligned(w);

    *w = s->weights[nz];

    // start of the source node's ring of `horizon` groups of 8 floats
    float *node_history = s->delay_buffer + s->horizon*s->indices[nz]*8;

    // adding horizon keeps the value non-negative before the mask wraps it
    const int t_read = s->horizon + i_time - s->idelays[nz];
    const int wrap_mask = s->horizon_minus_1;

    *b1 = node_history + ((t_read + 0) & wrap_mask)*8;
    *b2 = node_history + ((t_read + 1) & wrap_mask)*8;
}
// Add the contribution of a single CSR nonzero `nz` to both coupling
// accumulators: cx1 gets the delayed sample, cx2 the sample one step later.
static void INLINE csr_ij(
    the_sim *s, const int nz, const int i_time,
    float *cx1, float *cx2
) {
    float weight;
    float *hist_now;
    float *hist_next;

    prep_ij(s, i_time, nz, &hist_now, &hist_next, &weight);
    inc8_aligned(cx1, hist_now, weight);
    inc8_aligned(cx2, hist_next, weight);
}
// Vectorized prep_ij for 4 consecutive nonzeros starting at `nz`.
// Pointers are 8 bytes, so only 4 are handled at a time.
// Outputs: w[0..3] receive the weights; b1[0..3] / b2[0..3] receive pointers
// to each source node's 8-float history group at the delayed time and at
// the delayed time plus one step.
static void INLINE prep4_ij(the_sim *s,
    const int i_time, const int nz,
    float ** __restrict b1, float ** __restrict b2, float * __restrict w)
{
    // inputs
    // No alignment promise is made here. `nz` is an arbitrary CSR offset
    // (row start plus a multiple of 4), so weights+nz, indices+nz and
    // idelays+nz are in general only 4-byte aligned; telling the compiler
    // they are 32-byte aligned was undefined behaviour and lets it emit
    // aligned vector loads that can fault. The same goes for the outputs:
    // the caller in this file passes ordinary stack arrays.
    const float *weights = s->weights + nz;
    const int *indices = s->indices + nz;
    const int *idelays = s->idelays + nz;
    float *buf = s->delay_buffer;
    const int horizon = s->horizon;
    const int wrap_mask = s->horizon_minus_1;

    // get busy
    #pragma omp simd
    for (int i = 0; i < 4; i++)
    {
        // XXX clang doesn't vectorize this, gcc does
        w[i] = weights[i];
        // adding horizon keeps the value non-negative before the mask wraps it
        int t0 = horizon + i_time - idelays[i];
        b1[i] = buf + horizon*indices[i]*8 + ((t0 + 0) & wrap_mask)*8;
        b2[i] = buf + horizon*indices[i]*8 + ((t0 + 1) & wrap_mask)*8;
    }
}
// Add the contributions of 4 consecutive CSR nonzeros, starting at `nz`,
// to both coupling accumulators. Nonzeros are applied in index order, and
// for each one cx1 is updated before cx2.
static void INLINE csr4_ij(
    the_sim *s, const int nz, const int i_time,
    float *cx1, float *cx2
) {
    float *hist_now[4];
    float *hist_next[4];
    float weight[4];

    prep4_ij(s, i_time, nz, hist_now, hist_next, weight);

    #pragma GCC unroll 4
    for (int j = 0; j < 4; ++j) {
        inc8_aligned(cx1, hist_now[j], weight[j]);
        inc8_aligned(cx2, hist_next[j], weight[j]);
    }
}
// Right-hand side of the two-variable node model for one 8-float group:
//   dx = tau * (x - x^3/3 + y)
//   dy = (1/tau) * (a + k*cx - x)
// x, y are the state variables, cx the summed coupling input and a, tau, k
// the per-item parameters; results are written to dx and dy.
static void INLINE dfun8(
    // restrict required for autovectorization to work
    const float * __restrict x, const float *__restrict y,
    const float *__restrict cx, const float *__restrict a,
    const float *__restrict tau, const float *__restrict k,
    float *__restrict dx, float *__restrict dy)
{
    aligned(x);
    aligned(y);
    aligned(cx);
    aligned(a);
    aligned(tau);
    aligned(k);
    aligned(dx);
    aligned(dy);
    #pragma omp simd
    for (int lane = 0; lane < 8; ++lane)
    {
        const float xv = x[lane];
        dx[lane] = tau[lane] * (xv - xv*xv*xv/3 + y[lane]);
        dy[lane] = (1/tau[lane]) * (a[lane] + k[lane]*cx[lane] - xv);
    }
}
// Heun predictor for one 8-float group: xi = x + dt*dx.
static void INLINE heun1(
    const float * __restrict x, float * __restrict xi,
    const float * __restrict dx, const float dt)
{
    aligned(x);
    aligned(xi);
    aligned(dx);
    #pragma omp simd
    for (int lane = 0; lane < 8; ++lane) {
        xi[lane] = x[lane] + dt*dx[lane];
    }
}
// Heun corrector for one 8-float group, applied in place:
// x += dt/2 * (dx1 + dx2), where dx1 and dx2 are the two stage derivatives.
static void INLINE heun2(
    float * __restrict x, const float * __restrict dx1,
    const float * __restrict dx2, const float dt)
{
    aligned(x);
    aligned(dx1);
    aligned(dx2);
    const float half_dt = dt*0.5f;
    #pragma omp simd
    for (int lane = 0; lane < 8; ++lane) {
        x[lane] = x[lane] + half_dt*(dx1[lane] + dx2[lane]);
    }
}
// Advance node `i_node` by one Heun step and publish its new state.
//   cx1 - coupling input for the first (predictor) stage
//   cx2 - coupling input for the second (corrector) stage
// Reads and rewrites the node's two state variables in s->states, then
// copies state variable 0 into the node's delay-buffer slot for `i_time`.
// The model is hard-wired to 2 state variables and 8 items per group, and
// takes a, tau, k from rows 0, 1, 2 of s->params.
static void INLINE step_ode(
    the_sim *s,
    const int i_node,
    const int i_time,
    float *cx1,
    float *cx2
)
{
    // Scratch for both state variables, 8 lanes each. xi, dx1 and dx2 are
    // completely written by heun1/dfun8 before they are read, so they need
    // no zero fill; the former `z` scratch array was never read at all.
    float x[2*8], xi[2*8], dx1[2*8], dx2[2*8];

    #pragma GCC unroll 2
    for (int svar = 0; svar < 2; svar++)
        load8(x + svar*8, s->states + 8*(i_node + s->num_node*svar));

    // stage 1: derivative at the current state, then Euler predictor
    dfun8(x, x+8, cx1, s->params, s->params+8, s->params+16, dx1, dx1+8);
    #pragma GCC unroll 2
    for (int svar = 0; svar < 2; svar++)
        heun1(x + svar*8, xi + svar*8, dx1 + svar*8, s->dt);

    // stage 2: derivative at the predicted state, then trapezoidal update
    dfun8(xi, xi+8, cx2, s->params, s->params+8, s->params+16, dx2, dx2+8);
    #pragma GCC unroll 2
    for (int svar = 0; svar < 2; svar++)
    {
        heun2(x + svar*8, dx1 + svar*8, dx2 + svar*8, s->dt);
        load8(s->states + 8*(i_node + s->num_node*svar), x + svar*8);
    }

    // record the updated first state variable as this step's history sample
    int write_time = i_time & s->horizon_minus_1;
    load8(s->delay_buffer + 8*(i_node*s->horizon + write_time),
          s->states + 8*(i_node + s->num_node*0));
}
// Advance every node by one time step.
// For each node the delayed, weighted inputs of its CSR row are summed into
// two accumulators (one per Heun stage) and the node is then integrated.
// The row is consumed in blocks of 4 nonzeros, with a scalar loop for the
// remaining 0-3 entries.
void step_nodes(the_sim *s, int i_time)
{
    for (int i_node = 0; i_node < s->num_node; ++i_node)
    {
        float cx1[8], cx2[8];
        zero8(cx1);
        zero8(cx2);

        const int row_begin = s->indptr[i_node];
        const int row_end = s->indptr[i_node + 1];
        const int num_quads = (row_end - row_begin) / 4;
        const int tail_begin = row_begin + 4*num_quads;

        for (int nz = row_begin; nz < tail_begin; nz += 4)
            csr4_ij(s, nz, i_time, cx1, cx2);

        for (int nz = tail_begin; nz < row_end; ++nz)
            csr_ij(s, nz, i_time, cx1, cx2);

        step_ode(s, i_node, i_time, cx1, cx2);
    }
}
// Copy the current states into output frame `t_trace` of s->state_trace.
// A frame holds num_svar*num_node*num_item floats. Only state variables 0
// and 1 are copied, 8 floats per node, matching the two-variable model that
// step_ode integrates. `t_trace` is expected to be non-negative.
static void save_trace(the_sim *s, const int t_trace)
{
    // Compute the frame offset in size_t: the product of four ints can
    // exceed INT_MAX for long traces or large networks, and signed
    // overflow is undefined behaviour.
    const size_t frame_len =
        (size_t)s->num_svar * (size_t)s->num_node * (size_t)s->num_item;
    float *out = s->state_trace + (size_t)t_trace * frame_len;

    for (int i_node = 0; i_node < s->num_node; i_node++)
    {
        load8(out + i_node*8, s->states + i_node*8);
        load8(out + (s->num_node + i_node)*8,
              s->states + (s->num_node + i_node)*8);
    }
}
// Run the whole simulation described by `s`.
// Takes num_time / num_skip output samples; each sample is preceded by
// num_skip calls to step_nodes and followed by one save_trace. Steps left
// over when num_time is not a multiple of num_skip are not simulated.
// Does nothing when num_skip is not positive.
void run_batches(the_sim *s)
{
    // num_skip is a divisor below; a zero value would be a division by zero
    if (s->num_skip <= 0)
        return;

    const int num_trace = s->num_time / s->num_skip;
    for (int t_trace = 0; t_trace < num_trace; t_trace++)
    {
        const int t_total = t_trace * s->num_skip;
        for (int t_skip = 0; t_skip < s->num_skip; t_skip++)
            step_nodes(s, t_total + t_skip);
        save_trace(s, t_trace);
    }
}
import numpy as np
import ctypes as ct
import scipy.sparse
import time
import subprocess
class Sim(ct.Structure):
    """ctypes mirror of the C ``struct sim``.

    Field names, order and types follow the C declaration one-to-one;
    ctypes derives the memory layout from ``_fields_``, so the two must be
    kept in sync.
    """
    _fields_ = [
        # scalar sizes
        *((name, ct.c_int32) for name in (
            'rng_seed', 'num_item', 'num_node', 'num_svar', 'num_time',
            'num_params', 'num_spatial_params', 'num_simd', 'num_batch')),
        # time stepping and output
        ('dt', ct.c_float),
        ('num_skip', ct.c_int32),
        *((name, ct.POINTER(ct.c_float)) for name in (
            'state_trace', 'z_scale', 'states')),
        # delay buffer
        ('horizon', ct.c_int32),
        ('horizon_minus_1', ct.c_int32),
        ('horizon_is_pow_of_2', ct.c_bool),
        # delay buffer storage and parameters
        *((name, ct.POINTER(ct.c_float)) for name in (
            'delay_buffer', 'params', 'spatial_params')),
        # CSR connectivity
        ('num_nonzero', ct.c_int32),
        ('weights', ct.POINTER(ct.c_float)),
        *((name, ct.POINTER(ct.c_int32)) for name in (
            'indices', 'indptr', 'idelays')),
    ]
def map_array(sim, key, array):
    """Point the struct field ``key`` of ``sim`` at the data of ``array``.

    The field's declared ctypes type is looked up in ``sim._fields_`` and
    the array's data pointer is cast to it. Only the address is stored, so
    the caller must keep ``array`` alive for as long as ``sim`` is used.

    Raises:
        KeyError: if ``sim`` has no field named ``key``. Previously an
            unknown key fell out of the search loop and was silently set
            using the type of the last declared field.
    """
    for name, ctype in sim._fields_:
        if name == key:
            setattr(sim, key, array.ctypes.data_as(ctype))
            return
    raise KeyError(f'{type(sim).__name__} has no field named {key!r}')
def make_sim(
    csr_weights: scipy.sparse.csr_matrix,
    idelays: np.ndarray,
    sim_params: np.ndarray,
    z_scale: np.ndarray,
    horizon: int,
    rng_seed=43, num_item=8, num_node=90, num_svar=2, num_time=1000, dt=0.1,
    num_skip=5, num_simd=8
):
    """Build a ``Sim`` struct and the NumPy arrays that back its pointers.

    As a side effect a small C program is written to ``main.c`` in the
    current directory; it includes ``fused.c``, fills in the scalar fields,
    allocates ``state_trace`` and calls ``run_batches``.

    Args:
        csr_weights: connectivity in CSR form; data, indices and indptr are
            copied into float32/int32 arrays.
        idelays: per-nonzero delay in time steps, one entry per nonzero.
        sim_params: parameters of shape ``(num_params, num_item)``.
        z_scale: noise scale, stored as float32.
        horizon: length of the per-node delay ring; must be a power of two
            because the C kernels wrap indices with a bit mask.
        rng_seed, num_item, num_node, num_svar, num_time, dt, num_skip,
        num_simd: copied into the struct fields of the same name.

    Returns:
        ``(sim, sim_arrays)``. ``sim_arrays[0]`` is the state trace of shape
        ``(num_time // num_skip + 2, num_svar, num_node, num_item)``. The
        list holds every array ``sim`` points to and must be kept alive for
        as long as ``sim`` is in use.
    """
    # The kernels compute ring positions as `t & horizon_minus_1`, which is
    # a modulo only for powers of two. The argument used to be ignored in
    # favour of a hard-coded 256.
    assert horizon > 0 and (horizon & (horizon - 1)) == 0, \
        'horizon must be a power of 2'

    sim = Sim()
    sim.rng_seed = rng_seed
    sim.num_item = num_item
    sim.num_simd = num_simd
    sim.num_batch = num_item // num_simd
    sim.num_node = num_node
    sim.num_svar = num_svar
    sim.num_time = num_time
    sim.dt = dt
    sim.num_skip = num_skip
    sim.horizon = horizon
    assert num_item >= num_simd
    assert sim.num_batch*num_simd == num_item
    sim.horizon_minus_1 = sim.horizon - 1
    sim.horizon_is_pow_of_2 = True
    sim.num_params = sim_params.shape[0]
    sim.num_spatial_params = 0
    sim.num_nonzero = csr_weights.nnz

    # Pieces of main.c are collected here and written in one go at the end,
    # so the file handle is always closed.
    main_src = [f'''#include "fused.c"
#include<stdlib.h>
int main()
{{
struct sim sim = {{
.rng_seed = {rng_seed},
.num_item = {num_item},
.num_simd = {num_simd},
.num_batch = {num_item // num_simd},
.num_node = {num_node},
.num_svar = {num_svar},
.num_time = {num_time},
.dt = {dt},
.num_skip = {num_skip},
.horizon = {horizon},
.horizon_minus_1 = {sim.horizon - 1},
.horizon_is_pow_of_2 = true,
.num_params = {sim_params.shape[0]},
.num_spatial_params = {0},
.num_nonzero = {csr_weights.nnz}
}};
''']

    sim_arrays = []

    def zeros(shape, dtype='f'):
        # allocate and remember the array so it outlives this function
        arr = np.zeros(shape, dtype)
        sim_arrays.append(arr)
        return arr

    sim_state_trace = zeros((sim.num_time // sim.num_skip + 2,
                             sim.num_svar,
                             sim.num_node,
                             sim.num_item), 'f')
    map_array(sim, 'state_trace', sim_state_trace)
    main_src.append('sim.state_trace = malloc(sizeof(float)*(sim.num_time / sim.num_skip + 1)*sim.num_svar*sim.num_node*sim.num_item);')

    # TODO make this varying per item as well
    sim_z_scale = z_scale.astype('f')
    map_array(sim, 'z_scale', sim_z_scale)

    sim_states = zeros((sim.num_svar, sim.num_node, sim.num_item), 'f')
    map_array(sim, 'states', sim_states)

    sim_delay_buffer = zeros((sim.num_node, sim.horizon, sim.num_item), 'f')
    map_array(sim, 'delay_buffer', sim_delay_buffer)

    assert sim_params.shape == (sim.num_params, sim.num_item)
    sim_params = sim_params.copy().astype('f')
    sim_spatial_params = zeros((sim.num_spatial_params,
                                sim.num_node,
                                sim.num_item), 'f')
    map_array(sim, 'params', sim_params)
    map_array(sim, 'spatial_params', sim_spatial_params)

    sim_weights = zeros((sim.num_nonzero,), 'f')
    sim_indices = zeros((sim.num_nonzero,), np.int32)
    sim_indptr = zeros((sim.num_node+1,), np.int32)
    sim_idelays = zeros((sim.num_nonzero,), np.int32)
    sim_weights[:] = csr_weights.data.astype('f')
    sim_indices[:] = csr_weights.indices.astype('i')
    sim_indptr[:] = csr_weights.indptr.astype('i')
    sim_idelays[:] = idelays.astype('i')
    map_array(sim, 'weights', sim_weights)
    map_array(sim, 'indices', sim_indices)
    map_array(sim, 'indptr', sim_indptr)
    map_array(sim, 'idelays', sim_idelays)

    # These two were created here and handed to C without being retained,
    # so their lifetime depended on ctypes/NumPy keep-alive internals.
    # They are appended last so existing indices (sim_arrays[0] is the
    # trace) stay valid.
    sim_arrays.extend([sim_z_scale, sim_params])

    main_src.append('''
run_batches(&sim);
return 0;
}
''')
    with open('main.c', 'w') as fd:
        fd.write(''.join(main_src))

    return sim, sim_arrays
def load_c():
    """Compile ``fused.c`` with gcc into ``fused.so`` and load it.

    Both files live in the current working directory.

    Returns:
        The loaded library as a ``ctypes.CDLL``.

    Raises:
        subprocess.CalledProcessError: if compiling or linking fails.
    """
    # -fPIC: the object file is linked into a shared library below, which
    # requires position-independent code on most targets.
    flags = ('-O3 -g -funroll-loops -ffast-math -march=native -mtune=native '
             '-fopenmp-simd -fPIC')
    subprocess.check_call(['gcc', *flags.split(), '-c', 'fused.c'])
    subprocess.check_call(['gcc', '-shared', 'fused.o', '-o', 'fused.so', '-lm'])
    return ct.CDLL('./fused.so')
def run_sim_np(
    csr_weights: scipy.sparse.csr_matrix,
    idelays: np.ndarray,
    sim_params: np.ndarray,
    z_scale: np.ndarray,
    horizon: int,
    rng_seed=43, num_item=8, num_node=90, num_svar=2, num_time=1000, dt=0.1,
    num_skip=5
):
    """NumPy reference implementation of the delayed-network simulation.

    Integrates the two-variable node model with Heun's method, using
    delayed, weighted coupling on the first state variable. Noise is
    disabled, so ``z_scale`` and ``rng_seed`` are accepted but unused.

    Args:
        csr_weights: connectivity in CSR form (rows are target nodes).
        idelays: per-nonzero delay in time steps; all must be below
            ``horizon - 2``.
        sim_params: three rows ``a, tau, k``, each broadcastable against
            ``(num_node, num_item)``.
        horizon: length of the per-node history ring buffer.

    Returns:
        float32 array of shape
        ``(num_time // num_skip + 1, num_svar, num_node, num_item)`` with
        the state after every ``num_skip`` steps.
    """
    trace_shape = num_time // num_skip + 1, num_svar, num_node, num_item
    trace = np.zeros(trace_shape, 'f')
    assert idelays.max() < horizon-2
    # row 0: delays for the first Heun stage, row 1: one step later
    idelays2 = -horizon + np.c_[idelays, idelays-1].T
    assert idelays2.shape == (2, csr_weights.nnz)
    buffer = np.zeros((num_node, horizon, num_item))

    # np.add.reduceat does not return 0 for an empty segment (it returns
    # the element at the segment start, and raises if that start is past
    # the end), so nodes without incoming connections would get a wrong
    # coupling value. Reduce over non-empty rows only; the others stay 0.
    has_inputs = np.diff(csr_weights.indptr) > 0
    row_starts = csr_weights.indptr[:-1][has_inputs]

    def cfun(t):
        # delayed source states for both stages: (2, nnz, num_item)
        cx = buffer[csr_weights.indices, (t-idelays2) % horizon]
        cx *= csr_weights.data.reshape(-1, 1)
        out = np.zeros((2, num_node, num_item))
        if row_starts.size:
            out[:, has_inputs] = np.add.reduceat(cx, row_starts, axis=1)
        return out  # (2, num_node, num_item)

    def dfun(x, cx):
        a, tau, k = sim_params
        return np.array([
            tau*(x[0] - x[0]**3/3 + x[1]),
            (1/tau)*(a + k*cx - x[0])
        ])  # (2, num_node, num_item)

    def heun(x, cx):
        # skip rng in C code ftm
        z = np.zeros((2, num_node, num_item))
        dx1 = dfun(x, cx[0])
        dx2 = dfun(x + dt*dx1 + z, cx[1])
        return x + dt/2*(dx1 + dx2) + z

    x = np.zeros((2, num_node, num_item), 'f')
    for t in range(trace.shape[0]):
        for tt in range(num_skip):
            ttt = t*num_skip + tt
            cx = cfun(ttt)
            x = heun(x, cx)
            # only the first state variable is kept as coupling history
            buffer[:, ttt % horizon] = x[0]
        trace[t] = x
    return trace
if __name__ == '__main__':
    # Benchmark the compiled kernels on a random delayed network and,
    # optionally, validate them against the NumPy reference.
    np.random.seed(42)
    num_item = 8
    num_node = 90
    num_skip = 50
    dt = 0.1
    sparsity = 0.3
    horizon = 256
    num_time = int(10e3/dt)

    # The final comparison needs the NumPy reference trace. The reference
    # run had been commented out while the comparison was left in place, so
    # the script ended in a NameError. It is now an explicit switch; set it
    # to True to run the (much slower) reference and check the C result.
    run_numpy_reference = False

    sim_params = np.zeros((3, num_item), 'f')
    sim_params[0] = 1.001
    sim_params[1] = 1.0
    sim_params[2] = np.logspace(-1.8, -2.0, num_item)/num_node*80  # k
    z_scale = np.sqrt(dt)*np.r_[0.01, 0.1].astype('f')*1e-8

    weights, lengths = np.random.rand(2, num_node, num_node).astype('f')
    lengths[:] *= 0.8
    lengths *= (horizon*dt*0.8)
    zero_mask = weights < (1-sparsity)
    weights[zero_mask] = 0
    csr_weights = scipy.sparse.csr_matrix(weights)
    # +2 keeps every delay at two steps or more
    idelays = (lengths[~zero_mask]/dt).astype('i')+2

    run_args = csr_weights, idelays, sim_params, z_scale, horizon
    run_kwargs = dict(num_item=num_item, num_node=num_node, num_time=num_time,
                      dt=dt, num_skip=num_skip)

    lib_c = load_c()
    # run_batches takes a pointer to the struct and returns nothing
    lib_c.run_batches.argtypes = [ct.POINTER(Sim)]
    lib_c.run_batches.restype = None

    traces = {}
    if run_numpy_reference:
        tic = time.time()
        traces['numpy'] = run_sim_np(*run_args, **run_kwargs)[:-1]
        tok = time.time()
        print(tok - tic, 's', num_time*num_item/(tok-tic)*1e-3, 'Kiter/s numpy')

    sim, sim_arrays = make_sim(*run_args, **run_kwargs)
    traces['c'] = sim_arrays[0]
    tic = time.time()
    sim_ref = ct.byref(sim)
    lib_c.run_batches(sim_ref)
    tok = time.time()
    print(tok - tic, 's', num_time*num_item/(tok-tic)*1e-3, 'Kiter/s c')

    if 'numpy' in traces:
        ntnp = traces['numpy'].shape[0]
        np.testing.assert_allclose(traces['c'][:ntnp], traces['numpy'], 1e-5, 1e-5)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment