Last active
July 17, 2024 19:43
-
-
Save alonsosilvaallende/43fe22cf566d70546411882086f80cb3 to your computer and use it in GitHub Desktop.
outlines-KG-example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "dc01349d-f735-4c3a-b5e4-64519ab8b0e8", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-07-17T12:52:41.772743Z", | |
"iopub.status.busy": "2024-07-17T12:52:41.772253Z", | |
"iopub.status.idle": "2024-07-17T12:52:41.866367Z", | |
"shell.execute_reply": "2024-07-17T12:52:41.865439Z", | |
"shell.execute_reply.started": "2024-07-17T12:52:41.772694Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"from pydantic import BaseModel, Field\n", | |
"\n", | |
"class Node(BaseModel):\n", | |
" \"\"\"Node of the Knowledge Graph\"\"\"\n", | |
"\n", | |
" id: int = Field(..., description=\"Unique identifier of the node\")\n", | |
" label: str = Field(..., description=\"Label of the node\")\n", | |
" property: str = Field(..., description=\"Property of the node\")\n", | |
"\n", | |
"\n", | |
"class Edge(BaseModel):\n", | |
" \"\"\"Edge of the Knowledge Graph\"\"\"\n", | |
"\n", | |
" source: int = Field(..., description=\"Unique source of the edge\")\n", | |
" target: int = Field(..., description=\"Unique target of the edge\")\n", | |
" label: str = Field(..., description=\"Label of the edge\")\n", | |
" property: str = Field(..., description=\"Property of the edge\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "df8ad510-b13b-4309-a1e7-b5fe0cddb91c", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-07-17T12:52:41.867243Z", | |
"iopub.status.busy": "2024-07-17T12:52:41.867071Z", | |
"iopub.status.idle": "2024-07-17T12:52:41.872385Z", | |
"shell.execute_reply": "2024-07-17T12:52:41.871697Z", | |
"shell.execute_reply.started": "2024-07-17T12:52:41.867226Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"from typing import List\n", | |
"\n", | |
"class KnowledgeGraph(BaseModel):\n", | |
" \"\"\"Generated Knowledge Graph\"\"\"\n", | |
"\n", | |
" nodes: List[Node] = Field(..., description=\"List of nodes of the knowledge graph\")\n", | |
" edges: List[Edge] = Field(..., description=\"List of edges of the knowledge graph\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "7bc727a1-cc26-44d6-84ea-4f67de777676", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-07-17T12:52:41.873223Z", | |
"iopub.status.busy": "2024-07-17T12:52:41.873066Z", | |
"iopub.status.idle": "2024-07-17T12:52:41.894285Z", | |
"shell.execute_reply": "2024-07-17T12:52:41.892878Z", | |
"shell.execute_reply.started": "2024-07-17T12:52:41.873208Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"schema = KnowledgeGraph.model_json_schema()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "f69f4232-10fd-4e45-b496-78a2aa139891", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-07-17T12:52:41.895804Z", | |
"iopub.status.busy": "2024-07-17T12:52:41.895562Z", | |
"iopub.status.idle": "2024-07-17T12:52:48.305301Z", | |
"shell.execute_reply": "2024-07-17T12:52:48.303978Z", | |
"shell.execute_reply.started": "2024-07-17T12:52:41.895781Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/asilva/test-finite-state-machine/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", | |
" from .autonotebook import tqdm as notebook_tqdm\n", | |
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" | |
] | |
} | |
], | |
"source": [ | |
"import llama_cpp\n", | |
"from llama_cpp import Llama\n", | |
"from outlines import generate, models\n", | |
"\n", | |
"llm = Llama(\n", | |
" \"/home/asilva/models/Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf\", # replace with your /path/to/the/model\n", | |
" tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(\n", | |
" \"NousResearch/Hermes-2-Pro-Llama-3-8B\"\n", | |
" ),\n", | |
" n_gpu_layers=-1,\n", | |
" n_ctx=8192,\n", | |
" flash_attn=True,\n", | |
" verbose=False,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "71ee25c8-27b7-47ea-a53c-40868eb9b5ca", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-07-17T12:52:48.306197Z", | |
"iopub.status.busy": "2024-07-17T12:52:48.306031Z", | |
"iopub.status.idle": "2024-07-17T12:52:48.310246Z", | |
"shell.execute_reply": "2024-07-17T12:52:48.309183Z", | |
"shell.execute_reply.started": "2024-07-17T12:52:48.306181Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"user_prompt = \"Alice loves Bob and she hates Charlie.\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "83f126cb-6898-437c-96cf-fd9df7f0b923", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-07-17T12:52:48.314381Z", | |
"iopub.status.busy": "2024-07-17T12:52:48.313777Z", | |
"iopub.status.idle": "2024-07-17T12:52:48.335875Z", | |
"shell.execute_reply": "2024-07-17T12:52:48.334430Z", | |
"shell.execute_reply.started": "2024-07-17T12:52:48.314334Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"def generate_hermes_prompt(user_prompt):\n", | |
" return (\n", | |
" \"<|im_start|>system\\n\"\n", | |
" \"You are a world class AI model who answers questions in JSON \"\n", | |
" f\"Here's the json schema you must adhere to:\\n<schema>\\n{schema}\\n</schema><|im_end|>\\n\"\n", | |
" \"<|im_start|>user\\n\"\n", | |
" + user_prompt\n", | |
" + \"<|im_end|>\"\n", | |
" + \"\\n<|im_start|>assistant\\n\"\n", | |
" \"<schema>\"\n", | |
" )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "0c4ae094-9e63-4c75-a20f-53c4d1161a05", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-07-17T12:52:48.338232Z", | |
"iopub.status.busy": "2024-07-17T12:52:48.337275Z", | |
"iopub.status.idle": "2024-07-17T12:52:48.557305Z", | |
"shell.execute_reply": "2024-07-17T12:52:48.556276Z", | |
"shell.execute_reply.started": "2024-07-17T12:52:48.338186Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"from outlines import generate, models\n", | |
"\n", | |
"model = models.LlamaCpp(llm)\n", | |
"generator = generate.json(model, KnowledgeGraph)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "5e5cb3b3-3057-411e-a343-34a0fb5736aa", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-07-17T12:52:48.559320Z", | |
"iopub.status.busy": "2024-07-17T12:52:48.558776Z", | |
"iopub.status.idle": "2024-07-17T12:52:51.319255Z", | |
"shell.execute_reply": "2024-07-17T12:52:51.318688Z", | |
"shell.execute_reply.started": "2024-07-17T12:52:48.559273Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/asilva/test-finite-state-machine/.venv/lib/python3.10/site-packages/llama_cpp/llama.py:1054: RuntimeWarning: Detected duplicate leading \"<|begin_of_text|>\" in prompt, this will likely reduce response quality, consider removing it...\n", | |
" warnings.warn(\n" | |
] | |
} | |
], | |
"source": [ | |
"prompt = generate_hermes_prompt(user_prompt)\n", | |
"response = generator(prompt, max_tokens=1024, temperature=0, seed=42)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "eb79f281-3d62-4e70-90d9-a846752f66cf", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-07-17T12:52:51.320207Z", | |
"iopub.status.busy": "2024-07-17T12:52:51.319965Z", | |
"iopub.status.idle": "2024-07-17T12:52:51.325420Z", | |
"shell.execute_reply": "2024-07-17T12:52:51.325029Z", | |
"shell.execute_reply.started": "2024-07-17T12:52:51.320187Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[Node(id=1, label='Alice', property='Person'),\n", | |
" Node(id=2, label='Bob', property='Person'),\n", | |
" Node(id=3, label='Charlie', property='Person')]" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"response.nodes" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "6d58debb-1710-45bb-91c5-02264a9ccf0f", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-07-17T12:52:51.326147Z", | |
"iopub.status.busy": "2024-07-17T12:52:51.325943Z", | |
"iopub.status.idle": "2024-07-17T12:52:51.353580Z", | |
"shell.execute_reply": "2024-07-17T12:52:51.352358Z", | |
"shell.execute_reply.started": "2024-07-17T12:52:51.326131Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[Edge(source=1, target=2, label='love', property='Relationship'),\n", | |
" Edge(source=1, target=3, label='hate', property='Relationship')]" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"response.edges" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "85bcf40b-cfb9-49a6-bcea-e844a771bc0c", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-07-17T12:52:51.355673Z", | |
"iopub.status.busy": "2024-07-17T12:52:51.355180Z", | |
"iopub.status.idle": "2024-07-17T12:52:51.460466Z", | |
"shell.execute_reply": "2024-07-17T12:52:51.459124Z", | |
"shell.execute_reply.started": "2024-07-17T12:52:51.355626Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/svg+xml": [ | |
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", | |
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", | |
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", | |
"<!-- Generated by graphviz version 2.43.0 (0)\n", | |
" -->\n", | |
"<!-- Title: %3 Pages: 1 -->\n", | |
"<svg width=\"170pt\" height=\"203pt\"\n", | |
" viewBox=\"0.00 0.00 170.00 203.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", | |
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 199)\">\n", | |
"<title>%3</title>\n", | |
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-199 166,-199 166,4 -4,4\"/>\n", | |
"<!-- 1 -->\n", | |
"<g id=\"node1\" class=\"node\">\n", | |
"<title>1</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"81\" cy=\"-159\" rx=\"36\" ry=\"36\"/>\n", | |
"<text text-anchor=\"middle\" x=\"81\" y=\"-155.3\" font-family=\"Times,serif\" font-size=\"14.00\">Alice</text>\n", | |
"</g>\n", | |
"<!-- 2 -->\n", | |
"<g id=\"node2\" class=\"node\">\n", | |
"<title>2</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"36\" cy=\"-36\" rx=\"36\" ry=\"36\"/>\n", | |
"<text text-anchor=\"middle\" x=\"36\" y=\"-32.3\" font-family=\"Times,serif\" font-size=\"14.00\">Bob</text>\n", | |
"</g>\n", | |
"<!-- 1->2 -->\n", | |
"<g id=\"edge1\" class=\"edge\">\n", | |
"<title>1->2</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M68.7,-124.94C63.49,-110.93 57.36,-94.44 51.82,-79.54\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"55.05,-78.18 48.28,-70.02 48.49,-80.62 55.05,-78.18\"/>\n", | |
"<text text-anchor=\"middle\" x=\"77\" y=\"-93.8\" font-family=\"Times,serif\" font-size=\"14.00\">love</text>\n", | |
"</g>\n", | |
"<!-- 3 -->\n", | |
"<g id=\"node3\" class=\"node\">\n", | |
"<title>3</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"126\" cy=\"-36\" rx=\"36\" ry=\"36\"/>\n", | |
"<text text-anchor=\"middle\" x=\"126\" y=\"-32.3\" font-family=\"Times,serif\" font-size=\"14.00\">Charlie</text>\n", | |
"</g>\n", | |
"<!-- 1->3 -->\n", | |
"<g id=\"edge2\" class=\"edge\">\n", | |
"<title>1->3</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M93.3,-124.94C98.51,-110.93 104.64,-94.44 110.18,-79.54\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"113.51,-80.62 113.72,-70.02 106.95,-78.18 113.51,-80.62\"/>\n", | |
"<text text-anchor=\"middle\" x=\"122\" y=\"-93.8\" font-family=\"Times,serif\" font-size=\"14.00\">hate</text>\n", | |
"</g>\n", | |
"</g>\n", | |
"</svg>\n" | |
], | |
"text/plain": [ | |
"<graphviz.graphs.Digraph at 0x7b47c46d4700>" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from graphviz import Digraph\n", | |
"\n", | |
"dot = Digraph()\n", | |
"for node in response.nodes:\n", | |
" dot.node(str(node.id), node.label, shape='circle', width='1', height='1')\n", | |
"for edge in response.edges:\n", | |
" dot.edge(str(edge.source), str(edge.target), label=edge.label)\n", | |
"\n", | |
"dot" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.12" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment