Last active
June 8, 2024 02:27
-
-
Save 0xh3x/9a6abdbffa821091b6ff822029e1b644 to your computer and use it in GitHub Desktop.
VectorGraph - PoC from Memory Hackathon
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"authorship_tag": "ABX9TyPjIZOlMiymdw5WceeA/FR4", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/0xh3x/9a6abdbffa821091b6ff822029e1b644/vectorgraph.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# VectorGraph\n", | |
"Memory hackathon\n", | |
"\n", | |
"Slides: https://docs.google.com/presentation/d/1nGW2xEakHgR2TgbbdtoHrTlxU7F2IwcWQ8VcCNEHQiA/edit?usp=sharing" | |
], | |
"metadata": { | |
"id": "LB7f3xRGt1Qv" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "9GFN8iPWxK7V" | |
}, | |
"outputs": [], | |
"source": [ | |
"!pip install -U langchain pypdf pymongo openai python-dotenv tiktoken" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from google.colab import userdata" | |
], | |
"metadata": { | |
"id": "lo0Pu_W-yKnd" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import openai\n", | |
"openai.api_key = userdata.get('OPENAI_API_KEY')" | |
], | |
"metadata": { | |
"id": "y38oKcsI_GW4" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from openai import OpenAI\n", | |
"openai_client = OpenAI(api_key=userdata.get('OPENAI_API_KEY'))\n", | |
"\n", | |
"def get_embedding(text: str) -> list[float]:\n", | |
" response = openai_client.embeddings.create(\n", | |
" input=text,\n", | |
" model=\"text-embedding-3-large\",\n", | |
" dimensions=2048\n", | |
" )\n", | |
" return response.data[0].embedding\n", | |
"print(len(get_embedding(\"apple\")))" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "11Fee_mRy74V", | |
"outputId": "bc9738dc-dd28-4cd4-c17a-3700319ef166" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"2048\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from pymongo import MongoClient\n", | |
"\n", | |
"DB_NAME = \"semanticgraphdb\"\n", | |
"COLLECTION_NAME = \"SemanticGraphDb\"\n", | |
"ATLAS_VECTOR_SEARCH_INDEX_NAME = \"default\"\n", | |
"EMBEDDING_FIELD_NAME = \"embedding\"\n", | |
"\n", | |
"\n", | |
"client = MongoClient(userdata.get('mongo_uri'))\n", | |
"\n", | |
"db = client[DB_NAME]\n", | |
"collection = db[COLLECTION_NAME]" | |
], | |
"metadata": { | |
"id": "4zTz28l58xGx" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Get this runtime's public IP address so it can be added to the MongoDB Atlas IP access list." | |
], | |
"metadata": { | |
"id": "JDwXg1fvoSi9" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!curl ipecho.net/plain" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "MhBd2WX2BM3s", | |
"outputId": "0dc3c1bd-dbb7-4547-fb0e-d3d4d1bcd09c" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"34.32.147.23" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from langchain.embeddings import OpenAIEmbeddings\n", | |
"from langchain.vectorstores import MongoDBAtlasVectorSearch\n", | |
"from langchain.docstore.document import Document\n", | |
"# Toy corpus: each word below becomes one Document to embed and store.\n", | |
"docs = [\n", | |
" Document(page_content=doc)\n", | |
" for doc in [\"pizza\",\"pasta\", \"salad\", \"italy\", \"germany\", \"france\", \"europe\",\"asia\", \"africa\", \"pie\"]\n", | |
"]\n", | |
"# Same model and dimensions as get_embedding(), so stored and query vectors are comparable.\n", | |
"embedder = OpenAIEmbeddings(model=\"text-embedding-3-large\", disallowed_special=(), openai_api_key=userdata.get('OPENAI_API_KEY'), dimensions=2048)\n", | |
"# insert the documents in MongoDB Atlas Vector Search\n", | |
"# NOTE(review): index_name here is 'default', but getsim() queries the index 'vector_index' - confirm the real Atlas index name.\n", | |
"# NOTE(review): re-running this cell presumably inserts the documents again - clear the collection first to avoid duplicates.\n", | |
"x = MongoDBAtlasVectorSearch.from_documents(\n", | |
" documents=docs, embedding=embedder, collection=collection, index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME\n", | |
" )\n" | |
], | |
"metadata": { | |
"id": "7sFb3Zdm-Vv4" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"collection.count_documents({})" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "3V4yJP01A1b7", | |
"outputId": "b42c43af-c05f-4651-80e1-3abe71439f93" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"10" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 132 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def getsim(query, topk=5, exclude=None, adjust_probs=1.0):\n", | |
" if exclude is None:\n", | |
" exclude = [query]\n", | |
" print(\"query\",query, \"exclude\", exclude, \"adjust_probs\", adjust_probs)\n", | |
" results = collection.aggregate([\n", | |
" {\n", | |
" \"$vectorSearch\": {\n", | |
" \"index\": \"vector_index\",\n", | |
" \"queryVector\": get_embedding(query),\n", | |
" \"numCandidates\": 200,\n", | |
" \"limit\": topk,\n", | |
" \"path\": \"embedding\",\n", | |
" \"filter\": {\n", | |
" \"text\": { \"$nin\": exclude}\n", | |
" }\n", | |
" }},\n", | |
" {\n", | |
" \"$project\": {\n", | |
" \"_id\": 0,\n", | |
" \"text\": 1,\n", | |
" \"score\": { \"$meta\": \"vectorSearchScore\" }\n", | |
" }\n", | |
" }\n", | |
"\n", | |
" ])\n", | |
"\n", | |
" return [{'text':r['text'], 'score':r['score'] * adjust_probs} for r in list(results)[:topk]]\n", | |
"\n", | |
"getsim(\"3.15\", topk=1)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "UbPCcKKyB3oV", | |
"outputId": "1b956610-3cab-4be6-d706-8bc15778d23f" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"query 3.15 exclude ['3.15'] adjust_probs 1.0\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[{'text': 'pie', 'score': 0.6632951498031616}]" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 123 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# WARNING: the empty filter matches every document, so this wipes the whole collection.\n", | |
"# The getsim() cells below query this collection; re-run the insertion cell above\n", | |
"# before them, otherwise a top-to-bottom run searches an empty collection.\n", | |
"collection.delete_many({})\n" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "xG-DiLcTEmkl", | |
"outputId": "1a857a7a-2b05-47f0-be86-e5d5059fd67e" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"DeleteResult({'n': 14, 'electionId': ObjectId('7fffffff00000000000000e0'), 'opTime': {'ts': Timestamp(1712436514, 13), 't': 224}, 'ok': 1.0, '$clusterTime': {'clusterTime': Timestamp(1712436514, 16), 'signature': {'hash': b'\\x11\"\\xd8Q{m,\\xd4(\\xbc{\\x0f\\xe6`5sl#^\\xec', 'keyId': 7299225970987237377}}, 'operationTime': Timestamp(1712436514, 13)}, acknowledged=True)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 53 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"input1=\"3.15\"\n", | |
"input2=\"ferrari\"\n", | |
"\n", | |
"s1 = getsim(input1, topk=2, exclude=[input1])\n", | |
"set1 = set([s[\"text\"] for s in s1])\n", | |
"print(input1, s1)\n", | |
"print(\"set1\", set1)\n", | |
"s2 = getsim(input2, topk=2, exclude=[input2])\n", | |
"set2 = set([s[\"text\"] for s in s2])\n", | |
"print(input2, s2)\n", | |
"print(\"set2\", set2)\n", | |
"\n", | |
"print()\n", | |
"\n", | |
"for s in s1:\n", | |
" sims = getsim(s[\"text\"], topk=2, exclude=list(set1), adjust_probs=s['score'])\n", | |
" print(s[\"text\"], sims)\n", | |
" for sim in sims:\n", | |
" set1.add(sim[\"text\"])\n", | |
" if sim[\"text\"] in set2:\n", | |
" print(\"found: \", sim[\"text\"])\n", | |
"print(\"set1\", set1)\n", | |
"\n", | |
"print()\n", | |
"\n", | |
"for s in s2:\n", | |
" sims = getsim(s[\"text\"], topk=2, exclude=list(set2), adjust_probs=s['score'])\n", | |
" print(s[\"text\"], sims)\n", | |
" for sim in sims:\n", | |
" set2.add(sim[\"text\"])\n", | |
" if sim[\"text\"] in set1:\n", | |
" print(\"found: \", sim[\"text\"])\n", | |
"print(\"set2\", set2)\n", | |
"\n", | |
"print (set1.intersection(set2))\n" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "xWlfWGwi-XU_", | |
"outputId": "9ec81dcf-8ff9-4e33-b4b1-1f9cff615c7a" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"query pie exclude ['pie'] adjust_probs 1.0\n", | |
"pie [{'text': 'pizza', 'score': 0.76506507396698}, {'text': 'pasta', 'score': 0.6837033629417419}]\n", | |
"set1 {'pizza', 'pasta'}\n", | |
"query ferrari exclude ['ferrari'] adjust_probs 1.0\n", | |
"ferrari [{'text': 'italy', 'score': 0.6754150390625}, {'text': 'france', 'score': 0.6711523532867432}]\n", | |
"set2 {'italy', 'france'}\n", | |
"\n", | |
"query pizza exclude ['pizza', 'pasta'] adjust_probs 0.76506507396698\n", | |
"pizza [{'text': 'pie', 'score': 0.58530760367141}, {'text': 'salad', 'score': 0.550856861858513}]\n", | |
"query pasta exclude ['pizza', 'pasta', 'pie', 'salad'] adjust_probs 0.6837033629417419\n", | |
"pasta [{'text': 'italy', 'score': 0.45891093243378833}, {'text': 'asia', 'score': 0.4391373789216999}]\n", | |
"found: italy\n", | |
"set1 {'italy', 'salad', 'pasta', 'pizza', 'asia', 'pie'}\n", | |
"\n", | |
"query italy exclude ['italy', 'france'] adjust_probs 0.6754150390625\n", | |
"italy [{'text': 'germany', 'score': 0.5503174412078806}, {'text': 'europe', 'score': 0.5277450528374175}]\n", | |
"query france exclude ['italy', 'germany', 'france', 'europe'] adjust_probs 0.6711523532867432\n", | |
"france [{'text': 'africa', 'score': 0.47426246229650815}, {'text': 'asia', 'score': 0.4330421091996328}]\n", | |
"found: asia\n", | |
"set2 {'africa', 'europe', 'france', 'germany', 'italy', 'asia'}\n", | |
"{'italy', 'asia'}\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"getsim(\"3.15 pie pizza\", topk=2)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "wNtNE71_JFoA", | |
"outputId": "7462c980-ee49-435f-b475-257e2eaf0269" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"query 3.15 pie exclude ['bla'] adjust_probs 1.0\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[{'text': 'pie', 'score': 0.7743104696273804},\n", | |
" {'text': 'pizza', 'score': 0.6535331010818481}]" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 99 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"getsim(\"ferrari italy\", topk=2)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "Wo47PQR8LqJ9", | |
"outputId": "1e48fb85-0489-47f3-b6f0-2b11a27663e4" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"query ferrari italy exclude ['bla'] adjust_probs 1.0\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[{'text': 'italy', 'score': 0.8061650991439819},\n", | |
" {'text': 'germany', 'score': 0.7086942195892334}]" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 101 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"getsim(\"3.15\", topk=2)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "M0yGcS1Ocy2n", | |
"outputId": "80f0f983-4429-4f28-ba5b-1ca3cd50388a" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"query 3.15 exclude ['3.15'] adjust_probs 1.0\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[{'text': 'pie', 'score': 0.6632951498031616},\n", | |
" {'text': 'salad', 'score': 0.6039043068885803}]" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 112 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"getsim(\"pie\", topk=2)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "TPvN0sOtdVjs", | |
"outputId": "f7d22b5e-c282-4571-eb51-fd0618f62b68" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"query pie exclude ['pie'] adjust_probs 1.0\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[{'text': 'pizza', 'score': 0.76506507396698},\n", | |
" {'text': 'pasta', 'score': 0.6837033629417419}]" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 113 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"getsim(\"italy\", topk=20)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "a3hldsqIdZGo", | |
"outputId": "15624646-5ff9-4a7c-e7e7-49b59036aa82" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"query italy exclude ['italy'] adjust_probs 1.0\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[{'text': 'germany', 'score': 0.8147841095924377},\n", | |
" {'text': 'europe', 'score': 0.7813640832901001},\n", | |
" {'text': 'france', 'score': 0.7730833292007446},\n", | |
" {'text': 'africa', 'score': 0.7081910967826843},\n", | |
" {'text': 'asia', 'score': 0.6715290546417236},\n", | |
" {'text': 'pasta', 'score': 0.6712135076522827},\n", | |
" {'text': 'pizza', 'score': 0.6585623025894165},\n", | |
" {'text': 'salad', 'score': 0.6119321584701538},\n", | |
" {'text': 'pie', 'score': 0.6002786159515381}]" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 126 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"getsim(\"ferrari\", topk=20)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "_IoZHLs-gO8T", | |
"outputId": "099cb0fc-c9b8-434c-9c56-d4f8e5d0ad5c" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"query ferrari exclude ['ferrari'] adjust_probs 1.0\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[{'text': 'italy', 'score': 0.6754190921783447},\n", | |
" {'text': 'france', 'score': 0.6711257696151733},\n", | |
" {'text': 'pasta', 'score': 0.6701401472091675},\n", | |
" {'text': 'germany', 'score': 0.6491234302520752},\n", | |
" {'text': 'africa', 'score': 0.6408870220184326},\n", | |
" {'text': 'pizza', 'score': 0.638198733329773},\n", | |
" {'text': 'europe', 'score': 0.6174775958061218},\n", | |
" {'text': 'salad', 'score': 0.605734646320343},\n", | |
" {'text': 'asia', 'score': 0.6048133373260498},\n", | |
" {'text': 'pie', 'score': 0.6046735644340515}]" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 128 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"#### TODO: figure out how to combine similarity scores" | |
], | |
"metadata": { | |
"id": "GjEsrGuzm_L_" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [], | |
"metadata": { | |
"id": "XScACRAZdroo" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This is the PoC of the idea I had at Memory Hackathon https://lu.ma/taa6ijxt (organized by
@LangChainAI, @newcomputer, @AnthropicAI and @MongoDB ) and my idea got second place in the "Memory Infra Category" (40 teams competing).
hackathon-vid.mp4