Structured Outputs with Outlines+Hermes70
@alonsosilvaallende, last active August 8, 2024 07:27
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "3d169209-a562-465a-bbe8-7580fe05c3ce",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-08T07:10:45.052514Z",
"iopub.status.busy": "2024-08-08T07:10:45.052019Z",
"iopub.status.idle": "2024-08-08T07:10:45.080962Z",
"shell.execute_reply": "2024-08-08T07:10:45.079578Z",
"shell.execute_reply.started": "2024-08-08T07:10:45.052463Z"
}
},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b33f0a08-a699-4682-a50a-e5b21acb7645",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-08T07:10:46.127082Z",
"iopub.status.busy": "2024-08-08T07:10:46.126588Z",
"iopub.status.idle": "2024-08-08T07:10:46.143929Z",
"shell.execute_reply": "2024-08-08T07:10:46.142492Z",
"shell.execute_reply.started": "2024-08-08T07:10:46.127035Z"
}
},
"outputs": [],
"source": [
"import warnings\n",
"warnings.filterwarnings(\"ignore\", category=RuntimeWarning) # ignore runtime warnings"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "96cdbccc-c584-4966-a442-02741a171ab2",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-08T07:10:46.712614Z",
"iopub.status.busy": "2024-08-08T07:10:46.712104Z",
"iopub.status.idle": "2024-08-08T07:10:46.838002Z",
"shell.execute_reply": "2024-08-08T07:10:46.837097Z",
"shell.execute_reply.started": "2024-08-08T07:10:46.712566Z"
}
},
"outputs": [],
"source": [
"from pydantic import BaseModel, Field\n",
"\n",
"class Reasoning_Step(BaseModel):\n",
" reasoning_step: str = Field(..., description=\"Reasoning step\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "532995c5-f076-4bf7-b294-ed70332c1c10",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-08T07:10:47.364049Z",
"iopub.status.busy": "2024-08-08T07:10:47.363486Z",
"iopub.status.idle": "2024-08-08T07:10:47.385444Z",
"shell.execute_reply": "2024-08-08T07:10:47.384082Z",
"shell.execute_reply.started": "2024-08-08T07:10:47.363997Z"
}
},
"outputs": [],
"source": [
"from typing import List\n",
"\n",
"class Reasoning(BaseModel):\n",
" reasoning: List[Reasoning_Step] = Field(..., description=\"List of reasoning steps\")\n",
" conclusion: str = Field(..., description=\"Conclusion\")"
]
},
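{
"cell_type": "markdown",
"id": "md-example-shape",
"metadata": {},
"source": [
"A quick sketch of the JSON shape the model will be constrained to produce: build a `Reasoning` instance by hand and dump it as JSON (the field values below are made up for illustration)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "code-example-shape",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: a hand-built Reasoning instance with made-up values\n",
"example = Reasoning(\n",
"    reasoning=[Reasoning_Step(reasoning_step=\"An example reasoning step.\")],\n",
"    conclusion=\"An example conclusion.\",\n",
")\n",
"print(example.model_dump_json(indent=2))"
]
},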
{
"cell_type": "code",
"execution_count": 5,
"id": "a125eb20-efb2-4274-a42d-a35f58d9db54",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-08T07:10:48.314992Z",
"iopub.status.busy": "2024-08-08T07:10:48.314485Z",
"iopub.status.idle": "2024-08-08T07:10:48.338411Z",
"shell.execute_reply": "2024-08-08T07:10:48.337160Z",
"shell.execute_reply.started": "2024-08-08T07:10:48.314945Z"
}
},
"outputs": [],
"source": [
"json_schema = Reasoning.model_json_schema()"
]
},
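{
"cell_type": "markdown",
"id": "md-inspect-schema",
"metadata": {},
"source": [
"Optionally, pretty-print the schema to see exactly what will be embedded in the prompt later on."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "code-inspect-schema",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"# Inspect the JSON schema derived from the Reasoning model\n",
"print(json.dumps(json_schema, indent=2))"
]
},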
{
"cell_type": "code",
"execution_count": 6,
"id": "61ee86ce-3503-43ae-879d-f947aee88623",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-08T07:10:49.271964Z",
"iopub.status.busy": "2024-08-08T07:10:49.271430Z",
"iopub.status.idle": "2024-08-08T07:10:52.941327Z",
"shell.execute_reply": "2024-08-08T07:10:52.940042Z",
"shell.execute_reply.started": "2024-08-08T07:10:49.271917Z"
}
},
"outputs": [],
"source": [
"from outlines.integrations.utils import convert_json_schema_to_str\n",
"from outlines.fsm.json_schema import build_regex_from_schema\n",
"\n",
"def get_regex_from_Pydantic_class(pydantic_class):\n",
" json_schema = pydantic_class.model_json_schema()\n",
" schema_str = convert_json_schema_to_str(json_schema=json_schema)\n",
" regex_str = build_regex_from_schema(schema_str)\n",
" return regex_str"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "f6e5e037-f005-47c8-9f30-9e4eeb5ef48a",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-08T07:10:53.649740Z",
"iopub.status.busy": "2024-08-08T07:10:53.648888Z",
"iopub.status.idle": "2024-08-08T07:10:53.692343Z",
"shell.execute_reply": "2024-08-08T07:10:53.691279Z",
"shell.execute_reply.started": "2024-08-08T07:10:53.649687Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"'\\\\{[ ]?\"reasoning\"[ ]?:[ ]?\\\\[[ ]?((\\\\{[ ]?\"reasoning_step\"[ ]?:[ ]?\"([^\"\\\\\\\\\\\\x00-\\\\x1F\\\\x7F-\\\\x9F]|\\\\\\\\[\"\\\\\\\\])*\"[ ]?\\\\})(,[ ]?(\\\\{[ ]?\"reasoning_step\"[ ]?:[ ]?\"([^\"\\\\\\\\\\\\x00-\\\\x1F\\\\x7F-\\\\x9F]|\\\\\\\\[\"\\\\\\\\])*\"[ ]?\\\\})){0,})?[ ]?\\\\][ ]?,[ ]?\"conclusion\"[ ]?:[ ]?\"([^\"\\\\\\\\\\\\x00-\\\\x1F\\\\x7F-\\\\x9F]|\\\\\\\\[\"\\\\\\\\])*\"[ ]?\\\\}'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"regex_str = get_regex_from_Pydantic_class(Reasoning)\n",
"regex_str"
]
},
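{
"cell_type": "markdown",
"id": "md-regex-check",
"metadata": {},
"source": [
"A small sanity check, assuming the Outlines-generated regex stays within Python's `re` syntax: a minimal, hand-written JSON string that follows the schema should match in full."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "code-regex-check",
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"\n",
"# Hand-written sample that follows the Reasoning schema; it should match the regex\n",
"sample = '{\"reasoning\": [{\"reasoning_step\": \"step\"}], \"conclusion\": \"done\"}'\n",
"print(re.fullmatch(regex_str, sample) is not None)"
]
},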
{
"cell_type": "code",
"execution_count": 8,
"id": "64f2fe73-56f7-4533-9357-394d5c6555dd",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-08T07:10:56.210740Z",
"iopub.status.busy": "2024-08-08T07:10:56.210186Z",
"iopub.status.idle": "2024-08-08T07:11:04.506512Z",
"shell.execute_reply": "2024-08-08T07:11:04.505545Z",
"shell.execute_reply.started": "2024-08-08T07:10:56.210690Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
}
],
"source": [
"import llama_cpp\n",
"from llama_cpp import Llama\n",
"from outlines import generate, models\n",
"\n",
"llm = Llama(\n",
" \"/big_storage/llms/models/Hermes-2-Pro-Llama-3-70B_Q4_K_M.gguf\",\n",
" tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(\n",
" \"NousResearch/Hermes-2-Pro-Llama-3-70B\"\n",
" ),\n",
" n_gpu_layers=-1,\n",
" flash_attn=True,\n",
" n_ctx=8192,\n",
" verbose=False\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "3fddb1a4-fc3d-4a81-bfb6-ab07c94c16e2",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-08T07:12:11.224836Z",
"iopub.status.busy": "2024-08-08T07:12:11.224185Z",
"iopub.status.idle": "2024-08-08T07:12:11.262220Z",
"shell.execute_reply": "2024-08-08T07:12:11.260828Z",
"shell.execute_reply.started": "2024-08-08T07:12:11.224782Z"
}
},
"outputs": [],
"source": [
"def generate_hermes_prompt(question, schema=\"\"):\n",
" return (\n",
" \"<|im_start|>system\\n\"\n",
" \"You are a world class AI model who answers questions in JSON with correct Pydantic schema. \"\n",
" f\"Here's the json schema you must adhere to:\\n<schema>\\n{schema}\\n</schema>\"\n",
" \"\\n<|im_start|>user\\n\" + question + \"<|im_end|>\"\n",
" \"\\n<|im_start|>assistant\\n\"\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "0cc7dede-c0b8-48aa-ab44-14deaea9064d",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-08T07:12:12.660511Z",
"iopub.status.busy": "2024-08-08T07:12:12.659953Z",
"iopub.status.idle": "2024-08-08T07:12:12.697704Z",
"shell.execute_reply": "2024-08-08T07:12:12.696830Z",
"shell.execute_reply.started": "2024-08-08T07:12:12.660462Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<|im_start|>system\n",
"You are a world class AI model who answers questions in JSON with correct Pydantic schema. Here's the json schema you must adhere to:\n",
"<schema>\n",
"{'$defs': {'Reasoning_Step': {'properties': {'reasoning_step': {'description': 'Reasoning step', 'title': 'Reasoning Step', 'type': 'string'}}, 'required': ['reasoning_step'], 'title': 'Reasoning_Step', 'type': 'object'}}, 'properties': {'reasoning': {'description': 'List of reasoning steps', 'items': {'$ref': '#/$defs/Reasoning_Step'}, 'title': 'Reasoning', 'type': 'array'}, 'conclusion': {'description': 'Conclusion', 'title': 'Conclusion', 'type': 'string'}}, 'required': ['reasoning', 'conclusion'], 'title': 'Reasoning', 'type': 'object'}\n",
"</schema>\n",
"<|im_start|>user\n",
"9.11 and 9.9 -- which is bigger?<|im_end|>\n",
"<|im_start|>assistant\n",
"\n"
]
}
],
"source": [
"prompt = generate_hermes_prompt(\"9.11 and 9.9 -- which is bigger?\", schema=Reasoning.model_json_schema())\n",
"print(prompt)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "c090e919-7743-48b2-9ae9-7a196939fb7b",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-08T07:12:13.800446Z",
"iopub.status.busy": "2024-08-08T07:12:13.799903Z",
"iopub.status.idle": "2024-08-08T07:12:13.836576Z",
"shell.execute_reply": "2024-08-08T07:12:13.835439Z",
"shell.execute_reply.started": "2024-08-08T07:12:13.800397Z"
}
},
"outputs": [],
"source": [
"model = models.LlamaCpp(llm)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "809349a7-451e-4888-ae24-ffe88fc50402",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-08T07:12:21.701769Z",
"iopub.status.busy": "2024-08-08T07:12:21.701211Z",
"iopub.status.idle": "2024-08-08T07:12:30.100184Z",
"shell.execute_reply": "2024-08-08T07:12:30.099261Z",
"shell.execute_reply.started": "2024-08-08T07:12:21.701720Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Compiling FSM index for all state transitions: 100%|██████████| 80/80 [00:02<00:00, 31.27it/s]\n"
]
}
],
"source": [
"generator = generate.regex(\n",
" model,\n",
" regex_str,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "8b6574e0-6d76-4ca7-9841-57b8f1fec7d5",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-08T07:12:31.146824Z",
"iopub.status.busy": "2024-08-08T07:12:31.146210Z",
"iopub.status.idle": "2024-08-08T07:12:45.357296Z",
"shell.execute_reply": "2024-08-08T07:12:45.356781Z",
"shell.execute_reply.started": "2024-08-08T07:12:31.146767Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"'{\"reasoning\": [{\"reasoning_step\": \"Compare the whole numbers: 9 in 9.11 and 9 in 9.9.\"}, {\"reasoning_step\": \"The whole numbers are equal, so compare the decimal parts: 0.11 and 0.09.\"}, {\"reasoning_step\": \"Compare the tenths place: 1 in 0.11 and 0 in 0.09.\"}], \"conclusion\": \"9.11 is bigger than 9.9.\"}'"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"response = generator(prompt, max_tokens=1024, temperature=0, seed=42)\n",
"response"
]
}
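,
{
"cell_type": "markdown",
"id": "md-parse-response",
"metadata": {},
"source": [
"Since the output is constrained by the regex, it can be parsed straight back into the `Reasoning` model; a short sketch using the pydantic v2 API:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "code-parse-response",
"metadata": {},
"outputs": [],
"source": [
"# Parse the generated JSON string back into the Pydantic model\n",
"parsed = Reasoning.model_validate_json(response)\n",
"for step in parsed.reasoning:\n",
"    print(\"-\", step.reasoning_step)\n",
"print(\"Conclusion:\", parsed.conclusion)"
]
}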
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}