{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The HF method is broken; it somehow forces the model to go into ALL FUNCTION CALLING MODE."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Steps:\n",
"\n",
"- Import from HF\n",
"- Set the checkpoint to 8B\n",
"- Enable ipython in the system prompt; without this, the model doesn't work as expected"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The HuggingFace implementation currently doesn't work; please see the example below.\n",
"\n",
"While they fix their tokenizer, I have used the approach suggested [here](https://github.com/meta-llama/llama-agentic-system/issues/36)\n",
"\n",
"TL;DR: Put the tool definitions in the system prompt instead of the first user prompt; otherwise the model goes into function-calling mode all the time, as sketched in the next cell"
]
},
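{
"cell_type": "markdown",
"metadata": {},
"source": [
"In message-list terms, the workaround looks like this (a sketch only; the actual tool definitions are built later in this notebook):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Broken today: the HF chat template injects the tool definitions into the\n",
"# first *user* turn, which pushes the model into all-function-calling mode.\n",
"# Workaround: describe the tools in the *system* prompt yourself.\n",
"messages_workaround = [\n",
"    {\"role\": \"system\", \"content\": \"<tool definitions go here>\"},\n",
"    {\"role\": \"user\", \"content\": \"Who is the CEO of Facebook?\"},\n",
"]"
]
},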
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"import torch\n",
"\n",
"checkpoint = \"meta-llama/Meta-Llama-3.1-8B-Instruct\" # first try 8B, then we check 70B\n",
"SYSTEM_PROMPT = \"\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3618361024284f1a8d4eb31ca80928cd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(checkpoint)  # torch_dtype/device_map don't apply to tokenizers\n",
"model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, device_map=\"auto\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because of how Hugging Face injects the tool definitions into the prompt, the model misbehaves and goes into all-tool mode. For example, with 8B:\n",
"\n",
"Which company trained the llama LLM?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
"\n",
"<|python_tag|>{\"name\": \"get_current_temperature\", \"parameters\": {\"location\": \"Menlo Park, California\", \"unit\": \"celsius\"}}<|eom_id|>"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def get_current_temperature(location: str, unit: str) -> float:\n",
" \"\"\"\n",
" Get the current temperature at a location.\n",
" \n",
" Args:\n",
" location: The location to get the temperature for, in the format \"City, Country\"\n",
" unit: The unit to return the temperature in. (choices: [\"celsius\", \"fahrenheit\"])\n",
" Returns:\n",
" The current temperature at the specified location in the specified units, as a float.\n",
" \"\"\"\n",
" return 22. # A real function should probably actually get the temperature!\n",
"\n",
"def get_current_wind_speed(location: str) -> float:\n",
" \"\"\"\n",
" Get the current wind speed in km/h at a given location.\n",
" \n",
" Args:\n",
" location: The location to get the temperature for, in the format \"City, Country\"\n",
" Returns:\n",
" The current wind speed at the given location in km/h, as a float.\n",
" \"\"\"\n",
" return 6. # A real function should probably actually get the wind speed!\n",
"\n",
"tools = [get_current_temperature, get_current_wind_speed]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
"\n",
"Environment: ipython\n",
"Cutting Knowledge Date: December 2023\n",
"Today Date: 26 Jul 2024\n",
"\n",
"<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
"\n",
"Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n",
"\n",
"Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n",
"\n",
"{\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_temperature\",\n",
" \"description\": \"Get the current temperature at a location.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The location to get the temperature for, in the format \\\"City, Country\\\"\"\n",
" },\n",
" \"unit\": {\n",
" \"type\": \"string\",\n",
" \"enum\": [\n",
" \"celsius\",\n",
" \"fahrenheit\"\n",
" ],\n",
" \"description\": \"The unit to return the temperature in.\"\n",
" }\n",
" },\n",
" \"required\": [\n",
" \"location\",\n",
" \"unit\"\n",
" ]\n",
" },\n",
" \"return\": {\n",
" \"type\": \"number\",\n",
" \"description\": \"The current temperature at the specified location in the specified units, as a float.\"\n",
" }\n",
" }\n",
"}\n",
"\n",
"{\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_wind_speed\",\n",
" \"description\": \"Get the current wind speed in km/h at a given location.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The location to get the temperature for, in the format \\\"City, Country\\\"\"\n",
" }\n",
" },\n",
" \"required\": [\n",
" \"location\"\n",
" ]\n",
" },\n",
" \"return\": {\n",
" \"type\": \"number\",\n",
" \"description\": \"The current wind speed at the given location in km/h, as a float.\"\n",
" }\n",
" }\n",
"}\n",
"\n",
"Who is the CEO of Facebook?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
"\n",
"\n"
]
}
],
"source": [
"messages = [\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": SYSTEM_PROMPT,\n",
" },\n",
" {\"role\": \"user\", \"content\": \"Who is the CEO of Facebook?\"},\n",
" ]\n",
"tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors=\"pt\",tools=tools)\n",
"print(tokenizer.decode(tokenized_chat[0]))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n",
"The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"/home/sanyambhutani/miniconda3/lib/python3.12/site-packages/transformers/generation/utils.py:1885: UserWarning: You are calling .generate() with the `input_ids` being on a device type different than your model's device. `input_ids` is on cpu, whereas the model is on cuda. You may experience unexpected behaviors or slower generation. Please make sure that you have put `input_ids` to the correct device by calling for example input_ids = input_ids.to('cuda') before running `.generate()`.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
"\n",
"Environment: ipython\n",
"Cutting Knowledge Date: December 2023\n",
"Today Date: 26 Jul 2024\n",
"\n",
"<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
"\n",
"Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n",
"\n",
"Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n",
"\n",
"{\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_temperature\",\n",
" \"description\": \"Get the current temperature at a location.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The location to get the temperature for, in the format \\\"City, Country\\\"\"\n",
" },\n",
" \"unit\": {\n",
" \"type\": \"string\",\n",
" \"enum\": [\n",
" \"celsius\",\n",
" \"fahrenheit\"\n",
" ],\n",
" \"description\": \"The unit to return the temperature in.\"\n",
" }\n",
" },\n",
" \"required\": [\n",
" \"location\",\n",
" \"unit\"\n",
" ]\n",
" },\n",
" \"return\": {\n",
" \"type\": \"number\",\n",
" \"description\": \"The current temperature at the specified location in the specified units, as a float.\"\n",
" }\n",
" }\n",
"}\n",
"\n",
"{\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_wind_speed\",\n",
" \"description\": \"Get the current wind speed in km/h at a given location.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The location to get the temperature for, in the format \\\"City, Country\\\"\"\n",
" }\n",
" },\n",
" \"required\": [\n",
" \"location\"\n",
" ]\n",
" },\n",
" \"return\": {\n",
" \"type\": \"number\",\n",
" \"description\": \"The current wind speed at the given location in km/h, as a float.\"\n",
" }\n",
" }\n",
"}\n",
"\n",
"Who is the CEO of Facebook?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
"\n",
"<|python_tag|>{\"name\": \"get_current_wind_speed\", \"parameters\": {\"location\": \"Menlo Park, USA\"}}<|eom_id|>\n"
]
}
],
"source": [
"outputs = model.generate(tokenized_chat.to(model.device), max_new_tokens=2048)  # move inputs to the model's device\n",
"print(tokenizer.decode(outputs[0]))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e0f0be71cf714545ac805d8234a32d2a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/30 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n",
"/home/sanyambhutani/miniconda3/lib/python3.12/site-packages/transformers/generation/utils.py:1885: UserWarning: You are calling .generate() with the `input_ids` being on a device type different than your model's device. `input_ids` is on cpu, whereas the model is on cuda. You may experience unexpected behaviors or slower generation. Please make sure that you have put `input_ids` to the correct device by calling for example input_ids = input_ids.to('cuda') before running `.generate()`.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
"\n",
"Environment: ipython\n",
"Cutting Knowledge Date: December 2023\n",
"Today Date: 26 Jul 2024\n",
"\n",
"<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
"\n",
"Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n",
"\n",
"Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n",
"\n",
"{\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_temperature\",\n",
" \"description\": \"Get the current temperature at a location.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The location to get the temperature for, in the format \\\"City, Country\\\"\"\n",
" },\n",
" \"unit\": {\n",
" \"type\": \"string\",\n",
" \"enum\": [\n",
" \"celsius\",\n",
" \"fahrenheit\"\n",
" ],\n",
" \"description\": \"The unit to return the temperature in.\"\n",
" }\n",
" },\n",
" \"required\": [\n",
" \"location\",\n",
" \"unit\"\n",
" ]\n",
" },\n",
" \"return\": {\n",
" \"type\": \"number\",\n",
" \"description\": \"The current temperature at the specified location in the specified units, as a float.\"\n",
" }\n",
" }\n",
"}\n",
"\n",
"{\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_wind_speed\",\n",
" \"description\": \"Get the current wind speed in km/h at a given location.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The location to get the temperature for, in the format \\\"City, Country\\\"\"\n",
" }\n",
" },\n",
" \"required\": [\n",
" \"location\"\n",
" ]\n",
" },\n",
" \"return\": {\n",
" \"type\": \"number\",\n",
" \"description\": \"The current wind speed at the given location in km/h, as a float.\"\n",
" }\n",
" }\n",
"}\n",
"\n",
"Who is the CEO of Facebook?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
"\n",
"There is no function available to find the CEO of Facebook. The provided functions are related to weather information, and there is no function that can be used to find the CEO of a company.<|eot_id|>\n"
]
}
],
"source": [
"checkpoint = \"meta-llama/Meta-Llama-3.1-70B-Instruct\"\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(checkpoint)  # torch_dtype/device_map don't apply to tokenizers\n",
"model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, device_map=\"auto\")\n",
"\n",
"messages = [\n",
"    {\n",
"        \"role\": \"system\",\n",
"        \"content\": SYSTEM_PROMPT,\n",
"    },\n",
"    {\"role\": \"user\", \"content\": \"Who is the CEO of Facebook?\"},\n",
"]\n",
"tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors=\"pt\",tools=tools)\n",
"#print(tokenizer.decode(tokenized_chat[0]))\n",
"\n",
"outputs = model.generate(tokenized_chat.to(model.device), max_new_tokens=2048)\n",
"print(tokenizer.decode(outputs[0]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Error with 8B model:\n",
"\n",
"It doesn't like Wolfram Alpha"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import transformers"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"SYSTEM_PROMPT = \"\"\"\n",
"Here are the tools you have access to, along with brave_search and wolfram_alpha, which you can use for searching the web and doing math operations respectively:\n",
"\n",
"Use the function 'spotify_trending_songs' to: Get top trending songs on Spotify\n",
"{\n",
" \"name\": \"spotify_trending_songs\",\n",
" \"description\": \"Get top trending songs on Spotify\",\n",
" \"parameters\": {\n",
" \"n\": {\n",
" \"param_type\": \"int\",\n",
" \"description\": \"Number of trending songs to get\",\n",
" \"required\": true\n",
" }\n",
" }\n",
"}\n",
"\n",
"Use the function 'yahoo_weather' to: Get weather reports\n",
"{\n",
" \"name\": \"yahoo_weather\",\n",
" \"description\": \"Check weather details, temperature, wind, etc\",\n",
" \"parameters\": {\n",
" \"city\": {\n",
" \"param_type\": \"string\",\n",
" \"description\": \"Name of City\",\n",
" \"required\": true\n",
" }\n",
" }\n",
"}\n",
"\n",
"Use the function 'yahoo_finance' to: Get Stock Price\n",
"{\n",
" \"name\": \"yahoo_finance\",\n",
" \"description\": \"Get Stock Price\",\n",
" \"parameters\": {\n",
" \"stock\": {\n",
" \"param_type\": \"string\",\n",
" \"description\": \"Name of Stock\",\n",
" \"required\": true\n",
" }\n",
" }\n",
"}\n",
"\n",
"Think very carefully before calling functions.\n",
"If you choose to call a function ONLY reply in the following format with no prefix or suffix:\n",
"\n",
"<function=function_name_being_used>{{\"example_name\": \"example_value\"}}</function>\n",
"\n",
"Reminder:\n",
"- Do not use a function call if you can answer the Q directly, only use it if it satisfies the requirement\n",
"- If looking for real time information use relevant functions before falling back to brave_search\n",
"- Function calls MUST follow the specified format, start with <function= and end with </function>\n",
"- Required parameters MUST be specified\n",
"- Only call one function at a time\n",
"- Put the entire function call reply on one line\n",
"\"\"\"\n",
"checkpoint = \"meta-llama/Meta-Llama-3.1-8B-Instruct\"\n",
"#checkpoint = \"meta-llama/Meta-Llama-3.1-70B-Instruct\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[Prompt Ref](https://github.com/meta-llama/llama-agentic-system/issues/36)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b0df751198504f488470f67a6b5cb3b8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"pipeline = transformers.pipeline(\n",
" \"text-generation\",\n",
" model=checkpoint,\n",
" model_kwargs={\"torch_dtype\": torch.bfloat16},\n",
" device_map=\"auto\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"#\"Who founded facebook?\"\n",
"#\"How is the weather for walking llamas tomorrow in Menlo Park?\"\n",
"#\"What festivals are happening in Mountain View tomorrow?\" <- 8B fails at this, does weather search instead of brave_search"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'role': 'assistant', 'content': '<function=wolfram_alpha>{{\"query\": \"densest metals on earth\"}}</function>'}\n"
]
}
],
"source": [
"messages = [\n",
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
" {\"role\": \"user\", \"content\": \"What are the densest metals on earth?\"},\n",
"]\n",
"\n",
"#\"content\": \"What is the sqrt of 21313131\"}\n",
"\n",
"outputs = pipeline(\n",
" messages,\n",
" max_new_tokens=2048,\n",
")\n",
"print(outputs[0][\"generated_text\"][-1])\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<function=wolfram_alpha>{{\"query\": \"densest metals on earth\"}}</function>\n"
]
}
],
"source": [
"print(outputs[0][\"generated_text\"][-1]['content'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 70B works fine"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"#checkpoint = \"meta-llama/Meta-Llama-3.1-8B-Instruct\"\n",
"checkpoint = \"meta-llama/Meta-Llama-3.1-70B-Instruct\""
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d877526cc5964818895f19a96b052d53",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/30 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"pipeline = transformers.pipeline(\n",
" \"text-generation\",\n",
" model=checkpoint,\n",
" model_kwargs={\"torch_dtype\": torch.bfloat16},\n",
" device_map=\"auto\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'role': 'assistant', 'content': '<function=wolfram_alpha>{}</function>'}\n"
]
}
],
"source": [
"messages = [\n",
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
" {\"role\": \"user\", \"content\": \"What are the densest metals on earth?\"},\n",
"]\n",
"\n",
"#\"content\": \"What is the sqrt of 21313131\"}\n",
"\n",
"outputs = pipeline(\n",
" messages,\n",
" max_new_tokens=2048,\n",
")\n",
"print(outputs[0][\"generated_text\"][-1])\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<function=wolfram_alpha>{}</function>\n"
]
}
],
"source": [
"print(outputs[0][\"generated_text\"][-1]['content'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## TUTORIAL STARTS HERE"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/meta-llama/llama-recipes/blob/main/recipes/quickstart/inference/local_inference/Getting_Started_With_Tool_Calling.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
"\n",
"# Tool Calling with llama 3.1 models\n",
"\n",
"Tool (or function) calling was popularised by the Gorilla paper, which showed that Large Language Models (LLMs) can be used effectively to call other programs. \n",
"\n",
"**Why is this useful?**\n",
"\n",
"LLMs are great at planning and \"thinking\" about how to execute a program. We can leverage this to connect them with multiple tools and let the model act as the \"brains\" of a larger program. \n",
"\n",
"In other words, llama 3.1 can order your pizza for you :) \n",
"\n",
"The llama 3.1 model family really shines at tool calling, which has been a popular community request. \n",
"\n",
"This notebook gives a starter tutorial, followed by an implementation showing how to use the 8B model and connect it with external tools:\n",
"- We use locally hosted HuggingFace weights of the `llama 3.1-8B-Instruct` model\n",
"- Understand how function calling happens when we interact with the model\n",
"- Set the system prompt to give the model knowledge of the existing tools\n",
"- Profit!\n",
"\n",
"Note: If you'd like to look at examples using hosted models, please check the following links:\n",
"- []()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Understanding the tool calling behaviour of llama 3.1\n",
"\n",
"Let's first understand how the tool calling behaviour works in the new models.\n",
"\n",
"The llama 3.1 model family allows you to extend the model's access to tools by defining them in the system prompt. As a first example, let's give the model access to 3 additional tools:\n",
"- Finance\n",
"- Weather\n",
"- Spotify"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"import transformers"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the system prompt, we describe each tool's expected behaviour to the model:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"SYSTEM_PROMPT = \"\"\"\n",
"Here are the tools you have access to, along with brave_search and wolfram_alpha, which you can use for searching the web and doing math operations respectively:\n",
"\n",
"Use the function 'spotify_trending_songs' to: Get top trending songs on Spotify\n",
"{\n",
" \"name\": \"spotify_trending_songs\",\n",
" \"description\": \"Get top trending songs on Spotify\",\n",
" \"parameters\": {\n",
" \"n\": {\n",
" \"param_type\": \"int\",\n",
" \"description\": \"Number of trending songs to get\",\n",
" \"required\": true\n",
" }\n",
" }\n",
"}\n",
"\n",
"Use the function 'yahoo_weather' to: Get weather reports\n",
"{\n",
" \"name\": \"yahoo_weather\",\n",
" \"description\": \"Check weather details, temperature, wind, etc\",\n",
" \"parameters\": {\n",
" \"city\": {\n",
" \"param_type\": \"string\",\n",
" \"description\": \"Name of City\",\n",
" \"required\": true\n",
" }\n",
" }\n",
"}\n",
"\n",
"Use the function 'yahoo_finance' to: Get Stock Price\n",
"{\n",
" \"name\": \"yahoo_finance\",\n",
" \"description\": \"Get Stock Price\",\n",
" \"parameters\": {\n",
" \"stock\": {\n",
" \"param_type\": \"string\",\n",
" \"description\": \"Name of Stock\",\n",
" \"required\": true\n",
" }\n",
" }\n",
"}\n",
"\n",
"Think very carefully before calling functions.\n",
"If you choose to call a function ONLY reply in the following format with no prefix or suffix:\n",
"\n",
"<function=function_name_being_used>{{\"example_name\": \"example_value\"}}</function>\n",
"\n",
"Reminder:\n",
"- Do not use a function call if you can answer the Q directly, only use it if it satisfies the requirement\n",
"- If looking for real time information use relevant functions before falling back to brave_search\n",
"- Function calls MUST follow the specified format, start with <function= and end with </function>\n",
"- Required parameters MUST be specified\n",
"- Only call one function at a time\n",
"- Put the entire function call reply on one line\n",
"\"\"\"\n",
"checkpoint = \"meta-llama/Meta-Llama-3.1-8B-Instruct\"\n",
"#checkpoint = \"meta-llama/Meta-Llama-3.1-70B-Instruct\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To simplify things, we can use the `pipeline` function from the `transformers` library. This allows us to focus on the output directly."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9ef7587377904766b1c831d6b39222a8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"pipeline = transformers.pipeline(\n",
" \"text-generation\",\n",
" model=checkpoint,\n",
" model_kwargs={\"torch_dtype\": torch.bfloat16},\n",
" device_map=\"auto\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After the model has been moved to the GPU, we can start passing messages to it."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'role': 'assistant', 'content': '<function=yahoo_finance>{{\"stock\": \"GameStop Corporation (GME)\"}}</function>'}\n"
]
}
],
"source": [
"messages = [\n",
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
" {\"role\": \"user\", \"content\": \"What is price of GameStonk today?\"},\n",
"]\n",
"\n",
"#\"content\": \"What is the sqrt of 21313131\"}\n",
"\n",
"outputs = pipeline(\n",
" messages,\n",
" max_new_tokens=2048,\n",
")\n",
"print(outputs[0][\"generated_text\"][-1])\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To extract the function call, we can just take the content of the last message and call the corresponding API:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<function=yahoo_finance>{{\"stock\": \"GameStop Corporation (GME)\"}}</function>\n"
]
}
],
"source": [
"output = outputs[0][\"generated_text\"][-1]['content']\n",
"print(output)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"At this point, we need to implement a method that parses the function-call reply and dispatches it to the matching tool, as sketched below."
]
},
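{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a minimal sketch of such a parser (an illustrative helper, not an official API): it pulls the function name and JSON arguments out of the `<function=...>...</function>` format our system prompt defines. Note that the replies above wrap the arguments in doubled braces, copied from the `{{\"example_name\": \"example_value\"}}` template in the prompt, so we strip one layer before parsing:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import re\n",
"\n",
"def parse_function_call(content: str):\n",
"    \"\"\"Parse a '<function=name>{...}</function>' reply into (name, args).\n",
"\n",
"    Returns None if the reply is not a function call.\n",
"    \"\"\"\n",
"    match = re.search(r\"<function=(\\w+)>(.*?)</function>\", content, re.DOTALL)\n",
"    if match is None:\n",
"        return None\n",
"    name, raw_args = match.group(1), match.group(2).strip()\n",
"    # Strip the doubled braces the model copies from the prompt template.\n",
"    if raw_args.startswith(\"{{\") and raw_args.endswith(\"}}\"):\n",
"        raw_args = raw_args[1:-1]\n",
"    args = json.loads(raw_args) if raw_args else {}\n",
"    return name, args\n",
"\n",
"parse_function_call(output)  # e.g. ('yahoo_finance', {'stock': 'GameStop Corporation (GME)'})"
]
},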
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'role': 'assistant', 'content': '<function=yahoo_weather>{{\"city\": \"Menlo Park\"}}</function>'}\n"
]
}
],
"source": [
"messages.append({\"role\":\"user\", \"content\":\"Is it a good weather to walk llamas in Menlo Park Today?\"})\n",
"outputs = pipeline(\n",
" messages,\n",
" max_new_tokens=2048,\n",
")\n",
"print(outputs[0][\"generated_text\"][-1])\n"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'role': 'assistant', 'content': \"The Lands Between, a realm of strife,\\nWhere Tarnished souls, in darkness rife,\\nSeek to claim the Elden Ring's great might,\\nAnd bring order to the twisted night.\\n\\nThe Erdtree looms, a sight to behold,\\nA symbol of the Ring's ancient hold,\\nThe Fingers, four, with power and with guile,\\nSeek to claim it, and the world's great trial.\\n\\nFrom Stormveil to Volcano Manor's gate,\\nThe Tarnished journey, through trials and fate,\\nMeet the demigods, with strength and with will,\\nAnd face the horrors, that in darkness dwell.\\n\\nMalenia's beauty, a cruel deceit,\\nA princess, with a heart, that's incomplete,\\nMohg's dark past, a tale of woe and pain,\\nA god, who's lost, in his own, dark reign.\\n\\nThe Tarnished, with their own, dark past,\\nSeek redemption, in this, endless quest,\\nThrough blood and fire, they make their way,\\nTo claim the Ring, and bring a new day.\\n\\nBut will they succeed, or fall to the ground?\\nOnly time will tell, in this, endless round,\\nThe Elden Ring, a symbol of power and might,\\nA test of wills, in the dark of night.\"}\n"
]
}
],
"source": [
"messages.append({\"role\":\"user\", \"content\":\"Write a poem on Elden Ring\"})\n",
"outputs = pipeline(\n",
" messages,\n",
" max_new_tokens=2048,\n",
")\n",
"print(outputs[0][\"generated_text\"][-1])"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'role': 'assistant', 'content': '<function=yahoo_search>{{\"query\": \"Elden Ring successor 2024\"}}</function>'}\n"
]
}
],
"source": [
"messages.append({\"role\":\"user\", \"content\":\"Whats the successor of Elden Ring in 2024?\"})\n",
"outputs = pipeline(\n",
" messages,\n",
" max_new_tokens=2048,\n",
")\n",
"print(outputs[0][\"generated_text\"][-1])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So let's recap: function/tool calling involves:\n",
"\n",
"- Reminding the llama model that it has access to tools by providing them in the system prompt\n",
"- Generating responses in the form of tool calls\n",
"- As the examples above show, the model is able to select the right function for each query\n",
"- Implementing the function call and returning the value to the model (sketched below)\n",
"- Generating the final output using the function result"
]
},
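{
"cell_type": "markdown",
"metadata": {},
"source": [
"The last two steps aren't shown above, so here is a hedged sketch. The stub tools, the `TOOL_REGISTRY` name, and the idea of returning the result as a user turn are assumptions for illustration; a real setup would call the actual APIs and use whatever tool-response format your serving stack expects:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative stubs -- a real implementation would call the actual APIs.\n",
"def spotify_trending_songs(n: int):\n",
"    return [f\"song_{i}\" for i in range(int(n))]\n",
"\n",
"def yahoo_weather(city: str):\n",
"    return {\"city\": city, \"temperature_c\": 22, \"wind_kmh\": 6}\n",
"\n",
"def yahoo_finance(stock: str):\n",
"    return {\"stock\": stock, \"price_usd\": 25.0}\n",
"\n",
"TOOL_REGISTRY = {\n",
"    \"spotify_trending_songs\": spotify_trending_songs,\n",
"    \"yahoo_weather\": yahoo_weather,\n",
"    \"yahoo_finance\": yahoo_finance,\n",
"}\n",
"\n",
"def run_tool_turn(messages, reply_content):\n",
"    \"\"\"Execute a parsed function call and ask the model to answer with the result.\"\"\"\n",
"    parsed = parse_function_call(reply_content)\n",
"    if parsed is None or parsed[0] not in TOOL_REGISTRY:\n",
"        return reply_content  # plain answer (or unknown tool), nothing to execute\n",
"    name, args = parsed\n",
"    result = TOOL_REGISTRY[name](**args)\n",
"    messages.append({\"role\": \"assistant\", \"content\": reply_content})\n",
"    # Hand the tool result back as a user turn for the model to summarise.\n",
"    messages.append({\"role\": \"user\", \"content\": f\"Result of {name}: {result}\"})\n",
"    return pipeline(messages, max_new_tokens=2048)[0][\"generated_text\"][-1][\"content\"]\n",
"\n",
"# e.g. run_tool_turn(list(messages), output) would execute yahoo_finance and summarise the result"
]
},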
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The HuggingFace way 🤗:\n",
"\n",
"Now that we understand how tool calling works at a fundamental level, here is how `transformers` supports it: we can simply pass a list of tools when building the inputs."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"def get_current_temperature(location: str, unit: str) -> float:\n",
" \"\"\"\n",
" Get the current temperature at a location.\n",
" \n",
" Args:\n",
" location: The location to get the temperature for, in the format \"City, Country\"\n",
" unit: The unit to return the temperature in. (choices: [\"celsius\", \"fahrenheit\"])\n",
" Returns:\n",
" The current temperature at the specified location in the specified units, as a float.\n",
" \"\"\"\n",
" return 22. # A real function should probably actually get the temperature!\n",
"\n",
"def get_current_wind_speed(location: str) -> float:\n",
" \"\"\"\n",
" Get the current wind speed in km/h at a given location.\n",
" \n",
" Args:\n",
" location: The location to get the temperature for, in the format \"City, Country\"\n",
" Returns:\n",
" The current wind speed at the given location in km/h, as a float.\n",
" \"\"\"\n",
" return 6. # A real function should probably actually get the wind speed!\n",
"\n",
"tools = [get_current_temperature, get_current_wind_speed]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
"\n",
"Environment: ipython\n",
"Cutting Knowledge Date: December 2023\n",
"Today Date: 26 Jul 2024\n",
"\n",
"You are a helpful assistant named llama 3.1 with the following functions available along with brave_search and wolfram_alpha<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
"\n",
"Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n",
"\n",
"Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n",
"\n",
"{\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_temperature\",\n",
" \"description\": \"Get the current temperature at a location.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The location to get the temperature for, in the format \\\"City, Country\\\"\"\n",
" },\n",
" \"unit\": {\n",
" \"type\": \"string\",\n",
" \"enum\": [\n",
" \"celsius\",\n",
" \"fahrenheit\"\n",
" ],\n",
" \"description\": \"The unit to return the temperature in.\"\n",
" }\n",
" },\n",
" \"required\": [\n",
" \"location\",\n",
" \"unit\"\n",
" ]\n",
" },\n",
" \"return\": {\n",
" \"type\": \"number\",\n",
" \"description\": \"The current temperature at the specified location in the specified units, as a float.\"\n",
" }\n",
" }\n",
"}\n",
"\n",
"{\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_wind_speed\",\n",
" \"description\": \"Get the current wind speed in km/h at a given location.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The location to get the temperature for, in the format \\\"City, Country\\\"\"\n",
" }\n",
" },\n",
" \"required\": [\n",
" \"location\"\n",
" ]\n",
" },\n",
" \"return\": {\n",
" \"type\": \"number\",\n",
" \"description\": \"The current wind speed at the given location in km/h, as a float.\"\n",
" }\n",
" }\n",
"}\n",
"\n",
"What is llama 3.1 release date?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
"\n",
"\n"
]
}
],
"source": [
"messages = [\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": \"You are a helpful assistant named llama 3.1 with the following functions available along with brave_search and wolfram_alpha\",\n",
" },\n",
" {\"role\": \"user\", \"content\": \"What is llama 3.1 release date?\"},\n",
" ]\n",
"tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors=\"pt\",tools=tools)\n",
"print(tokenizer.decode(tokenized_chat[0]))"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n",
"/home/sanyambhutani/miniconda3/lib/python3.12/site-packages/transformers/generation/utils.py:1885: UserWarning: You are calling .generate() with the `input_ids` being on a device type different than your model's device. `input_ids` is on cpu, whereas the model is on cuda. You may experience unexpected behaviors or slower generation. Please make sure that you have put `input_ids` to the correct device by calling for example input_ids = input_ids.to('cuda') before running `.generate()`.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
"\n",
"Environment: ipython\n",
"Cutting Knowledge Date: December 2023\n",
"Today Date: 26 Jul 2024\n",
"\n",
"You are a helpful assistant named llama 3.1 with the following functions available along with brave_search and wolfram_alpha<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
"\n",
"Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n",
"\n",
"Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n",
"\n",
"{\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_temperature\",\n",
" \"description\": \"Get the current temperature at a location.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The location to get the temperature for, in the format \\\"City, Country\\\"\"\n",
" },\n",
" \"unit\": {\n",
" \"type\": \"string\",\n",
" \"enum\": [\n",
" \"celsius\",\n",
" \"fahrenheit\"\n",
" ],\n",
" \"description\": \"The unit to return the temperature in.\"\n",
" }\n",
" },\n",
" \"required\": [\n",
" \"location\",\n",
" \"unit\"\n",
" ]\n",
" },\n",
" \"return\": {\n",
" \"type\": \"number\",\n",
" \"description\": \"The current temperature at the specified location in the specified units, as a float.\"\n",
" }\n",
" }\n",
"}\n",
"\n",
"{\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_wind_speed\",\n",
" \"description\": \"Get the current wind speed in km/h at a given location.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The location to get the temperature for, in the format \\\"City, Country\\\"\"\n",
" }\n",
" },\n",
" \"required\": [\n",
" \"location\"\n",
" ]\n",
" },\n",
" \"return\": {\n",
" \"type\": \"number\",\n",
" \"description\": \"The current wind speed at the given location in km/h, as a float.\"\n",
" }\n",
" }\n",
"}\n",
"\n",
"What is llama 3.1 release date?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
"\n",
"<|python_tag|>{\"name\": \"brave_search\", \"parameters\": {\"query\": \"Llama 3.1 release date\"}}<|eom_id|>\n"
]
}
],
"source": [
"outputs = model.generate(tokenized_chat.to(model.device), max_new_tokens=2048)  # move inputs to the model's device\n",
"print(tokenizer.decode(outputs[0]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Function call tags\n",
"\n",
"As you can see, the model generates `<|python_tag|>` along with an `<|eom_id|>` instead of `<|eot_id|>`; these markers can be used to detect the function call, execute it, and send the output back to the model."
]
},
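{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch (illustrative; it reuses the `tokenizer` and `outputs` from the generation above), we can split on those special tokens to recover the call as JSON:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"def extract_builtin_call(decoded: str):\n",
"    \"\"\"Return (name, parameters) for a tool call, or None for a plain reply.\"\"\"\n",
"    # Everything after the final assistant header is the new completion.\n",
"    completion = decoded.split(\"<|start_header_id|>assistant<|end_header_id|>\")[-1]\n",
"    if \"<|python_tag|>\" not in completion or \"<|eom_id|>\" not in completion:\n",
"        return None  # an <|eot_id|> ending means a normal text answer\n",
"    payload = completion.split(\"<|python_tag|>\")[1].split(\"<|eom_id|>\")[0]\n",
"    call = json.loads(payload)\n",
"    return call[\"name\"], call[\"parameters\"]\n",
"\n",
"extract_builtin_call(tokenizer.decode(outputs[0]))  # ('brave_search', {'query': 'Llama 3.1 release date'})"
]
},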
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"#fin"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"captumWidgetMessage": [],
"dataExplorerConfig": [],
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
},
"last_base_url": "https://bento.edge.x2p.facebook.net/",
"last_kernel_id": "161e2a7b-2d2b-4995-87f3-d1539860ecac",
"last_msg_id": "4eab1242-d815b886ebe4f5b1966da982_543",
"last_server_session_id": "4a7b41c5-ed66-4dcb-a376-22673aebb469",
"operator_data": [],
"outputWidgetContext": []
},
"nbformat": 4,
"nbformat_minor": 4
}