Created
February 20, 2024 23:50
-
-
Save Phylliida/ec38ffba6addf8c65bc0fd2479b1e063 to your computer and use it in GitHub Desktop.
Noising mamba
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partial

from einops import rearrange
from jaxtyping import Float
import pandas as pd
import plotly.express as px
import torch
import tqdm
from transformer_lens.hook_points import HookPoint

from mamba_lens import HookedMamba
from test_data import (
    greater_than_data_generator,
    IOI_generator,
    ABC_TEMPLATES,
    BAC_TEMPLATES,
    BABA_TEMPLATES,
    BABA_LONG_TEMPLATES,
    BABA_LATE_IOS,
    BABA_EARLY_IOS,
)
# ---------------------------------------------------------------------------
# Experiment setup: load the model, build one batch of IOI prompts, and
# allocate one result slot per noise level.
# ---------------------------------------------------------------------------
model = HookedMamba.from_pretrained("state-spaces/mamba-370m")
# Inference only; nothing below needs autograd.
torch.set_grad_enabled(False)

seed = 27
num_examples = 120
data = IOI_generator(
    templates=[BABA_TEMPLATES[0]],
    tokenizer=model.tokenizer,
    num_examples=num_examples,
    seed=seed,
)

batched_data = []
batched_correct = []
batched_incorrect = []
for i, (prompt, corrects, incorrects) in enumerate(data):
    if i < 3:
        # Print the first few examples as a sanity check.
        print(prompt, corrects, incorrects)
    batched_data.append(
        torch.tensor(model.tokenizer.encode(prompt), device=model.cfg.device)
    )
    # Only the first token id of the first correct / incorrect answer is kept.
    batched_correct.append(model.tokenizer.encode(corrects[0])[0])
    batched_incorrect.append(model.tokenizer.encode(incorrects[0])[0])

# torch.stack needs every encoded prompt to have the same number of tokens.
batched_data = torch.stack(batched_data)
batched_correct = torch.tensor(batched_correct)
batched_incorrect = torch.tensor(batched_incorrect)

# Noise standard deviations to sweep. ``.tolist()`` yields plain Python floats
# instead of a list of 0-dim tensors, so the values can be used directly as a
# std argument, as DataFrame index labels and as plot x values.
points = torch.linspace(0, 0.1, 200).tolist()

# One result slot per noise level.
output_accuracies = torch.zeros([len(points)], device=model.cfg.device)
output_prs = torch.zeros([len(points)], device=model.cfg.device)
output_prs_incorrect = torch.zeros([len(points)], device=model.cfg.device)
def resid_pre_hook(
    resid_pre: Float[torch.Tensor, "B L D"],
    hook: HookPoint,
    noise_std: float,
) -> Float[torch.Tensor, "B L D"]:
    """Add zero-mean Gaussian noise to an activation.

    Args:
        resid_pre: Activation handed in by the hook point (annotated "B L D").
        hook: The hook point that invoked this function; unused.
        noise_std: Standard deviation of the noise. A value of 0 leaves the
            activation unchanged.

    Returns:
        ``resid_pre`` plus freshly sampled noise, with the same shape.
    """
    # randn_like samples on resid_pre's own device and dtype, so the hook does
    # not depend on the global ``model`` and does not upcast reduced-precision
    # activations. Scaling a standard normal also handles noise_std == 0.
    noise = torch.randn_like(resid_pre) * noise_std
    return resid_pre + noise
# ---------------------------------------------------------------------------
# Noise sweep: for every standard deviation in ``points``, add noise to the
# input of every layer and record accuracy and answer probabilities.
# ---------------------------------------------------------------------------
# Row index used for every per-example lookup below; loop-invariant.
example_idx = torch.arange(num_examples)

# enumerate(tqdm(...)) rather than tqdm(enumerate(...)): tqdm can only read the
# length (and therefore show a total and ETA) from the sized iterable itself.
for i, noise_std in enumerate(tqdm.tqdm(points)):
    hook = partial(resid_pre_hook, noise_std=noise_std)
    # Attach the same noising hook to the residual input of every block.
    hooks = [
        (f"blocks.{layer}.hook_resid_pre", hook)
        for layer in range(model.cfg.n_layers)
    ]
    # Keep only the logits at the final sequence position.
    logits = model.run_with_hooks(
        input=batched_data, fwd_hooks=hooks, fast_ssm=True, fast_conv=True
    )[:, -1]
    prs = torch.nn.functional.softmax(logits, dim=1)

    correct_prs = prs[example_idx, batched_correct]
    incorrect_prs = prs[example_idx, batched_incorrect]
    correct_logits = logits[example_idx, batched_correct]
    incorrect_logits = logits[example_idx, batched_incorrect]

    # An example counts as correct when the correct answer's logit is
    # strictly larger than the incorrect answer's logit.
    num_correct = torch.sum(correct_logits > incorrect_logits)
    output_accuracies[i] = num_correct / float(num_examples)
    output_prs[i] = torch.mean(correct_prs)
    output_prs_incorrect[i] = torch.mean(incorrect_prs)
def bar_chart(data, x_labels, y_label, title, font_size=None):
    """Show a plotly bar chart with one bar per entry of ``x_labels``.

    Args:
        data: Bar heights; must support ``.cpu().numpy()``.
        x_labels: Sequence of x values, one per row of ``data``.
        y_label: Name given to the value column and used as the y field.
        title: Figure title.
        font_size: When not None, pins the x ticks to ``x_labels`` and sets
            the figure font to this size in black.
    """
    # plotly express wants a DataFrame with named rows and columns. A frame
    # built from an array is labelled with ints, so both axes are relabelled.
    row_names = dict(enumerate(x_labels))
    frame = pd.DataFrame(data.cpu().numpy())
    frame = frame.rename(row_names, axis='rows')
    frame = frame.rename({0: y_label}, axis='columns')

    fig = px.bar(frame, y=y_label, x=x_labels, title=title)
    if font_size is not None:
        x_axis = dict(tickmode='array', tickvals=x_labels, ticktext=x_labels)
        fig.update_layout(
            xaxis=x_axis,
            font=dict(size=font_size, color="black"),
        )
    fig.show()
# Plot each recorded metric against the noise level; all three figures share
# the same title.
sweep_title = 'applying mean zero noise to input of every layer'
for series, series_label in (
    (output_accuracies, 'accuracy'),
    (output_prs, 'pr correct answer'),
    (output_prs_incorrect, 'pr incorrect answer'),
):
    bar_chart(data=series, x_labels=points, y_label=series_label, title=sweep_title)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment