Created December 22, 2023 14:16
-
-
Save danveloper/9264860ac71651eb4f37d92d52621301 to your computer and use it in GitHub Desktop.
Translate Llama2-7B-functions model output to the OpenAI (OAI) chat-completions JSON spec
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json

import requests
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
# Example request body (OpenAI chat-completions format with a `functions`
# list) that this proxy is meant to accept on POST /v1/chat/completions.
# Reference only -- nothing reads it at runtime.
#
# {
#     "messages": [
#         {
#             "content": "you are immensely smart function AI",
#             "role": "system"
#         },
#         {
#             "content": "Search for the latest AI news.",
#             "role": "user"
#         }
#     ],
#     "model": "gpt-4",
#     "functions": [
#         {
#             "name": "search_bing",
#             "description": "Search the web for content on Bing. This allows users to search online/the internet/the web for content.",
#             "parameters": {
#                 "type": "object",
#                 "properties": {
#                     "query": {
#                         "type": "string",
#                         "description": "The search query string."
#                     }
#                 },
#                 "required": [
#                     "query"
#                 ]
#             }
#         }
#     ],
#     "stream": false
# }
def _map_oai_func_to_llama(func: dict) -> dict:
    """Convert one OpenAI function definition to the Llama2-functions format.

    Input (an entry of the request's ``functions`` list)::

        {
            "name": "search_bing",
            "description": "Search the web for content on Bing. ...",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {"type": "string",
                              "description": "The search query string."}
                },
                "required": ["query"]
            }
        }

    Output::

        {
            "function": "search_bing",
            "description": "Search the web for content on Bing. ...",
            "arguments": [
                {"name": "query", "type": "string",
                 "description": "The search query string."}
            ]
        }

    ``description``, ``parameters``, ``properties`` and each property's
    ``type``/``description`` may be absent (or ``parameters``/``properties``
    may be ``None``). Missing descriptions/types become ``""`` and missing
    parameters yield an empty ``arguments`` list, instead of raising
    ``KeyError``. The ``required`` list is not carried over.

    Args:
        func: One OpenAI function definition.

    Returns:
        A dict with ``function``, ``description`` and ``arguments`` keys, in
        that order (the order matters because the dict is later JSON-dumped
        into the prompt).

    Raises:
        KeyError: If ``func`` has no ``name``.
    """
    # `or {}` covers both a missing key and an explicit null.
    properties = (func.get("parameters") or {}).get("properties") or {}
    return {
        "function": func["name"],
        "description": func.get("description", ""),
        "arguments": [
            {
                "name": arg_name,
                "type": details.get("type", ""),
                "description": details.get("description", ""),
            }
            for arg_name, details in properties.items()
        ],
    }
app = FastAPI()

# llama.cpp server `/completion` endpoint that actually runs the model.
_LLAMA_COMPLETION_URL = "http://localhost:8080/completion"


def _build_llama_prompt(orig_request: dict) -> str:
    """Render an OpenAI chat-completions request as a Llama2-functions prompt.

    The prompt starts with ``<s>``. If the request has a non-empty
    ``functions`` list, each function (converted by
    ``_map_oai_func_to_llama``) is written as one JSON object per line inside
    ``<FUNCTIONS>...</FUNCTIONS>\\n\\n``. Then the conversation follows:

    * user turns become ``[INST]...[/INST]``; every user turn after the first
      gets its own ``<s>`` prefix (the first reuses the leading one);
    * assistant and function turns are emitted verbatim and closed with
      ``</s>``;
    * messages whose ``content`` is missing or ``None`` are skipped.

    NOTE(review): ``system`` messages are dropped, as in the original; confirm
    whether the target model expects them (e.g. in a ``<<SYS>>`` section).
    """
    parts = ["<s>"]
    functions = orig_request.get("functions")
    if functions:
        parts.append("<FUNCTIONS>")
        parts.extend(f"{json.dumps(_map_oai_func_to_llama(func))}\n" for func in functions)
        parts.append("</FUNCTIONS>\n\n")

    seen_s = False
    for message in orig_request["messages"]:
        content = message.get("content")
        if content is None:
            continue
        role = message.get("role")
        if role == "user":
            parts.append(f"{'<s>' if seen_s else ''}[INST]{content}[/INST]")
            seen_s = True
        elif role in ("assistant", "function"):
            parts.append(f"{content}</s>")
    return "".join(parts)


def _to_oai_response(response: dict) -> dict:
    """Attach an OpenAI-style ``choices`` list to a llama.cpp completion.

    If the generated ``content`` parses as a JSON object with a non-null
    ``"function"`` key, the single choice is an assistant ``function_call``.
    Its ``arguments`` are re-serialised as a JSON string, defaulting to
    ``{}``, and it gets ``finish_reason="function_call"``. Any other output,
    including invalid JSON or ``"function": null``, becomes a plain-text
    assistant message with ``finish_reason="stop"``. llama.cpp's own stop
    reason is not inspected.

    ``response`` is mutated in place (its llama.cpp fields are kept) and
    returned.
    """
    content = response.get("content", "")
    try:
        parsed = json.loads(content)
    except (json.JSONDecodeError, TypeError):
        parsed = None

    function_name = parsed.get("function") if isinstance(parsed, dict) else None
    if function_name is not None:
        message = {
            "role": "assistant",
            "content": None,
            "function_call": {
                "name": function_name,
                # OpenAI sends `arguments` as a JSON-encoded string.
                "arguments": json.dumps(parsed.get("arguments", {})),
            },
        }
        finish_reason = "function_call"
    else:
        message = {"role": "assistant", "content": content}
        finish_reason = "stop"

    # Per the OpenAI schema, finish_reason sits on the choice, next to message.
    response["choices"] = [
        {"index": 0, "message": message, "finish_reason": finish_reason}
    ]
    return response


@app.post("/v1/chat/completions")
async def chat(orig_request: dict):
    """OpenAI-compatible chat-completions endpoint backed by llama.cpp.

    Translates the request into a Llama2-functions prompt and POSTs it to
    llama.cpp's ``/completion`` endpoint. The reply is mapped back into an
    OpenAI-style response with a ``choices`` list. The ``stream`` flag is
    ignored; a single JSON response is always returned.

    Raises:
        requests.HTTPError: If the llama.cpp server answers with an error
            status (FastAPI turns this into a 500).
    """
    prompt = _build_llama_prompt(orig_request)
    print(f"\n\n new_request --> {prompt}")

    # `requests` is blocking; run it in a worker thread so a slow generation
    # does not stall the event loop for every other request.
    http_response = await run_in_threadpool(
        requests.post, _LLAMA_COMPLETION_URL, json={"prompt": prompt}
    )
    http_response.raise_for_status()
    response = http_response.json()
    print(f"\n\n response --> {response}")

    return _to_oai_response(response)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment