Created: April 4, 2023, 11:52
Save theosanderson/611ddd712c34669124a42b81365416a9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import json
import os
import time

import openai
from fastapi import FastAPI, WebSocket
from fastapi.logger import logger

# Read the OpenAI credential from the environment rather than hard-coding it.
openai.api_key = os.getenv("OPENAI_API_KEY")

app = FastAPI()
@app.websocket("/quiz/{topic}/{difficulty}")
async def generate_quiz(websocket: WebSocket, topic: str, difficulty: str):
    """Stream quiz questions to a websocket client.

    Asks GPT-4 for a quiz about *topic* at the given *difficulty*, consumes
    the streamed completion, and sends each complete line that parses as a
    JSON question object to the client as soon as it arrives.

    Path params:
        topic: quiz subject, interpolated into the prompt.
        difficulty: difficulty label, interpolated into the prompt.
    """
    await websocket.accept()

    # The blocking OpenAI stream is consumed in a worker thread below, so
    # capture the serving event loop now: websocket sends must be scheduled
    # back onto THIS loop, never run on a loop created inside the thread.
    loop = asyncio.get_running_loop()

    async def send_question(question):
        # Runs on the main loop; pushes one parsed question to the client.
        await websocket.send_json(question)

    def _send_threadsafe(question_obj):
        # Hand the send coroutine to the main loop from the worker thread.
        # .result() blocks the worker until the send completes, preserving
        # question ordering. (The original used asyncio.run() here, which
        # creates a new loop and cannot drive a websocket owned by another.)
        asyncio.run_coroutine_threadsafe(
            send_question(question_obj), loop
        ).result()

    def _parse_question(line):
        # Parse one complete streamed line into a question dict.
        # Returns None for blank or unparseable lines.
        if not line:
            return None
        # The model numbers its questions ("1.\n{...}" per the prompt, but
        # sometimes "1. {...}"); strip a leading "<digits>." prefix.
        if line[0].isdigit():
            line = line[line.find('.') + 1:]
        try:
            return json.loads(line)
        except json.JSONDecodeError:
            # Expected for non-JSON lines (bare numbering, chatter): skip.
            logger.warning("Skipping unparseable quiz line: %r", line)
            return None

    def _handle_chunk(chunk, buffer):
        """Fold one stream delta into *buffer*; emit any completed lines."""
        delta = chunk['choices'][0]['delta']
        if 'content' not in delta:
            # Role/finish deltas carry no text.
            return buffer
        buffer += delta['content']
        if '\n' not in buffer:
            return buffer
        # Everything before the final newline is complete; the remainder
        # (possibly "") stays buffered for the next delta. Unlike the
        # original, this handles several newlines in one delta without
        # dropping content.
        *complete_lines, buffer = buffer.split('\n')
        for line in complete_lines:
            question_obj = _parse_question(line)
            if question_obj is not None:
                _send_threadsafe(question_obj)
        return buffer

    def _read_chunks(response):
        # Runs in a worker thread: iterating the response blocks on HTTP.
        buffer = ""
        for chunk in response:
            buffer = _handle_chunk(chunk, buffer)
        # Flush a trailing question that was not newline-terminated.
        question_obj = _parse_question(buffer)
        if question_obj is not None:
            _send_threadsafe(question_obj)

    user_message = (
        "Give me a quiz about " + topic + ". Present 4 choices per question A, B, C and D. "
        "Then give the correct answer, then move onto the next Q. Use this format. "
        "Make sure that only one answer is correct. After, your answer, confirm whether "
        "it is the only correct answer.\n\n"
        "1.\n{\"question\":\"\", \"responses\" : {\"A\":..,}, \"answer\":\"X\", \"only_correct_answer\":true}\n\n"
        "Give 20 questions. Make them " + difficulty + " . "
    )
    response = openai.ChatCompletion.create(
        model='gpt-4',
        messages=[
            {"role": "user", "content": user_message}
        ],
        temperature=0.5,
        stream=True,
        max_tokens=4000,
    )
    # Consume the blocking stream off the event loop so it is not stalled.
    await asyncio.to_thread(_read_chunks, response)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.