Skip to content

Instantly share code, notes, and snippets.

View evanatyourservice's full-sized avatar

Evan Walters evanatyourservice

  • Denver, CO
View GitHub Profile
@evanatyourservice
evanatyourservice / batched_sqrt_inv.py
Created June 5, 2025 18:02
Batched inverse square roots using Newton–Schulz iteration
import torch
def compute_H_inv_cubic(A, num_iters=10):
    """Approximate the inverse square root of a (batched) square matrix.

    Runs the cubic Newton–Schulz recurrence

        X_{k+1} = 1.5 * X_k - 0.5 * X_k @ A @ X_k @ X_k

    starting from the identity broadcast to ``A``'s shape.

    Args:
        A: tensor of shape (..., n, n); leading dimensions are a batch.
            No pre-scaling is applied here, so the caller presumably
            supplies an SPD ``A`` whose spectrum keeps the iteration
            convergent (e.g. eigenvalues in (0, 3)) — confirm at call site.
        num_iters: number of Newton–Schulz steps to run.

    Returns:
        Tensor shaped like ``A`` holding the approximation after
        ``num_iters`` steps. With ``num_iters == 0`` this is the broadcast
        (non-contiguous, expanded) identity view.
    """
    size = A.shape[-1]
    identity = torch.eye(size, dtype=A.dtype, device=A.device)
    approx = identity.expand(A.shape)
    for _step in range(num_iters):
        # X A X X, contracted over the trailing two matrix dimensions.
        cubic_term = torch.einsum(
            '...ij,...jk,...kl,...lm->...im',
            approx,
            A,
            approx,
            approx,
        )
        approx = 1.5 * approx - 0.5 * cubic_term
    return approx
def compute_H_inv_quintic(A, num_iters=5):
@evanatyourservice
evanatyourservice / as_ns_inverses.py
Created June 5, 2025 18:00
Alexander Stotsky's factorized Newton–Schulz matrix inverses
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_spd_matrix
"""
Factorized newton schulz iters for inverse of SPD matrix, from paper
by Alexander Stotsky https://arxiv.org/pdf/2208.04068
From efficiency equation EI = n^(1/np) where EI is the efficiency, n is the
order of the algorithm, and np is the number of matmuls in the algorithm, n=11
import torch
def _qdwh_qr_step(u, params):
    """Apply one QR-based QDWH iteration to ``u``.

    Computes

        u_next = e * u + a_minus_e_by_sqrt_c * (Q1 @ Q2^H)

    where ``[sqrt_c * u; I] = Q R`` is a reduced QR factorization, Q1 is
    the top M rows of Q and Q2 the bottom N rows. Algebraically,
    ``Q1 @ Q2^H == sqrt_c * u @ inv(I + sqrt_c**2 * u^H @ u)``; the QR form
    avoids the explicit inverse.

    Args:
        u: tensor of shape (..., M, N); any leading dimensions are treated
            as a batch (plain 2-D input behaves exactly as before).
        params: tuple ``(a_minus_e_by_sqrt_c, sqrt_c, e)`` of iteration
            weights. Presumably these are (a - e)/sqrt(c), sqrt(c) and the
            weight e from the dynamically weighted Halley recurrence —
            confirm against the caller.

    Returns:
        Tensor with the same shape as ``u``.
    """
    a_minus_e_by_sqrt_c, sqrt_c, e = params
    M, N = u.shape[-2], u.shape[-1]
    # Identity block broadcast across any batch dimensions of u.
    eye_n = torch.eye(N, dtype=u.dtype, device=u.device).expand(
        *u.shape[:-2], N, N
    )
    y = torch.cat((sqrt_c * u, eye_n), dim=-2)
    q, _ = torch.linalg.qr(y, mode='reduced')
    q1, q2 = q[..., :M, :], q[..., M:, :]
    # Conjugate transpose: a plain transpose (.mT) is wrong for complex u;
    # for real tensors .mH and .mT are identical.
    return e * u + a_minus_e_by_sqrt_c * (q1 @ q2.mH)
@evanatyourservice
evanatyourservice / hellaswag_jax.py
Last active September 15, 2024 00:54
How to prepare and evaluate on HellaSwag in JAX
import json
from tqdm import tqdm
import numpy as np
import jax
import jax.numpy as jnp
from flax.training.train_state import TrainState
import optax
import optax.tree_utils as otu
import tensorflow as tf
@evanatyourservice
evanatyourservice / Beta-TCVAE in JAX Flax
Created January 1, 2024 20:10
Beta-TCVAE in JAX Flax
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from tensorflow_probability.substrates.jax import distributions as tfd
"""
There's a typo in most B-TCVAE implementations on github, so I thought I'd make a
quick gist of a working B-TCVAE.