Skip to content

Instantly share code, notes, and snippets.

@atondwal
Created August 12, 2024 09:11
Show Gist options
  • Save atondwal/06c4aa91960667517a5f2f079825eaec to your computer and use it in GitHub Desktop.
Save atondwal/06c4aa91960667517a5f2f079825eaec to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "20929026-d3de-4648-868d-80ba7bc40d45",
"metadata": {},
"source": [
"# Steering characters with Interpretability"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "73a8371a-45af-4751-95d6-fc6f6d832414",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8271b6c6-1e75-4216-a791-8c7aa1e9f594",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"import torch\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"from repeng import ControlVector, ControlModel, DatasetEntry\n",
"import ipywidgets as widgets\n",
"from IPython.display import display, clear_output\n",
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"torch.cuda.empty_cache()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "0115272c-9261-4ba7-973f-fbdf38a35f28",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7bac2c7571f8491292b31c29d2f0d6d7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/9 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model_name = \"aifeifei798/DarkIdol-Llama-3.1-8B-Instruct-1.2-Uncensored\"\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"tokenizer.pad_token_id = 0\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)\n",
"model = model.to(\"cuda:0\" if torch.cuda.is_available() else \"mps:0\" if torch.backends.mps.is_available() else \"cpu\")\n",
"model = ControlModel(model, list(range(-5, -18, -1)))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "84302e04-4f4f-490b-ac19-d64fbf9381c8",
"metadata": {},
"outputs": [],
"source": [
"\n",
"user_tag, asst_tag = \"<|start_header_id|>user<|end_header_id|>You: \", \"<|eot_id|><|start_header_id|>assistant<|end_header_id|>Edric Sideris:\""
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "4ba9543c-66c9-40db-95ed-a5530e91a4a7",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/74 [00:00<?, ?it/s]We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n",
"100%|███████████████████████████████████████████| 74/74 [00:23<00:00, 3.18it/s]\n",
"100%|███████████████████████████████████████████| 31/31 [00:05<00:00, 5.68it/s]\n"
]
}
],
"source": [
"# based on https://github.com/vgel/repeng/blob/main/notebooks/experiments.ipynb\n",
"with open(\"data/all_truncated_outputs.json\") as f:\n",
" suffixes = json.load(f)\n",
"\n",
"def template(persona: str, suffix: str) -> str:\n",
" return f\"{user_tag} Act as if you're extremely {persona}. {asst_tag} {suffix}\"\n",
" \n",
"def contrast_dataset(positive_personas,negative_personas):\n",
" dataset = []\n",
" for suffix in suffixes:\n",
" tokens = tokenizer.tokenize(suffix)\n",
" for i in range(1, len(tokens)):\n",
" truncated = tokenizer.convert_tokens_to_string(tokens[:i])\n",
" for positive_persona, negative_persona in zip(positive_personas, negative_personas):\n",
" dataset.append(\n",
" DatasetEntry(\n",
" positive=template(positive_persona, truncated),\n",
" negative=template(negative_persona, truncated),\n",
" )\n",
" )\n",
" return dataset\n",
" \n",
"def create_dataset_and_train_vector(synonyms, antonyms, model, tokenizer):\n",
" dataset = contrast_dataset(synonyms, antonyms)\n",
" model.reset()\n",
" return ControlVector.train(model, tokenizer, dataset)\n",
" \n",
"# Make sure it works\n",
"zero = create_dataset_and_train_vector([\"\"],[\"\"],model,tokenizer)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "2b9971aa-c6a1-4d57-a06d-a05a9764584e",
"metadata": {},
"outputs": [],
"source": [
"edric_prompt = ('<|start_header_id|>system<|end_header_id|>'\n",
" 'Name: Edric Sideris\\n'\n",
" \"height: 6'5 \\n\"\n",
" 'hair color: black\\n'\n",
" 'eyes: blue\\n'\n",
" \"pronoun's: he/him\\n\"\n",
" 'age: 26\\n'\n",
" 'birthday: 27th of September \\n'\n",
" '\\n'\n",
" 'mother: Isabelle Sideris\\n'\n",
" 'father: James Sideris\\n'\n",
" '\\n'\n",
" \"Edric is cold and stoic. He can be selfish and stubborn at times, because the only thing he thinks about, is his business and how to keep it going. Work is the only thing on his mind. He didn't want this marriage and despises You for trapping him like that. Before the marriage, he used to be quite the player and sleeping around. However, now that he's married, he won't do it anymore. \\n\"\n",
" '\\n'\n",
" 'Edric is a logical thinker and overall smart, especially when it comes to strategy. He acts and thinks with his brain, not his heart. \\n'\n",
" '\\n'\n",
" \"It takes a lot to make Edric seriously angry since he usually keeps his cool and composure. However, there are things that he hates: when someone disobeys him, is not being respectful, mocking him or teasing him. Do these things and you'll drive him to the edge. \\n\"\n",
" '\\n'\n",
" \"Edric does not feel any sympathy or empathy. The only thing he cares about his himself and his business. He does not care about You one bit. However, if there would be something that You requires, he'll provide it. After all, he doesn't want anyone to bother him. \\n\"\n",
" '\\n'\n",
" \"Edric doesn't have a lot of hobbies, since he is consumed by his work, but he does know how to play the violin and he plays tennis every once in a while.\\n\"\n",
" '\\n'\n",
" \"Edric worries that You will embarrass him somehow in public and make his name dirty. That's why he always feels the urge to know, where exactly You is, so that he can keep a close eye. \\n\"\n",
" '\\n'\n",
" \"Edric doesn't show any interest in interacting with You unless absolutely necessary. Other than that, he sees this marriage as nothing else but for business. He has no interest in making this marriage work.\\n\"\n",
" '\\n'\n",
" \"For Edric it's very hard to understand and to relate to other peoples feelings and thoughts. He himself doesn't usually feel a lot of things. Unless you make him angry, he always has a cold and stoic exterior. \\n\"\n",
" '\\n'\n",
" \"People usually don't stand up to him, since he's the CEO, so he is not used to something like that and absolutely hates it. He always gets what he wants and if not, then Edric will do everything that it takes to make it come true. \\n\"\n",
" '\\n'\n",
" \"In Edric opinion, money is the best. He thinks, that money can buy you everything and anything, which is why his money is very important to him and he's not afraid to show it off with things like an expensive suit or car.\\n\"\n",
" '\\n'\n",
" \"During sex, Edric feels possessive over You. Edric wants to be in control and makes sure that You realizes, that Edric is in charge. He'll try anything to get You to submit to him. Edric tends to get cocky and really rough.\\n\"\n",
" '\\n'\n",
" \"Edric doesn't speak, think, decide, or control the dialogues of You\\n\"\n",
" '\\n'\n",
" 'Scenario: You and Edric just moved in together, two weeks after their wedding. Neither of them are thrilled about it and in fact, despise each other. The want to avoid as much contact as possible.\\n'\n",
" '<|eot_id|>')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5ddfb4ca-487f-49c8-a855-3bd3ce2e0ec3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded data from pickle.\n"
]
}
],
"source": [
"word_groups = {\n",
" 'materialistic': (['materialistic', 'consumerist', 'acquisitive', 'worldly', 'mercenary'],\n",
" ['spiritual', 'altruistic', 'ascetic', 'frugal', 'selfless']),\n",
" 'yandere_dandere': (['yandere'],\n",
" ['dandere']),\n",
" 'cold': (['cold', 'impassive'],\n",
" ['affectionate', 'sensitive']),\n",
" 'selfish': (['selfish', 'stubborn'],\n",
" ['considerate', 'compliant']),\n",
" 'workaholic': (['pragmatic', 'workaholic'],\n",
" ['relaxed', 'laid-back']),\n",
" 'noncommittal': (['Indecisive', 'Ambivalent', 'Hesitant', 'Evasive', 'Nonbinding'],\n",
" ['committed', 'decisive', 'resolute', 'dedicated', 'determined']),\n",
" 'tyrannical': (['Tyrannical', 'Autocratic', 'Imperious', 'Despotic', 'Authoritarian'],\n",
" ['Submissive', 'Meek', 'Deferential', 'Compliant', 'Subservient'])\n",
"}\n",
"\n",
"import ipynbname\n",
"import pickle\n",
"import os\n",
"notebook_name = ipynbname.name()\n",
"pickle_filename = f\"{notebook_name.split('.')[0]}_data.pkl\" if notebook_name else \"v_anon.pkl\"\n",
"if os.path.exists(pickle_filename):\n",
" with open(pickle_filename, 'rb') as file:\n",
" v = pickle.load(file)\n",
" print(\"Loaded data from pickle.\")\n",
"else:\n",
" # Process each word group\n",
" v = {}\n",
" for trait, (synonyms, antonyms) in word_groups.items():\n",
" v[f'{trait}'] = create_dataset_and_train_vector(synonyms, antonyms, model, tokenizer)\n",
" with open(pickle_filename, 'wb') as file:\n",
" pickle.dump(v, file)\n",
" print(\"Computed and pickled data.\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "b5be6756-7f22-4b69-b565-8f53d67c5730",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ccecfe86cc2746818db1b764b6f5b74e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Text(value='Edric, have you ever thought about going back to school? What subjects piques your interest?', des…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cd748cb36ceb4fe588e50dcda6e22d54",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"FloatSlider(value=0.16666666666666666, continuous_update=False, description='materialistic', max=2.5, min=-2.5…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c4b4f050ddbb4446a101c468ace06e1a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"FloatSlider(value=0.16666666666666666, continuous_update=False, description='yandere_dandere', max=2.5, min=-2…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "930f2b3dbf724df8b7084f9f7adae7b3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"FloatSlider(value=0.16666666666666666, continuous_update=False, description='cold', max=2.5, min=-2.5, readout…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ccc9d105abfd4a8b973122ce6beddf8c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"FloatSlider(value=0.16666666666666666, continuous_update=False, description='selfish', max=2.5, min=-2.5, read…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6c6794eb9d984b94bc238282832b89b4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"FloatSlider(value=0.16666666666666666, continuous_update=False, description='workaholic', max=2.5, min=-2.5, r…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7e5b25897d6344fa8a0acd36a3af0e2e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"FloatSlider(value=0.16666666666666666, continuous_update=False, description='noncommittal', max=2.5, min=-2.5,…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "19054268b549471ab0b7f0d5899fa742",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"FloatSlider(value=0.16666666666666666, continuous_update=False, description='tyrannical', max=2.5, min=-2.5, r…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "52d1238120f944198ede7cbdf4da289a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Button(description='Enter', icon='check', style=ButtonStyle(), tooltip='Click to submit')"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5a874a7e03e44bed81fa62f4547e59f1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Function to create a FloatSlider widget\n",
"def create_slider(description):\n",
" return widgets.FloatSlider(\n",
" value=1.0/6,\n",
" min=-2.5,\n",
" max=2.5,\n",
" step=0.1,\n",
" description=description,\n",
" disabled=False,\n",
" continuous_update=False,\n",
" orientation='horizontal',\n",
" readout=True,\n",
" readout_format='.1f',\n",
" )\n",
"\n",
"# Create widgets\n",
"text_input = widgets.Text(value='Edric, have you ever thought about going back to school? What subjects piques your interest?',\n",
" description='Input:',\n",
" disabled=False, layout=widgets.Layout(width='1000px'))\n",
"sliders = { trait : create_slider(trait) for trait in v.keys() }\n",
"enter_button = widgets.Button(description='Enter', disabled=False, button_style='',\n",
" tooltip='Click to submit', icon='check')\n",
"output = widgets.Output()\n",
"\n",
"def custom_logic(text, values):\n",
" settings = {\n",
" \"pad_token_id\": tokenizer.eos_token_id, # silence warning\n",
" \"do_sample\": False, # temperature=0\n",
" \"max_new_tokens\": 256,\n",
" \"repetition_penalty\": 1.1, # reduce control jank\n",
" }\n",
" \n",
" input_ids = tokenizer(f\"{edric_prompt}{user_tag}{text}{asst_tag}\", return_tensors=\"pt\").to(model.device)\n",
" vector_mix = sum((vec * values[trait] for trait, vec in v.items()), start=zero)\n",
"\n",
" actions = [(\"prompt only\", 0), (\"prompt + control\", 1), (\"prompt - control\", -1)]\n",
" def print_chat(full_string, role = \"assistant\"):\n",
" for element in full_string.split(f\"<|start_header_id|>{role}<|end_header_id|>\")[1:]:\n",
" print(element.strip(\"<|eot_id|>\"))\n",
"\n",
" for action, control_strength in actions:\n",
" print(action)\n",
" if control_strength != 0:\n",
" model.set_control(vector_mix * control_strength)\n",
" output_text = model.generate(**input_ids, **settings)\n",
" print_chat(tokenizer.decode(output_text.squeeze()), \"assistant\")\n",
" print()\n",
" model.reset()\n",
"\n",
"def chat_with_2d_partner(b):\n",
" with output:\n",
" clear_output()\n",
" slider_values = {k: v.value for k, v in sliders.items()}\n",
" custom_logic(text_input.value, slider_values)\n",
"\n",
"enter_button.on_click(chat_with_2d_partner)\n",
"display(text_input, *sliders.values(), enter_button, output)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment