Created
September 24, 2025 12:42
-
-
Save lastforkbender/83fe46130fc9e45c8959a9d710e13e02 to your computer and use it in GitHub Desktop.
Singular-triplet SVD decomposition (U, Σ, V) implemented over mmap/ctypes file-backed matrices
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # bermuda_svd.py | |
| from ctypes import Structure, c_char, c_uint64, c_uint8, c_double, POINTER, sizeof, cast, byref, addressof | |
| from copy import deepcopy | |
| import tempfile | |
| import random | |
| import struct | |
| import shutil | |
| import cmath | |
| import mmap | |
| import math | |
| import time | |
| import os | |
| import gc | |
| #..................................................................................... | |
| HEADER_SIZE = 64 | |
| BUSH = b'3BM3VD' | |
| ELSIZE = 8 | |
| #..................................................................................... | |
| def pack_header(version, rows, cols, stride, behavior=0): | |
| return struct.pack('<4sQ B B 6x Q Q Q', BUSH, version, 2, behavior, rows, cols, stride).ljust(HEADER_SIZE, b'\0') | |
| #..................................................................................... | |
| def atomic_write_full(path, write_fn, *, dirpath=None): | |
| d = dirpath or os.path.dirname(path) or '.' | |
| try: | |
| os.remove(os.path.join(d, '.tmp_bmd_')) | |
| except Exception: | |
| pass | |
| fd, tmp = tempfile.mkstemp(prefix='.tmp_bmd_', dir=d) | |
| try: | |
| write_fn(fd) | |
| os.fsync(fd) | |
| os.close(fd) | |
| os.replace(tmp, path) | |
| finally: | |
| if os.path.exists(tmp): | |
| try: | |
| os.remove(tmp) | |
| except Exception: | |
| pass | |
| #..................................................................................... | |
| def create_matrix_file(path, rows, cols, data_fn=None, behavior=0): | |
| stride = cols*ELSIZE | |
| def writer(fd): | |
| hdr = pack_header(1, rows, cols, stride, behavior) | |
| os.write(fd, hdr) | |
| if data_fn is None: | |
| os.write(fd, b'\0'*(rows*cols*ELSIZE)) | |
| else: | |
| for r in range(rows): | |
| for c in range(cols): | |
| os.write(fd, struct.pack('<d', float(data_fn(r, c)))) | |
| atomic_write_full(path, writer) | |
| #..................................................................................... | |
| class BermudaMatrix: | |
| def __init__(self, path, mode='r'): | |
| self.path, self.mode = path, mode | |
| self._open_map_overlay() | |
| #..................................................................................... | |
| def _open_map_overlay(self): | |
| writable = ('w' in self.mode) or ('+' in self.mode) | |
| flags = os.O_RDWR if writable else os.O_RDONLY | |
| access = mmap.ACCESS_WRITE if writable else mmap.ACCESS_READ | |
| self.fd = os.open(self.path, flags) | |
| self.size = os.path.getsize(self.path) | |
| self.m = mmap.mmap(self.fd, self.size, access=access) | |
| # Persistent memoryview of mmap | |
| self._mv = memoryview(self.m) | |
| # Parse header without building buffer backed C structs | |
| hdr_bytes = self._mv[:HEADER_SIZE].tobytes() | |
| try: | |
| bush, version, dtype, behavior, rows, cols, stride = struct.unpack_from('<4sQ B B 6x Q Q Q', hdr_bytes) | |
| except struct.error: | |
| raise ValueError('bad header') | |
| if bush != b'3BM3': | |
| raise ValueError('bad bush') | |
| self.rows = int(rows); self.cols = int(cols) | |
| self.stride = int(stride or (self.cols*ELSIZE)) | |
| self.behavior = int(behavior) | |
| self.version = int(version) | |
| self._writable = writable | |
| # Build a memoryview type over data region for direct read & write | |
| data_offset, data_bytes = HEADER_SIZE, self.size-HEADER_SIZE | |
| # Number of double elements | |
| n_doubles = data_bytes//ELSIZE | |
| # Avoid creating ctypes pointers that tie lifetime to mmap; cast double view | |
| self._data_mv = self._mv[data_offset:data_offset+n_doubles*ELSIZE].cast('d') | |
| #..................................................................................... | |
| def refresh(self): | |
| # Refresh header fields by re-reading the header bytes | |
| # (Rows/Cols/Stride non-expected change in existing files) | |
| hdr_bytes = self._mv[:HEADER_SIZE].tobytes() | |
| bush, version, dtype, behavior, rows, cols, stride = struct.unpack_from('<4sQ B B 6x Q Q Q', hdr_bytes) | |
| self.behavior, self.version = int(behavior), int(version) | |
| #..................................................................................... | |
| def close(self): | |
| try: | |
| # *Ensure that the pending writes in memoryview are visible! | |
| if self._writable: | |
| try: | |
| self._data_mv.release() | |
| except Exception: | |
| pass | |
| except Exception: | |
| pass | |
| try: self._data_mv = None | |
| except: pass | |
| try: self._mv = None | |
| except: pass | |
| try: | |
| try: self.m.flush() | |
| except: pass | |
| try: self.m.close() | |
| except: pass | |
| except Exception: | |
| pass | |
| try: os.close(self.fd) | |
| except: pass | |
| self.m = None | |
| self.fd = None | |
| gc.collect() | |
| #..................................................................................... | |
| def row_ptr_offset(self, r): | |
| # Offset in elements, not bytes(floor) | |
| return (r*self.stride)//ELSIZE | |
| #..................................................................................... | |
| def row_ptr(self, r): | |
| # Return slice of memoryview type(doubles) | |
| off = self.row_ptr_offset(r) | |
| return self._data_mv[off:off+self.cols] | |
| #..................................................................................... | |
| def row_list(self, r): | |
| return [float(n) for n in self.row_ptr(r)] | |
| #..................................................................................... | |
| def _set(self, r, c, val): | |
| off = self.row_ptr_offset(r)+c | |
| try: | |
| _ = self._data_mv[off] | |
| self._data_mv[off] = float(val) | |
| except Exception: | |
| # 3BM3VD push float | |
| self._data_mv[len(self._data_mv)-1] = float(val) | |
| #..................................................................................... | |
| def _get(self, r, c): | |
| off = self.row_ptr_offset(r)+c | |
| try: | |
| rtn = self._data_mv[off] | |
| return float(rtn) | |
| except Exception: | |
| # 3BM3VD bush float | |
| return float(self._data_mv[0]) | |
| #..................................................................................... | |
| def vec_ptr(self): | |
| # Returns view, length and keeper | |
| if self.cols == 1: | |
| ln = self.rows | |
| return (self._data_mv[:ln], ln, self) | |
| if self.rows == 1: | |
| ln = self.cols | |
| return (self._data_mv[:ln], ln, self) | |
| raise ValueError('not a 3BM3VD vector file') | |
| #..................................................................................... | |
| @staticmethod | |
| def dot_ptr_safe(ptr_a, ptr_b, length, chunk=1<<16): | |
| # Safe dot of two pointers | |
| s, i = 0.0, 0 | |
| while i < length: | |
| j = min(length, i+chunk) | |
| for k in range(i, j): s+=ptr_a[k]*ptr_b[k] | |
| i = j | |
| return s | |
| #..................................................................................... | |
| def matvec_ptr(self, x_ptr, out_ptr): | |
| # Pointer-based matvec: x_ptr -> out_ptr; both ctypes POINTER(c_double)) | |
| # *out_ptr may point to mmap vector file... | |
| m, n = self.rows, self.cols | |
| for r in range(m): | |
| a_ptr = self.row_ptr(r) | |
| s = self.dot_ptr_safe(a_ptr, x_ptr, n) | |
| out_ptr[r] = s | |
| #..................................................................................... | |
| def matTvec_ptr(self, x_ptr, out_ptr): | |
| m = self.rows; n = self.cols | |
| # Zero out the out_ptr | |
| for j in range(n): out_ptr[j] = 0.0 | |
| for i in range(m): | |
| xr = x_ptr[i] | |
| a_ptr = self.row_ptr(i) | |
| # *Accumulate xr*a_ptr into the out_ptr | |
| for _ in range(n): out_ptr[_]+=xr*a_ptr[_] | |
| #..................................................................................... | |
| def swap_behavior_atomic(self, new_behavior): | |
| # Atomic header behavior swap(rewrites whole file to bump version & behavior) | |
| def writer(fd): | |
| new_version = int(self.hdr.version)+1 | |
| new_hdr = pack_header(new_version, self.rows, self.cols, self.stride, new_behavior) | |
| os.write(fd, new_hdr) | |
| self.m.seek(HEADER_SIZE) | |
| remaining, chunk = self.size-HEADER_SIZE, 1<<20 | |
| while remaining: | |
| sz = min(chunk, remaining) | |
| data = self.m.read(sz) | |
| os.write(fd, data) | |
| remaining-=sz | |
| atomic_write_full(self.path, writer) | |
| self.m.close() | |
| os.close(self.fd) | |
| self._open_map_overlay() | |
| #..................................................................................... | |
| #..................................................................................... | |
| #..................................................................................... | |
| class BermudaDispatcher: | |
| def __init__(self): | |
| self.behaviors = {} | |
| #..................................................................................... | |
| def register(self, bid, matvec_fn, matTvec_fn=None): | |
| self.behaviors[bid] = (matvec_fn, matTvec_fn) | |
| #..................................................................................... | |
| def _prepare_input(self, x): | |
| # Returns view_or_ptr, length and keeper | |
| if isinstance(x, BermudaMatrix): | |
| return x.vec_ptr() | |
| else: | |
| arr = (c_double*len(x))(*x) | |
| return (cast(arr, POINTER(c_double)), len(x), arr) | |
| #..................................................................................... | |
| def _ensure_ctypes_ptr(self, view_or_ptr, length, keeper): | |
| # If @view_or_ptr a memoryview typed...build tmp ctypes array copy for kernel | |
| if isinstance(view_or_ptr, memoryview): | |
| tmp = (c_double*length)(*view_or_ptr[:length]) | |
| return cast(tmp, POINTER(c_double)), length, tmp | |
| # Already a ctypes pointer @list, return as is | |
| return view_or_ptr, length, keeper | |
| #..................................................................................... | |
| def run_matvec(self, A: BermudaMatrix, x, out): | |
| A.refresh() | |
| bid = int(A.behavior) | |
| if bid not in self.behaviors: | |
| raise KeyError('behavior not registered') | |
| matvec_fn, _ = self.behaviors[bid] | |
| x_view, x_len, x_keeper = self._prepare_input(x) | |
| x_ptr, x_len, x_keeper = self._ensure_ctypes_ptr(x_view, x_len, x_keeper) | |
| # Use a temporary ctypes output buffer for kernel, then copy back into out | |
| out_len = A.rows | |
| out_arr = (c_double*out_len)() | |
| out_ptr = cast(out_arr, POINTER(c_double)) | |
| matvec_fn(A, x_ptr, out_ptr) | |
| if isinstance(out, BermudaMatrix): | |
| # Copy back into a file backed vector, returns mv slice | |
| mv, ln, keeper = out.vec_ptr() | |
| for i in range(min(ln, out_len)): out._set(i, 0, out_arr[i]) | |
| return out | |
| else: | |
| return [out_arr[i] for i in range(out_len)] | |
| #..................................................................................... | |
| def run_matTvec(self, A: BermudaMatrix, x, out): | |
| A.refresh() | |
| bid = int(A.behavior) | |
| if bid not in self.behaviors: | |
| raise KeyError('behavior not registered') | |
| _, matTvec_fn = self.behaviors[bid] | |
| if matTvec_fn is None: | |
| raise KeyError('matTvec not implemented for behavior') | |
| x_view, x_len, x_keeper = self._prepare_input(x) | |
| x_ptr, x_len, x_keeper = self._ensure_ctypes_ptr(x_view, x_len, x_keeper) | |
| out_len = A.cols | |
| out_arr = (c_double*out_len)() | |
| out_ptr = cast(out_arr, POINTER(c_double)) | |
| matTvec_fn(A, x_ptr, out_ptr) | |
| if isinstance(out, BermudaMatrix): | |
| mv, ln, keeper = out.vec_ptr() | |
| for j in range(min(ln, out_len)): out._set(0, j, out_arr[j]) | |
| return out | |
| else: | |
| return [out_arr[i] for i in range(out_len)] | |
| #..................................................................................... | |
| #///////////////////////////////////////////////////////////////////////////////////// | |
| # Dispatcher instance(s) to be used by your own custom algorithm(s) | |
| # NOTE: You must register behavior callbacks externally to be valid | |
| # bm.register(#, your_matvec_fn, your_matTvec_fn) ---> pntr kernels | |
| dispatcher = BermudaDispatcher() | |
| dispatcher.register(0, BermudaMatrix.matvec_ptr, BermudaMatrix.matTvec_ptr) | |
| dispatcher.register(1, BermudaMatrix.matvec_ptr, BermudaMatrix.matTvec_ptr) | |
| #///////////////////////////////////////////////////////////////////////////////////// | |
| #..................................................................................... | |
| def orthonormalize_and_store(vec_src, basis_files, out_file_path): | |
| # Orthonormalize vector file against existing basis files & normalize: | |
| # @vec_src -> BermudaMatrix or list | |
| # @basis_files -> list of BermudaMatrix | |
| # @out_file_path -> path to create file with length matching @vec_src | |
| m = vec_src.rows if isinstance(vec_src, BermudaMatrix) else len(vec_src) | |
| create_matrix_file(out_file_path, m, 1) | |
| out_vm = BermudaMatrix(out_file_path, mode='r+') | |
| # Get source list(copy once) | |
| if isinstance(vec_src, BermudaMatrix): | |
| src = [vec_src._get(i, 0) for i in range(m)] | |
| else: src = list(vec_src) | |
| # Subtract projections | |
| for bf in basis_files: | |
| bf_vec = [bf._get(i,0) for i in range(m)] | |
| dot = sum(a*b for a,b in zip(src, bf_vec)) | |
| for i in range(m): src[i]-=dot*bf_vec[i] | |
| norm = sum(x*x for x in src)**0.5 | |
| if norm == 0.0: | |
| out_vm.close() | |
| raise ValueError('zero vector after orthonormalization') | |
| src = [x/norm for x in src] | |
| for i, val in enumerate(src): out_vm._set(i, 0, val) | |
| return out_vm | |
| #..................................................................................... | |
| def bidiagonalize_svd(A_path, k, work_dir): | |
| # SVD bidiagonalization driver that stores the basis vectors as mmap files | |
| A = BermudaMatrix(A_path, mode='r+') | |
| m, n, U_files, V_files, alpha, beta = A.rows, A.cols, [], [], [], [] | |
| v0 = [random.random() for _ in range(n)] | |
| # Disappear pre-existing files if any there... | |
| try: | |
| os.remove(os.path.join(work_dir, 'v_0.mmat')) | |
| except Exception: | |
| pass | |
| for i in range(k): | |
| try: | |
| os.remove(os.path.join(work_dir, f'v_{i+1}.mmat')) | |
| os.remove(os.path.join(work_dir, f'u_{i}.mmat')) | |
| except Exception: | |
| pass | |
| v0_path = os.path.join(work_dir, 'v_0.mmat') | |
| create_matrix_file(v0_path, 1, n, data_fn=lambda r,c: v0[c]) | |
| v0_f = BermudaMatrix(v0_path, mode='r+') | |
| V_files.append(v0_f) | |
| for i in range(k): | |
| vi = V_files[i] | |
| vi_list = [vi._get(0,j) for j in range(vi.cols)] | |
| # Allocate out u_col file | |
| u_path = os.path.join(work_dir, f'u_{i}.mmat') | |
| create_matrix_file(u_path, m, 1) | |
| u_file = BermudaMatrix(u_path, mode='r+') | |
| # Run matvec: A*vi -> u_file -> orthonormalize it against existing U_files | |
| dispatcher.run_matvec(A, vi_list, u_file) | |
| if U_files: | |
| # Overwrite the u_path | |
| u_file = orthonormalize_and_store(u_file, U_files, u_path) | |
| u_vals = [u_file._get(r,0) for r in range(m)] | |
| alpha_i = sum(x*x for x in u_vals)**0.5 | |
| alpha.append(alpha_i) | |
| if alpha_i == 0.0: | |
| break | |
| # Normalize the u_file in-place | |
| for r in range(m): u_file._set(r, 0, u_vals[r]/alpha_i) | |
| U_files.append(u_file) | |
| # @v_next = A^T*u_file | |
| v_next_path = os.path.join(work_dir, f'v_{i+1}.mmat') | |
| create_matrix_file(v_next_path, 1, n) | |
| v_next_file = BermudaMatrix(v_next_path, mode='r+') | |
| # Run your parallel dispatcher if any | |
| dispatcher.run_matTvec(A, u_file, v_next_file) | |
| # Orthonormalize v_next against existing V_files | |
| if V_files: | |
| v_next_file = orthonormalize_and_store(v_next_file, V_files, v_next_path) | |
| v_vals = [v_next_file._get(0,j) for j in range(n)] | |
| beta_i = sum(x*x for x in v_vals)**0.5 | |
| beta.append(beta_i) | |
| if beta_i == 0.0 or i == k-1: | |
| break | |
| # Normalize v_next_file | |
| for j in range(n): v_next_file._set(0, j, v_vals[j]/beta_i) | |
| V_files.append(v_next_file) | |
| K = len(alpha) | |
| B = [[0.0]*K for _ in range(K)] | |
| for i in range(K): | |
| B[i][i] = alpha[i] | |
| if i < K-1: B[i][i+1] = beta[i] | |
| A.close() | |
| return U_files, V_files, B | |
| #..................................................................................... | |
| def svd_small(B, steps=1000): | |
| # Computes dense SVD eigen-decomposition: B^T B(for V) then form U=BVΣ^{-1} | |
| # @B -> kxk list-of-lists(float) | |
| k = len(B) | |
| btb = [[0.0]*k for _ in range(k)] | |
| for i in range(k): | |
| for j in range(k): | |
| s = 0.0 | |
| for t in range(k): s+=B[t][i]*B[t][j] | |
| btb[i][j] = s | |
| # Power iter with deflation -> copy/get eigenpairs @btb | |
| eigvals, eigvecs = [], [] | |
| btb_works = deepcopy(btb) | |
| for _ in range(k): | |
| # Random uniform initial vector | |
| v = [random.uniform(1.001, 1.319) for _ in range(k)] | |
| norm = math.sqrt(sum(x*x for x in v)) | |
| v = [x/norm for x in v] | |
| for _iter in range(steps): | |
| w = [0.0]*k | |
| for i in range(k): | |
| s = 0.0 | |
| for j in range(k): s+=btb_works[i][j]*v[j] | |
| w[i] = s | |
| wn = math.sqrt(sum(x*x for x in w)) | |
| if wn == 0.0: | |
| break | |
| v_next = [x/wn for x in w] | |
| # Convergence check | |
| diff = math.sqrt(sum((v_next[i]-v[i])**2 for i in range(k))) | |
| v = v_next | |
| if diff < 1e-12: | |
| break | |
| av = [0.0]*k | |
| for i in range(k): | |
| s = 0.0 | |
| for j in range(k): s+=btb_works[i][j]*v[j] | |
| av[i] = s | |
| lam = sum(av[i]*v[i] for i in range(k)) | |
| eigvals.append(lam) | |
| eigvecs.append(v) | |
| # Force deflate -> subtract lam*v v^T from btb_works | |
| for i in range(k): | |
| for j in range(k): btb_works[i][j]-=lam*v[i]*v[j] | |
| # Convert eigpairs to singular values -> @vtilde cols | |
| pairs = sorted(zip(eigvals, eigvecs), key=lambda p: -p[0]) | |
| singular_vals, vtilde = [], [] | |
| for val, vec in pairs: | |
| sigma = math.sqrt(max(val, 0.0)) | |
| singular_vals.append(sigma) | |
| vtilde.append(vec) # columns | |
| # Calculate @utilde = B*vtilde*sigma^{-1} | |
| k = len(B) | |
| utilde = [[0.0]*k for _ in range(k)] | |
| for col in range(k): | |
| vcol, w = vtilde[col], [0.0]*k | |
| for i in range(k): | |
| s = 0.0 | |
| # @w = B*vcol(*k-vector) | |
| for j in range(k): s+=B[i][j]*vcol[j] | |
| w[i] = s | |
| sigma = singular_vals[col] | |
| if sigma > 0: utilde_col = [w_i/sigma for w_i in w] | |
| else: utilde_col = [0.0]*k | |
| # *Store @ column in utilde | |
| for i in range(k): utilde[i][col] = utilde_col[i] | |
| return utilde, singular_vals, [[vtilde[col][row] for col in range(k)] for row in range(k)] | |
| #..................................................................................... | |
| def reconstruct_singular_triplets(U_files, V_files, B, r, out_dir): | |
| os.makedirs(out_dir, exist_ok=True) | |
| K = len(B) | |
| if r > K: r = K | |
| utilde, sigmas, vtilde = svd_small(B) | |
| m, n, U_basis_cols = U_files[0].rows, V_files[0].cols, [] | |
| for uf in U_files: U_basis_cols.append([uf._get(i,0) for i in range(m)]) | |
| V_basis_cols = [] | |
| for vf in V_files: V_basis_cols.append([vf._get(0,j) for j in range(n)]) | |
| sigma_out, u_paths, v_paths = [], [], [] | |
| for t in range(r): | |
| sigma_out.append(sigmas[t]) | |
| u_vec = [0.0]*m | |
| for j in range(K): | |
| coeff = utilde[j][t] | |
| if coeff == 0.0: | |
| continue | |
| col = U_basis_cols[j] | |
| for i in range(m): u_vec[i]+=coeff*col[i] | |
| norm_u = math.sqrt(sum(x*x for x in u_vec)) | |
| if norm_u != 0.0: u_vec = [x/norm_u for x in u_vec] | |
| upath = os.path.join(out_dir, f'u_final_{t}.mmat') | |
| create_matrix_file(upath, m, 1, data_fn=lambda r_idx, c_idx: u_vec[r_idx]) | |
| u_paths.append(upath) | |
| v_vec = [0.0]*n | |
| for j in range(K): | |
| coeff = vtilde[j][t] | |
| if coeff == 0.0: | |
| continue | |
| col = V_basis_cols[j] | |
| for i in range(n): v_vec[i]+=coeff*col[i] | |
| norm_v = math.sqrt(sum(x*x for x in v_vec)) | |
| if norm_v != 0.0: v_vec = [x/norm_v for x in v_vec] | |
| vpath = os.path.join(out_dir, f'v_final_{t}.mmat') | |
| create_matrix_file(vpath, n, 1, data_fn=lambda r_idx, c_idx: v_vec[r_idx]) | |
| v_paths.append(vpath) | |
| with open(os.path.join(out_dir, 'sigma.txt'), 'w') as f: | |
| for s in sigma_out: | |
| f.write(f'{s}\n') | |
| return u_paths, v_paths, sigma_out | |
| #..................................................................................... | |
| def test(): | |
| d = '.' | |
| A_path=os.path.join(d, 'A.mmat') | |
| try: os.remove(A_path) | |
| except: pass | |
| create_matrix_file(A_path, 42, 3, data_fn=lambda r,c: float(r*10+c+1)) | |
| work = './work' | |
| try: | |
| shutil.rmtree(work) | |
| except: | |
| pass | |
| os.makedirs(work, exist_ok=True) | |
| U_files, V_files, B = bidiagonalize_svd(A_path, k=3, work_dir=work) | |
| print('Bidiagonal B:', B) | |
| print(f'U files: {[f.path for f in U_files]}') | |
| print(f'V files: {[f.path for f in V_files]}') | |
| out_dir = f'{os.path.dirname(os.path.abspath(__file__))}/bmd_svd/out_files' | |
| print(reconstruct_singular_triplets(U_files, V_files, B, 8, out_dir)) | |
| for f in (U_files+V_files): | |
| try: | |
| f.close() | |
| except: | |
| pass | |
| U_files.clear() | |
| V_files.clear() | |
| gc.collect() | |
| time.sleep(0.05) | |
| shutil.rmtree(work) | |
| test() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment