Skip to content

Instantly share code, notes, and snippets.

View norabelrose's full-sized avatar

Nora Belrose norabelrose

View GitHub Profile
@norabelrose
norabelrose / cdf-erasure.py
Created October 3, 2023 07:55
Erasing CIFAR-10 classes with componentwise probability integral transform
from argparse import ArgumentParser
from itertools import pairwise
from pathlib import Path
from typing import Callable, Sized
import random
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchmetrics as tm
@norabelrose
norabelrose / classifier.py
Last active October 2, 2023 14:15
CUDA-enabled logistic regression with CV
from dataclasses import dataclass, field
import torch
from torch import Tensor
from torch.nn.functional import (
binary_cross_entropy_with_logits as bce_with_logits,
)
from torch.nn.functional import (
cross_entropy,
)
@norabelrose
norabelrose / cifar-leace.py
Last active September 30, 2023 07:36
Messy CIFAR LEACE testing
from argparse import ArgumentParser
from typing import Any, Callable, Protocol, Sized, Type
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchmetrics as tm
import torchvision as tv
from concept_erasure import LeaceFitter, OracleFitter, QuadraticFitter
from pytorch_lightning.loggers import WandbLogger
@norabelrose
norabelrose / qleace-mlp.py
Last active October 1, 2023 06:12
Q-LEACE Learning Prevention on 3-layer MLP
from argparse import ArgumentParser
from itertools import pairwise
from typing import Callable, Sized
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchmetrics as tm
import torchvision as tv
from concept_erasure import LeaceFitter, OracleFitter, QuadraticFitter
@norabelrose
norabelrose / kronecker_decompose.py
Last active March 13, 2025 14:40
Fast, optimal Kronecker decomposition
from einops import rearrange
from torch import Tensor
import torch
def kronecker_decompose(
A: Tensor, m: int, n: int, *, k: int = 1, niter: int = 10
) -> tuple[Tensor, Tensor]:
"""Frobenius-optimal decomposition of `A` into a sum of `k` Kronecker products.
@norabelrose
norabelrose / imagenet_kaggle.py
Last active September 19, 2023 20:51
PyTorch dataset for loading ImageNet images from Kaggle
from PIL import Image
from torch.utils.data import Dataset
import json
import os
class ImageNetKaggle(Dataset):
def __init__(self, root, split, transform=None):
self.samples = []
self.targets = []
@norabelrose
norabelrose / intersection.py
Created September 9, 2023 19:20
Intersection of ranges
from torch import Tensor
import torch
def intersection_of_ranges(As: Tensor) -> Tensor:
    """Compute the intersection of the ranges of a batch of matrices.

    We use the formula from "Projectors on Intersections of Subspaces" by
    Ben-Israel (2015) <http://benisrael.net/ADI-BENISRAEL-AUG-29-13.pdf>.

    Args:
        As: A batch of matrices whose ranges (column spaces) are intersected.
            NOTE(review): presumably stacked along the leading dimension,
            e.g. shape (k, m, n) — confirm against the full implementation.

    Returns:
        A Tensor (per the annotation) describing the intersection of the
        ranges. Whether it is a projector onto the intersection or a basis
        for it is not visible in this excerpt — TODO confirm.
    """
    # NOTE(review): only the docstring survives in this excerpt. As shown,
    # the function has no body statements and therefore returns None, which
    # contradicts the `-> Tensor` annotation. Restore the implementation.
@norabelrose
norabelrose / groupby.py
Created August 23, 2023 21:30
GroupedTensor
from dataclasses import dataclass
from typing import Callable, Iterator
from torch import LongTensor, Tensor
import torch
@dataclass(frozen=True)
class GroupedTensor:
"""A tensor split into groups along a given dimension.
import torch
from torch import Tensor, nn
class ConceptEraser(nn.Module):
"""Removes the subspace responsible for correlations between hiddens and labels."""
mean_x: Tensor
"""Running mean of X."""
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import types, gc, os, time, re
import torch
from torch.nn import functional as F
# Global PyTorch backend settings. They are applied once, at import time,
# as a module-level side effect that affects the whole process:
# - cudnn.benchmark: let cuDNN benchmark candidate convolution algorithms
#   and cache the fastest one for the input shapes it encounters.
# - cudnn.allow_tf32 / cuda.matmul.allow_tf32: permit TensorFloat-32 math in
#   cuDNN convolutions and CUDA matmuls. Per the PyTorch docs, this trades
#   some float32 precision for speed on GPUs that support TF32.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True