Last active: May 18, 2022 11:12
Save afiaka87/473035e4de4b99f60ba8f0f863c66bd9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import logging
import os
from datetime import datetime

import discord
import replicate

logging.basicConfig(level=logging.INFO)

REPLICATE_API_TOKEN = os.environ.get("REPLICATE_API_TOKEN")
# Seconds to wait between status checks of a running Replicate prediction.
POLL_INTERVAL_SECONDS = 2

# os.environ.get returns None when the variable is unset; a truthiness test
# covers both the unset and the empty-string case.
# NOTE(review): the token is never passed to replicate explicitly in this file;
# presumably the replicate client reads it from the environment itself — verify.
if not REPLICATE_API_TOKEN:
    # SystemExit with a string argument prints it to stderr and exits with status 1.
    raise SystemExit("Please set the REPLICATE_API_TOKEN environment variable")
class CodeGenClient(discord.Client):
    """Discord client that completes ```python code blocks using a Replicate model.

    When a message containing a ```python fenced block arrives, the bot replies,
    starts a Replicate prediction on the first code block, and keeps editing its
    reply with the prediction's status until the prediction finishes.
    """

    # Maximum number of characters Discord accepts in a message's content.
    _MESSAGE_LIMIT = 2000

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def on_ready(self):
        """Print the bot's name and id once the client is logged in."""
        print("Logged in as")
        print(self.user.name)
        print(self.user.id)
        print("------")

    @staticmethod
    async def _run_blocking(func, *args, **kwargs):
        """Run a synchronous call in the default executor and await its result."""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, lambda: func(*args, **kwargs))

    @classmethod
    def _format_success(cls, raw_output, logs):
        """Build the reply for a finished prediction, kept within _MESSAGE_LIMIT.

        The preferred form shows the console logs followed by the generated code.
        If that is too long the logs are dropped; if the code alone is still too
        long it is cut with a visible marker while keeping the fence closed.
        """
        code = raw_output.replace("======", "").strip()
        code_block = "```py\n" + code + "\n```"
        full = f"```sh\n# console: \n{logs}\n```\n\n{code_block}"
        if len(full) <= cls._MESSAGE_LIMIT:
            return full
        if len(code_block) <= cls._MESSAGE_LIMIT:
            return code_block
        marker = "\n# ... (output truncated)"
        room = cls._MESSAGE_LIMIT - len("```py\n") - len(marker) - len("\n```")
        return "```py\n" + code[:room] + marker + "\n```"

    async def create_and_poll_prediction(
        self, message: discord.Message, input_data: dict, model_name: str = "salesforce/codegen"
    ):
        """Start a Replicate prediction and mirror its status into ``message``.

        Args:
            message: A message sent by this bot; it is edited in place on every poll.
            input_data: Input passed unchanged to ``replicate.predictions.create``.
            model_name: Replicate model whose first listed version is run.

        Returns:
            The final message content on success, or None if the prediction was
            cancelled or failed.

        Raises:
            Exception: If Replicate reports a status this loop does not handle.
        """
        # The replicate calls below are synchronous (not awaitable); running them in
        # the executor keeps the event loop free to handle other Discord events.
        model = await self._run_blocking(replicate.models.get, model_name)
        versions = await self._run_blocking(model.versions.list)
        prediction = await self._run_blocking(
            replicate.predictions.create,
            version=versions[0],
            input=input_data,
        )
        while True:
            await asyncio.sleep(POLL_INTERVAL_SECONDS)
            await self._run_blocking(prediction.reload)
            status = prediction.status
            logging.info("Prediction status: %s", status)
            current_time = datetime.now().strftime("%H:%M:%S")
            if status == "succeeded":
                final_output = self._format_success(
                    prediction.output["raw_output"], prediction.logs
                )
                await message.edit(content=final_output)
                return final_output
            if status == "starting":
                await message.edit(content=f"Prediction still starting... {current_time}")
            elif status == "processing":
                await message.edit(
                    content=f"The predict() method of the model is currently running... {current_time}"
                )
            elif status == "cancelled":
                await message.edit(content=f"Prediction cancelled... {current_time}")
                return None
            elif status == "failed":
                failure = f"Prediction failed. Error: {prediction.error} {current_time}"
                # The error text comes from the service and may be arbitrarily long.
                await message.edit(content=failure[: self._MESSAGE_LIMIT])
                return None
            else:
                raise Exception(f"Unknown status: {status}")

    async def on_code_block(self, message: discord.Message):
        """Reply to ``message`` with a completion of its first code block."""
        response_message = await message.reply(
            "Detected python code block. Generating completion using codegen.",
            mention_author=True,
        )
        # Normalise the ```python opener so that splitting on ``` puts the body
        # of the first code block at index 1.
        content = message.content.replace("```python", "```").split("```")[1]
        input_data = {
            "context": content,
            "raw_output": True,
        }
        try:
            final_message = await self.create_and_poll_prediction(
                input_data=input_data,
                message=response_message,
            )
        except Exception:
            # Event-handler boundary: keep the traceback in the log and replace the
            # stale "Generating..." reply instead of leaving it hanging.
            logging.exception("Prediction request failed")
            await response_message.edit(
                content="Prediction errored; see the bot logs for details."
            )
            return
        # create_and_poll_prediction returns None for cancelled/failed predictions,
        # so only acknowledge a successful completion.
        if final_message is not None:
            await response_message.add_reaction("\N{THUMBS UP SIGN}")

    async def on_message(self, message: discord.Message):
        """Dispatch messages that contain a ```python block to on_code_block."""
        # we do not want the bot to reply to itself
        if message.author.id == self.user.id:
            logging.debug("Ignoring message from self")
            return
        if "```python" in message.content:
            await self.on_code_block(message)
if __name__ == "__main__":
    discord_token = os.environ.get("DISCORD_API_TOKEN")
    if not discord_token:
        # SystemExit with a string argument prints it to stderr and exits with status 1.
        raise SystemExit("Please set the DISCORD_API_TOKEN environment variable")
    intents = discord.Intents.default()
    # The bot reads message.content. discord.py 2.x requires the message_content
    # intent for that, and the attribute does not exist on 1.x, so set it only if
    # present. NOTE(review): the intent must also be enabled in the Developer Portal.
    if hasattr(intents, "message_content"):
        intents.message_content = True
    client = CodeGenClient(intents=intents)
    client.run(discord_token)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.