Skip to content

Instantly share code, notes, and snippets.

@hizkifw
Created August 31, 2022 13:11
Show Gist options
  • Save hizkifw/1608fbcd0ffecd719adbd5612cfd0e3d to your computer and use it in GitHub Desktop.
Save hizkifw/1608fbcd0ffecd719adbd5612cfd0e3d to your computer and use it in GitHub Desktop.
Discord bot to generate images using Stable Diffusion
"""
discord.py==2.0.1
diffusers==0.2.4
transformers
scipy
ftfy
"""
import os
import time
import discord
import asyncio
from discord import app_commands
from diffusers import StableDiffusionPipeline
from threading import Thread, Lock
def _require_env(name):
    """
    Return the value of the environment variable *name*.

    Raises SystemExit with a readable message when the variable is unset or
    empty, so a misconfigured bot stops at startup instead of failing later
    with an opaque error (e.g. ``int(None)`` raising TypeError, or
    ``client.run(None)`` failing after the worker thread has already started).
    """
    value = os.environ.get(name)
    if not value:
        raise SystemExit(f"Missing required environment variable: {name}")
    return value


config = {
    # Optional: forwarded as use_auth_token to from_pretrained; None if unset.
    "huggingface_token": os.environ.get("HUGGINGFACE_TOKEN", None),
    "discord_token": _require_env("DISCORD_TOKEN"),
    "discord_guild_id": int(_require_env("DISCORD_GUILD_ID")),
}
# The single guild that slash commands are copied to and synced with.
MY_GUILD = discord.Object(id=config["discord_guild_id"])
class MyClient(discord.Client):
    """Discord client carrying an app-command tree that is synced to MY_GUILD."""

    def __init__(self, *, intents: discord.Intents):
        super().__init__(intents=intents)
        # Slash commands attach to this tree through @client.tree.command().
        self.tree = app_commands.CommandTree(self)

    async def setup_hook(self):
        """Copy the globally declared commands into MY_GUILD and sync them."""
        command_tree = self.tree
        target_guild = MY_GUILD
        command_tree.copy_global_to(guild=target_guild)
        await command_tree.sync(guild=target_guild)


# Default intents are sufficient for slash-command interactions.
intents = discord.Intents.default()
client = MyClient(intents=intents)
class TaskQueue:
    """
    Lock-protected FIFO of tasks identified by sequential integer IDs.

    The requesting side calls ``add_task`` and then polls ``is_ready`` until a
    worker has posted a result with ``submit_result``; the worker pulls tasks
    in insertion order with ``get_task_to_process``. Every method holds
    ``self.mutex`` while touching shared state.
    """

    def __init__(self):
        self.tasks = {}  # task_id -> payload not yet handed to a worker
        self.results = {}  # task_id -> result submitted but not yet collected
        self.insert_id = 0  # ID the next added task will receive
        self.process_id = 0  # ID of the next task to hand to a worker
        self.mutex = Lock()

    def add_task(self, task):
        """
        Add a task to the queue.

        Returns the integer ID assigned to the task.
        """
        with self.mutex:
            task_id = self.insert_id
            self.tasks[task_id] = task
            self.insert_id += 1
            return task_id

    def get_position(self, task_id):
        """
        Get the position of a task in the queue.

        0 means the task is next to be handed out and larger values mean
        further back. A negative value means the task has already been handed
        to a worker.
        """
        with self.mutex:
            return task_id - self.process_id

    def get_task_to_process(self):
        """
        Remove and return the next task to be processed.

        Returns ``(task_id, task)``. When the queue is empty, ``task`` is
        ``None`` and ``task_id`` is the ID the next added task will receive.
        """
        with self.mutex:
            task_id = self.process_id
            if task_id == self.insert_id:
                return task_id, None
            task = self.tasks.pop(task_id)
            self.process_id += 1
            return task_id, task

    def get_result(self, task_id):
        """
        Remove and return the submitted result of a task.

        Raises KeyError if no result is available; check ``is_ready`` first.
        """
        with self.mutex:
            return self.results.pop(task_id)

    def is_ready(self, task_id):
        """
        Return True once a result has been submitted for the task and not yet
        collected with ``get_result``.
        """
        with self.mutex:
            return task_id in self.results

    def submit_result(self, task_id, result):
        """
        Submit the result of a task so the requester can collect it.
        """
        with self.mutex:
            self.results[task_id] = result
# Global application state, shared by the Discord event loop and the worker thread.
#
# Running estimate, in seconds, of one generation's duration. It starts at a
# fixed initial guess and is updated by /dream after every finished task.
avg_time = 230
avg_time_lock = Lock()  # guards reads/writes of avg_time
queue = TaskQueue()  # requests from /dream, consumed by process_tasks
@client.event
async def on_ready():
    """Print the bot account's name and ID once the client is ready."""
    me = client.user
    print(f"Logged in as {me} (ID: {me.id})")
    print("------")
@client.tree.command()
@app_commands.describe(prompt="Describe the image to generate.")
async def dream(interaction: discord.Interaction, prompt: str):
    """
    Slash command: generate an image for *prompt*.

    Steps:
    1. Reply immediately with a placeholder.
    2. Queue the stripped prompt on the shared TaskQueue.
    3. Poll until the worker thread posts a result. Whenever the queue
       position changes, edit the reply to show the position and an ETA
       derived from the running average ``avg_time``.
    4. Replace the reply with the generated image, or with an abort notice
       if the result is flagged as NSFW.
    """
    global avg_time

    # Acknowledge the interaction before the (slow) generation starts.
    await interaction.response.send_message("Hang on a sec...")
    message = await interaction.original_response()

    task_id = queue.add_task({"prompt": prompt.strip()})

    last_position = None
    loop_sleep = 1  # seconds between polls; grows each time an edit fails
    while not queue.is_ready(task_id):
        position = queue.get_position(task_id)
        # Only edit when the position changed (or the previous edit failed).
        if position != last_position:
            with avg_time_lock:
                avg_time_local = avg_time
            # Estimate (position + 2) average durations. For a waiting task,
            # that is itself, one task assumed to be running, and `position`
            # queued tasks ahead. A running task (position -1) gets one.
            eta = int(time.time() + (avg_time_local * (position + 2)))
            eta_str = "ETA unknown" if avg_time_local == 0 else f"ETA <t:{eta}:R>"
            if position < 0:
                # Negative position: the worker has already picked the task up.
                content = f"Dreaming... {eta_str}\n`{prompt}`"
            else:
                content = f"Waiting... {position+1} in queue, {eta_str}\n`{prompt}`"
            try:
                await message.edit(content=content)
            except discord.HTTPException:
                # The edit failed. Poll less often; last_position is left
                # unchanged so the edit is retried on the next pass. Only
                # HTTP failures are caught so cancellation and programming
                # errors still propagate.
                loop_sleep += 1
            else:
                last_position = position
        await asyncio.sleep(loop_sleep)

    result = queue.get_result(task_id)

    # Fold this run into the exponential moving average (new sample weight 1/5).
    with avg_time_lock:
        avg_time = (avg_time * 4 + result["time_elapsed"]) / 5

    if result["nsfw_content_detected"]:
        await message.edit(content="NSFW content detected. Aborted.")
        return

    await message.edit(
        content=f"`{prompt}`", attachments=[discord.File(result["image_path"])]
    )
# Background task to process tasks in the queue
def _run_task(pipe, task_id, task):
    """Generate, save and submit the result for one queued task."""
    print(f"Processing task {task_id}: {task}")
    time_start = time.time()
    result = pipe(
        task["prompt"],
        guidance_scale=7.5,
        num_inference_steps=8,
        width=768,
        height=512,
    )
    time_elapsed = time.time() - time_start
    print(f"Task {task_id} finished in {time_elapsed:.2f} seconds")

    print(f"Done processing task {task_id}, saving image...")
    timenow = time.strftime("%Y%m%d%H%M%S")
    fname = f"dream/dream_{timenow}_{task_id}.png"
    # A single prompt string was passed, so only the first sample/flag is used.
    result["sample"][0].save(fname)
    print(f"Done saving image {fname}")

    queue.submit_result(
        task_id,
        {
            "nsfw_content_detected": result["nsfw_content_detected"][0],
            "image_path": fname,
            "time_elapsed": time_elapsed,
        },
    )


def process_tasks():
    """
    Worker loop: load the Stable Diffusion pipeline, then serve queued tasks forever.

    Polls the shared ``queue`` once per second while idle. For each task it
    writes the generated image under ``dream/`` and submits a result dict with
    ``nsfw_content_detected``, ``image_path`` and ``time_elapsed`` (seconds).
    """
    import traceback

    print("Starting task processing thread")
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", use_auth_token=config["huggingface_token"]
    )
    # Output directory for saved images. Without it, the first save raises
    # FileNotFoundError and would kill this (only) worker thread.
    os.makedirs("dream", exist_ok=True)
    print("Initialized model, waiting for tasks")

    while True:
        task_id, task = queue.get_task_to_process()
        if task is None:
            time.sleep(1)
            continue
        try:
            _run_task(pipe, task_id, task)
        except Exception:
            # Keep the worker alive for later tasks.
            # NOTE(review): no result is submitted for this task, so its
            # /dream invocation keeps waiting; an error result would need
            # matching handling in dream().
            print(f"Task {task_id} failed:")
            traceback.print_exc()
# Start the background worker. process_tasks never returns, so run it as a
# daemon thread. Otherwise the process could never exit after the bot stops,
# and joining it would block forever.
thread = Thread(target=process_tasks, daemon=True)
thread.start()

# Run the bot; this call blocks until the client shuts down.
client.run(config["discord_token"])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment