Last active
August 7, 2024 20:16
-
-
Save alonsosilvaallende/27213240a19aa94f308a38b00a913b5b to your computer and use it in GitHub Desktop.
Structured Outputs with Outlines+Hermes
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "3d169209-a562-465a-bbe8-7580fe05c3ce", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-07T20:14:42.331411Z", | |
"iopub.status.busy": "2024-08-07T20:14:42.330874Z", | |
"iopub.status.idle": "2024-08-07T20:14:42.372611Z", | |
"shell.execute_reply": "2024-08-07T20:14:42.371402Z", | |
"shell.execute_reply.started": "2024-08-07T20:14:42.331343Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"%load_ext autoreload\n", | |
"%autoreload 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "b33f0a08-a699-4682-a50a-e5b21acb7645", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-07T20:14:42.838673Z", | |
"iopub.status.busy": "2024-08-07T20:14:42.838160Z", | |
"iopub.status.idle": "2024-08-07T20:14:42.854839Z", | |
"shell.execute_reply": "2024-08-07T20:14:42.853533Z", | |
"shell.execute_reply.started": "2024-08-07T20:14:42.838623Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"import warnings\n", | |
"warnings.filterwarnings(\"ignore\", category=RuntimeWarning) # ignore runtime warnings" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "96cdbccc-c584-4966-a442-02741a171ab2", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-07T20:14:43.399221Z", | |
"iopub.status.busy": "2024-08-07T20:14:43.398710Z", | |
"iopub.status.idle": "2024-08-07T20:14:43.510281Z", | |
"shell.execute_reply": "2024-08-07T20:14:43.508967Z", | |
"shell.execute_reply.started": "2024-08-07T20:14:43.399173Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"from pydantic import BaseModel, Field\n", | |
"\n", | |
"# One step of the model's step-by-step reasoning (an item of Reasoning.reasoning).\n", | |
"# NOTE: no class docstring on purpose -- pydantic copies a model docstring into\n", | |
"# model_json_schema(), which would change the schema shown in the prompt.\n", | |
"class Reasoning_Step(BaseModel):\n", | |
"    reasoning_step: str = Field(description=\"Reasoning step\")  # required: no default given" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "532995c5-f076-4bf7-b294-ed70332c1c10", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-07T20:14:44.016348Z", | |
"iopub.status.busy": "2024-08-07T20:14:44.015808Z", | |
"iopub.status.idle": "2024-08-07T20:14:44.038348Z", | |
"shell.execute_reply": "2024-08-07T20:14:44.037085Z", | |
"shell.execute_reply.started": "2024-08-07T20:14:44.016298Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"from typing import List\n", | |
"\n", | |
"# Top-level answer format: a list of reasoning steps followed by a conclusion.\n", | |
"# Field order matters: the generated regex emits \"reasoning\" before \"conclusion\".\n", | |
"# (No class docstring: pydantic would copy it into the JSON schema used in the prompt.)\n", | |
"class Reasoning(BaseModel):\n", | |
"    reasoning: List[Reasoning_Step] = Field(description=\"List of reasoning steps\")\n", | |
"    conclusion: str = Field(description=\"Conclusion\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "a125eb20-efb2-4274-a42d-a35f58d9db54", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-07T20:14:44.635274Z", | |
"iopub.status.busy": "2024-08-07T20:14:44.634244Z", | |
"iopub.status.idle": "2024-08-07T20:14:44.660004Z", | |
"shell.execute_reply": "2024-08-07T20:14:44.658754Z", | |
"shell.execute_reply.started": "2024-08-07T20:14:44.635223Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"json_schema = Reasoning.model_json_schema()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "61ee86ce-3503-43ae-879d-f947aee88623", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-07T20:14:45.282287Z", | |
"iopub.status.busy": "2024-08-07T20:14:45.281773Z", | |
"iopub.status.idle": "2024-08-07T20:14:48.955568Z", | |
"shell.execute_reply": "2024-08-07T20:14:48.954148Z", | |
"shell.execute_reply.started": "2024-08-07T20:14:45.282238Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"from outlines.integrations.utils import convert_json_schema_to_str\n", | |
"from outlines.fsm.json_schema import build_regex_from_schema\n", | |
"\n", | |
"def get_regex_from_Pydantic_class(pydantic_class, whitespace_pattern=r\" \"):\n", | |
"    \"\"\"Build the Outlines generation regex for a Pydantic model.\n", | |
"\n", | |
"    Args:\n", | |
"        pydantic_class: Pydantic model class whose JSON schema constrains the output.\n", | |
"        whitespace_pattern: regex allowed between JSON tokens. The default, a single\n", | |
"            space, forces exactly one space between tokens instead of letting the\n", | |
"            model choose arbitrary whitespace.\n", | |
"\n", | |
"    Returns:\n", | |
"        The regex string to pass to `outlines.generate.regex`.\n", | |
"    \"\"\"\n", | |
"    json_schema = pydantic_class.model_json_schema()\n", | |
"    schema_str = convert_json_schema_to_str(json_schema=json_schema)\n", | |
"    return build_regex_from_schema(schema_str, whitespace_pattern=whitespace_pattern)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "f6e5e037-f005-47c8-9f30-9e4eeb5ef48a", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-07T20:14:48.957977Z", | |
"iopub.status.busy": "2024-08-07T20:14:48.957296Z", | |
"iopub.status.idle": "2024-08-07T20:14:48.986551Z", | |
"shell.execute_reply": "2024-08-07T20:14:48.985904Z", | |
"shell.execute_reply.started": "2024-08-07T20:14:48.957937Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'\\\\{ \"reasoning\" : \\\\[ ((\\\\{ \"reasoning_step\" : \"([^\"\\\\\\\\\\\\x00-\\\\x1F\\\\x7F-\\\\x9F]|\\\\\\\\[\"\\\\\\\\])*\" \\\\})(, (\\\\{ \"reasoning_step\" : \"([^\"\\\\\\\\\\\\x00-\\\\x1F\\\\x7F-\\\\x9F]|\\\\\\\\[\"\\\\\\\\])*\" \\\\})){0,})? \\\\] , \"conclusion\" : \"([^\"\\\\\\\\\\\\x00-\\\\x1F\\\\x7F-\\\\x9F]|\\\\\\\\[\"\\\\\\\\])*\" \\\\}'" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"regex_str = get_regex_from_Pydantic_class(Reasoning)\n", | |
"regex_str" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "64f2fe73-56f7-4533-9357-394d5c6555dd", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-07T20:14:51.316487Z", | |
"iopub.status.busy": "2024-08-07T20:14:51.315935Z", | |
"iopub.status.idle": "2024-08-07T20:14:54.274501Z", | |
"shell.execute_reply": "2024-08-07T20:14:54.273492Z", | |
"shell.execute_reply.started": "2024-08-07T20:14:51.316436Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" | |
] | |
} | |
], | |
"source": [ | |
"import os\n", | |
"\n", | |
"import llama_cpp\n", | |
"from llama_cpp import Llama\n", | |
"from outlines import generate, models\n", | |
"\n", | |
"# Path to the local GGUF weights. Set HERMES_GGUF_PATH to point elsewhere instead of\n", | |
"# editing this cell; the default is a machine-specific absolute path.\n", | |
"MODEL_PATH = os.environ.get(\n", | |
"    \"HERMES_GGUF_PATH\", \"/big_storage/llms/models/Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf\"\n", | |
")\n", | |
"TOKENIZER_REPO = \"NousResearch/Hermes-2-Pro-Llama-3-8B\"  # Hugging Face repo of the tokenizer\n", | |
"N_CTX = 8192  # context window size in tokens\n", | |
"\n", | |
"llm = Llama(\n", | |
"    MODEL_PATH,\n", | |
"    tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(TOKENIZER_REPO),\n", | |
"    n_gpu_layers=-1,  # -1 offloads all layers to the GPU\n", | |
"    flash_attn=True,\n", | |
"    n_ctx=N_CTX,\n", | |
"    verbose=False,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "3fddb1a4-fc3d-4a81-bfb6-ab07c94c16e2", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-07T20:14:54.275697Z", | |
"iopub.status.busy": "2024-08-07T20:14:54.275509Z", | |
"iopub.status.idle": "2024-08-07T20:14:54.294054Z", | |
"shell.execute_reply": "2024-08-07T20:14:54.292995Z", | |
"shell.execute_reply.started": "2024-08-07T20:14:54.275680Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"import json\n", | |
"\n", | |
"def generate_hermes_prompt(question, schema=\"\"):\n", | |
"    \"\"\"Build a ChatML prompt for Hermes-2-Pro's JSON mode.\n", | |
"\n", | |
"    Args:\n", | |
"        question: the user message.\n", | |
"        schema: JSON schema the answer must follow, either as a dict (e.g. the result\n", | |
"            of `model_json_schema()`) or as an already-serialized string.\n", | |
"\n", | |
"    Returns:\n", | |
"        The prompt string: a closed system turn, a closed user turn, and an open\n", | |
"        assistant turn for the model to complete.\n", | |
"    \"\"\"\n", | |
"    # Interpolating a dict directly would use str(), i.e. a Python repr with single\n", | |
"    # quotes, which is not valid JSON; serialize it properly instead.\n", | |
"    schema_str = schema if isinstance(schema, str) else json.dumps(schema)\n", | |
"    return (\n", | |
"        \"<|im_start|>system\\n\"\n", | |
"        \"You are a world class AI model who answers questions in JSON with correct Pydantic schema. \"\n", | |
"        # Every ChatML turn, including the system turn, must be closed with <|im_end|>.\n", | |
"        f\"Here's the json schema you must adhere to:\\n<schema>\\n{schema_str}\\n</schema><|im_end|>\"\n", | |
"        \"\\n<|im_start|>user\\n\" + question + \"<|im_end|>\"\n", | |
"        \"\\n<|im_start|>assistant\\n\"\n", | |
"    )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "0cc7dede-c0b8-48aa-ab44-14deaea9064d", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-07T20:14:55.174614Z", | |
"iopub.status.busy": "2024-08-07T20:14:55.174089Z", | |
"iopub.status.idle": "2024-08-07T20:14:55.213874Z", | |
"shell.execute_reply": "2024-08-07T20:14:55.212761Z", | |
"shell.execute_reply.started": "2024-08-07T20:14:55.174565Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<|im_start|>system\n", | |
"You are a world class AI model who answers questions in JSON with correct Pydantic schema. Here's the json schema you must adhere to:\n", | |
"<schema>\n", | |
"{'$defs': {'Reasoning_Step': {'properties': {'reasoning_step': {'description': 'Reasoning step', 'title': 'Reasoning Step', 'type': 'string'}}, 'required': ['reasoning_step'], 'title': 'Reasoning_Step', 'type': 'object'}}, 'properties': {'reasoning': {'description': 'List of reasoning steps', 'items': {'$ref': '#/$defs/Reasoning_Step'}, 'title': 'Reasoning', 'type': 'array'}, 'conclusion': {'description': 'Conclusion', 'title': 'Conclusion', 'type': 'string'}}, 'required': ['reasoning', 'conclusion'], 'title': 'Reasoning', 'type': 'object'}\n", | |
"</schema>\n", | |
"<|im_start|>user\n", | |
"9.11 and 9.9 -- which is bigger?<|im_end|>\n", | |
"<|im_start|>assistant\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"prompt = generate_hermes_prompt(\"9.11 and 9.9 -- which is bigger?\", schema=Reasoning.model_json_schema())\n", | |
"print(prompt)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "c090e919-7743-48b2-9ae9-7a196939fb7b", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-07T20:14:56.638844Z", | |
"iopub.status.busy": "2024-08-07T20:14:56.638166Z", | |
"iopub.status.idle": "2024-08-07T20:14:56.676125Z", | |
"shell.execute_reply": "2024-08-07T20:14:56.675297Z", | |
"shell.execute_reply.started": "2024-08-07T20:14:56.638792Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"model = models.LlamaCpp(llm)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "809349a7-451e-4888-ae24-ffe88fc50402", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-07T20:14:59.155260Z", | |
"iopub.status.busy": "2024-08-07T20:14:59.154070Z", | |
"iopub.status.idle": "2024-08-07T20:14:59.384102Z", | |
"shell.execute_reply": "2024-08-07T20:14:59.383068Z", | |
"shell.execute_reply.started": "2024-08-07T20:14:59.155207Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"# Constrain decoding so every completion matches regex_str, i.e. the Reasoning schema.\n", | |
"generator = generate.regex(model, regex_str)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "8b6574e0-6d76-4ca7-9841-57b8f1fec7d5", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-07T20:15:00.101880Z", | |
"iopub.status.busy": "2024-08-07T20:15:00.101120Z", | |
"iopub.status.idle": "2024-08-07T20:15:03.386404Z", | |
"shell.execute_reply": "2024-08-07T20:15:03.385585Z", | |
"shell.execute_reply.started": "2024-08-07T20:15:00.101823Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'{ \"reasoning\" : [ { \"reasoning_step\" : \"Both 9.11 and 9.9 are decimal numbers.\" }, { \"reasoning_step\" : \"When comparing decimal numbers, we look at the numbers after the decimal point.\" }, { \"reasoning_step\" : \"In this case, 9.11 has the number 1 after the decimal point, while 9.9 has the number 9.\" }, { \"reasoning_step\" : \"Since 1 is greater than 9, 9.11 is greater than 9.9.\" } ], \"conclusion\" : \"9.11 is bigger.\" }'" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# temperature=0 plus a fixed seed is meant to make re-runs reproduce the same completion\n", | |
"# (assumption: llama.cpp decodes greedily at temperature 0 -- confirm on your build).\n", | |
"# NOTE: the regex guarantees schema-valid JSON, not a correct answer -- the recorded\n", | |
"# run concludes that 9.11 is bigger than 9.9, which is wrong.\n", | |
"response = generator(\n", | |
"    prompt,\n", | |
"    max_tokens=1024,\n", | |
"    temperature=0,\n", | |
"    seed=0,\n", | |
")\n", | |
"response" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.12" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment