{
"cells": [
{
"cell_type": "markdown",
"source": [
"# **Proof of Concept:** Prompt Sanitization Against Forbidden Topics.\n",
"\n",
"## Workflow Description:\n",
"\n",
"1. The user prompt is semantically checked against the **forbidden topics**.\n",
"2. Prompts semantically similar to any **forbidden topics** are blocked.\n",
"3. Otherwise, the prompt is passed to a specified LLM (e.g., Gemma, Llama) for processing.\n",
"\n",
"## Technical Implementation Details:\n",
"\n",
"1. A small transformative LLM `sLLM` is used for syntax checking.\n",
"2. A list of vectors `forbidden_vectors` is created using the `sLLM`, corresponding to the **forbidden topics**.\n",
"3. Each user prompt is converted into a `prompt_vector` using the `sLLM`.\n",
"4. Cosine similarity is calculated between the `prompt_vector` and the `forbidden_vectors`.\n",
"5. If the similarity exceeds a defined `threshold`, the prompt is flagged as forbidden.\n",
"6. Otherwise, the prompt is processed normally using a generic LLM.\n",
"\n",
"## Vector Calculation Methodology:\n",
"\n",
"1. The `sLLM`'s last hidden state (the last layer) is obtained.\n",
"2. Mean pooling is applied to all tokens of the given sentence, generating a single vector that represents the prompt.\n",
"\n",
"## Challenges:\n",
"\n",
"1. Long input prompts may divert the vector's attention from the core topic.\n",
"2. This POC is designed for small and/or truncated prompts only.\n",
"3. Consider sanitizing the output as well by processing it through the `sLLM`.\n",
"\n",
"---\n",
"\n",
"\n"
],
"metadata": {
"id": "pKRkB2W4Lhyf"
},
"id": "pKRkB2W4Lhyf"
},
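{
"cell_type": "markdown",
"source": [
"Before the full implementation below, here is a minimal self-contained sketch of the gating rule just described: the prompt vector is compared with each forbidden vector via cosine similarity, and the prompt is blocked as soon as any score exceeds the `threshold`. The vectors in this sketch are toy placeholders, not real embeddings; the `sLLM`-based embedding code follows later in the notebook."
],
"metadata": {
"id": "toyGateSketchMd"
},
"id": "toyGateSketchMd"
},
{
"cell_type": "code",
"source": [
"# 🧪 Toy illustration of the cosine-similarity gate (placeholder vectors, not real embeddings)\n",
"\n",
"import numpy as np\n",
"\n",
"def toy_cosine_similarity(a, b):\n",
"  # cos(a, b) = (a . b) / (|a| * |b|)\n",
"  return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))\n",
"\n",
"def toy_is_forbidden(prompt_vector, forbidden_vectors, threshold=0.30):\n",
"  # Blocked as soon as any forbidden vector is similar enough\n",
"  return any(toy_cosine_similarity(prompt_vector, f) > threshold for f in forbidden_vectors)\n",
"\n",
"# Toy 3-dimensional \"embeddings\" standing in for forbidden topics\n",
"toy_forbidden = [np.array([1.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0])]\n",
"\n",
"print(toy_is_forbidden(np.array([0.9, 0.1, 0.0]), toy_forbidden))  # True: close to the first topic\n",
"print(toy_is_forbidden(np.array([0.0, 0.1, 0.9]), toy_forbidden))  # False: far from both topics"
],
"metadata": {
"id": "toyGateSketchCode"
},
"id": "toyGateSketchCode",
"execution_count": null,
"outputs": []
},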
{
"cell_type": "markdown",
"source": [
"# **Pre-work:** prepare the system and install dependencies."
],
"metadata": {
"id": "imPvfoDhUjSb"
},
"id": "imPvfoDhUjSb"
},
{
"cell_type": "code",
"source": [
"# πŸ«₯ Authenticate against Huggingface to be able to download models.\n",
"!huggingface-cli login"
],
"metadata": {
"id": "HqujGS-vVhRC"
},
"execution_count": null,
"outputs": [],
"id": "HqujGS-vVhRC"
},
{
"cell_type": "code",
"source": [
"# βš™οΈ Install dependencies\n",
"\n",
"# !pip install git+https://github.com/huggingface/transformers.git\n",
"!pip install transformers torch\n",
"!pip install huggingface_hub[cli]"
],
"metadata": {
"id": "C1X6vmJeF1K4"
},
"id": "C1X6vmJeF1K4",
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# πŸ₯Ύ Important to reboot the kernel after dependencies installation\n",
"\n",
"import IPython\n",
"IPython.get_ipython().kernel.do_shutdown(True)"
],
"metadata": {
"id": "c-aof4lqGRc_"
},
"id": "c-aof4lqGRc_",
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# **Pre-work:** prepare models and load them into device."
],
"metadata": {
"id": "JmM6fBqxYaW0"
},
"id": "JmM6fBqxYaW0"
},
{
"cell_type": "code",
"source": [
"# πŸ’Έ Import dependencies\n",
"\n",
"import torch\n",
"import torch.nn.functional as F\n",
"from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM\n",
"from sentence_transformers import SentenceTransformer, util\n",
"from sklearn.metrics.pairwise import cosine_similarity"
],
"metadata": {
"id": "_ciSTNgKG8Et"
},
"id": "_ciSTNgKG8Et",
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# 🍩 Decide on which model to use as `sLLM`\n",
"# And the target generic LLM\n",
"# Define the threshold, and download both\n",
"\n",
"def device():\n",
"\n",
" # Load into GPU vs CPU\n",
" return \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"def load_sllm(\n",
" sllm,\n",
" manual_transform = False,\n",
" similarity_threshold = 0.30):\n",
"\n",
" # The decision boundary (0 to 1, Higher is stricter)\n",
" # Heavily model dependant\n",
" # Override with other values depending on the model to be used\n",
" # similarity_threshold = 0.30\n",
"\n",
" # The transformers\n",
" tokenizer = None\n",
" model = SentenceTransformer(sllm)\n",
"\n",
" # Or.. Transform manually (??)\n",
" if manual_transform:\n",
" tokenizer = AutoTokenizer.from_pretrained(sllm)\n",
" model = AutoModelForCausalLM.from_pretrained(\n",
" sllm,\n",
" # Use bfloat16 for memory efficiency\n",
" torch_dtype=torch.bfloat16,\n",
" output_hidden_states=True)\n",
"\n",
" model.to(device())\n",
"\n",
" return model, tokenizer, similarity_threshold\n",
"\n",
"def load_llm(llm):\n",
"\n",
" tokenizer = AutoTokenizer.from_pretrained(llm)\n",
" model = AutoModelForCausalLM.from_pretrained(\n",
" llm,\n",
" # Use bfloat16 for memory efficiency\n",
" torch_dtype=torch.bfloat16)\n",
"\n",
" model.to(device())\n",
"\n",
" return model, tokenizer\n",
"\n",
"# Waiting for this to download the internet..\n",
"# Generic model to be used for prompts after sanitization\n",
"LLM = load_llm(\"google/gemma-2b\")\n",
"\n",
"#==================================\n",
"# Models that works (with opinion):\n",
"#==================================\n",
"\n",
"# IMPORTANT: must use model that supports embeddings\n",
"\n",
"# Best model used so far, threshold 0.3, small size:\n",
"sLLM = load_sllm(\"sentence-transformers/all-MiniLM-L6-v2\")\n",
"\n",
"# Other OK models:\n",
"# sLLM = load_sllm(\"sentence-transformers/paraphrase-MiniLM-L3-v2\")\n",
"# sLLM = load_sllm(\"sentence-transformers/multi-qa-MiniLM-L6-cos-v1\")\n",
"\n",
"# OK models but can't transform manually:\n",
"# sLLM = load_sllm(\"sentence-transformers/all-mpnet-base-v2\", False)\n",
"# sLLM = load_sllm(\"sentence-transformers/paraphrase-mpnet-base-v2\", False)\n",
"\n",
"# OK large model but different threshold:\n",
"# sLLM = load_sllm(\"BAAI/bge-large-en-v1.5\", similarity_threshold=0.50)\n",
"\n",
"#===========================\n",
"# Models that does not work:\n",
"#===========================\n",
"\n",
"# model_name = \"google/gemma-2b\" # By design, does not support embeddings\n",
"# model_name = \"meta-llama/Llama-3.1-8B\" # By design\n",
"# model_name = \"Jaume/gemma-2b-embeddings\" # Further research needed\n",
"# model_name = \"intfloat/e5-large-v2\""
],
"metadata": {
"id": "UhrLpCMcM_YX"
},
"id": "UhrLpCMcM_YX",
"execution_count": null,
"outputs": []
},
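{
"cell_type": "markdown",
"source": [
"An optional sanity check, not part of the original workflow: since the `sLLM` must support embeddings, one quick way to verify a candidate model is to confirm that two paraphrases score higher than an unrelated sentence. This sketch assumes the default `SentenceTransformer` path of `load_sllm` (i.e. `manual_transform=False`); the example sentences are arbitrary."
],
"metadata": {
"id": "sllmSanityCheckMd"
},
"id": "sllmSanityCheckMd"
},
{
"cell_type": "code",
"source": [
"# 🔎 Optional sanity check: does the loaded sLLM produce meaningful sentence embeddings?\n",
"# Assumes sLLM[0] is a SentenceTransformer (manual_transform=False)\n",
"\n",
"para_a = sLLM[0].encode(\"The cat sleeps on the sofa.\", convert_to_tensor=True)\n",
"para_b = sLLM[0].encode(\"A cat is napping on the couch.\", convert_to_tensor=True)\n",
"unrelated = sLLM[0].encode(\"Quarterly revenue grew by three percent.\", convert_to_tensor=True)\n",
"\n",
"paraphrase_score = util.cos_sim(para_a, para_b).item()\n",
"unrelated_score = util.cos_sim(para_a, unrelated).item()\n",
"\n",
"print(f\"Paraphrase similarity: {paraphrase_score:.2f}\")\n",
"print(f\"Unrelated similarity: {unrelated_score:.2f}\")\n",
"assert paraphrase_score > unrelated_score, \"Model does not separate related from unrelated text\""
],
"metadata": {
"id": "sllmSanityCheckCode"
},
"id": "sllmSanityCheckCode",
"execution_count": null,
"outputs": []
},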
{
"cell_type": "markdown",
"source": [
"# **Helper functions:** calculate vectors"
],
"metadata": {
"id": "D7d2ikcmia6C"
},
"id": "D7d2ikcmia6C"
},
{
"cell_type": "code",
"id": "0PYs7eGIvHDz93yIRORuXgFf",
"metadata": {
"tags": [],
"id": "0PYs7eGIvHDz93yIRORuXgFf"
},
"source": [
"# ⁉️ Convert text into embeddings, a vector used for synatx evaluation\n",
"\n",
"\"\"\"\n",
"Gets a single embedding vector for a given text.\n",
"\"\"\"\n",
"def get_embedding(text, model, tokenizer):\n",
"\n",
" # Tokenize the input text\n",
" inputs = tokenizer(\n",
" text,\n",
" return_tensors=\"pt\",\n",
" padding=True,\n",
" truncation=True).to(model.device)\n",
"\n",
" # Get model outputs without calculating gradients\n",
" with torch.no_grad():\n",
" outputs = model(**inputs, output_hidden_states=True)\n",
"\n",
" # Get the hidden states from the final layer\n",
" last_hidden_state = outputs.hidden_states[-1]\n",
"\n",
" # Apply mean pooling to get a single vector for the entire sentence\n",
" # We average across the sequence length dimension (dim=1)\n",
" embedding = last_hidden_state.mean(dim=1)\n",
"\n",
" return embedding\n",
"\n",
"def get_vector(text, model, tokenizer):\n",
"\n",
" # Null tokenizer means that it's already loaded and encoded\n",
" # using SentenceTransformer(), check section above.\n",
" if tokenizer is None:\n",
" return model.encode(text, convert_to_tensor=True)\n",
" else:\n",
" return get_embedding(text, model, tokenizer)"
],
"execution_count": null,
"outputs": []
},
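{
"cell_type": "markdown",
"source": [
"An optional variant, not in the original POC: the mean pooling above averages over every token position, which also pulls in padding tokens whenever several texts are tokenized in one batch. The sketch below masks out padding before averaging; for a single prompt (as used in this notebook) it behaves the same as `get_embedding` above."
],
"metadata": {
"id": "maskedPoolingMd"
},
"id": "maskedPoolingMd"
},
{
"cell_type": "code",
"source": [
"# ➕ Optional: mean pooling that respects the attention mask (only matters for padded batches)\n",
"\n",
"def get_embedding_masked(text, model, tokenizer):\n",
"\n",
"  inputs = tokenizer(\n",
"    text,\n",
"    return_tensors=\"pt\",\n",
"    padding=True,\n",
"    truncation=True).to(model.device)\n",
"\n",
"  with torch.no_grad():\n",
"    outputs = model(**inputs, output_hidden_states=True)\n",
"\n",
"  last_hidden_state = outputs.hidden_states[-1]\n",
"\n",
"  # Zero out padding positions, then average only over real tokens\n",
"  mask = inputs[\"attention_mask\"].unsqueeze(-1).to(last_hidden_state.dtype)\n",
"  summed = (last_hidden_state * mask).sum(dim=1)\n",
"  counts = mask.sum(dim=1).clamp(min=1)\n",
"\n",
"  return summed / counts"
],
"metadata": {
"id": "maskedPoolingCode"
},
"id": "maskedPoolingCode",
"execution_count": null,
"outputs": []
},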
{
"cell_type": "markdown",
"source": [
"# **Pre-run:** prepare forbidden vectors, and evaluation function"
],
"metadata": {
"id": "CPUIuEfTjyV2"
},
"id": "CPUIuEfTjyV2"
},
{
"cell_type": "code",
"source": [
"#------------------------------------------------------------------------------#\n",
"# βš”οΈ Mods here #\n",
"#------------------------------------------------------------------------------#\n",
"\n",
"# List of forbidden words/topics\n",
"forbidden_words = [\"avocado\", \"cat\", \"sky\", \"rocket\", \"house\"]\n",
"\n",
"# Generate forbidden vectors\n",
"forbidden_vectors = []\n",
"for word in forbidden_words:\n",
" vector = get_vector(word, sLLM[0], sLLM[1])\n",
" forbidden_vectors.append(vector)\n",
"\n",
"def sanytize(prompt, debug = False):\n",
"\n",
" print(\"\\n\\n\\n\")\n",
" print(\"=====================================================================\")\n",
" print(\"[PROMPT] πŸ“\", prompt)\n",
"\n",
" # Generate a vector for a given prompt\n",
" vector = get_vector(prompt, sLLM[0], sLLM[1])\n",
"\n",
" # Calculate the cosine similarity between the vector and each forbidden vector\n",
"\n",
" for i, forbidden_vector in enumerate(forbidden_vectors):\n",
" # f_score = F.cosine_similarity(vector, forbidden_vector).item()\n",
" f_score = util.cos_sim(vector, forbidden_vector).item()\n",
"\n",
" if debug:\n",
" print(f\"[INFO] ℹ️ Evaluating with: {forbidden_words[i]}\")\n",
" print(f\"[INFO] πŸ’― Cosine similarity: {f_score}\")\n",
" if(f_score > 0.70):\n",
" print(\"[DEBUG] 🚨 Given prompt has a really heigh similarity\")\n",
" elif(f_score > 0.40):\n",
" print(\"[DEBUG] ⚠️ Given prompt has a medium similarity\")\n",
" elif(f_score > 0.30):\n",
" print(\"[DEBUG] πŸ“’ Given prompt has a mild similarity\")\n",
" elif(f_score > 0.20):\n",
" print(\"[DEBUG] ❓ Given prompt has a potential similarity\")\n",
" else:\n",
" print(\"[DEBUG] πŸ’© Given prompt has a low similarity\")\n",
"\n",
" if(f_score > sLLM[2]):\n",
" print(f\"[SCORE] ❌ Found similarity with: {forbidden_words[i]}\")\n",
" return False\n",
"\n",
" print(f\"[SCORE] βœ… No similarity found\")\n",
" return True"
],
"metadata": {
"id": "7DkWk2LMeSMr"
},
"id": "7DkWk2LMeSMr",
"execution_count": null,
"outputs": []
},
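{
"cell_type": "markdown",
"source": [
"An optional workaround for the long-prompt limitation listed under Challenges: split the text into short chunks and accept the prompt only if every chunk passes `sanitize`. The chunk size and the simple whitespace splitting below are illustrative choices, not tuned values."
],
"metadata": {
"id": "chunkCheckMd"
},
"id": "chunkCheckMd"
},
{
"cell_type": "code",
"source": [
"# ✂️ Optional: chunk-wise check for longer prompts (see the Challenges section above)\n",
"# The chunk size and whitespace splitting are illustrative, not tuned\n",
"\n",
"def sanitize_long(prompt, chunk_words=12, debug=False):\n",
"\n",
"  words = prompt.split()\n",
"  chunks = [\" \".join(words[i:i + chunk_words]) for i in range(0, len(words), chunk_words)] or [prompt]\n",
"\n",
"  # The prompt is accepted only if every chunk passes the check\n",
"  return all(sanitize(chunk, debug) for chunk in chunks)\n",
"\n",
"# Example:\n",
"# sanitize_long(\"A long story about gardens, recipes, and finally a rocket launch.\", debug=True)"
],
"metadata": {
"id": "chunkCheckCode"
},
"id": "chunkCheckCode",
"execution_count": null,
"outputs": []
},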
{
"cell_type": "markdown",
"source": [
"# **Playground:** Execute prompts here and send them to the generic LLM"
],
"metadata": {
"id": "IoZ9bo_Vq463"
},
"id": "IoZ9bo_Vq463"
},
{
"cell_type": "code",
"source": [
"# Helper function to evaluate and execute prompts\n",
"def genAI(prompt, debug=False):\n",
"\n",
" # Evaluate if the prompt is accepted:\n",
" if(sanytize(prompt, debug)):\n",
" # Prompt is evaluated as safe, do normal LLM\n",
"\n",
" # Generate a response\n",
" print(\"Generating response..\\r\", end=\"\", flush=True)\n",
"\n",
" tokenizer = LLM[1]\n",
" inputs = tokenizer(\n",
" prompt,\n",
" return_tensors=\"pt\").to(device())\n",
"\n",
" with torch.no_grad(): # Disable gradient calculations for inference\n",
" outputs = LLM[0].generate(\n",
" **inputs,\n",
" max_new_tokens=256,\n",
" do_sample=True,\n",
" temperature=0.7,\n",
" top_k=50,\n",
" top_p=0.95)\n",
"\n",
" # Decode the generated tokens\n",
" print(tokenizer.decode(\n",
" outputs[0],\n",
" skip_special_tokens=True))\n",
"\n",
" else:\n",
" print(\"I'm not allowed to talk about this topic.\")\n",
"\n",
"#------------------------------------------------------------------------------#\n",
"# 🏁 Insert prompts here #\n",
"#------------------------------------------------------------------------------#\n",
"\n",
"genAI(\"A green fruit, cirlce shaped and healthy?\")\n",
"genAI(\"The pet jumps from the third floor.\")\n",
"genAI(\"It's clear and blue!\", debug=True)\n",
"genAI(\"Goes kaboom boom boom and destroy a lot!\", debug=True)\n",
"genAI(\"Home sweet home!\")\n",
"\n",
"genAI(\"What's the weather today?\")\n",
"genAI(\"How to make a salad?\")\n",
"genAI(\"Sit on a table\")\n",
"genAI(\"Healthy food\", debug=True)\n",
"genAI(\"Moon walk\", debug=True)"
],
"metadata": {
"id": "t8nh28XypaCj"
},
"id": "t8nh28XypaCj",
"execution_count": null,
"outputs": []
},
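{
"cell_type": "markdown",
"source": [
"An optional extension for Challenge 3 above (sanitizing the output as well): `genAI_guarded` below is an illustrative variant of `genAI` that also runs the generated text back through `sanitize` before printing it. Generated responses are usually longer than the short prompts this POC targets, so the same length caveats apply to this check."
],
"metadata": {
"id": "outputGuardMd"
},
"id": "outputGuardMd"
},
{
"cell_type": "code",
"source": [
"# 🛡️ Optional: also sanitize the model's output (see Challenge 3 above)\n",
"# genAI_guarded is an illustrative variant of genAI, not part of the original workflow\n",
"\n",
"def genAI_guarded(prompt, debug=False):\n",
"\n",
"  if not sanitize(prompt, debug):\n",
"    print(\"I'm not allowed to talk about this topic.\")\n",
"    return\n",
"\n",
"  tokenizer = LLM[1]\n",
"  inputs = tokenizer(prompt, return_tensors=\"pt\").to(device())\n",
"\n",
"  with torch.no_grad():\n",
"    outputs = LLM[0].generate(\n",
"      **inputs,\n",
"      max_new_tokens=256,\n",
"      do_sample=True,\n",
"      temperature=0.7,\n",
"      top_k=50,\n",
"      top_p=0.95)\n",
"\n",
"  response = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
"\n",
"  # The generated text must also clear the forbidden-topic check\n",
"  if sanitize(response, debug):\n",
"    print(response)\n",
"  else:\n",
"    print(\"The generated answer touched a forbidden topic and was withheld.\")\n",
"\n",
"# Example:\n",
"# genAI_guarded(\"How to make a salad?\", debug=True)"
],
"metadata": {
"id": "outputGuardCode"
},
"id": "outputGuardCode",
"execution_count": null,
"outputs": []
},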
{
"cell_type": "markdown",
"source": [
"# **Clean up**: free cache, memory, and device"
],
"metadata": {
"id": "w0QUHjUFqrB5"
},
"id": "w0QUHjUFqrB5"
},
{
"cell_type": "code",
"source": [
"# 🧹 Cleanup\n",
"\n",
"# Don't use except when needed\n",
"# Run in termainl?\n",
"# !huggingface-cli delete-cache\n",
"\n",
"# Also?\n",
"# del model\n",
"import gc\n",
"gc.collect()\n",
"torch.cuda.empty_cache()"
],
"metadata": {
"id": "8K02Q_W2N-rj"
},
"id": "8K02Q_W2N-rj",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
},
"colab": {
"provenance": [],
"name": "poc - filter embeddings ",
"gpuType": "L4",
"collapsed_sections": [
]
},
"accelerator": "GPU"
},
"nbformat": 4,
"nbformat_minor": 5
}