Created August 31, 2022 13:11
-
-
Save hizkifw/1608fbcd0ffecd719adbd5612cfd0e3d to your computer and use it in GitHub Desktop.
Discord bot to generate images using Stable Diffusion
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
discord.py==2.0.1 | |
diffusers==0.2.4 | |
transformers | |
scipy | |
ftfy | |
""" | |
import asyncio
import os
import time
import traceback
from threading import Lock, Thread

import discord
from diffusers import StableDiffusionPipeline
from discord import app_commands
def _require_env(name):
    """Return environment variable *name*, raising RuntimeError if unset or empty.

    Fails fast at startup with a readable message instead of a late, cryptic
    error (e.g. ``int(None)`` raising TypeError).
    """
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(f"The {name} environment variable must be set")
    return value


config = {
    # May be None; passed as use_auth_token when the diffusion pipeline is loaded.
    "huggingface_token": os.environ.get("HUGGINGFACE_TOKEN", None),
    # Required: used by client.run() to log in.
    "discord_token": _require_env("DISCORD_TOKEN"),
    # Required: the guild the slash commands are synced to.
    "discord_guild_id": int(_require_env("DISCORD_GUILD_ID")),
}
MY_GUILD = discord.Object(id=config["discord_guild_id"])
class MyClient(discord.Client):
    """Discord client that carries an application-command tree.

    In setup_hook, commands registered globally on the tree are copied to
    MY_GUILD and synced to that guild.
    """

    def __init__(self, *, intents: discord.Intents):
        super().__init__(intents=intents)
        # Slash commands such as /dream are registered on this tree.
        self.tree = app_commands.CommandTree(self)

    async def setup_hook(self):
        """Copy the global commands onto MY_GUILD and sync them there."""
        target = MY_GUILD
        self.tree.copy_global_to(guild=target)
        await self.tree.sync(guild=target)


intents = discord.Intents.default()
client = MyClient(intents=intents)
class TaskQueue:
    """Thread-safe FIFO of tasks plus a store of their results.

    Producers call add_task() and poll is_ready(); a worker calls
    get_task_to_process() and later submit_result(). Task ids are assigned
    sequentially from 0 and tasks are handed out in that order. Every method
    holds ``self.mutex`` while touching shared state.
    """

    def __init__(self):
        self.tasks = {}  # task id -> payload, for tasks not yet handed out
        self.results = {}  # task id -> result, until fetched by get_result()
        self.insert_id = 0  # id the next added task will receive
        self.process_id = 0  # id of the next task get_task_to_process() returns
        self.mutex = Lock()

    def add_task(self, task):
        """Append *task* to the queue and return its id."""
        with self.mutex:
            task_id = self.insert_id
            self.tasks[task_id] = task
            self.insert_id += 1
            return task_id

    def get_position(self, task_id):
        """Return how many not-yet-taken tasks are ahead of *task_id*.

        0 means the task is next in line; a negative value means it has
        already been handed out by get_task_to_process().
        """
        with self.mutex:
            return task_id - self.process_id

    def get_task_to_process(self):
        """Remove and return the oldest pending task.

        Returns ``(task_id, task)``, or ``(next_id, None)`` when nothing is
        pending.
        """
        with self.mutex:
            task_id = self.process_id
            if task_id == self.insert_id:
                return task_id, None
            task = self.tasks.pop(task_id)
            self.process_id += 1
            return task_id, task

    def get_result(self, task_id):
        """Remove and return the submitted result for *task_id*.

        Raises KeyError if no result has been submitted yet (check
        is_ready() first) or if it was already fetched.
        """
        with self.mutex:
            return self.results.pop(task_id)

    def is_ready(self, task_id):
        """Return True once a result for *task_id* has been submitted."""
        with self.mutex:
            return task_id in self.results

    def submit_result(self, task_id, result):
        """Store *result* for *task_id* so a waiting producer can fetch it."""
        with self.mutex:
            self.results[task_id] = result
# Global application state
# Guards avg_time, which the /dream handler both reads and updates.
avg_time_lock = Lock()
# Moving average of generation time in seconds, used for ETAs; 230 is the
# starting estimate before any task has finished.
avg_time = 230
# Shared between the /dream handler (producer) and the worker thread.
queue = TaskQueue()
@client.event
async def on_ready():
    """Print the logged-in bot identity followed by a separator line."""
    me = client.user
    print(f"Logged in as {me} (ID: {me.id})")
    print("------")
@client.tree.command()
@app_commands.describe(prompt="Describe the image to generate.")
async def dream(interaction: discord.Interaction, prompt: str):
    """Slash command: queue *prompt* for image generation and post the result.

    The initial reply is edited in place while the task waits (queue
    position plus an ETA derived from ``avg_time``), then replaced by the
    generated image or by an abort notice.
    """
    global avg_time

    # Reply immediately; this message is edited as the task progresses.
    await interaction.response.send_message("Hang on a sec...")
    message = await interaction.original_response()

    task_id = queue.add_task({"prompt": prompt.strip()})

    last_position = None
    loop_sleep = 1
    while not queue.is_ready(task_id):
        position = queue.get_position(task_id)
        # Only edit the message when the position changed, to limit API calls.
        if position != last_position:
            with avg_time_lock:
                avg_time_local = avg_time
            # Rough estimate: one avg_time per task still queued ahead, plus
            # the task in progress, plus this one.
            eta = int(time.time() + (avg_time_local * (position + 2)))
            eta_str = "ETA unknown" if avg_time_local == 0 else f"ETA <t:{eta}:R>"
            try:
                if position < 0:
                    # Negative position: the worker has already taken this task.
                    await message.edit(content=f"Dreaming... {eta_str}\n`{prompt}`")
                else:
                    await message.edit(
                        content=f"Waiting... {position+1} in queue, {eta_str}\n`{prompt}`"
                    )
                # Record the position only after a successful edit, so a
                # failed edit is retried on the next iteration.
                last_position = position
            except Exception as err:
                # Status updates are best effort. Unlike a bare ``except``,
                # this lets asyncio.CancelledError propagate so the command
                # can still be cancelled (e.g. when the bot shuts down).
                print(f"Could not update status for task {task_id}: {err!r}")
                # Back off a little more after each failed edit.
                loop_sleep += 1
        await asyncio.sleep(loop_sleep)

    result = queue.get_result(task_id)

    if not result.get("image_path"):
        # No image was produced, so there is nothing to attach; skip the
        # average update as well, since this was not a normal generation.
        await message.edit(
            content=f"Something went wrong while dreaming. Aborted.\n`{prompt}`"
        )
        return

    # Fold this run into the moving average used for ETAs (new sample weighted 1/5).
    with avg_time_lock:
        avg_time = (avg_time * 4 + result["time_elapsed"]) / 5

    if result["nsfw_content_detected"]:
        await message.edit(content="NSFW content detected. Aborted.")
        return

    await message.edit(
        content=f"`{prompt}`", attachments=[discord.File(result["image_path"])]
    )
def _run_pipeline(pipe, task_id, task):
    """Generate and save the image for one task and return its result dict."""
    time_start = time.time()
    result = pipe(
        task["prompt"],
        guidance_scale=7.5,
        num_inference_steps=8,
        width=768,
        height=512,
    )
    time_elapsed = time.time() - time_start
    print(f"Task {task_id} finished in {time_elapsed:.2f} seconds")

    print(f"Done processing task {task_id}, saving image...")
    timenow = time.strftime("%Y%m%d%H%M%S")
    fname = f"dream/dream_{timenow}_{task_id}.png"
    result["sample"][0].save(fname)
    print(f"Done saving image {fname}")

    return {
        "nsfw_content_detected": result["nsfw_content_detected"][0],
        "image_path": fname,
        "time_elapsed": time_elapsed,
    }


# Background task to process tasks in the queue
def process_tasks():
    """Worker loop: load the model, then generate images for queued tasks forever.

    Each task gets a result with ``nsfw_content_detected``, ``image_path``
    and ``time_elapsed``. If generating or saving fails, the traceback is
    printed and a result with ``image_path`` set to None is submitted, so
    the waiting command stops polling and this loop keeps serving tasks.
    """
    print("Starting task processing thread")
    # Images are saved under dream/; make sure it exists so saving cannot
    # fail just because the directory is missing.
    os.makedirs("dream", exist_ok=True)
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", use_auth_token=config["huggingface_token"]
    )
    print("Initialized model, waiting for tasks")
    while True:
        task_id, task = queue.get_task_to_process()
        if task is None:
            # Nothing pending; poll again shortly.
            time.sleep(1)
            continue
        print(f"Processing task {task_id}: {task}")
        time_start = time.time()
        try:
            outcome = _run_pipeline(pipe, task_id, task)
        except Exception:
            # An uncaught exception would end this thread, leaving this and
            # every later task unprocessed while their commands poll forever.
            print(f"Task {task_id} failed:")
            traceback.print_exc()
            outcome = {
                "nsfw_content_detected": False,
                "image_path": None,
                "time_elapsed": time.time() - time_start,
            }
        queue.submit_result(task_id, outcome)
# Start the background worker. process_tasks() loops forever, so it runs as a
# daemon thread: a non-daemon thread (and joining it after client.run) would
# keep the process alive forever once the bot has shut down.
thread = Thread(target=process_tasks, daemon=True)
thread.start()

# Run the bot; this blocks until the client stops, after which the process
# exits and the daemon worker is terminated with it.
client.run(config["discord_token"])
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment