Skip to content

Instantly share code, notes, and snippets.

@alonsosilvaallende
Last active August 7, 2024 20:16
Show Gist options
  • Save alonsosilvaallende/27213240a19aa94f308a38b00a913b5b to your computer and use it in GitHub Desktop.
Structured Outputs with Outlines+Hermes
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "3d169209-a562-465a-bbe8-7580fe05c3ce",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-07T20:14:42.331411Z",
"iopub.status.busy": "2024-08-07T20:14:42.330874Z",
"iopub.status.idle": "2024-08-07T20:14:42.372611Z",
"shell.execute_reply": "2024-08-07T20:14:42.371402Z",
"shell.execute_reply.started": "2024-08-07T20:14:42.331343Z"
}
},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b33f0a08-a699-4682-a50a-e5b21acb7645",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-07T20:14:42.838673Z",
"iopub.status.busy": "2024-08-07T20:14:42.838160Z",
"iopub.status.idle": "2024-08-07T20:14:42.854839Z",
"shell.execute_reply": "2024-08-07T20:14:42.853533Z",
"shell.execute_reply.started": "2024-08-07T20:14:42.838623Z"
}
},
"outputs": [],
"source": [
"import warnings\n",
"warnings.filterwarnings(\"ignore\", category=RuntimeWarning) # ignore runtime warnings"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "96cdbccc-c584-4966-a442-02741a171ab2",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-07T20:14:43.399221Z",
"iopub.status.busy": "2024-08-07T20:14:43.398710Z",
"iopub.status.idle": "2024-08-07T20:14:43.510281Z",
"shell.execute_reply": "2024-08-07T20:14:43.508967Z",
"shell.execute_reply.started": "2024-08-07T20:14:43.399173Z"
}
},
"outputs": [],
"source": [
"from pydantic import BaseModel, Field\n",
"\n",
"# A single atomic step of a chain-of-thought trace.\n",
"# NOTE: intentionally no class docstring -- pydantic would expose it as the\n",
"# schema-level \"description\" and alter model_json_schema() output.\n",
"class Reasoning_Step(BaseModel):\n",
"    reasoning_step: str = Field(..., description=\"Reasoning step\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "532995c5-f076-4bf7-b294-ed70332c1c10",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-07T20:14:44.016348Z",
"iopub.status.busy": "2024-08-07T20:14:44.015808Z",
"iopub.status.idle": "2024-08-07T20:14:44.038348Z",
"shell.execute_reply": "2024-08-07T20:14:44.037085Z",
"shell.execute_reply.started": "2024-08-07T20:14:44.016298Z"
}
},
"outputs": [],
"source": [
"from typing import List\n",
"\n",
"# The full structured answer: an ordered list of reasoning steps followed by\n",
"# a final conclusion. No class docstring on purpose, so that\n",
"# model_json_schema() stays byte-identical to the original.\n",
"class Reasoning(BaseModel):\n",
"    reasoning: List[Reasoning_Step] = Field(..., description=\"List of reasoning steps\")\n",
"    conclusion: str = Field(..., description=\"Conclusion\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a125eb20-efb2-4274-a42d-a35f58d9db54",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-07T20:14:44.635274Z",
"iopub.status.busy": "2024-08-07T20:14:44.634244Z",
"iopub.status.idle": "2024-08-07T20:14:44.660004Z",
"shell.execute_reply": "2024-08-07T20:14:44.658754Z",
"shell.execute_reply.started": "2024-08-07T20:14:44.635223Z"
}
},
"outputs": [],
"source": [
"json_schema = Reasoning.model_json_schema()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "61ee86ce-3503-43ae-879d-f947aee88623",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-07T20:14:45.282287Z",
"iopub.status.busy": "2024-08-07T20:14:45.281773Z",
"iopub.status.idle": "2024-08-07T20:14:48.955568Z",
"shell.execute_reply": "2024-08-07T20:14:48.954148Z",
"shell.execute_reply.started": "2024-08-07T20:14:45.282238Z"
}
},
"outputs": [],
"source": [
"from outlines.integrations.utils import convert_json_schema_to_str\n",
"from outlines.fsm.json_schema import build_regex_from_schema\n",
"\n",
"def get_regex_from_Pydantic_class(pydantic_class):\n",
"    \"\"\"Build an outlines regex that constrains generation to the class's JSON schema.\n",
"\n",
"    The single-space `whitespace_pattern` keeps the emitted JSON compact and\n",
"    deterministic between structural tokens.\n",
"    \"\"\"\n",
"    schema_str = convert_json_schema_to_str(json_schema=pydantic_class.model_json_schema())\n",
"    return build_regex_from_schema(schema_str, whitespace_pattern=r\" \")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "f6e5e037-f005-47c8-9f30-9e4eeb5ef48a",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-07T20:14:48.957977Z",
"iopub.status.busy": "2024-08-07T20:14:48.957296Z",
"iopub.status.idle": "2024-08-07T20:14:48.986551Z",
"shell.execute_reply": "2024-08-07T20:14:48.985904Z",
"shell.execute_reply.started": "2024-08-07T20:14:48.957937Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"'\\\\{ \"reasoning\" : \\\\[ ((\\\\{ \"reasoning_step\" : \"([^\"\\\\\\\\\\\\x00-\\\\x1F\\\\x7F-\\\\x9F]|\\\\\\\\[\"\\\\\\\\])*\" \\\\})(, (\\\\{ \"reasoning_step\" : \"([^\"\\\\\\\\\\\\x00-\\\\x1F\\\\x7F-\\\\x9F]|\\\\\\\\[\"\\\\\\\\])*\" \\\\})){0,})? \\\\] , \"conclusion\" : \"([^\"\\\\\\\\\\\\x00-\\\\x1F\\\\x7F-\\\\x9F]|\\\\\\\\[\"\\\\\\\\])*\" \\\\}'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"regex_str = get_regex_from_Pydantic_class(Reasoning)\n",
"regex_str"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "64f2fe73-56f7-4533-9357-394d5c6555dd",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-07T20:14:51.316487Z",
"iopub.status.busy": "2024-08-07T20:14:51.315935Z",
"iopub.status.idle": "2024-08-07T20:14:54.274501Z",
"shell.execute_reply": "2024-08-07T20:14:54.273492Z",
"shell.execute_reply.started": "2024-08-07T20:14:51.316436Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
}
],
"source": [
"import llama_cpp\n",
"from llama_cpp import Llama\n",
"from outlines import generate, models\n",
"\n",
"# NOTE(review): hardcoded absolute local path -- consider a configurable\n",
"# MODEL_DIR so the notebook is runnable on other machines.\n",
"llm = Llama(\n",
"    \"/big_storage/llms/models/Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf\",\n",
"    # HF tokenizer for the source model; the stderr output below shows it adds\n",
"    # special tokens to the vocabulary.\n",
"    tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(\n",
"        \"NousResearch/Hermes-2-Pro-Llama-3-8B\"\n",
"    ),\n",
"    n_gpu_layers=-1,  # offload every layer to the GPU\n",
"    flash_attn=True,\n",
"    n_ctx=8192,\n",
"    verbose=False\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "3fddb1a4-fc3d-4a81-bfb6-ab07c94c16e2",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-07T20:14:54.275697Z",
"iopub.status.busy": "2024-08-07T20:14:54.275509Z",
"iopub.status.idle": "2024-08-07T20:14:54.294054Z",
"shell.execute_reply": "2024-08-07T20:14:54.292995Z",
"shell.execute_reply.started": "2024-08-07T20:14:54.275680Z"
}
},
"outputs": [],
"source": [
"def generate_hermes_prompt(question, schema=\"\"):\n",
"    \"\"\"Build a ChatML prompt for Hermes-2-Pro with the JSON schema in the system turn.\n",
"\n",
"    ChatML requires every turn to be terminated with <|im_end|> before the next\n",
"    <|im_start|>; the system turn was previously left unterminated, so\n",
"    <|im_end|> is now appended after </schema>.\n",
"    \"\"\"\n",
"    return (\n",
"        \"<|im_start|>system\\n\"\n",
"        \"You are a world class AI model who answers questions in JSON with correct Pydantic schema. \"\n",
"        f\"Here's the json schema you must adhere to:\\n<schema>\\n{schema}\\n</schema><|im_end|>\"\n",
"        \"\\n<|im_start|>user\\n\" + question + \"<|im_end|>\"\n",
"        \"\\n<|im_start|>assistant\\n\"\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "0cc7dede-c0b8-48aa-ab44-14deaea9064d",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-07T20:14:55.174614Z",
"iopub.status.busy": "2024-08-07T20:14:55.174089Z",
"iopub.status.idle": "2024-08-07T20:14:55.213874Z",
"shell.execute_reply": "2024-08-07T20:14:55.212761Z",
"shell.execute_reply.started": "2024-08-07T20:14:55.174565Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<|im_start|>system\n",
"You are a world class AI model who answers questions in JSON with correct Pydantic schema. Here's the json schema you must adhere to:\n",
"<schema>\n",
"{'$defs': {'Reasoning_Step': {'properties': {'reasoning_step': {'description': 'Reasoning step', 'title': 'Reasoning Step', 'type': 'string'}}, 'required': ['reasoning_step'], 'title': 'Reasoning_Step', 'type': 'object'}}, 'properties': {'reasoning': {'description': 'List of reasoning steps', 'items': {'$ref': '#/$defs/Reasoning_Step'}, 'title': 'Reasoning', 'type': 'array'}, 'conclusion': {'description': 'Conclusion', 'title': 'Conclusion', 'type': 'string'}}, 'required': ['reasoning', 'conclusion'], 'title': 'Reasoning', 'type': 'object'}\n",
"</schema>\n",
"<|im_start|>user\n",
"9.11 and 9.9 -- which is bigger?<|im_end|>\n",
"<|im_start|>assistant\n",
"\n"
]
}
],
"source": [
"prompt = generate_hermes_prompt(\"9.11 and 9.9 -- which is bigger?\", schema=Reasoning.model_json_schema())\n",
"print(prompt)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "c090e919-7743-48b2-9ae9-7a196939fb7b",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-07T20:14:56.638844Z",
"iopub.status.busy": "2024-08-07T20:14:56.638166Z",
"iopub.status.idle": "2024-08-07T20:14:56.676125Z",
"shell.execute_reply": "2024-08-07T20:14:56.675297Z",
"shell.execute_reply.started": "2024-08-07T20:14:56.638792Z"
}
},
"outputs": [],
"source": [
"model = models.LlamaCpp(llm)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "809349a7-451e-4888-ae24-ffe88fc50402",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-07T20:14:59.155260Z",
"iopub.status.busy": "2024-08-07T20:14:59.154070Z",
"iopub.status.idle": "2024-08-07T20:14:59.384102Z",
"shell.execute_reply": "2024-08-07T20:14:59.383068Z",
"shell.execute_reply.started": "2024-08-07T20:14:59.155207Z"
}
},
"outputs": [],
"source": [
"generator = generate.regex(\n",
" model,\n",
" regex_str,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "8b6574e0-6d76-4ca7-9841-57b8f1fec7d5",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-07T20:15:00.101880Z",
"iopub.status.busy": "2024-08-07T20:15:00.101120Z",
"iopub.status.idle": "2024-08-07T20:15:03.386404Z",
"shell.execute_reply": "2024-08-07T20:15:03.385585Z",
"shell.execute_reply.started": "2024-08-07T20:15:00.101823Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"'{ \"reasoning\" : [ { \"reasoning_step\" : \"Both 9.11 and 9.9 are decimal numbers.\" }, { \"reasoning_step\" : \"When comparing decimal numbers, we look at the numbers after the decimal point.\" }, { \"reasoning_step\" : \"In this case, 9.11 has the number 1 after the decimal point, while 9.9 has the number 9.\" }, { \"reasoning_step\" : \"Since 1 is greater than 9, 9.11 is greater than 9.9.\" } ], \"conclusion\" : \"9.11 is bigger.\" }'"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"response = generator(prompt, max_tokens=1024, temperature=0, seed=0)\n",
"response"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment