Skip to content

Instantly share code, notes, and snippets.

@hizkifw
Last active March 3, 2023 13:47
Show Gist options
  • Save hizkifw/1401f9e78649e7f498eb985ded9ddfce to your computer and use it in GitHub Desktop.
Save hizkifw/1401f9e78649e7f498eb985ded9ddfce to your computer and use it in GitHub Desktop.
"""
Ultra-jank Discord bot for AUTOMATIC1111's Stable Diffusion web UI.
It talks to the web UI's HTTP API, so the web UI must be started with --api.
"""
import re
import os
import io
import time
import json
import base64
import discord
import asyncio
import requests
from discord import app_commands
from dataclasses import dataclass
from threading import Thread, Lock
from typing import Any, Optional, Tuple
from enum import Enum
# Bot configuration. The DISCORD_TOKEN / DISCORD_GUILD_ID environment
# variables, when set, take precedence over the placeholders below, so the
# real token does not have to be committed to source.
config = {
    "discord_token": os.environ.get("DISCORD_TOKEN", "..."),
    "discord_guild_id": int(os.environ.get("DISCORD_GUILD_ID", "123")),
}
# Guild that MyClient.setup_hook copies the slash commands to and syncs.
MY_GUILD = discord.Object(id=config["discord_guild_id"])
class MyClient(discord.Client):
    """Discord client that carries its own slash-command tree.

    Commands registered globally on ``self.tree`` are copied to ``MY_GUILD``
    and synced there from ``setup_hook``.
    """

    def __init__(self, *, intents: discord.Intents):
        super().__init__(intents=intents)
        # Commands are attached to this tree via ``@client.tree.command()``.
        self.tree = app_commands.CommandTree(self)

    async def setup_hook(self):
        """Copy the global commands onto MY_GUILD and sync them with Discord."""
        target_guild = MY_GUILD
        self.tree.copy_global_to(guild=target_guild)
        await self.tree.sync(guild=target_guild)
# Default gateway intents; no additional intents are enabled.
intents = discord.Intents.default()
# The single client instance; its command tree hosts the /dream command.
client = MyClient(intents=intents)
class TaskType(Enum):
    """Kind of generation job; the values mirror the web UI API endpoint names."""

    TXT2IMG = "txt2img"  # generate from the prompt alone
    IMG2IMG = "img2img"  # generate from the prompt plus a source image
@dataclass
class Task:
    """Parameters for one queued generation request."""

    # Selects which web UI endpoint process_tasks dispatches to.
    task: TaskType
    prompt: str
    # Base64-encoded source image; set for IMG2IMG, None for TXT2IMG.
    image: Optional[str] = None
    width: int = 512
    height: int = 768
    steps: int = 20
    cfg_scale: int = 7
    # Sent to the API unchanged; dream maps any negative seed to -1.
    # NOTE(review): presumably the web UI treats -1 as "random seed" — confirm.
    seed: int = -1
class TaskQueue(object):
    """
    Thread-safe FIFO of tasks keyed by sequential integer IDs.

    Producers enqueue with ``add_task`` and poll ``is_ready`` / ``get_result``.
    A worker pulls tasks with ``get_task_to_process`` and posts outcomes with
    ``submit_result``. All state is guarded by a single mutex.
    """

    def __init__(self):
        self.tasks: dict[int, Task] = {}  # pending tasks by ID
        self.results: dict[int, Any] = {}  # submitted, not-yet-collected results
        self.insert_id = 0  # ID the next added task will receive
        self.process_id = 0  # ID of the next task to hand to the worker
        self.mutex = Lock()

    def add_task(self, task: "Task") -> int:
        """
        Add a task to the queue and return its ID.
        """
        with self.mutex:
            task_id = self.insert_id
            self.tasks[task_id] = task
            self.insert_id += 1
            return task_id

    def get_position(self, task_id: int) -> int:
        """
        Return how many not-yet-taken tasks are ahead of ``task_id``.

        0 means it is next in line; a negative value means the worker has
        already taken it (-1 for the most recently taken task).
        """
        with self.mutex:
            return task_id - self.process_id

    def get_task_to_process(self) -> Tuple[int, Optional["Task"]]:
        """
        Remove and return the oldest pending task as ``(task_id, task)``.

        When nothing is pending, returns ``(next_id, None)`` and leaves the
        queue unchanged.
        """
        with self.mutex:
            task_id = self.process_id
            if task_id == self.insert_id:
                return task_id, None
            task = self.tasks.pop(task_id)
            self.process_id += 1
            return task_id, task

    def get_result(self, task_id: int) -> Any:
        """
        Remove and return the submitted result for ``task_id``.

        Raises:
            KeyError: if no result has been submitted (check ``is_ready``).
        """
        with self.mutex:
            return self.results.pop(task_id)

    def is_ready(self, task_id: int) -> bool:
        """
        Return True once a result for ``task_id`` has been submitted.
        """
        with self.mutex:
            return task_id in self.results

    def submit_result(self, task_id: int, result: Any) -> None:
        """
        Store the result of a task so ``get_result`` can collect it.
        """
        with self.mutex:
            self.results[task_id] = result
# Global application state
# Shared by the async Discord handlers (which add tasks and poll results) and
# the worker thread started at the bottom of the file (which processes them).
queue: TaskQueue = TaskQueue()
# NOTE(review): not referenced anywhere else in the code visible here.
num_steps: int = 6
# Running average generation time in seconds, used for ETAs; initial value 15,
# updated by handle_message after each successful task. Read and written under
# avg_time_lock.
avg_time: float = 15
avg_time_lock = Lock()
@client.event
async def on_ready():
    """Print which bot account is logged in when discord.py fires on_ready."""
    me = client.user
    print(f"Logged in as {me} (ID: {me.id})")
    print("------")
def make_param_string(task_params: "Task", max_prompt_len: int = 1500) -> str:
    """Format a task's prompt and settings for display in a Discord message.

    The prompt is shown as inline code on the first line. The second line lists
    size, step count and CFG scale, using thousands separators.

    Args:
        task_params: Task whose settings are displayed.
        max_prompt_len: Longest prompt shown in full. Longer prompts are cut
            and end in "...", so the message (plus the status text callers add
            around it) stays under Discord's 2000-character content limit.

    Returns:
        The two-line display string.
    """
    # A backtick inside the prompt would terminate the inline-code span early.
    prompt = task_params.prompt.replace("`", "'")
    if len(prompt) > max_prompt_len:
        prompt = prompt[: max(max_prompt_len - 3, 0)] + "..."
    return (
        f"`{prompt}`\n"
        f"{task_params.width:,}x{task_params.height:,}, "
        f"{task_params.steps:,} steps, cfg scale {task_params.cfg_scale:,}"
    )
# URL prefixes of Discord-hosted attachments; img2img only fetches from these.
_DISCORD_ATTACHMENT_PREFIXES = (
    "https://cdn.discordapp.com/attachments/",
    "https://media.discordapp.net/attachments/",
)


def _clamp(value: int, low: int, high: int) -> int:
    """Return ``value`` limited to the inclusive range [low, high]."""
    return max(low, min(value, high))


async def _download_discord_image(image_url: str) -> str:
    """Download ``image_url`` and return its bytes base64-encoded.

    The blocking ``requests.get`` runs in the default executor, so the event
    loop that serves every other Discord interaction is not stalled.

    Raises:
        requests.RequestException: on network errors, timeout, or a non-2xx
            HTTP status.
    """
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(
        None, lambda: requests.get(image_url, timeout=30)
    )
    response.raise_for_status()
    return base64.b64encode(response.content).decode("utf-8")


# /dream slash command: queue a txt2img job, or an img2img job when image_url
# points at a Discord attachment. The command then keeps its reply updated
# with progress and the result. Out-of-range numeric options are clamped.
# (No docstring on purpose: discord.py would use it as the command description.)
@client.tree.command()
@app_commands.describe(prompt="Describe the image to generate.")
async def dream(
    interaction: discord.Interaction,
    prompt: str,
    image_url: str = "",
    width: int = 512,
    height: int = 768,
    steps: int = 28,
    cfg_scale: int = 12,
    seed: int = -1,
):
    # Replace each run of non-ASCII characters with a single space.
    prompt = re.sub(r"[^\x00-\x7F]+", " ", prompt).strip()

    task_type = TaskType.TXT2IMG if image_url == "" else TaskType.IMG2IMG
    image = None
    if task_type == TaskType.IMG2IMG:
        # Only accept images from Discord
        if not image_url.startswith(_DISCORD_ATTACHMENT_PREFIXES):
            return await interaction.response.send_message(
                "Only images from Discord are accepted.", ephemeral=True
            )
        # NOTE(review): Discord expects an initial interaction response within
        # a few seconds; a slow download here could exceed that. Consider
        # interaction.response.defer().
        try:
            image = await _download_discord_image(image_url)
        except Exception as e:
            print("Error downloading image:", e)
            return await interaction.response.send_message(
                "Failed to download the image.", ephemeral=True
            )

    # Clamp user-supplied numbers to ranges the web UI can handle.
    task_params = Task(
        task=task_type,
        prompt=prompt,
        image=image,
        width=_clamp(width, 64, 1024),
        height=_clamp(height, 64, 1024),
        steps=_clamp(steps, 1, 50),
        cfg_scale=_clamp(cfg_scale, 1, 200),
        seed=max(seed, -1),  # any negative seed becomes -1
    )

    # Display copy of the prompt: keep it under Discord's 2000-character
    # message limit, and stop backticks from breaking the code span.
    shown_prompt = prompt.replace("`", "'")
    if len(shown_prompt) > 1500:
        shown_prompt = shown_prompt[:1497] + "..."
    await interaction.response.send_message(f"Waiting...\n`{shown_prompt}`")
    message = await interaction.original_response()

    # Queue the task and keep the reply updated until it finishes.
    task_id = queue.add_task(task_params)
    await handle_message(message, task_params, task_id)
async def handle_message(message, task_params: Task, task_id: int):
    """Keep ``message`` updated while task ``task_id`` waits, then post its result.

    Polls ``queue`` until a result for the task has been submitted. Whenever
    the queue position changes, the message is edited to show either the
    position or "Dreaming..." (once the task has been taken), plus an ETA
    derived from ``avg_time``. Finally the message is edited with the error,
    or with the generated image; for img2img the source image is attached too.

    Args:
        message: Discord message to edit (``dream`` passes the interaction's
            original response).
        task_params: Parameters the task was queued with; shown in the message.
        task_id: ID returned by ``queue.add_task`` for this task.
    """
    global avg_time

    param_string = make_param_string(task_params)
    last_position = None
    loop_sleep = 1
    # Each failed edit lengthens the poll interval by a second. Cap it so a run
    # of failures cannot delay delivering the finished result indefinitely.
    max_loop_sleep = 10

    while not queue.is_ready(task_id):
        position = queue.get_position(task_id)
        if position != last_position:
            with avg_time_lock:
                avg_time_local = avg_time
            # position is the number of not-yet-started tasks ahead of this one
            # (-1 once it has been taken), so position + 2 average durations
            # covers the task currently running plus everything up to this one.
            eta = int(time.time() + (avg_time_local * (position + 2)))
            eta_str = "ETA unknown" if avg_time_local == 0 else f"ETA <t:{eta}:R>"
            if position < 0:
                content = f"Dreaming... {eta_str}\n{param_string}"
            else:
                content = f"Waiting... {position+1} in queue, {eta_str}\n{param_string}"
            try:
                await message.edit(content=content)
                # Only record the position once it is displayed, so a failed
                # edit is retried on the next poll.
                last_position = position
            except Exception as e:
                print(f"Couldn't edit message: {e}")
                loop_sleep = min(loop_sleep + 1, max_loop_sleep)
        await asyncio.sleep(loop_sleep)

    result = queue.get_result(task_id)

    # If error, send the error message
    if "error" in result:
        await message.edit(content=f"Error: {result['error']}\n{param_string}")
        return

    seed = result["params"]["seed"]
    elapsed = result["time_elapsed"]
    with avg_time_lock:
        # Exponential moving average: the newest sample gets half the weight.
        avg_time = (avg_time + elapsed) / 2

    attachments = []
    if task_params.task == TaskType.IMG2IMG and task_params.image is not None:
        # Also send the original image for IMG2IMG
        attachments.append(
            discord.File(
                io.BytesIO(base64.b64decode(task_params.image)),
                filename="original.png",
            )
        )
    attachments.append(discord.File(result["image_path"]))

    try:
        await message.edit(
            content=param_string + f", seed {seed}, took {elapsed:.2f}s",
            attachments=attachments,
        )
    except Exception as e:
        # Surface the failure to the user instead of leaving "Dreaming...".
        print(f"Couldn't send result: {e}")
        await message.edit(content=f"Error: couldn't upload the result ({e})")
# Background task to process tasks in the queue
def process_tasks():
    """Worker loop: run queued tasks one at a time, forever.

    Sleeps for a second whenever the queue is empty. Each task is dispatched
    on ``task.task`` to ``process_txt2img`` or ``process_img2img``. The outcome
    is handed back with ``queue.submit_result``, either as
    ``{"image_path", "time_elapsed", "params"}`` or as ``{"error": exception}``,
    so the Discord side is never left polling for a task that failed.
    """
    handlers = {
        TaskType.TXT2IMG: process_txt2img,
        TaskType.IMG2IMG: process_img2img,
    }
    print("Waiting for tasks")
    while True:
        # Reset every iteration so an error is never filed under an earlier
        # task's ID (or an unbound name).
        task_id = None
        try:
            task_id, task = queue.get_task_to_process()
            if task is None:
                time.sleep(1)
                continue

            print(f"Processing task {task_id}: {task}")
            handler = handlers.get(task.task)
            if handler is None:
                raise ValueError(f"Unknown task type: {task.task!r}")

            time_start = time.time()
            fname, params = handler(task)
            time_elapsed = time.time() - time_start

            print(f"Task {task_id} finished in {time_elapsed:.2f} seconds")
            print(f"Done saving image {fname}")
            queue.submit_result(
                task_id,
                {
                    "image_path": fname,
                    "time_elapsed": time_elapsed,
                    "params": params,
                },
            )
        except KeyboardInterrupt:
            # Only reachable when this loop runs on the main thread; Python
            # delivers KeyboardInterrupt to the main thread only.
            print("Keyboard interrupt, exiting")
            break
        except Exception as e:
            print("Error processing task:", e)
            if task_id is not None:
                queue.submit_result(task_id, {"error": e})
def parse_img_response(response_json, output_dir: str = "outputs/bot") -> Tuple[str, Any]:
    """Save the first image of a web UI API response and return its metadata.

    Args:
        response_json: Decoded JSON body from the txt2img/img2img endpoint.
            ``images`` (a list of base64 strings) and ``info`` (a JSON string)
            are read.
        output_dir: Directory to write the PNG into; created if missing.

    Returns:
        ``(path of the written file, info parsed with json.loads)``.

    Raises:
        RuntimeError: if the response contains no images (presumably an API
            error body); the message includes its ``detail``/``error`` text.
    """
    images = response_json.get("images")
    if not images:
        detail = response_json.get("detail") or response_json.get("error") or response_json
        raise RuntimeError(f"API returned no images: {str(detail)[:300]}")

    # Parse base64 to binary
    image_bytes = base64.b64decode(images[0])

    # Write binary to file; the directory may not exist on a fresh checkout.
    os.makedirs(output_dir, exist_ok=True)
    fname = os.path.join(output_dir, f"dream_{time.time()}.png")
    with open(fname, "wb") as f:
        f.write(image_bytes)

    return fname, json.loads(response_json["info"])
def process_txt2img(task: Task, api_url: str = "http://127.0.0.1:7860") -> Tuple[str, Any]:
    """Run a txt2img job against the web UI API and save the resulting image.

    Args:
        task: Prompt and generation settings.
        api_url: Base URL of the web UI (started with --api).

    Returns:
        ``(image path, generation params)`` as produced by parse_img_response.

    Raises:
        RuntimeError: if the API answers with a non-success HTTP status.
        requests.RequestException: on connection errors or the 300 s timeout.
    """
    request_json = {
        "prompt": task.prompt,
        "styles": ["korean doll"],
        "seed": task.seed,
        "sampler_name": "DPM++ SDE Karras",
        "steps": task.steps,
        "cfg_scale": task.cfg_scale,
        "width": task.width,
        "height": task.height,
        "restore_faces": True,
    }

    # Create the request
    r = requests.post(
        f"{api_url.rstrip('/')}/sdapi/v1/txt2img",
        json=request_json,
        timeout=300,
    )
    print("Sent request, got", r.status_code)
    # An error body is not an image response; fail with its text instead of
    # a confusing KeyError/JSONDecodeError further down.
    if not r.ok:
        raise RuntimeError(f"txt2img failed with HTTP {r.status_code}: {r.text[:300]}")
    return parse_img_response(r.json())
def process_img2img(task: Task, api_url: str = "http://127.0.0.1:7860") -> Tuple[str, Any]:
    """Run an img2img job against the web UI API and save the resulting image.

    Args:
        task: Prompt, settings and base64-encoded source image (``task.image``).
        api_url: Base URL of the web UI (started with --api).

    Returns:
        ``(image path, generation params)`` as produced by parse_img_response.

    Raises:
        ValueError: if ``task.image`` is None.
        RuntimeError: if the API answers with a non-success HTTP status.
        requests.RequestException: on connection errors or the 300 s timeout.
    """
    # Explicit check rather than assert, which is stripped under `python -O`.
    if task.image is None:
        raise ValueError("img2img task has no source image")

    request_json = {
        "init_images": [task.image],
        "resize_mode": 0,
        "denoising_strength": 0.75,
        "prompt": task.prompt,
        "styles": ["korean doll"],
        "seed": task.seed,
        "sampler_name": "DPM++ SDE Karras",
        "steps": task.steps,
        "cfg_scale": task.cfg_scale,
        "width": task.width,
        "height": task.height,
        "restore_faces": True,
    }

    # Create the request
    r = requests.post(
        f"{api_url.rstrip('/')}/sdapi/v1/img2img",
        json=request_json,
        timeout=300,
    )
    print("Sent request, got", r.status_code)
    # An error body is not an image response; fail with its text instead of
    # a confusing KeyError/JSONDecodeError further down.
    if not r.ok:
        raise RuntimeError(f"img2img failed with HTTP {r.status_code}: {r.text[:300]}")
    return parse_img_response(r.json())
if __name__ == "__main__":
    # Start the background worker. process_tasks() loops forever, so it must be
    # a daemon thread: a non-daemon thread (or a join on it) would keep the
    # process alive after client.run() returns. An in-flight generation is
    # abandoned on exit.
    thread = Thread(target=process_tasks, daemon=True)
    thread.start()

    # Run the bot; this call blocks until the client shuts down.
    client.run(config["discord_token"])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment