Created
April 25, 2024 06:10
-
-
Save Mistobaan/a2e11effc4f1679533517e723a19c908 to your computer and use it in GitHub Desktop.
Query Together API using Async
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from tqdm.auto import tqdm | |
import asyncio | |
import aiohttp | |
import json | |
import time | |
import os | |
import textwrap | |
# API key for the Together endpoint. Read at import time, so a missing
# variable fails immediately with KeyError before any request is sent.
TOGETHER_API_TOKEN = os.environ['TOGETHER_API_TOKEN']

# Maximum number of requests allowed in flight at the same time.
# NOTE: despite the name, this is a concurrency cap, not a per-second rate:
# the semaphore below only limits how many `get` calls run the HTTP request
# simultaneously; it does not space requests out in time.
QPS = 100

# NOTE(review): on Python < 3.10, asyncio.Semaphore binds to the event loop
# that is current at construction time, while asyncio.run() creates a new
# loop -- contended acquires can then fail with "attached to a different
# loop". On 3.10+ the binding is lazy and this module-level instance is fine.
sem = asyncio.Semaphore(QPS)
async def get(sample, session): | |
try: | |
prompt = textwrap.dedent(""" | |
You are a Math teacher. | |
- What are the essential mathematical concepts required to solve the following problem? | |
- In what world context is the problem set? | |
- What is the difficulty level (1-5)? | |
Be concise and specific. Output ONLY JSON. | |
# Example JSON output | |
```json | |
{ | |
"essential_math_concepts": ["percentages", "ratios", "algebraic_equations"], | |
"world_context": ["gardening"], | |
"difficulty": 3 | |
} | |
``` | |
""") | |
endpoint = "https://api.together.xyz/v1/chat/completions" | |
async with sem: | |
async with session.post( | |
endpoint, | |
json={ | |
"model": "meta-llama/Llama-3-70b-chat-hf", | |
"max_tokens": 512, | |
"temperature": 0, | |
"top_p": 0.9, | |
"top_k": 50, | |
"repetition_penalty": 1, | |
# "response_format": {"type": "json_object"}, | |
"stop": ["<|eot_id|>"], | |
"messages": [ | |
{"content": prompt, "role": "system"}, | |
{"content": sample['question'], "role": "user"} | |
], | |
}, | |
headers={ | |
"Authorization": "Bearer "+ TOGETHER_API_TOKEN, | |
}) as response: | |
resp = await response.read() | |
if response.headers['x-ratelimit-remaining'] == 0: | |
time_remaining = int(response.headers['x-ratelimit-reset']) | |
await asyncio.sleep(time_remaining) | |
payload = json.loads(resp) | |
sample['skills_raw'] = payload['choices'][0]['message']['content'] | |
return sample | |
except Exception as e: | |
import traceback | |
traceback.print_exc() | |
print("Unable to get url {} due to {}.".format(sample, e.__class__)) | |
async def amain(samples):
    """Run ``get`` over every sample concurrently and collect the results.

    Args:
        samples: iterable of sample dicts; it is consumed once.

    Returns:
        A list with one entry per input sample, in the same order as
        *samples*: the annotated sample, or ``None`` where ``get`` failed.
    """
    async with aiohttp.ClientSession() as session:
        tasks = [asyncio.create_task(get(q, session)) for q in samples]
        # as_completed drives the progress bar in completion order; the
        # results are read back from the tasks afterwards so they stay
        # aligned with the input order.
        for finished in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
            await finished
        return [task.result() for task in tasks]
def main():
    """Annotate the GSM8K train split and pickle the results to values.pickle."""
    import datasets
    import pickle

    gsm8k = datasets.load_dataset("mistobaan/gsm8k-train-nomic-text-v1.5")
    # `get` mutates each sample it receives, so hand it copies of the rows.
    values = asyncio.run(amain(s.copy() for s in gsm8k["train"]))
    # The context manager guarantees the pickle is flushed and the file closed.
    with open("values.pickle", "wb") as f:
        pickle.dump(values, f)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment