Skip to content

Instantly share code, notes, and snippets.

View KeAWang's full-sized avatar

Alex Wang KeAWang

View GitHub Profile
@KeAWang
KeAWang / indexing.py
Created May 12, 2025 01:00
Indexing exercises
# %%
import numpy as np
"""Column-wise Sorting
Spec: sort every column of X using the pre-computed indices in idx
Desired result shape: (M, N)
"""
# Problem size for the exercise: M rows, N columns.
M = 6
N = 4
# Input matrix of standard-normal draws from NumPy's global (legacy) RNG;
# equivalent to np.random.randn(M, N).
X = np.random.standard_normal(size=(M, N))
@KeAWang
KeAWang / gd_jacobian.py
Last active January 15, 2025 23:03
Jacobian of a gradient update recurrence
import torch
from einops import einsum, rearrange
from typing import NamedTuple
# Immutable bundle of MLP parameters: three tensors stored under the field
# names W1, W2 and b1 (presumably two weight matrices and the first layer's
# bias -- confirm against mlp()).
MLPParams = NamedTuple(
    "MLPParams",
    [
        ("W1", torch.Tensor),
        ("W2", torch.Tensor),
        ("b1", torch.Tensor),
    ],
)
def mlp(params: NamedTuple, x):
@KeAWang
KeAWang / gpt-2-wikitext-103.py
Created September 23, 2024 20:23 — forked from thomwolf/gpt-2-wikitext-103.py
A very small and self-contained gist to train a GPT-2 transformer model on wikitext-103
# Copyright (c) 2019-present, Thomas Wolf.
# All rights reserved. This source code is licensed under the MIT-style license.
""" A very small and self-contained gist to train a GPT-2 transformer model on wikitext-103 """
import os
from collections import namedtuple
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from ignite.engine import Engine, Events
@KeAWang
KeAWang / mfu_compute.py
Created April 11, 2024 17:17 — forked from Chillee/mfu_compute.py
Compute Flop Utilization in PyTorch
import torch
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench
def get_flops_achieved(f):
flop_counter = FlopCounterMode(display=False)
with flop_counter:
f()
total_flops = flop_counter.get_total_flops()
ms_per_iter = do_bench(f)
@KeAWang
KeAWang / tcn_experiment.py
Last active November 20, 2023 17:37
TCN experiment with correct residual connection
# %%
import torch
import numpy as np
def make_adding_dataset(num_seqs, seq_len, num_terms=2, seed=43141):
assert 0 <= num_terms <= seq_len
rng = np.random.default_rng(seed=seed)
numbers = rng.uniform(0, 1, (num_seqs, seq_len)) # B x T
mask = np.zeros_like(numbers) # B x T
non_zero = np.stack([rng.choice(seq_len, num_terms, replace=True) for _ in range(num_seqs)]) # B x 2
mask[np.arange(num_seqs)[:, None], non_zero] = 1 # mask[i, non_zero[i, j]]
@KeAWang
KeAWang / nan_embedder.py
Created November 14, 2023 18:42
PyTorch NaN embedder
import torch
class NanWrapper(torch.nn.Module):
"""Wrapper module around a torch Module that handles incoming nans"""
def __init__(self, module):
super().__init__()
self.module = module
def forward(self, x):
""" Masks the entire last dimension (usually the feature/channel dimension) if any element is NaN. """
@KeAWang
KeAWang / count_torch_params.py
Last active November 14, 2023 18:35
Count the number of PyTorch parameters
import torch
def count_params(model: torch.nn.Module, *, trainable_only: bool = False) -> int:
    """Count the scalar parameters of a PyTorch model.

    Args:
        model: Module whose ``parameters()`` are counted.
        trainable_only: If True, count only parameters with
            ``requires_grad=True``. The default (False) counts every
            parameter, frozen or not, which matches the original behavior.

    Returns:
        Total number of elements across the selected parameter tensors
        (0 for a module with no parameters).
    """
    return sum(
        p.numel()
        for p in model.parameters()
        # When trainable_only is False every parameter passes this filter.
        if p.requires_grad or not trainable_only
    )
@KeAWang
KeAWang / constrain.py
Created November 14, 2023 18:32
Constrain and unconstrain
import torch
def constrain(x, min, max, temperature:float=1.):
    """Squash an unconstrained tensor ``x`` into the range [min, max].

    A temperature-scaled sigmoid maps ``x`` into (0, 1). That value is then
    stretched by the interval width and shifted by ``min``. Mathematically
    the range is open, but float saturation can reach the endpoints.
    Larger ``temperature`` makes the transition flatter.
    """
    width = max - min
    squashed = torch.sigmoid(x / temperature)
    return width * squashed + min
def unconstrain(y, min, max, temperature:float=1, EPS:float=1e-8):
assert torch.all(y >= min) and torch.all(y <= max)
# ensure both numerator and denominator are positive
numerator = y - min
@KeAWang
KeAWang / dexcom_palette.py
Created October 23, 2023 03:46
Dexcom Clarity Color Palette
import numpy as np
# Dexcom Clarity hex colours keyed by range category, listed from the lowest
# category to the highest ("tir" presumably means time-in-range -- confirm).
tir_palette = dict(
    very_low="#A61D2A",
    low="#EE1D23",
    in_range="#26B257",
    high="#FAAB1A",
    very_high="#F47D21",
)
def color_bg(bgs):
bins = [0, 54, 70, 180, 250, 1000]
@KeAWang
KeAWang / array_to_dataframe.py
Created May 16, 2023 22:57
Multidimensional array to pandas dataframe
import pandas as pd
from typing import Optional, List
def array_to_dataframe(array, axis_names: Optional[List[str]]=None):
"""Based on https://stackoverflow.com/questions/35525028/how-to-transform-a-3d-arrays-into-a-dataframe-in-python"""
if axis_names is None:
axis_names = list(range(array.ndim))
index = pd.MultiIndex.from_product([range(s) for s in array.shape], names=names)
df = pd.DataFrame({"array": array.flatten()}, index=index)["array"]