This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch as th | |
import torch.nn.functional as F | |
# Sinkhorn-Knopp algorithm for projecting onto doubly stochastic matrices | |
def sinkhorn_knopp(A: th.Tensor, max_iter: int = 20): | |
A = A.clone() | |
for _ in range(max_iter): | |
A /= A.sum(dim=1, keepdim=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from copy import deepcopy | |
from einops import rearrange | |
from tqdm.auto import tqdm, trange | |
from transformers import PreTrainedModel | |
from typing import ( | |
Literal, NamedTuple, Optional, Union, Sequence | |
) | |
from white_box import TunedLens | |
from white_box.causal import ablate_subspace, remove_subspace | |
from white_box.nn import Decoder |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from itertools import product | |
from scipy.optimize import curve_fit | |
from typing import NamedTuple, Sequence | |
import numpy as np | |
class Break(NamedTuple): | |
c: float | |
d: float | |
f: float |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass | |
from elk.metrics import to_one_hot | |
from elk.training import Classifier | |
from scipy.optimize import brentq | |
from sklearn.datasets import make_classification | |
from sklearn.linear_model import LogisticRegression | |
from torch import Tensor | |
import numpy as np | |
import torch | |
import torch.nn.functional as F |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass | |
from functools import partial | |
from itertools import cycle | |
import logging | |
import multiprocessing as std_mp | |
import socket | |
import warnings | |
import dill | |
import os |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import Tensor, nn | |
class ConceptEraser(nn.Module): | |
"""Removes the subspace responsible for correlations between hiddens and labels.""" | |
mean_x: Tensor | |
"""Running mean of X.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass | |
from typing import Callable, Iterator | |
from torch import LongTensor, Tensor | |
import torch | |
@dataclass(frozen=True) | |
class GroupedTensor: | |
"""A tensor split into groups along a given dimension. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch import Tensor | |
import torch | |
def intersection_of_ranges(As: Tensor) -> Tensor: | |
"""Compute the intersection of the ranges of a batch of matrices. | |
We use the formula from "Projectors on Intersections of Subspaces" by | |
Ben-Israel (2015) <http://benisrael.net/ADI-BENISRAEL-AUG-29-13.pdf>. | |
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from PIL import Image | |
from torch.utils.data import Dataset | |
import json | |
import os | |
class ImageNetKaggle(Dataset): | |
def __init__(self, root, split, transform=None): | |
self.samples = [] | |
self.targets = [] |
OlderNewer