Skip to content

Instantly share code, notes, and snippets.

View norabelrose's full-sized avatar

Nora Belrose norabelrose

View GitHub Profile
@norabelrose
norabelrose / extract.py
Created November 15, 2023 05:59
Hidden state extraction
from argparse import ArgumentParser
from pathlib import Path
from datasets import Dataset, load_dataset
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
@norabelrose
norabelrose / train.py
Last active December 8, 2023 22:19
Features across time training script
from argparse import ArgumentParser
from dataclasses import dataclass
import torch
import torchvision.transforms as T
from concept_erasure import QuadraticEditor, QuadraticFitter
from datasets import (
ClassLabel, Dataset, DatasetDict, Features, Image, load_dataset
)
from einops import rearrange
@norabelrose
norabelrose / relu-ols.py
Created October 17, 2024 05:41
Analytically computing the least-squares linear approximation to a ReLU network
import numpy as np
from scipy.stats import norm
def compute_E_xf(W1, W2, b1):
"""
Computes the analytical expectation E[x f(x)^T] for a single hidden layer ReLU network.
Parameters:
- W1: numpy.ndarray, shape (k, n)
Weight matrix of the first layer (W^{(1)}).
@norabelrose
norabelrose / x_gelu_expectation.py
Created October 18, 2024 00:46
Expectation of x * gelu(x) where x ~ N(mu, sigma^2)
def x_gelu_expectation(mu, sigma):
"""Compute E[x * gelu(x)] for x ~ N(mu, sigma^2) analytically."""
evCDF = norm.cdf(mu / np.sqrt(1 + sigma**2))
evPDF = norm.pdf(mu / np.sqrt(1 + sigma**2)) / np.sqrt(1 + sigma**2)
evZPDF = -mu*sigma/np.sqrt(1 + sigma**2)**3 * norm.pdf(mu / np.sqrt(1 + sigma**2))
# linearity
evXPDF = mu * evPDF + sigma * evZPDF
# identity (first time)
@norabelrose
norabelrose / relu_poly_ev.py
Last active November 5, 2024 03:15
Compute E[x^n * ReLU(x)] analytically where x ~ N(mu, sigma^2)
import math
import numpy as np
from numpy.typing import ArrayLike, NDArray
from scipy.special import factorial2
from scipy.stats import norm
def relu_poly_ev(n: int, mu: ArrayLike, sigma: ArrayLike) -> NDArray:
"""
Compute E[x^n * ReLU(x)] analytically where x ~ N(mu, sigma^2)