This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass | |
from functools import partial | |
from itertools import cycle | |
import logging | |
import multiprocessing as std_mp | |
import socket | |
import warnings | |
import dill | |
import os |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass | |
from elk.metrics import to_one_hot | |
from elk.training import Classifier | |
from scipy.optimize import brentq | |
from sklearn.datasets import make_classification | |
from sklearn.linear_model import LogisticRegression | |
from torch import Tensor | |
import numpy as np | |
import torch | |
import torch.nn.functional as F |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from itertools import product | |
from scipy.optimize import curve_fit | |
from typing import NamedTuple, Sequence | |
import numpy as np | |
class Break(NamedTuple):
    """Immutable record of three float parameters: ``c``, ``d`` and ``f``.

    Being a ``NamedTuple``, instances unpack and compare like the plain
    tuple ``(c, d, f)`` and also expose each value by attribute name.
    """

    # NOTE(review): the meaning of each field is not visible in this
    # excerpt. Presumably they parameterise one break of a piecewise curve
    # fitted with scipy's `curve_fit` -- confirm against the fitting code.
    c: float
    d: float
    f: float
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from copy import deepcopy | |
from einops import rearrange | |
from tqdm.auto import tqdm, trange | |
from transformers import PreTrainedModel | |
from typing import ( | |
Literal, NamedTuple, Optional, Union, Sequence | |
) | |
from white_box import TunedLens | |
from white_box.causal import ablate_subspace, remove_subspace | |
from white_box.nn import Decoder |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch as th | |
import torch.nn.functional as F | |
# Sinkhorn-Knopp algorithm for projecting onto doubly stochastic matrices
def sinkhorn_knopp(A: th.Tensor, max_iter: int = 20): | |
A = A.clone() | |
for _ in range(max_iter): | |
A /= A.sum(dim=1, keepdim=True) |
NewerOlder