Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save alonsosilvaallende/43fe22cf566d70546411882086f80cb3 to your computer and use it in GitHub Desktop.
Save alonsosilvaallende/43fe22cf566d70546411882086f80cb3 to your computer and use it in GitHub Desktop.
outlines-KG-example
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "dc01349d-f735-4c3a-b5e4-64519ab8b0e8",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-17T12:52:41.772743Z",
"iopub.status.busy": "2024-07-17T12:52:41.772253Z",
"iopub.status.idle": "2024-07-17T12:52:41.866367Z",
"shell.execute_reply": "2024-07-17T12:52:41.865439Z",
"shell.execute_reply.started": "2024-07-17T12:52:41.772694Z"
}
},
"outputs": [],
"source": [
"from pydantic import BaseModel, Field\n",
"\n",
"# Pydantic model for one vertex of the knowledge graph; every field is required\n",
"# (`...` is passed as the default).\n",
"# NOTE(review): pydantic is expected to copy the class docstring and the field\n",
"# descriptions into the JSON schema that this notebook puts in the LLM prompt,\n",
"# so treat those strings as prompt text -- confirm before rewording them.\n",
"class Node(BaseModel):\n",
" \"\"\"Node of the Knowledge Graph\"\"\"\n",
"\n",
" # Integer key; the plotting cell uses str(node.id) as the Graphviz node name.\n",
" id: int = Field(..., description=\"Unique identifier of the node\")\n",
" # Display text; passed to Graphviz as the node label.\n",
" label: str = Field(..., description=\"Label of the node\")\n",
" # Free-form string attribute ('Person' for every node in the recorded run).\n",
" property: str = Field(..., description=\"Property of the node\")\n",
"\n",
"\n",
"# Pydantic model for one directed edge; every field is required.\n",
"# The plotting cell passes str(edge.source) / str(edge.target) to Graphviz as\n",
"# node names, so both should match `Node.id` values.\n",
"class Edge(BaseModel):\n",
" \"\"\"Edge of the Knowledge Graph\"\"\"\n",
"\n",
" # Id of the node the edge starts from (first argument to dot.edge).\n",
" source: int = Field(..., description=\"Unique source of the edge\")\n",
" # Id of the node the edge points to (second argument to dot.edge).\n",
" target: int = Field(..., description=\"Unique target of the edge\")\n",
" # Text drawn next to the arrow when the graph is rendered.\n",
" label: str = Field(..., description=\"Label of the edge\")\n",
" # Free-form string attribute ('Relationship' for every edge in the recorded run).\n",
" property: str = Field(..., description=\"Property of the edge\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "df8ad510-b13b-4309-a1e7-b5fe0cddb91c",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-17T12:52:41.867243Z",
"iopub.status.busy": "2024-07-17T12:52:41.867071Z",
"iopub.status.idle": "2024-07-17T12:52:41.872385Z",
"shell.execute_reply": "2024-07-17T12:52:41.871697Z",
"shell.execute_reply.started": "2024-07-17T12:52:41.867226Z"
}
},
"outputs": [],
"source": [
"from typing import List\n",
"\n",
"# Top-level container and the output type handed to `generate.json`: a list of\n",
"# Node objects plus a list of Edge objects, both required.\n",
"class KnowledgeGraph(BaseModel):\n",
" \"\"\"Generated Knowledge Graph\"\"\"\n",
"\n",
" # Vertices of the graph.\n",
" nodes: List[Node] = Field(..., description=\"List of nodes of the knowledge graph\")\n",
" # Directed edges between the vertices.\n",
" edges: List[Edge] = Field(..., description=\"List of edges of the knowledge graph\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "7bc727a1-cc26-44d6-84ea-4f67de777676",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-17T12:52:41.873223Z",
"iopub.status.busy": "2024-07-17T12:52:41.873066Z",
"iopub.status.idle": "2024-07-17T12:52:41.894285Z",
"shell.execute_reply": "2024-07-17T12:52:41.892878Z",
"shell.execute_reply.started": "2024-07-17T12:52:41.873208Z"
}
},
"outputs": [],
"source": [
"# JSON schema of KnowledgeGraph; generate_hermes_prompt reads this module-level\n",
"# name when it builds the system prompt.\n",
"schema = KnowledgeGraph.model_json_schema()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f69f4232-10fd-4e45-b496-78a2aa139891",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-17T12:52:41.895804Z",
"iopub.status.busy": "2024-07-17T12:52:41.895562Z",
"iopub.status.idle": "2024-07-17T12:52:48.305301Z",
"shell.execute_reply": "2024-07-17T12:52:48.303978Z",
"shell.execute_reply.started": "2024-07-17T12:52:41.895781Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/asilva/test-finite-state-machine/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
}
],
"source": [
"import llama_cpp\n",
"from llama_cpp import Llama\n",
"from outlines import generate, models\n",
"\n",
"import os\n",
"\n",
"# Location of the GGUF weights. Set the HERMES_GGUF_PATH environment variable to\n",
"# point at your own copy; the fallback is the original author's local path.\n",
"MODEL_PATH = os.environ.get(\n",
"    \"HERMES_GGUF_PATH\",\n",
"    \"/home/asilva/models/Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf\",\n",
")\n",
"\n",
"llm = Llama(\n",
"    MODEL_PATH,\n",
"    # Tokenizer loaded from the Hugging Face repo of the base model.\n",
"    tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(\n",
"        \"NousResearch/Hermes-2-Pro-Llama-3-8B\"\n",
"    ),\n",
"    n_gpu_layers=-1,  # -1 asks llama.cpp to offload all layers to the GPU\n",
"    n_ctx=8192,  # context window, in tokens\n",
"    flash_attn=True,\n",
"    verbose=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "71ee25c8-27b7-47ea-a53c-40868eb9b5ca",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-17T12:52:48.306197Z",
"iopub.status.busy": "2024-07-17T12:52:48.306031Z",
"iopub.status.idle": "2024-07-17T12:52:48.310246Z",
"shell.execute_reply": "2024-07-17T12:52:48.309183Z",
"shell.execute_reply.started": "2024-07-17T12:52:48.306181Z"
}
},
"outputs": [],
"source": [
"# Sentence the knowledge graph is extracted from.\n",
"user_prompt = \"Alice loves Bob and she hates Charlie.\""
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "83f126cb-6898-437c-96cf-fd9df7f0b923",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-17T12:52:48.314381Z",
"iopub.status.busy": "2024-07-17T12:52:48.313777Z",
"iopub.status.idle": "2024-07-17T12:52:48.335875Z",
"shell.execute_reply": "2024-07-17T12:52:48.334430Z",
"shell.execute_reply.started": "2024-07-17T12:52:48.314334Z"
}
},
"outputs": [],
"source": [
"def generate_hermes_prompt(user_prompt):\n",
"    \"\"\"Build the ChatML-style prompt sent to the Hermes-2-Pro model.\n",
"\n",
"    The system turn embeds the module-level `schema` serialised as JSON; the\n",
"    user turn carries `user_prompt` verbatim.\n",
"\n",
"    Args:\n",
"        user_prompt: Text (str) to extract a knowledge graph from.\n",
"\n",
"    Returns:\n",
"        The full prompt string, ending inside the assistant turn.\n",
"    \"\"\"\n",
"    import json  # imported here so the cell does not rely on an earlier cell\n",
"\n",
"    # Interpolating a dict into an f-string yields its Python repr (single\n",
"    # quotes), which is not JSON, so serialise it explicitly.\n",
"    schema_text = schema if isinstance(schema, str) else json.dumps(schema)\n",
"    return (\n",
"        \"<|im_start|>system\\n\"\n",
"        \"You are a world class AI model who answers questions in JSON. \"\n",
"        f\"Here's the json schema you must adhere to:\\n<schema>\\n{schema_text}\\n</schema><|im_end|>\\n\"\n",
"        \"<|im_start|>user\\n\"\n",
"        + user_prompt\n",
"        + \"<|im_end|>\"\n",
"        + \"\\n<|im_start|>assistant\\n\"\n",
"        # NOTE(review): the assistant turn is pre-filled with an opening <schema>\n",
"        # tag, kept from the original prompt -- confirm this is intended.\n",
"        \"<schema>\"\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "0c4ae094-9e63-4c75-a20f-53c4d1161a05",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-17T12:52:48.338232Z",
"iopub.status.busy": "2024-07-17T12:52:48.337275Z",
"iopub.status.idle": "2024-07-17T12:52:48.557305Z",
"shell.execute_reply": "2024-07-17T12:52:48.556276Z",
"shell.execute_reply.started": "2024-07-17T12:52:48.338186Z"
}
},
"outputs": [],
"source": [
"from outlines import generate, models\n",
"\n",
"# Wrap the llama.cpp model for outlines, then build a JSON generator typed by\n",
"# the KnowledgeGraph model.\n",
"model = models.LlamaCpp(llm)\n",
"generator = generate.json(model, KnowledgeGraph)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "5e5cb3b3-3057-411e-a343-34a0fb5736aa",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-17T12:52:48.559320Z",
"iopub.status.busy": "2024-07-17T12:52:48.558776Z",
"iopub.status.idle": "2024-07-17T12:52:51.319255Z",
"shell.execute_reply": "2024-07-17T12:52:51.318688Z",
"shell.execute_reply.started": "2024-07-17T12:52:48.559273Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/asilva/test-finite-state-machine/.venv/lib/python3.10/site-packages/llama_cpp/llama.py:1054: RuntimeWarning: Detected duplicate leading \"<|begin_of_text|>\" in prompt, this will likely reduce response quality, consider removing it...\n",
" warnings.warn(\n"
]
}
],
"source": [
"# Build the prompt and run generation with temperature=0 and a fixed seed,\n",
"# capped at 1024 new tokens.\n",
"# NOTE(review): the recorded outputs suggest `response` is a KnowledgeGraph\n",
"# instance (its .nodes / .edges hold Node / Edge objects).\n",
"prompt = generate_hermes_prompt(user_prompt)\n",
"response = generator(prompt, max_tokens=1024, temperature=0, seed=42)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "eb79f281-3d62-4e70-90d9-a846752f66cf",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-17T12:52:51.320207Z",
"iopub.status.busy": "2024-07-17T12:52:51.319965Z",
"iopub.status.idle": "2024-07-17T12:52:51.325420Z",
"shell.execute_reply": "2024-07-17T12:52:51.325029Z",
"shell.execute_reply.started": "2024-07-17T12:52:51.320187Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[Node(id=1, label='Alice', property='Person'),\n",
" Node(id=2, label='Bob', property='Person'),\n",
" Node(id=3, label='Charlie', property='Person')]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Display the nodes returned by the generator.\n",
"response.nodes"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "6d58debb-1710-45bb-91c5-02264a9ccf0f",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-17T12:52:51.326147Z",
"iopub.status.busy": "2024-07-17T12:52:51.325943Z",
"iopub.status.idle": "2024-07-17T12:52:51.353580Z",
"shell.execute_reply": "2024-07-17T12:52:51.352358Z",
"shell.execute_reply.started": "2024-07-17T12:52:51.326131Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[Edge(source=1, target=2, label='love', property='Relationship'),\n",
" Edge(source=1, target=3, label='hate', property='Relationship')]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Display the edges returned by the generator.\n",
"response.edges"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "85bcf40b-cfb9-49a6-bcea-e844a771bc0c",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-17T12:52:51.355673Z",
"iopub.status.busy": "2024-07-17T12:52:51.355180Z",
"iopub.status.idle": "2024-07-17T12:52:51.460466Z",
"shell.execute_reply": "2024-07-17T12:52:51.459124Z",
"shell.execute_reply.started": "2024-07-17T12:52:51.355626Z"
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.43.0 (0)\n",
" -->\n",
"<!-- Title: %3 Pages: 1 -->\n",
"<svg width=\"170pt\" height=\"203pt\"\n",
" viewBox=\"0.00 0.00 170.00 203.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 199)\">\n",
"<title>%3</title>\n",
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-199 166,-199 166,4 -4,4\"/>\n",
"<!-- 1 -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>1</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"81\" cy=\"-159\" rx=\"36\" ry=\"36\"/>\n",
"<text text-anchor=\"middle\" x=\"81\" y=\"-155.3\" font-family=\"Times,serif\" font-size=\"14.00\">Alice</text>\n",
"</g>\n",
"<!-- 2 -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>2</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"36\" cy=\"-36\" rx=\"36\" ry=\"36\"/>\n",
"<text text-anchor=\"middle\" x=\"36\" y=\"-32.3\" font-family=\"Times,serif\" font-size=\"14.00\">Bob</text>\n",
"</g>\n",
"<!-- 1&#45;&gt;2 -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>1&#45;&gt;2</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M68.7,-124.94C63.49,-110.93 57.36,-94.44 51.82,-79.54\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"55.05,-78.18 48.28,-70.02 48.49,-80.62 55.05,-78.18\"/>\n",
"<text text-anchor=\"middle\" x=\"77\" y=\"-93.8\" font-family=\"Times,serif\" font-size=\"14.00\">love</text>\n",
"</g>\n",
"<!-- 3 -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>3</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"126\" cy=\"-36\" rx=\"36\" ry=\"36\"/>\n",
"<text text-anchor=\"middle\" x=\"126\" y=\"-32.3\" font-family=\"Times,serif\" font-size=\"14.00\">Charlie</text>\n",
"</g>\n",
"<!-- 1&#45;&gt;3 -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>1&#45;&gt;3</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M93.3,-124.94C98.51,-110.93 104.64,-94.44 110.18,-79.54\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"113.51,-80.62 113.72,-70.02 106.95,-78.18 113.51,-80.62\"/>\n",
"<text text-anchor=\"middle\" x=\"122\" y=\"-93.8\" font-family=\"Times,serif\" font-size=\"14.00\">hate</text>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x7b47c46d4700>"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from graphviz import Digraph\n",
"\n",
"# Render the extracted graph with Graphviz: one circle per node, one labelled\n",
"# arrow per edge.\n",
"dot = Digraph()\n",
"\n",
"for node in response.nodes:\n",
"    node_name = str(node.id)  # Graphviz node names are strings\n",
"    dot.node(node_name, label=node.label, shape=\"circle\", width=\"1\", height=\"1\")\n",
"\n",
"for edge in response.edges:\n",
"    tail_name, head_name = str(edge.source), str(edge.target)\n",
"    dot.edge(tail_name, head_name, label=edge.label)\n",
"\n",
"# Last expression of the cell, so Jupyter displays the rendered graph.\n",
"dot"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment