Skip to content

Instantly share code, notes, and snippets.

View norabelrose's full-sized avatar

Nora Belrose norabelrose

View GitHub Profile
@norabelrose
norabelrose / relu_poly_ev.py
Last active November 5, 2024 03:15
Compute E[x^n * ReLU(x)] analytically where x ~ N(mu, sigma^2)
import math
import numpy as np
from numpy.typing import ArrayLike, NDArray
from scipy.special import factorial2
from scipy.stats import norm
def relu_poly_ev(n: int, mu: ArrayLike, sigma: ArrayLike) -> NDArray:
"""
Compute E[x^n * ReLU(x)] analytically where x ~ N(mu, sigma^2)
@norabelrose
norabelrose / x_gelu_expectation.py
Created October 18, 2024 00:46
Expectation of x * gelu(x) where x ~ N(mu, sigma^2)
def x_gelu_expectation(mu, sigma):
"""Compute E[x * gelu(x)] for x ~ N(mu, sigma^2) analytically."""
evCDF = norm.cdf(mu / np.sqrt(1 + sigma**2))
evPDF = norm.pdf(mu / np.sqrt(1 + sigma**2)) / np.sqrt(1 + sigma**2)
evZPDF = -mu*sigma/np.sqrt(1 + sigma**2)**3 * norm.pdf(mu / np.sqrt(1 + sigma**2))
# linearity
evXPDF = mu * evPDF + sigma * evZPDF
# identity (first time)
@norabelrose
norabelrose / relu-ols.py
Created October 17, 2024 05:41
Analytically computing the least-squares linear approximation to a ReLU network
import numpy as np
from scipy.stats import norm
def compute_E_xf(W1, W2, b1):
"""
Computes the analytical expectation E[x f(x)^T] for a single hidden layer ReLU network.
Parameters:
- W1: numpy.ndarray, shape (k, n)
Weight matrix of the first layer (W^{(1)}).
@norabelrose
norabelrose / train.py
Last active December 8, 2023 22:19
Features across time training script
from argparse import ArgumentParser
from dataclasses import dataclass
import torch
import torchvision.transforms as T
from concept_erasure import QuadraticEditor, QuadraticFitter
from datasets import (
ClassLabel, Dataset, DatasetDict, Features, Image, load_dataset
)
from einops import rearrange
@norabelrose
norabelrose / extract.py
Created November 15, 2023 05:59
Hidden state extraction
from argparse import ArgumentParser
from pathlib import Path
from datasets import Dataset, load_dataset
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
@norabelrose
norabelrose / dpo.py
Created November 8, 2023 07:04
Training quirky models with DPO
from argparse import ArgumentParser
from datasets import load_dataset
from peft import LoraConfig
from trl import DPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
if __name__ == "__main__":
parser = ArgumentParser()
@norabelrose
norabelrose / training-code.py
Created October 24, 2023 00:36
training code
from itertools import pairwise
from typing import Literal
import pytorch_lightning as pl
import torch
import torchmetrics as tm
import torchvision as tv
from torch import nn
from torch.optim import RAdam
from torch.optim.lr_scheduler import CosineAnnealingLR
@norabelrose
norabelrose / moments.py
Last active October 22, 2023 05:18
Blocked moment generator
from itertools import (
combinations_with_replacement as pyramid
)
from typing import Iterable
import math
from opt_einsum import get_symbol
from torch import Tensor
import torch
@norabelrose
norabelrose / triton-covariance.py
Last active October 20, 2023 10:30
Compute covariance matrix in Triton
from itertools import product
import torch
import triton
import triton.language as tl
@triton.autotune(
configs=[
triton.Config({'BLOCK_N': n, 'BLOCK_D': d, 'GROUP_SIZE_D': 8}, num_stages=4, num_warps=4)
@norabelrose
norabelrose / cumulants.py
Last active October 19, 2023 03:56
Ryan Greenblatt's cumulant estimation code
from typing import Optional
import torch
def get_all_the_cumulants(
x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, w: torch.Tensor, weights_in: Optional[torch.Tensor] = None
):
if weights_in is not None:
weights = weights_in
weights = weights / weights.sum()