Last active
August 8, 2024 07:27
-
-
Save alonsosilvaallende/0521d509ed89cead87d973f47fca2964 to your computer and use it in GitHub Desktop.
Structured Outputs with Outlines + Hermes-2-Pro-Llama-3-70B
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "3d169209-a562-465a-bbe8-7580fe05c3ce", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-08T07:10:45.052514Z", | |
"iopub.status.busy": "2024-08-08T07:10:45.052019Z", | |
"iopub.status.idle": "2024-08-08T07:10:45.080962Z", | |
"shell.execute_reply": "2024-08-08T07:10:45.079578Z", | |
"shell.execute_reply.started": "2024-08-08T07:10:45.052463Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"# Re-import edited modules automatically before each cell runs (mode 2 = all modules).\n", | |
"# NOTE(review): every import in this notebook is third-party or stdlib, so this may be unnecessary.\n", | |
"%load_ext autoreload\n", | |
"%autoreload 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "b33f0a08-a699-4682-a50a-e5b21acb7645", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-08T07:10:46.127082Z", | |
"iopub.status.busy": "2024-08-08T07:10:46.126588Z", | |
"iopub.status.idle": "2024-08-08T07:10:46.143929Z", | |
"shell.execute_reply": "2024-08-08T07:10:46.142492Z", | |
"shell.execute_reply.started": "2024-08-08T07:10:46.127035Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"import warnings\n", | |
"\n", | |
"# Silence every RuntimeWarning, from any module, for the rest of the kernel session.\n", | |
"warnings.simplefilter(\"ignore\", category=RuntimeWarning)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "96cdbccc-c584-4966-a442-02741a171ab2", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-08T07:10:46.712614Z", | |
"iopub.status.busy": "2024-08-08T07:10:46.712104Z", | |
"iopub.status.idle": "2024-08-08T07:10:46.838002Z", | |
"shell.execute_reply": "2024-08-08T07:10:46.837097Z", | |
"shell.execute_reply.started": "2024-08-08T07:10:46.712566Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"from pydantic import BaseModel, Field\n", | |
"\n", | |
"\n", | |
"# A single entry of `Reasoning.reasoning`. Comments are used instead of a docstring on\n", | |
"# purpose: Pydantic copies a class docstring into the schema's \"description\", which\n", | |
"# would change the schema text embedded in the prompt.\n", | |
"class Reasoning_Step(BaseModel):\n", | |
"    # Field() without a default is required, exactly like Field(...).\n", | |
"    reasoning_step: str = Field(description=\"Reasoning step\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "532995c5-f076-4bf7-b294-ed70332c1c10", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-08T07:10:47.364049Z", | |
"iopub.status.busy": "2024-08-08T07:10:47.363486Z", | |
"iopub.status.idle": "2024-08-08T07:10:47.385444Z", | |
"shell.execute_reply": "2024-08-08T07:10:47.384082Z", | |
"shell.execute_reply.started": "2024-08-08T07:10:47.363997Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"from typing import List\n", | |
"\n", | |
"\n", | |
"# Top-level answer schema: an ordered list of reasoning steps, then a conclusion.\n", | |
"# Declaration order fixes the key order in the generated regex, so the model has to\n", | |
"# write its steps before it is allowed to write the conclusion.\n", | |
"class Reasoning(BaseModel):\n", | |
"    reasoning: List[Reasoning_Step] = Field(description=\"List of reasoning steps\")\n", | |
"    conclusion: str = Field(description=\"Conclusion\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "a125eb20-efb2-4274-a42d-a35f58d9db54", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-08T07:10:48.314992Z", | |
"iopub.status.busy": "2024-08-08T07:10:48.314485Z", | |
"iopub.status.idle": "2024-08-08T07:10:48.338411Z", | |
"shell.execute_reply": "2024-08-08T07:10:48.337160Z", | |
"shell.execute_reply.started": "2024-08-08T07:10:48.314945Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"# JSON schema of the top-level model, kept for inspection.\n", | |
"# NOTE(review): later cells call model_json_schema() again instead of reading this variable.\n", | |
"json_schema = Reasoning.model_json_schema()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "61ee86ce-3503-43ae-879d-f947aee88623", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-08T07:10:49.271964Z", | |
"iopub.status.busy": "2024-08-08T07:10:49.271430Z", | |
"iopub.status.idle": "2024-08-08T07:10:52.941327Z", | |
"shell.execute_reply": "2024-08-08T07:10:52.940042Z", | |
"shell.execute_reply.started": "2024-08-08T07:10:49.271917Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"from outlines.fsm.json_schema import build_regex_from_schema\n", | |
"from outlines.integrations.utils import convert_json_schema_to_str\n", | |
"\n", | |
"\n", | |
"def get_regex_from_Pydantic_class(pydantic_class):\n", | |
"    \"\"\"Return the regex Outlines uses to constrain output to `pydantic_class`.\n", | |
"\n", | |
"    Args:\n", | |
"        pydantic_class: A Pydantic model class providing `model_json_schema()`.\n", | |
"\n", | |
"    Returns:\n", | |
"        The regular expression string built by Outlines' `build_regex_from_schema`\n", | |
"        from the model's JSON schema (serialized by `convert_json_schema_to_str`).\n", | |
"    \"\"\"\n", | |
"    schema_text = convert_json_schema_to_str(json_schema=pydantic_class.model_json_schema())\n", | |
"    return build_regex_from_schema(schema_text)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "f6e5e037-f005-47c8-9f30-9e4eeb5ef48a", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-08T07:10:53.649740Z", | |
"iopub.status.busy": "2024-08-08T07:10:53.648888Z", | |
"iopub.status.idle": "2024-08-08T07:10:53.692343Z", | |
"shell.execute_reply": "2024-08-08T07:10:53.691279Z", | |
"shell.execute_reply.started": "2024-08-08T07:10:53.649687Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'\\\\{[ ]?\"reasoning\"[ ]?:[ ]?\\\\[[ ]?((\\\\{[ ]?\"reasoning_step\"[ ]?:[ ]?\"([^\"\\\\\\\\\\\\x00-\\\\x1F\\\\x7F-\\\\x9F]|\\\\\\\\[\"\\\\\\\\])*\"[ ]?\\\\})(,[ ]?(\\\\{[ ]?\"reasoning_step\"[ ]?:[ ]?\"([^\"\\\\\\\\\\\\x00-\\\\x1F\\\\x7F-\\\\x9F]|\\\\\\\\[\"\\\\\\\\])*\"[ ]?\\\\})){0,})?[ ]?\\\\][ ]?,[ ]?\"conclusion\"[ ]?:[ ]?\"([^\"\\\\\\\\\\\\x00-\\\\x1F\\\\x7F-\\\\x9F]|\\\\\\\\[\"\\\\\\\\])*\"[ ]?\\\\}'" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Regex that constrains decoding to JSON instances of Reasoning (displayed below).\n", | |
"regex_str = get_regex_from_Pydantic_class(Reasoning)\n", | |
"regex_str" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "64f2fe73-56f7-4533-9357-394d5c6555dd", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-08T07:10:56.210740Z", | |
"iopub.status.busy": "2024-08-08T07:10:56.210186Z", | |
"iopub.status.idle": "2024-08-08T07:11:04.506512Z", | |
"shell.execute_reply": "2024-08-08T07:11:04.505545Z", | |
"shell.execute_reply.started": "2024-08-08T07:10:56.210690Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" | |
] | |
} | |
], | |
"source": [ | |
"import os\n", | |
"\n", | |
"import llama_cpp\n", | |
"from llama_cpp import Llama\n", | |
"from llama_cpp.llama_tokenizer import LlamaHFTokenizer\n", | |
"from outlines import generate, models\n", | |
"\n", | |
"# Local GGUF weights. Set HERMES_GGUF_PATH to run on another machine; the default\n", | |
"# is the original author's absolute path.\n", | |
"MODEL_PATH = os.environ.get(\n", | |
"    \"HERMES_GGUF_PATH\",\n", | |
"    \"/big_storage/llms/models/Hermes-2-Pro-Llama-3-70B_Q4_K_M.gguf\",\n", | |
")\n", | |
"# Hugging Face repo whose tokenizer overrides llama.cpp's default tokenizer for this model.\n", | |
"HF_TOKENIZER_REPO = \"NousResearch/Hermes-2-Pro-Llama-3-70B\"\n", | |
"\n", | |
"llm = Llama(\n", | |
"    MODEL_PATH,\n", | |
"    tokenizer=LlamaHFTokenizer.from_pretrained(HF_TOKENIZER_REPO),\n", | |
"    n_gpu_layers=-1,  # -1 offloads all layers to the GPU\n", | |
"    flash_attn=True,\n", | |
"    n_ctx=8192,  # context window, in tokens\n", | |
"    verbose=False,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "3fddb1a4-fc3d-4a81-bfb6-ab07c94c16e2", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-08T07:12:11.224836Z", | |
"iopub.status.busy": "2024-08-08T07:12:11.224185Z", | |
"iopub.status.idle": "2024-08-08T07:12:11.262220Z", | |
"shell.execute_reply": "2024-08-08T07:12:11.260828Z", | |
"shell.execute_reply.started": "2024-08-08T07:12:11.224782Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"import json\n", | |
"\n", | |
"\n", | |
"def generate_hermes_prompt(question, schema=\"\"):\n", | |
"    \"\"\"Build a ChatML prompt asking Hermes-2-Pro to answer `question` in JSON.\n", | |
"\n", | |
"    Args:\n", | |
"        question: Text of the user turn.\n", | |
"        schema: JSON schema the answer must follow. A non-string (e.g. the dict\n", | |
"            returned by `model_json_schema()`) is serialized with `json.dumps`;\n", | |
"            a string is inserted verbatim.\n", | |
"\n", | |
"    Returns:\n", | |
"        The prompt, ending with an open assistant turn for the model to complete.\n", | |
"    \"\"\"\n", | |
"    if not isinstance(schema, str):\n", | |
"        # Interpolating a dict gives its Python repr (single quotes, True/None),\n", | |
"        # which is not valid JSON.\n", | |
"        schema = json.dumps(schema)\n", | |
"    system = (\n", | |
"        \"You are a world class AI model who answers questions in JSON with correct Pydantic schema. \"\n", | |
"        f\"Here's the json schema you must adhere to:\\n<schema>\\n{schema}\\n</schema>\"\n", | |
"    )\n", | |
"    # Every ChatML turn, the system turn included, is closed with <|im_end|>.\n", | |
"    return (\n", | |
"        f\"<|im_start|>system\\n{system}<|im_end|>\"\n", | |
"        f\"\\n<|im_start|>user\\n{question}<|im_end|>\"\n", | |
"        \"\\n<|im_start|>assistant\\n\"\n", | |
"    )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "0cc7dede-c0b8-48aa-ab44-14deaea9064d", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-08T07:12:12.660511Z", | |
"iopub.status.busy": "2024-08-08T07:12:12.659953Z", | |
"iopub.status.idle": "2024-08-08T07:12:12.697704Z", | |
"shell.execute_reply": "2024-08-08T07:12:12.696830Z", | |
"shell.execute_reply.started": "2024-08-08T07:12:12.660462Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<|im_start|>system\n", | |
"You are a world class AI model who answers questions in JSON with correct Pydantic schema. Here's the json schema you must adhere to:\n", | |
"<schema>\n", | |
"{'$defs': {'Reasoning_Step': {'properties': {'reasoning_step': {'description': 'Reasoning step', 'title': 'Reasoning Step', 'type': 'string'}}, 'required': ['reasoning_step'], 'title': 'Reasoning_Step', 'type': 'object'}}, 'properties': {'reasoning': {'description': 'List of reasoning steps', 'items': {'$ref': '#/$defs/Reasoning_Step'}, 'title': 'Reasoning', 'type': 'array'}, 'conclusion': {'description': 'Conclusion', 'title': 'Conclusion', 'type': 'string'}}, 'required': ['reasoning', 'conclusion'], 'title': 'Reasoning', 'type': 'object'}\n", | |
"</schema>\n", | |
"<|im_start|>user\n", | |
"9.11 and 9.9 -- which is bigger?<|im_end|>\n", | |
"<|im_start|>assistant\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"question = \"9.11 and 9.9 -- which is bigger?\"\n", | |
"prompt = generate_hermes_prompt(question, schema=Reasoning.model_json_schema())\n", | |
"# print() renders the prompt's newlines instead of showing its escaped repr.\n", | |
"print(prompt)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "c090e919-7743-48b2-9ae9-7a196939fb7b", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-08T07:12:13.800446Z", | |
"iopub.status.busy": "2024-08-08T07:12:13.799903Z", | |
"iopub.status.idle": "2024-08-08T07:12:13.836576Z", | |
"shell.execute_reply": "2024-08-08T07:12:13.835439Z", | |
"shell.execute_reply.started": "2024-08-08T07:12:13.800397Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"# Wrap the llama.cpp `Llama` instance in Outlines' model interface; generate.regex below takes this wrapper.\n", | |
"model = models.LlamaCpp(llm)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "809349a7-451e-4888-ae24-ffe88fc50402", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-08T07:12:21.701769Z", | |
"iopub.status.busy": "2024-08-08T07:12:21.701211Z", | |
"iopub.status.idle": "2024-08-08T07:12:30.100184Z", | |
"shell.execute_reply": "2024-08-08T07:12:30.099261Z", | |
"shell.execute_reply.started": "2024-08-08T07:12:21.701720Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Compiling FSM index for all state transitions: 100%|██████████| 80/80 [00:02<00:00, 31.27it/s]\n" | |
] | |
} | |
], | |
"source": [ | |
"# Building the generator compiles the regex into an FSM index (progress bar below),\n", | |
"# so create it once and reuse it for every prompt.\n", | |
"generator = generate.regex(model, regex_str)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "8b6574e0-6d76-4ca7-9841-57b8f1fec7d5", | |
"metadata": { | |
"execution": { | |
"iopub.execute_input": "2024-08-08T07:12:31.146824Z", | |
"iopub.status.busy": "2024-08-08T07:12:31.146210Z", | |
"iopub.status.idle": "2024-08-08T07:12:45.357296Z", | |
"shell.execute_reply": "2024-08-08T07:12:45.356781Z", | |
"shell.execute_reply.started": "2024-08-08T07:12:31.146767Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'{\"reasoning\": [{\"reasoning_step\": \"Compare the whole numbers: 9 in 9.11 and 9 in 9.9.\"}, {\"reasoning_step\": \"The whole numbers are equal, so compare the decimal parts: 0.11 and 0.09.\"}, {\"reasoning_step\": \"Compare the tenths place: 1 in 0.11 and 0 in 0.09.\"}], \"conclusion\": \"9.11 is bigger than 9.9.\"}'" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# temperature=0 and a fixed seed aim to make the run reproducible.\n", | |
"response = generator(prompt, max_tokens=1024, temperature=0, seed=42)\n", | |
"\n", | |
"# The regex constrains only the *shape* of the text. Parsing it back into the model\n", | |
"# confirms it is complete, valid JSON (e.g. not cut off by max_tokens) and gives typed\n", | |
"# access. It does not make the answer correct: the recorded output claims 9.11 > 9.9.\n", | |
"reasoning = Reasoning.model_validate_json(response)\n", | |
"response" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.12" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment