@rbiswasfc
Created October 14, 2024 15:47
05_jina_v3_lora_embed.ipynb
{
"cells": [
{
"metadata": {
"trusted": false
},
"id": "b23c3628",
"cell_type": "code",
"source": "import numpy as np\nimport torch\nfrom transformers import AutoModel, AutoTokenizer",
"execution_count": 2,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"id": "c99431d9",
"cell_type": "code",
"source": "np.set_printoptions(precision=3, linewidth=200)\ntorch.set_printoptions(precision=3, linewidth=200)",
"execution_count": 3,
"outputs": []
},
{
"metadata": {},
"id": "b10d69ec",
"cell_type": "markdown",
"source": "### Model & Tokenizer"
},
{
"metadata": {
"trusted": false
},
"id": "5537a730",
"cell_type": "code",
"source": "%%capture\narch = \"jinaai/jina-embeddings-v3\"\ntok = AutoTokenizer.from_pretrained(arch)\nmodel = AutoModel.from_pretrained(arch, trust_remote_code=True).to(\"cuda\")",
"execution_count": 5,
"outputs": []
},
{
"metadata": {},
"id": "4e63f5bc",
"cell_type": "markdown",
"source": "### Examples & Task"
},
{
"metadata": {
"trusted": false
},
"id": "09705a0e",
"cell_type": "code",
"source": "texts = [\"Follow the white rabbit.\", \"Sigue al conejo blanco.\", \"This is a test.\"]  # test samples\n# Supported tasks: \"retrieval.query\", \"retrieval.passage\", \"separation\", \"classification\", \"text-matching\"\ntask = \"retrieval.passage\"",
"execution_count": 20,
"outputs": []
},
{
"metadata": {},
"id": "169f1343",
"cell_type": "markdown",
"source": "### Get embeddings using model call"
},
{
"metadata": {
"trusted": false
},
"id": "e4c19ec3",
"cell_type": "code",
"source": "inp_texts = [model._task_instructions[task] + t for t in texts]\ninp_texts",
"execution_count": 21,
"outputs": [
{
"data": {
"text/plain": "['Represent the document for retrieval: Follow the white rabbit.',\n 'Represent the document for retrieval: Sigue al conejo blanco.',\n 'Represent the document for retrieval: This is a test.']"
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": false
},
"id": "32e7744a",
"cell_type": "code",
"source": "inputs = tok(inp_texts, padding=True, return_tensors=\"pt\")\ninputs = {k: v.to(model.device) for k, v in inputs.items()}  # move tensors to the model's device",
"execution_count": 22,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"id": "b1a57206",
"cell_type": "code",
"source": "task_id = model._adaptation_map[task]\nnum_examples = len(texts)\nadapter_mask = torch.full((num_examples,), task_id, dtype=torch.int32, device=model.device)\nlora_arguments = {\"adapter_mask\": adapter_mask}\nlora_arguments",
"execution_count": 23,
"outputs": [
{
"data": {
"text/plain": "{'adapter_mask': tensor([1, 1, 1], device='cuda:0', dtype=torch.int32)}"
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": false
},
"id": "c152abf4",
"cell_type": "code",
"source": "# Run the backbone with the task-specific LoRA adapter selected for each example\nwith torch.inference_mode():\n    token_embs = model.roberta(**inputs, **lora_arguments).last_hidden_state\n\ntoken_embs = token_embs.float()  # accumulate in fp32 to avoid overflow\nembeddings_a = model.roberta.mean_pooling(token_embs, inputs[\"attention_mask\"])  # mean pooling over tokens\nembeddings_a = torch.nn.functional.normalize(embeddings_a, p=2, dim=1)  # L2-normalize each embedding\nembeddings_a = embeddings_a.cpu().numpy()  # convert to numpy\nembeddings_a[0][:10]",
"execution_count": 24,
"outputs": [
{
"data": {
"text/plain": "array([-0.148, -0.016, 0.029, -0.026, 0.025, 0.028, 0.031, 0.019, 0.004, -0.036], dtype=float32)"
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"id": "da421285",
"cell_type": "markdown",
"source": "### Get embeddings using the high-level API"
},
{
"metadata": {
"trusted": false
},
"id": "46875b01",
"cell_type": "code",
"source": "embeddings_b = model.encode(texts, task=task)\nembeddings_b[0][:10]",
"execution_count": 25,
"outputs": [
{
"data": {
"text/plain": "array([-0.148, -0.016, 0.029, -0.026, 0.025, 0.028, 0.031, 0.019, 0.004, -0.036], dtype=float32)"
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"id": "5cea38ed",
"cell_type": "markdown",
"source": "### Test"
},
{
"metadata": {
"trusted": false
},
"id": "bb2f14a9",
"cell_type": "code",
"source": "np.testing.assert_allclose(embeddings_a, embeddings_b, rtol=1e-5, atol=1e-8)\nassert np.all(np.abs(embeddings_a - embeddings_b) < 1e-6), \"Embeddings differ by more than 1e-6\"\nprint(\"Embeddings are equal within tolerance\")",
"execution_count": 26,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "Embeddings are equal within tolerance\n"
}
]
},
{
"metadata": {
"trusted": false
},
"id": "18323907",
"cell_type": "code",
"source": "# code_dir =\"~/.cache/huggingface/modules/transformers_modules/jinaai/xlm-roberta-flash-implementation/12700ba4972d9e900313a85ae855f5a76fb9500e\"\n# !cat $code_dir/modeling_xlm_roberta.py",
"execution_count": 103,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3 (ipykernel)",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.10.15",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"gist": {
"id": "",
"data": {
"description": "05_jina_v3_lora_embed.ipynb",
"public": true
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}
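The notebook calls `model.roberta.mean_pooling` and `torch.nn.functional.normalize` without showing what they compute. The sketch below illustrates the usual technique behind such a step, assuming a standard attention-mask-weighted mean followed by L2 normalization; it is not taken from the Jina implementation, and `mean_pool` and `l2_normalize` are hypothetical helper names written in plain Python for one sequence at a time.

```python
import math


def mean_pool(token_embs, attention_mask):
    """Average token vectors over positions where the mask is 1 (hypothetical helper)."""
    dim = len(token_embs[0])
    total = [0.0] * dim
    count = 0
    for vec, keep in zip(token_embs, attention_mask):
        if keep:  # padding positions (mask == 0) are ignored
            total = [t + v for t, v in zip(total, vec)]
            count += 1
    count = max(count, 1)  # guard against an all-padding sequence
    return [t / count for t in total]


def l2_normalize(vec, eps=1e-12):
    """Scale a vector to unit Euclidean length (hypothetical helper)."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / max(norm, eps) for v in vec]


# Toy sequence: two real tokens and one padding token.
token_embs = [[3.0, 0.0], [0.0, 4.0], [100.0, 100.0]]
attention_mask = [1, 1, 0]

pooled = mean_pool(token_embs, attention_mask)  # [1.5, 2.0]
embedding = l2_normalize(pooled)                # approximately [0.6, 0.8]
print(pooled, embedding)
```

The padding token's large values do not affect the result, which is why the attention mask is passed to the pooling step in the notebook.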