Skip to content

Instantly share code, notes, and snippets.

@lastforkbender
Created September 24, 2025 12:42
Show Gist options
  • Select an option

  • Save lastforkbender/83fe46130fc9e45c8959a9d710e13e02 to your computer and use it in GitHub Desktop.

Select an option

Save lastforkbender/83fe46130fc9e45c8959a9d710e13e02 to your computer and use it in GitHub Desktop.
Triple singular-value SVD decomposition backed by mmap/ctypes
# bermuda_svd.py
from ctypes import Structure, c_char, c_uint64, c_uint8, c_double, POINTER, sizeof, cast, byref, addressof
from copy import deepcopy
import tempfile
import random
import struct
import shutil
import cmath
import mmap
import math
import time
import os
import gc
#.....................................................................................
HEADER_SIZE = 64
BUSH = b'3BM3VD'
ELSIZE = 8
#.....................................................................................
def pack_header(version, rows, cols, stride, behavior=0):
return struct.pack('<4sQ B B 6x Q Q Q', BUSH, version, 2, behavior, rows, cols, stride).ljust(HEADER_SIZE, b'\0')
#.....................................................................................
def atomic_write_full(path, write_fn, *, dirpath=None):
d = dirpath or os.path.dirname(path) or '.'
try:
os.remove(os.path.join(d, '.tmp_bmd_'))
except Exception:
pass
fd, tmp = tempfile.mkstemp(prefix='.tmp_bmd_', dir=d)
try:
write_fn(fd)
os.fsync(fd)
os.close(fd)
os.replace(tmp, path)
finally:
if os.path.exists(tmp):
try:
os.remove(tmp)
except Exception:
pass
#.....................................................................................
def create_matrix_file(path, rows, cols, data_fn=None, behavior=0):
stride = cols*ELSIZE
def writer(fd):
hdr = pack_header(1, rows, cols, stride, behavior)
os.write(fd, hdr)
if data_fn is None:
os.write(fd, b'\0'*(rows*cols*ELSIZE))
else:
for r in range(rows):
for c in range(cols):
os.write(fd, struct.pack('<d', float(data_fn(r, c))))
atomic_write_full(path, writer)
#.....................................................................................
class BermudaMatrix:
def __init__(self, path, mode='r'):
self.path, self.mode = path, mode
self._open_map_overlay()
#.....................................................................................
def _open_map_overlay(self):
writable = ('w' in self.mode) or ('+' in self.mode)
flags = os.O_RDWR if writable else os.O_RDONLY
access = mmap.ACCESS_WRITE if writable else mmap.ACCESS_READ
self.fd = os.open(self.path, flags)
self.size = os.path.getsize(self.path)
self.m = mmap.mmap(self.fd, self.size, access=access)
# Persistent memoryview of mmap
self._mv = memoryview(self.m)
# Parse header without building buffer backed C structs
hdr_bytes = self._mv[:HEADER_SIZE].tobytes()
try:
bush, version, dtype, behavior, rows, cols, stride = struct.unpack_from('<4sQ B B 6x Q Q Q', hdr_bytes)
except struct.error:
raise ValueError('bad header')
if bush != b'3BM3':
raise ValueError('bad bush')
self.rows = int(rows); self.cols = int(cols)
self.stride = int(stride or (self.cols*ELSIZE))
self.behavior = int(behavior)
self.version = int(version)
self._writable = writable
# Build a memoryview type over data region for direct read & write
data_offset, data_bytes = HEADER_SIZE, self.size-HEADER_SIZE
# Number of double elements
n_doubles = data_bytes//ELSIZE
# Avoid creating ctypes pointers that tie lifetime to mmap; cast double view
self._data_mv = self._mv[data_offset:data_offset+n_doubles*ELSIZE].cast('d')
#.....................................................................................
def refresh(self):
# Refresh header fields by re-reading the header bytes
# (Rows/Cols/Stride non-expected change in existing files)
hdr_bytes = self._mv[:HEADER_SIZE].tobytes()
bush, version, dtype, behavior, rows, cols, stride = struct.unpack_from('<4sQ B B 6x Q Q Q', hdr_bytes)
self.behavior, self.version = int(behavior), int(version)
#.....................................................................................
def close(self):
try:
# *Ensure that the pending writes in memoryview are visible!
if self._writable:
try:
self._data_mv.release()
except Exception:
pass
except Exception:
pass
try: self._data_mv = None
except: pass
try: self._mv = None
except: pass
try:
try: self.m.flush()
except: pass
try: self.m.close()
except: pass
except Exception:
pass
try: os.close(self.fd)
except: pass
self.m = None
self.fd = None
gc.collect()
#.....................................................................................
def row_ptr_offset(self, r):
# Offset in elements, not bytes(floor)
return (r*self.stride)//ELSIZE
#.....................................................................................
def row_ptr(self, r):
# Return slice of memoryview type(doubles)
off = self.row_ptr_offset(r)
return self._data_mv[off:off+self.cols]
#.....................................................................................
def row_list(self, r):
return [float(n) for n in self.row_ptr(r)]
#.....................................................................................
def _set(self, r, c, val):
off = self.row_ptr_offset(r)+c
try:
_ = self._data_mv[off]
self._data_mv[off] = float(val)
except Exception:
# 3BM3VD push float
self._data_mv[len(self._data_mv)-1] = float(val)
#.....................................................................................
def _get(self, r, c):
off = self.row_ptr_offset(r)+c
try:
rtn = self._data_mv[off]
return float(rtn)
except Exception:
# 3BM3VD bush float
return float(self._data_mv[0])
#.....................................................................................
def vec_ptr(self):
# Returns view, length and keeper
if self.cols == 1:
ln = self.rows
return (self._data_mv[:ln], ln, self)
if self.rows == 1:
ln = self.cols
return (self._data_mv[:ln], ln, self)
raise ValueError('not a 3BM3VD vector file')
#.....................................................................................
@staticmethod
def dot_ptr_safe(ptr_a, ptr_b, length, chunk=1<<16):
# Safe dot of two pointers
s, i = 0.0, 0
while i < length:
j = min(length, i+chunk)
for k in range(i, j): s+=ptr_a[k]*ptr_b[k]
i = j
return s
#.....................................................................................
def matvec_ptr(self, x_ptr, out_ptr):
# Pointer-based matvec: x_ptr -> out_ptr; both ctypes POINTER(c_double))
# *out_ptr may point to mmap vector file...
m, n = self.rows, self.cols
for r in range(m):
a_ptr = self.row_ptr(r)
s = self.dot_ptr_safe(a_ptr, x_ptr, n)
out_ptr[r] = s
#.....................................................................................
def matTvec_ptr(self, x_ptr, out_ptr):
m = self.rows; n = self.cols
# Zero out the out_ptr
for j in range(n): out_ptr[j] = 0.0
for i in range(m):
xr = x_ptr[i]
a_ptr = self.row_ptr(i)
# *Accumulate xr*a_ptr into the out_ptr
for _ in range(n): out_ptr[_]+=xr*a_ptr[_]
#.....................................................................................
def swap_behavior_atomic(self, new_behavior):
# Atomic header behavior swap(rewrites whole file to bump version & behavior)
def writer(fd):
new_version = int(self.hdr.version)+1
new_hdr = pack_header(new_version, self.rows, self.cols, self.stride, new_behavior)
os.write(fd, new_hdr)
self.m.seek(HEADER_SIZE)
remaining, chunk = self.size-HEADER_SIZE, 1<<20
while remaining:
sz = min(chunk, remaining)
data = self.m.read(sz)
os.write(fd, data)
remaining-=sz
atomic_write_full(self.path, writer)
self.m.close()
os.close(self.fd)
self._open_map_overlay()
#.....................................................................................
#.....................................................................................
#.....................................................................................
class BermudaDispatcher:
def __init__(self):
self.behaviors = {}
#.....................................................................................
def register(self, bid, matvec_fn, matTvec_fn=None):
self.behaviors[bid] = (matvec_fn, matTvec_fn)
#.....................................................................................
def _prepare_input(self, x):
# Returns view_or_ptr, length and keeper
if isinstance(x, BermudaMatrix):
return x.vec_ptr()
else:
arr = (c_double*len(x))(*x)
return (cast(arr, POINTER(c_double)), len(x), arr)
#.....................................................................................
def _ensure_ctypes_ptr(self, view_or_ptr, length, keeper):
# If @view_or_ptr a memoryview typed...build tmp ctypes array copy for kernel
if isinstance(view_or_ptr, memoryview):
tmp = (c_double*length)(*view_or_ptr[:length])
return cast(tmp, POINTER(c_double)), length, tmp
# Already a ctypes pointer @list, return as is
return view_or_ptr, length, keeper
#.....................................................................................
def run_matvec(self, A: BermudaMatrix, x, out):
A.refresh()
bid = int(A.behavior)
if bid not in self.behaviors:
raise KeyError('behavior not registered')
matvec_fn, _ = self.behaviors[bid]
x_view, x_len, x_keeper = self._prepare_input(x)
x_ptr, x_len, x_keeper = self._ensure_ctypes_ptr(x_view, x_len, x_keeper)
# Use a temporary ctypes output buffer for kernel, then copy back into out
out_len = A.rows
out_arr = (c_double*out_len)()
out_ptr = cast(out_arr, POINTER(c_double))
matvec_fn(A, x_ptr, out_ptr)
if isinstance(out, BermudaMatrix):
# Copy back into a file backed vector, returns mv slice
mv, ln, keeper = out.vec_ptr()
for i in range(min(ln, out_len)): out._set(i, 0, out_arr[i])
return out
else:
return [out_arr[i] for i in range(out_len)]
#.....................................................................................
def run_matTvec(self, A: BermudaMatrix, x, out):
A.refresh()
bid = int(A.behavior)
if bid not in self.behaviors:
raise KeyError('behavior not registered')
_, matTvec_fn = self.behaviors[bid]
if matTvec_fn is None:
raise KeyError('matTvec not implemented for behavior')
x_view, x_len, x_keeper = self._prepare_input(x)
x_ptr, x_len, x_keeper = self._ensure_ctypes_ptr(x_view, x_len, x_keeper)
out_len = A.cols
out_arr = (c_double*out_len)()
out_ptr = cast(out_arr, POINTER(c_double))
matTvec_fn(A, x_ptr, out_ptr)
if isinstance(out, BermudaMatrix):
mv, ln, keeper = out.vec_ptr()
for j in range(min(ln, out_len)): out._set(0, j, out_arr[j])
return out
else:
return [out_arr[i] for i in range(out_len)]
#.....................................................................................
#/////////////////////////////////////////////////////////////////////////////////////
# Dispatcher instance(s) to be used by your own custom algorithm(s)
# NOTE: You must register behavior callbacks externally to be valid
# bm.register(#, your_matvec_fn, your_matTvec_fn) ---> pntr kernels
dispatcher = BermudaDispatcher()
dispatcher.register(0, BermudaMatrix.matvec_ptr, BermudaMatrix.matTvec_ptr)
dispatcher.register(1, BermudaMatrix.matvec_ptr, BermudaMatrix.matTvec_ptr)
#/////////////////////////////////////////////////////////////////////////////////////
#.....................................................................................
def orthonormalize_and_store(vec_src, basis_files, out_file_path):
# Orthonormalize vector file against existing basis files & normalize:
# @vec_src -> BermudaMatrix or list
# @basis_files -> list of BermudaMatrix
# @out_file_path -> path to create file with length matching @vec_src
m = vec_src.rows if isinstance(vec_src, BermudaMatrix) else len(vec_src)
create_matrix_file(out_file_path, m, 1)
out_vm = BermudaMatrix(out_file_path, mode='r+')
# Get source list(copy once)
if isinstance(vec_src, BermudaMatrix):
src = [vec_src._get(i, 0) for i in range(m)]
else: src = list(vec_src)
# Subtract projections
for bf in basis_files:
bf_vec = [bf._get(i,0) for i in range(m)]
dot = sum(a*b for a,b in zip(src, bf_vec))
for i in range(m): src[i]-=dot*bf_vec[i]
norm = sum(x*x for x in src)**0.5
if norm == 0.0:
out_vm.close()
raise ValueError('zero vector after orthonormalization')
src = [x/norm for x in src]
for i, val in enumerate(src): out_vm._set(i, 0, val)
return out_vm
#.....................................................................................
def bidiagonalize_svd(A_path, k, work_dir):
    """Bidiagonalization driver (Golub-Kahan/Lanczos-like) that stores every
    basis vector as its own mmap-backed matrix file under *work_dir*.

    A_path   : path of an existing matrix file (opened 'r+')
    k        : number of bidiagonalization steps requested
    work_dir : directory for the u_i / v_i vector files (must exist)

    Returns (U_files, V_files, B): open BermudaMatrix handles (caller closes
    them) and the K x K upper bidiagonal matrix B (alpha on the diagonal,
    beta on the superdiagonal).

    NOTE(review): for i > 0, u is already unit-normalized by
    orthonormalize_and_store, so alpha_i ~ 1 there -- confirm intended.
    NOTE(review): v files are row vectors (1 x n) but are fed back through
    orthonormalize_and_store, which treats them as columns and leans on the
    BermudaMatrix out-of-range fallback -- confirm intended.
    """
    A = BermudaMatrix(A_path, mode='r+')
    m, n, U_files, V_files, alpha, beta = A.rows, A.cols, [], [], [], []
    # Random start vector for the right subspace.
    v0 = [random.random() for _ in range(n)]
    # Disappear pre-existing files if any there...
    try:
        os.remove(os.path.join(work_dir, 'v_0.mmat'))
    except Exception:
        pass
    for i in range(k):
        try:
            # NOTE: both removals share one try; if v_{i+1} removal fails,
            # the u_i removal is skipped too.
            os.remove(os.path.join(work_dir, f'v_{i+1}.mmat'))
            os.remove(os.path.join(work_dir, f'u_{i}.mmat'))
        except Exception:
            pass
    v0_path = os.path.join(work_dir, 'v_0.mmat')
    create_matrix_file(v0_path, 1, n, data_fn=lambda r,c: v0[c])
    v0_f = BermudaMatrix(v0_path, mode='r+')
    V_files.append(v0_f)
    for i in range(k):
        vi = V_files[i]
        vi_list = [vi._get(0,j) for j in range(vi.cols)]
        # Allocate out u_col file
        u_path = os.path.join(work_dir, f'u_{i}.mmat')
        create_matrix_file(u_path, m, 1)
        u_file = BermudaMatrix(u_path, mode='r+')
        # Run matvec: A*vi -> u_file -> orthonormalize it against existing U_files
        dispatcher.run_matvec(A, vi_list, u_file)
        if U_files:
            # Overwrite the u_path (returns a fresh handle over the same path)
            u_file = orthonormalize_and_store(u_file, U_files, u_path)
        u_vals = [u_file._get(r,0) for r in range(m)]
        alpha_i = sum(x*x for x in u_vals)**0.5
        alpha.append(alpha_i)
        if alpha_i == 0.0:
            # Breakdown: the projected vector vanished; stop early.
            break
        # Normalize the u_file in-place
        for r in range(m): u_file._set(r, 0, u_vals[r]/alpha_i)
        U_files.append(u_file)
        # @v_next = A^T*u_file
        v_next_path = os.path.join(work_dir, f'v_{i+1}.mmat')
        create_matrix_file(v_next_path, 1, n)
        v_next_file = BermudaMatrix(v_next_path, mode='r+')
        # Run your parallel dispatcher if any
        dispatcher.run_matTvec(A, u_file, v_next_file)
        # Orthonormalize v_next against existing V_files
        if V_files:
            v_next_file = orthonormalize_and_store(v_next_file, V_files, v_next_path)
        v_vals = [v_next_file._get(0,j) for j in range(n)]
        beta_i = sum(x*x for x in v_vals)**0.5
        beta.append(beta_i)
        if beta_i == 0.0 or i == k-1:
            break
        # Normalize v_next_file
        for j in range(n): v_next_file._set(0, j, v_vals[j]/beta_i)
        V_files.append(v_next_file)
    # Assemble the K x K upper bidiagonal matrix from alpha/beta.
    K = len(alpha)
    B = [[0.0]*K for _ in range(K)]
    for i in range(K):
        B[i][i] = alpha[i]
        if i < K-1: B[i][i+1] = beta[i]
    A.close()
    return U_files, V_files, B
#.....................................................................................
def svd_small(B, steps=1000):
    """Dense SVD of a small k x k matrix.

    Eigen-decomposes G = B^T B by power iteration with deflation to obtain
    the right singular vectors and sigma^2, then forms the left vectors as
    U = B V Sigma^{-1}.

    B     : k x k list-of-lists of floats
    steps : max power-iteration sweeps per eigenpair

    Returns (utilde, singular_vals, vtilde_rows) where utilde/vtilde_rows
    are k x k row-major matrices and singular_vals is descending.
    """
    k = len(B)
    # Gram matrix G = B^T B.
    gram = [[sum(B[t][i] * B[t][j] for t in range(k)) for j in range(k)]
            for i in range(k)]
    work = deepcopy(gram)
    eigvals, eigvecs = [], []
    for _ in range(k):
        # Random positive start vector, unit-normalized.
        vec = [random.uniform(1.001, 1.319) for _ in range(k)]
        scale = math.sqrt(sum(c * c for c in vec))
        vec = [c / scale for c in vec]
        for _iter in range(steps):
            prod = [sum(work[row][col] * vec[col] for col in range(k))
                    for row in range(k)]
            pn = math.sqrt(sum(c * c for c in prod))
            if pn == 0.0:
                break
            nxt = [c / pn for c in prod]
            drift = math.sqrt(sum((a - b) ** 2 for a, b in zip(nxt, vec)))
            vec = nxt
            if drift < 1e-12:
                break
        # Rayleigh quotient for the converged direction.
        gv = [sum(work[row][col] * vec[col] for col in range(k))
              for row in range(k)]
        lam = sum(a * b for a, b in zip(gv, vec))
        eigvals.append(lam)
        eigvecs.append(vec)
        # Deflate: work -= lam * vec vec^T.
        for row in range(k):
            for col in range(k):
                work[row][col] -= lam * vec[row] * vec[col]
    # Sort eigenpairs by descending eigenvalue; sigma = sqrt(max(lam, 0)).
    ordered = sorted(zip(eigvals, eigvecs), key=lambda p: -p[0])
    singular_vals, vtilde = [], []
    for val, vec in ordered:
        singular_vals.append(math.sqrt(max(val, 0.0)))
        vtilde.append(vec)   # stored as columns
    # utilde = B * vtilde * Sigma^{-1}, column by column.
    utilde = [[0.0] * k for _ in range(k)]
    for col, vcol in enumerate(vtilde):
        bv = [sum(B[row][j] * vcol[j] for j in range(k)) for row in range(k)]
        sigma = singular_vals[col]
        ucol = [w / sigma for w in bv] if sigma > 0 else [0.0] * k
        for row in range(k):
            utilde[row][col] = ucol[row]
    # Return vtilde transposed back to row-major orientation.
    return utilde, singular_vals, [[vtilde[c][r] for c in range(k)] for r in range(k)]
#.....................................................................................
def reconstruct_singular_triplets(U_files, V_files, B, r, out_dir):
    """Expand the r leading singular triplets of B into full-length vectors.

    Combines the stored u/v basis files with the small-SVD coefficients,
    unit-normalizes each result, and writes u_final_t.mmat, v_final_t.mmat
    and sigma.txt under *out_dir* (created if missing).

    Returns (u_paths, v_paths, sigma_out).
    """
    os.makedirs(out_dir, exist_ok=True)
    K = len(B)
    r = min(r, K)
    utilde, sigmas, vtilde = svd_small(B)
    m = U_files[0].rows
    n = V_files[0].cols
    # Materialize every basis column once up front.
    U_basis_cols = [[uf._get(i, 0) for i in range(m)] for uf in U_files]
    V_basis_cols = [[vf._get(0, j) for j in range(n)] for vf in V_files]

    def _expand(coeffs, basis_cols, length, path):
        # Linear combination of basis columns, unit-normalized, stored to disk.
        acc = [0.0] * length
        for j in range(K):
            c = coeffs[j]
            if c == 0.0:
                continue
            col = basis_cols[j]
            for i in range(length):
                acc[i] += c * col[i]
        nrm = math.sqrt(sum(x * x for x in acc))
        if nrm != 0.0:
            acc = [x / nrm for x in acc]
        create_matrix_file(path, length, 1, data_fn=lambda ri, ci: acc[ri])
        return path

    sigma_out, u_paths, v_paths = [], [], []
    for t in range(r):
        sigma_out.append(sigmas[t])
        u_paths.append(_expand([utilde[j][t] for j in range(K)], U_basis_cols, m,
                               os.path.join(out_dir, f'u_final_{t}.mmat')))
        v_paths.append(_expand([vtilde[j][t] for j in range(K)], V_basis_cols, n,
                               os.path.join(out_dir, f'v_final_{t}.mmat')))
    with open(os.path.join(out_dir, 'sigma.txt'), 'w') as f:
        f.writelines(f'{s}\n' for s in sigma_out)
    return u_paths, v_paths, sigma_out
#.....................................................................................
def test():
d = '.'
A_path=os.path.join(d, 'A.mmat')
try: os.remove(A_path)
except: pass
create_matrix_file(A_path, 42, 3, data_fn=lambda r,c: float(r*10+c+1))
work = './work'
try:
shutil.rmtree(work)
except:
pass
os.makedirs(work, exist_ok=True)
U_files, V_files, B = bidiagonalize_svd(A_path, k=3, work_dir=work)
print('Bidiagonal B:', B)
print(f'U files: {[f.path for f in U_files]}')
print(f'V files: {[f.path for f in V_files]}')
out_dir = f'{os.path.dirname(os.path.abspath(__file__))}/bmd_svd/out_files'
print(reconstruct_singular_triplets(U_files, V_files, B, 8, out_dir))
for f in (U_files+V_files):
try:
f.close()
except:
pass
U_files.clear()
V_files.clear()
gc.collect()
time.sleep(0.05)
shutil.rmtree(work)
test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment