Skip to content

Instantly share code, notes, and snippets.

View norabelrose's full-sized avatar

Nora Belrose norabelrose

View GitHub Profile
@norabelrose
norabelrose / r-nn-svd.py
Last active January 2, 2023 08:34
Relaxed non-negative SVD
import torch as th
import torch.nn.functional as F
# Sinkhorn-Knopp algorithm for projecting onto doubly stochastic matrices
def sinkhorn_knopp(A: th.Tensor, max_iter: int = 20):
A = A.clone()
for _ in range(max_iter):
A /= A.sum(dim=1, keepdim=True)
@norabelrose
norabelrose / cbe.py
Created January 5, 2023 08:39
Causal basis extraction
from copy import deepcopy
from einops import rearrange
from tqdm.auto import tqdm, trange
from transformers import PreTrainedModel
from typing import (
Literal, NamedTuple, Optional, Union, Sequence
)
from white_box import TunedLens
from white_box.causal import ablate_subspace, remove_subspace
from white_box.nn import Decoder
from itertools import product
from scipy.optimize import curve_fit
from typing import NamedTuple, Sequence
import numpy as np
class Break(NamedTuple):
c: float
d: float
f: float
from dataclasses import dataclass
from elk.metrics import to_one_hot
from elk.training import Classifier
from scipy.optimize import brentq
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from torch import Tensor
import numpy as np
import torch
import torch.nn.functional as F
from dataclasses import dataclass
from functools import partial
from itertools import cycle
import logging
import multiprocessing as std_mp
import socket
import warnings
import dill
import os
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import types, gc, os, time, re
import torch
from torch.nn import functional as F
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
import torch
from torch import Tensor, nn
class ConceptEraser(nn.Module):
"""Removes the subspace responsible for correlations between hiddens and labels."""
mean_x: Tensor
"""Running mean of X."""
@norabelrose
norabelrose / groupby.py
Created August 23, 2023 21:30
GroupedTensor
from dataclasses import dataclass
from typing import Callable, Iterator
from torch import LongTensor, Tensor
import torch
@dataclass(frozen=True)
class GroupedTensor:
"""A tensor split into groups along a given dimension.
@norabelrose
norabelrose / intersection.py
Created September 9, 2023 19:20
Intersection of ranges
from torch import Tensor
import torch
def intersection_of_ranges(As: Tensor) -> Tensor:
"""Compute the intersection of the ranges of a batch of matrices.
We use the formula from "Projectors on Intersections of Subspaces" by
Ben-Israel (2015) <http://benisrael.net/ADI-BENISRAEL-AUG-29-13.pdf>.
"""
@norabelrose
norabelrose / imagenet_kaggle.py
Last active September 19, 2023 20:51
PyTorch dataset for loading imagenet images from Kaggle
from PIL import Image
from torch.utils.data import Dataset
import json
import os
class ImageNetKaggle(Dataset):
def __init__(self, root, split, transform=None):
self.samples = []
self.targets = []