{
"cells": [
{
"cell_type": "markdown",
"source": [
"# **Proof of Concept:** Prompt Sanitization Against Forbidden Topics\n",
"\n",
"## Workflow Description:\n",
"\n",
"1. The user prompt is semantically checked against the **forbidden topics**.\n",
"2. Prompts semantically similar to any **forbidden topic** are blocked.\n",
"3. Otherwise, the prompt is passed to a specified LLM (e.g., Gemma, Llama) for processing.\n",
"\n",
"## Technical Implementation Details:\n",
"\n",
"1. A small transformer-based embedding model `sLLM` is used for semantic checking.\n",
"2. A list of vectors `forbidden_vectors` is created using the `sLLM`, corresponding to the **forbidden topics**.\n",
"3. Each user prompt is converted into a `prompt_vector` using the `sLLM`.\n",
"4. Cosine similarity is calculated between the `prompt_vector` and the `forbidden_vectors` (see the formula in the next cell).\n",
"5. If the similarity exceeds a defined `threshold`, the prompt is flagged as forbidden.\n",
"6. Otherwise, the prompt is processed normally using a generic LLM.\n",
"\n",
"## Vector Calculation Methodology:\n",
"\n",
"1. The `sLLM`'s last hidden state (the last layer) is obtained.\n",
"2. Mean pooling is applied over all tokens of the given sentence, producing a single vector that represents the prompt.\n",
"\n",
"## Challenges:\n",
"\n",
"1. Long input prompts may dilute the embedding and divert it from the core topic.\n",
"2. This POC is therefore designed for small and/or truncated prompts only.\n",
"3. Consider sanitizing the output as well by passing it through the `sLLM`.\n",
"\n",
"---\n",
"\n",
"\n"
],
"metadata": {
"id": "pKRkB2W4Lhyf"
},
"id": "pKRkB2W4Lhyf"
},
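{
"cell_type": "markdown",
"source": [
"For reference, the cosine similarity used in step 4 above is the standard definition (nothing specific to this POC): for a prompt vector $p$ and a forbidden vector $f$,\n",
"\n",
"$$\\cos(p, f) = \\frac{p \\cdot f}{\\lVert p \\rVert \\, \\lVert f \\rVert}$$\n",
"\n",
"It ranges from $-1$ to $1$, and the prompt is blocked when $\\cos(p, f) > \\text{threshold}$ for any forbidden vector $f$."
],
"metadata": {
"id": "cosSimFormulaMd01"
},
"id": "cosSimFormulaMd01"
},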
{
"cell_type": "markdown",
"source": [
"# **Pre-work:** prepare the system and install dependencies."
],
"metadata": {
"id": "imPvfoDhUjSb"
},
"id": "imPvfoDhUjSb"
},
{
"cell_type": "code",
"source": [
"# 🫥 Authenticate against Hugging Face to be able to download models.\n",
"!huggingface-cli login"
],
"metadata": {
"id": "HqujGS-vVhRC"
},
"execution_count": null,
"outputs": [],
"id": "HqujGS-vVhRC"
},
{
"cell_type": "code",
"source": [
"# ⚙️ Install dependencies\n",
"\n",
"# !pip install git+https://github.com/huggingface/transformers.git\n",
"!pip install transformers torch sentence-transformers\n",
"!pip install \"huggingface_hub[cli]\""
],
"metadata": {
"id": "C1X6vmJeF1K4"
},
"id": "C1X6vmJeF1K4",
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# 🥾 Important: reboot the kernel after installing the dependencies\n",
"\n",
"import IPython\n",
"IPython.get_ipython().kernel.do_shutdown(True)"
],
"metadata": {
"id": "c-aof4lqGRc_"
},
"id": "c-aof4lqGRc_",
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# **Pre-work:** prepare models and load them into the device."
],
"metadata": {
"id": "JmM6fBqxYaW0"
},
"id": "JmM6fBqxYaW0"
},
{
"cell_type": "code",
"source": [
"# Import dependencies\n",
"\n",
"import torch\n",
"import torch.nn.functional as F\n",
"from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM\n",
"from sentence_transformers import SentenceTransformer, util\n",
"from sklearn.metrics.pairwise import cosine_similarity"
],
"metadata": {
"id": "_ciSTNgKG8Et"
},
"id": "_ciSTNgKG8Et",
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Decide which model to use as the `sLLM`,\n",
"# and which target generic LLM to use.\n",
"# Define the threshold, and download both.\n",
"\n",
"def device():\n",
"\n",
"    # Load into GPU vs CPU\n",
"    return \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"def load_sllm(\n",
"        sllm,\n",
"        manual_transform=False,\n",
"        similarity_threshold=0.30):\n",
"\n",
"    # The decision boundary (0 to 1; lower values block more aggressively)\n",
"    # Heavily model dependent\n",
"    # Override with other values depending on the model to be used\n",
"\n",
"    # The transformers\n",
"    tokenizer = None\n",
"    model = SentenceTransformer(sllm)\n",
"\n",
"    # Or transform manually (see `get_embedding` below)\n",
"    if manual_transform:\n",
"        tokenizer = AutoTokenizer.from_pretrained(sllm)\n",
"        model = AutoModelForCausalLM.from_pretrained(\n",
"            sllm,\n",
"            # Use bfloat16 for memory efficiency\n",
"            torch_dtype=torch.bfloat16,\n",
"            output_hidden_states=True)\n",
"\n",
"    model.to(device())\n",
"\n",
"    return model, tokenizer, similarity_threshold\n",
"\n",
"def load_llm(llm):\n",
"\n",
"    tokenizer = AutoTokenizer.from_pretrained(llm)\n",
"    model = AutoModelForCausalLM.from_pretrained(\n",
"        llm,\n",
"        # Use bfloat16 for memory efficiency\n",
"        torch_dtype=torch.bfloat16)\n",
"\n",
"    model.to(device())\n",
"\n",
"    return model, tokenizer\n",
"\n",
"# Waiting for this to download the internet...\n",
"# Generic model to be used for prompts after sanitization\n",
"LLM = load_llm(\"google/gemma-2b\")\n",
"\n",
"#==================================\n",
"# Models that work (opinionated):\n",
"#==================================\n",
"\n",
"# IMPORTANT: must use a model that supports embeddings\n",
"\n",
"# Best model used so far; threshold 0.3, small size:\n",
"sLLM = load_sllm(\"sentence-transformers/all-MiniLM-L6-v2\")\n",
"\n",
"# Other OK models:\n",
"# sLLM = load_sllm(\"sentence-transformers/paraphrase-MiniLM-L3-v2\")\n",
"# sLLM = load_sllm(\"sentence-transformers/multi-qa-MiniLM-L6-cos-v1\")\n",
"\n",
"# OK models, but can't transform manually:\n",
"# sLLM = load_sllm(\"sentence-transformers/all-mpnet-base-v2\", False)\n",
"# sLLM = load_sllm(\"sentence-transformers/paraphrase-mpnet-base-v2\", False)\n",
"\n",
"# OK large model, but needs a different threshold:\n",
"# sLLM = load_sllm(\"BAAI/bge-large-en-v1.5\", similarity_threshold=0.50)\n",
"\n",
"#===========================\n",
"# Models that do not work:\n",
"#===========================\n",
"\n",
"# model_name = \"google/gemma-2b\" # By design, does not support embeddings\n",
"# model_name = \"meta-llama/Llama-3.1-8B\" # By design\n",
"# model_name = \"Jaume/gemma-2b-embeddings\" # Further research needed\n",
"# model_name = \"intfloat/e5-large-v2\""
],
"metadata": {
"id": "UhrLpCMcM_YX"
},
"id": "UhrLpCMcM_YX",
"execution_count": null,
"outputs": []
},
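{
"cell_type": "code",
"source": [
"# Optional sanity check (an addition for this writeup, not part of the\n",
"# original POC): confirm the `sLLM` loaded via SentenceTransformer and\n",
"# inspect its embedding dimensionality. Assumes the default (non-manual)\n",
"# loading path above, where sLLM[1] (the tokenizer) is None.\n",
"\n",
"model, tokenizer, threshold = sLLM\n",
"if tokenizer is None:\n",
"    example = model.encode(\"hello world\", convert_to_tensor=True)\n",
"    print(f\"Embedding dimension: {example.shape[-1]}\")  # 384 for all-MiniLM-L6-v2\n",
"    print(f\"Similarity threshold: {threshold}\")"
],
"metadata": {
"id": "sanityCheckCell01"
},
"id": "sanityCheckCell01",
"execution_count": null,
"outputs": []
},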
{
"cell_type": "markdown",
"source": [
"# **Helper functions:** calculate vectors"
],
"metadata": {
"id": "D7d2ikcmia6C"
},
"id": "D7d2ikcmia6C"
},
{
"cell_type": "code",
"id": "0PYs7eGIvHDz93yIRORuXgFf",
"metadata": {
"tags": [],
"id": "0PYs7eGIvHDz93yIRORuXgFf"
},
"source": [
"# Convert text into embeddings: a vector used for semantic evaluation\n",
"\n",
"def get_embedding(text, model, tokenizer):\n",
"    \"\"\"Gets a single embedding vector for a given text.\"\"\"\n",
"\n",
"    # Tokenize the input text\n",
"    inputs = tokenizer(\n",
"        text,\n",
"        return_tensors=\"pt\",\n",
"        padding=True,\n",
"        truncation=True).to(model.device)\n",
"\n",
"    # Get model outputs without calculating gradients\n",
"    with torch.no_grad():\n",
"        outputs = model(**inputs, output_hidden_states=True)\n",
"\n",
"    # Get the hidden states from the final layer\n",
"    last_hidden_state = outputs.hidden_states[-1]\n",
"\n",
"    # Apply mean pooling to get a single vector for the entire sentence\n",
"    # We average across the sequence length dimension (dim=1)\n",
"    embedding = last_hidden_state.mean(dim=1)\n",
"\n",
"    return embedding\n",
"\n",
"def get_vector(text, model, tokenizer):\n",
"\n",
"    # A None tokenizer means the model was loaded (and encodes)\n",
"    # via SentenceTransformer(); see the section above.\n",
"    if tokenizer is None:\n",
"        return model.encode(text, convert_to_tensor=True)\n",
"    else:\n",
"        return get_embedding(text, model, tokenizer)"
],
"execution_count": null,
"outputs": []
},
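{
"cell_type": "code",
"source": [
"# Note (an addition, not part of the original POC): `get_embedding` above\n",
"# averages over every position, including padding tokens when several texts\n",
"# are padded in a batch. A common refinement is attention-mask-weighted mean\n",
"# pooling, sketched below under the same assumptions as `get_embedding`\n",
"# (a causal LM called with output_hidden_states=True).\n",
"# `get_masked_embedding` is a hypothetical helper name.\n",
"\n",
"def get_masked_embedding(text, model, tokenizer):\n",
"    inputs = tokenizer(\n",
"        text,\n",
"        return_tensors=\"pt\",\n",
"        padding=True,\n",
"        truncation=True).to(model.device)\n",
"\n",
"    with torch.no_grad():\n",
"        outputs = model(**inputs, output_hidden_states=True)\n",
"\n",
"    last_hidden_state = outputs.hidden_states[-1]\n",
"\n",
"    # Zero out padding positions, then divide by the number of real tokens\n",
"    mask = inputs[\"attention_mask\"].unsqueeze(-1).to(last_hidden_state.dtype)\n",
"    summed = (last_hidden_state * mask).sum(dim=1)\n",
"    counts = mask.sum(dim=1).clamp(min=1)\n",
"    return summed / counts"
],
"metadata": {
"id": "maskedPoolingCell01"
},
"id": "maskedPoolingCell01",
"execution_count": null,
"outputs": []
},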
{
"cell_type": "markdown",
"source": [
"# **Pre-run:** prepare forbidden vectors, and the evaluation function"
],
"metadata": {
"id": "CPUIuEfTjyV2"
},
"id": "CPUIuEfTjyV2"
},
{
"cell_type": "code",
"source": [
"#------------------------------------------------------------------------------#\n",
"# ✏️ Mods here                                                                  #\n",
"#------------------------------------------------------------------------------#\n",
"\n",
"# List of forbidden words/topics\n",
"forbidden_words = [\"avocado\", \"cat\", \"sky\", \"rocket\", \"house\"]\n",
"\n",
"# Generate the forbidden vectors\n",
"forbidden_vectors = []\n",
"for word in forbidden_words:\n",
"    vector = get_vector(word, sLLM[0], sLLM[1])\n",
"    forbidden_vectors.append(vector)\n",
"\n",
"def sanitize(prompt, debug=False):\n",
"\n",
"    print(\"\\n\\n\\n\")\n",
"    print(\"=====================================================================\")\n",
"    print(\"[PROMPT]\", prompt)\n",
"\n",
"    # Generate a vector for the given prompt\n",
"    vector = get_vector(prompt, sLLM[0], sLLM[1])\n",
"\n",
"    # Calculate the cosine similarity between the vector and each forbidden vector\n",
"    for i, forbidden_vector in enumerate(forbidden_vectors):\n",
"        # f_score = F.cosine_similarity(vector, forbidden_vector).item()\n",
"        f_score = util.cos_sim(vector, forbidden_vector).item()\n",
"\n",
"        if debug:\n",
"            print(f\"[INFO] ℹ️ Evaluating with: {forbidden_words[i]}\")\n",
"            print(f\"[INFO] 🎯 Cosine similarity: {f_score}\")\n",
"            if f_score > 0.70:\n",
"                print(\"[DEBUG] 🚨 Given prompt has a very high similarity\")\n",
"            elif f_score > 0.40:\n",
"                print(\"[DEBUG] ⚠️ Given prompt has a medium similarity\")\n",
"            elif f_score > 0.30:\n",
"                print(\"[DEBUG] 🟢 Given prompt has a mild similarity\")\n",
"            elif f_score > 0.20:\n",
"                print(\"[DEBUG] ❓ Given prompt has a potential similarity\")\n",
"            else:\n",
"                print(\"[DEBUG] 🟩 Given prompt has a low similarity\")\n",
"\n",
"        if f_score > sLLM[2]:\n",
"            print(f\"[SCORE] ❌ Found similarity with: {forbidden_words[i]}\")\n",
"            return False\n",
"\n",
"    print(\"[SCORE] ✅ No similarity found\")\n",
"    return True"
],
"metadata": {
"id": "7DkWk2LMeSMr"
},
"id": "7DkWk2LMeSMr",
"execution_count": null,
"outputs": []
},
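{
"cell_type": "code",
"source": [
"# Quick smoke test (illustrative prompts added here, not from the original\n",
"# POC). `sanitize` returns True when the prompt may proceed to the generic\n",
"# LLM, and False when it is blocked; the second prompt should score high\n",
"# against \"cat\" if the embedding model captures the semantic overlap.\n",
"\n",
"sanitize(\"What time is it?\")\n",
"sanitize(\"My kitten chases mice.\", debug=True)"
],
"metadata": {
"id": "smokeTestCell01"
},
"id": "smokeTestCell01",
"execution_count": null,
"outputs": []
},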
{
"cell_type": "markdown",
"source": [
"# **Playground:** execute prompts here and send them to the generic LLM"
],
"metadata": {
"id": "IoZ9bo_Vq463"
},
"id": "IoZ9bo_Vq463"
},
{
"cell_type": "code",
"source": [
"# Helper function to evaluate and execute prompts\n",
"def genAI(prompt, debug=False):\n",
"\n",
"    # Evaluate whether the prompt is accepted:\n",
"    if sanitize(prompt, debug):\n",
"        # Prompt is evaluated as safe; proceed with the normal LLM\n",
"\n",
"        # Generate a response\n",
"        print(\"Generating response..\\r\", end=\"\", flush=True)\n",
"\n",
"        tokenizer = LLM[1]\n",
"        inputs = tokenizer(\n",
"            prompt,\n",
"            return_tensors=\"pt\").to(device())\n",
"\n",
"        with torch.no_grad():  # Disable gradient calculations for inference\n",
"            outputs = LLM[0].generate(\n",
"                **inputs,\n",
"                max_new_tokens=256,\n",
"                do_sample=True,\n",
"                temperature=0.7,\n",
"                top_k=50,\n",
"                top_p=0.95)\n",
"\n",
"        # Decode the generated tokens\n",
"        print(tokenizer.decode(\n",
"            outputs[0],\n",
"            skip_special_tokens=True))\n",
"\n",
"    else:\n",
"        print(\"I'm not allowed to talk about this topic.\")\n",
"\n",
"#------------------------------------------------------------------------------#\n",
"# Insert prompts here                                                           #\n",
"#------------------------------------------------------------------------------#\n",
"\n",
"genAI(\"A green fruit, circle shaped and healthy?\")\n",
"genAI(\"The pet jumps from the third floor.\")\n",
"genAI(\"It's clear and blue!\", debug=True)\n",
"genAI(\"Goes kaboom boom boom and destroys a lot!\", debug=True)\n",
"genAI(\"Home sweet home!\")\n",
"\n",
"genAI(\"What's the weather today?\")\n",
"genAI(\"How to make a salad?\")\n",
"genAI(\"Sit on a table\")\n",
"genAI(\"Healthy food\", debug=True)\n",
"genAI(\"Moon walk\", debug=True)"
],
"metadata": {
"id": "t8nh28XypaCj"
},
"id": "t8nh28XypaCj",
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# **Clean up:** free cache, memory, and device"
],
"metadata": {
"id": "w0QUHjUFqrB5"
},
"id": "w0QUHjUFqrB5"
},
{
"cell_type": "code",
"source": [
"# 🧹 Cleanup\n",
"\n",
"# Only run when needed\n",
"# (or run it in a terminal instead):\n",
"# !huggingface-cli delete-cache\n",
"\n",
"# Optionally also delete the loaded models first:\n",
"# del model\n",
"import gc\n",
"gc.collect()\n",
"torch.cuda.empty_cache()"
],
"metadata": {
"id": "8K02Q_W2N-rj"
},
"id": "8K02Q_W2N-rj",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
},
"colab": {
"provenance": [],
"name": "poc - filter embeddings",
"gpuType": "L4",
"collapsed_sections": [
]
},
"accelerator": "GPU"
},
"nbformat": 4,
"nbformat_minor": 5
} |