Last active
March 3, 2023 13:47
-
-
Save hizkifw/1401f9e78649e7f498eb985ded9ddfce to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Ultra-jank Discord bot for AUTOMATIC1111's Stable Diffusion web UI.
Uses the web UI's HTTP API (start the web UI with the --api flag).
""" | |
import re | |
import os | |
import io | |
import time | |
import json | |
import base64 | |
import discord | |
import asyncio | |
import requests | |
from discord import app_commands | |
from dataclasses import dataclass | |
from threading import Thread, Lock | |
from typing import Any, Optional, Tuple | |
from enum import Enum | |
# Bot settings. The values below look like placeholders and need to be
# replaced with real ones before the bot can log in.
config = {
    # Token handed to client.run() at the bottom of the file.
    "discord_token": "...",
    # ID of the guild the slash commands are copied and synced to.
    "discord_guild_id": 123,
}

# Reference to the target guild, used by MyClient.setup_hook().
MY_GUILD = discord.Object(id=config["discord_guild_id"])
class MyClient(discord.Client):
    """Discord client that owns the application-command tree for this bot."""

    def __init__(self, *, intents: discord.Intents):
        super().__init__(intents=intents)
        # Commands declared with @client.tree.command() are registered here.
        self.tree = app_commands.CommandTree(self)

    async def setup_hook(self):
        """Copy the global commands to MY_GUILD and sync that guild's tree."""
        target_guild = MY_GUILD
        self.tree.copy_global_to(guild=target_guild)
        await self.tree.sync(guild=target_guild)
# The single client instance; event handlers and slash commands below attach
# to it. It is created with discord.py's default intent set.
intents = discord.Intents.default()
client = MyClient(intents=intents)
class TaskType(Enum):
    """Kinds of generation job that can be queued."""

    # Generate from the prompt alone (handled by process_txt2img).
    TXT2IMG = "txt2img"
    # Generate from the prompt plus a source image (handled by process_img2img).
    IMG2IMG = "img2img"
@dataclass
class Task:
    """Parameters describing one queued generation job."""

    # Which pipeline the worker should run for this job.
    task: TaskType
    # Prompt text sent to the generation API.
    prompt: str
    # Base64-encoded source image; only populated for IMG2IMG jobs.
    image: Optional[str] = None
    # Requested output size in pixels.
    width: int = 512
    height: int = 768
    # Sampler settings forwarded unchanged to the API.
    steps: int = 20
    cfg_scale: int = 7
    # -1 is the default; presumably it lets the backend pick a seed (confirm).
    seed: int = -1
class TaskQueue(object):
    """
    FIFO queue of tasks plus a mailbox for their results.

    Tasks receive consecutive integer ids in insertion order and are handed
    out in that same order. Producers poll is_ready() and then collect the
    outcome with get_result(). Every method holds an internal lock while it
    touches shared state.
    """

    def __init__(self):
        self.tasks: dict[int, "Task"] = {}  # pending tasks by id
        self.results: dict[int, Any] = {}  # finished results by id
        self.insert_id = 0  # id the next added task will receive
        self.process_id = 0  # id of the next task to hand to the worker
        self.mutex = Lock()

    def add_task(self, task: "Task") -> int:
        """
        Add a task to the queue and return the id assigned to it.
        """
        with self.mutex:
            task_id = self.insert_id
            self.tasks[task_id] = task
            self.insert_id += 1
        return task_id

    def get_position(self, task_id: int) -> int:
        """
        Get the position of a task in the queue.

        0 means the task is next to be handed out; a negative value means it
        has already been handed out by get_task_to_process().
        """
        with self.mutex:
            return task_id - self.process_id

    def get_task_to_process(self) -> Tuple[int, Optional["Task"]]:
        """
        Remove and return the next task as ``(task_id, task)``.

        When nothing is pending, the task element is None and the id is the
        one the next added task will receive.
        """
        with self.mutex:
            task_id = self.process_id
            if task_id == self.insert_id:
                return task_id, None
            task = self.tasks.pop(task_id)
            self.process_id += 1
        return task_id, task

    def get_result(self, task_id: int) -> Any:
        """
        Remove and return the stored result of a task.

        Raises KeyError if no result has been submitted for ``task_id``;
        check is_ready() first.
        """
        with self.mutex:
            return self.results.pop(task_id)

    def is_ready(self, task_id: int) -> bool:
        """
        Check whether a result for the task is available to collect.
        """
        with self.mutex:
            return task_id in self.results

    def submit_result(self, task_id: int, result: Any) -> None:
        """
        Store the result of a task so that get_result() can return it.
        """
        with self.mutex:
            self.results[task_id] = result
# Global application state
# Job queue used by both the Discord command handlers and the worker thread.
queue = TaskQueue()
# NOTE(review): num_steps is not referenced anywhere in the visible code.
num_steps = 6
# Running estimate of seconds per task; the starting value feeds the first ETA.
avg_time = 15
# Held around every read and write of avg_time.
avg_time_lock = Lock()
@client.event
async def on_ready():
    """Print the bot's user name and id when the ready event fires."""
    bot_user = client.user
    print(f"Logged in as {bot_user} (ID: {bot_user.id})")
    print("------")
def make_param_string(task_params: Task):
    """Return the two-line summary of a task shown in the bot's messages.

    The first line is the prompt in backticks; the second lists the size,
    step count and cfg scale, each formatted with thousands separators.
    """
    prompt_line = f"`{task_params.prompt}`"
    settings = ", ".join(
        (
            f"{task_params.width:,}x{task_params.height:,}",
            f"{task_params.steps:,} steps",
            f"cfg scale {task_params.cfg_scale:,}",
        )
    )
    return f"{prompt_line}\n{settings}"
def _clamp(value: int, low: int, high: int) -> int:
    """Return *value* limited to the inclusive range [low, high]."""
    return max(low, min(value, high))


# /dream slash command: validates the options, optionally downloads a source
# image, queues the job and then keeps the reply updated until it finishes.
#
# Deliberately documented with comments instead of a docstring: discord.py
# derives a command's description from the callback's docstring, so adding
# one would change what users see in the command picker.
@client.tree.command()
@app_commands.describe(prompt="Describe the image to generate.")
async def dream(
    interaction: discord.Interaction,
    prompt: str,
    image_url: str = "",
    width: int = 512,
    height: int = 768,
    steps: int = 28,
    cfg_scale: int = 12,
    seed: int = -1,
):
    # Fix the prompt: collapse every run of non-ASCII characters to a space.
    prompt = re.sub(r"[^\x00-\x7F]+", " ", prompt).strip()

    # An image URL switches the job from txt2img to img2img.
    task_type = TaskType.IMG2IMG if image_url else TaskType.TXT2IMG

    image = None
    if task_type == TaskType.IMG2IMG:
        # Only accept images from Discord
        allowed_prefixes = (
            "https://cdn.discordapp.com/attachments/",
            "https://media.discordapp.net/attachments/",
        )
        if not image_url.startswith(allowed_prefixes):
            return await interaction.response.send_message(
                "Only images from Discord are accepted.", ephemeral=True
            )

        # Download the image. requests is blocking, so run it in the default
        # executor to keep the event loop (and every other command) responsive,
        # and bound it with a timeout so a stalled download cannot hang forever.
        # NOTE(review): a slow download may still outlast Discord's window for
        # answering the interaction; consider deferring the response first.
        try:
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                None, lambda: requests.get(image_url, timeout=30)
            )
            # Treat HTTP errors as failures instead of encoding an error page.
            response.raise_for_status()
            image = base64.b64encode(response.content).decode("utf-8")
        except Exception as e:
            print("Error downloading image:", e)
            return await interaction.response.send_message(
                "Failed to download the image.", ephemeral=True
            )

    # Prompt parameters. Upper limits match the original caps; lower limits
    # stop zero or negative values from reaching the backend.
    # NOTE(review): the floors (64 px, 1 step, cfg scale 1) are assumed sane
    # minimums; confirm against what the backend accepts.
    task_params = Task(
        task=task_type,
        prompt=prompt,
        image=image,
        width=_clamp(width, 64, 1024),
        height=_clamp(height, 64, 1024),
        steps=_clamp(steps, 1, 50),
        cfg_scale=_clamp(cfg_scale, 1, 200),
        # Every negative seed is normalised to -1.
        seed=seed if seed >= 0 else -1,
    )

    # Send a message
    await interaction.response.send_message(f"Waiting...\n`{prompt}`")
    message = await interaction.original_response()

    # Queue the task, then keep the message updated until the result arrives.
    task_id = queue.add_task(task_params)
    await handle_message(message, task_params, task_id)
async def handle_message(message, task_params: Task, task_id: int):
    """Keep *message* updated while a task is queued, then post its result.

    Polls the global queue until a result for ``task_id`` exists, editing the
    message whenever the queue position changes. Afterwards the message is
    replaced by either the error text or the parameter summary with the
    generated image attached (plus the source image for IMG2IMG tasks).
    Successful runs also update the global ``avg_time`` estimate.
    """
    global avg_time
    param_string = make_param_string(task_params)
    last_position = None
    loop_sleep = 1
    # Upper bound for the back-off so repeated edit failures cannot delay
    # delivery of the finished result by an ever-growing amount.
    max_loop_sleep = 10

    while not queue.is_ready(task_id):
        # Get the queue position
        position = queue.get_position(task_id)
        if position != last_position:
            with avg_time_lock:
                avg_time_local = avg_time
            # position is -1 once the worker has picked the task up, so
            # (position + 2) is at least one average task duration.
            eta = int(time.time() + (avg_time_local * (position + 2)))
            eta_str = "ETA unknown" if avg_time_local == 0 else f"ETA <t:{eta}:R>"
            if position < 0:
                status = f"Dreaming... {eta_str}\n{param_string}"
            else:
                status = (
                    f"Waiting... {position+1} in queue, {eta_str}\n{param_string}"
                )
            try:
                await message.edit(content=status)
                # Only remember the position once the edit succeeded, so a
                # failed edit is retried on the next pass.
                last_position = position
            except Exception as e:
                print(f"Couldn't edit message: {e}")
                # Bump up the sleep time if the message couldn't be edited
                loop_sleep = min(loop_sleep + 1, max_loop_sleep)
        # Sleep for a bit
        await asyncio.sleep(loop_sleep)

    # Get the result
    result = queue.get_result(task_id)

    # If error, send the error message
    if "error" in result:
        await message.edit(content=f"Error: {result['error']}\n{param_string}")
        return

    seed = result["params"]["seed"]
    elapsed = result["time_elapsed"]

    # Update the average time: mean of the previous estimate and this run.
    with avg_time_lock:
        avg_time = (avg_time + elapsed) / 2

    attachments = []
    # Also send the original image for IMG2IMG. The None check replaces an
    # assert, which would be stripped when Python runs with -O.
    if task_params.task == TaskType.IMG2IMG and task_params.image is not None:
        attachments.append(
            discord.File(
                io.BytesIO(base64.b64decode(task_params.image)),
                filename="original.png",
            )
        )
    attachments.append(discord.File(result["image_path"]))

    # Send the result
    await message.edit(
        content=param_string + f", seed {seed}, took {elapsed:.2f}s",
        attachments=attachments,
    )
# Background task to process tasks in the queue
def process_tasks():
    """Worker loop: take tasks off the global queue and run them one by one.

    For each task the matching process_* function is called, and a dict with
    ``image_path``, ``time_elapsed`` and ``params`` is submitted as the
    result. If anything raises, ``{"error": exc}`` is submitted instead so
    the waiting handler can report it. Sleeps one second whenever the queue
    is empty. Runs until interrupted.
    """
    print("Waiting for tasks")
    while True:
        # Reset every iteration so the error handler never reports against
        # an unbound name or a stale id left over from an earlier task.
        task_id = None
        try:
            task_id, task = queue.get_task_to_process()
            if task is None:
                time.sleep(1)
                continue

            print(f"Processing task {task_id}: {task}")
            time_start = time.time()
            if task.task == TaskType.TXT2IMG:
                fname, params = process_txt2img(task)
            elif task.task == TaskType.IMG2IMG:
                fname, params = process_img2img(task)
            else:
                # Fail with a clear message rather than an UnboundLocalError
                # on fname further down.
                raise ValueError(f"Unsupported task type: {task.task!r}")
            time_elapsed = time.time() - time_start

            print(f"Task {task_id} finished in {time_elapsed:.2f} seconds")
            print(f"Done saving image {fname}")
            queue.submit_result(
                task_id,
                {
                    "image_path": fname,
                    "time_elapsed": time_elapsed,
                    "params": params,
                },
            )
        except KeyboardInterrupt:
            print("Keyboard interrupt, exiting")
            break
        except Exception as e:
            print("Error processing task:", e)
            # Without an id (the queue call itself failed) there is no
            # waiting handler to notify.
            if task_id is not None:
                queue.submit_result(task_id, {"error": e})
def parse_img_response(response_json) -> Tuple[str, Any]:
    """Save the first image of an API response and return its path and info.

    ``response_json`` must provide ``images`` (a list whose first entry is a
    base64-encoded image) and ``info`` (a JSON-encoded string). The image is
    written to ``outputs/bot/dream_<timestamp>.png``.

    Returns a ``(file_path, decoded_info)`` tuple.
    """
    # Parse base64 to binary
    image_bytes = base64.b64decode(response_json["images"][0])

    # Write binary to file. open() does not create missing directories, so
    # make sure the output folder exists first.
    fname = f"outputs/bot/dream_{time.time()}.png"
    os.makedirs(os.path.dirname(fname), exist_ok=True)
    with open(fname, "wb") as f:
        f.write(image_bytes)

    return fname, json.loads(response_json["info"])
# Returns image path and params json
def process_txt2img(task: Task) -> Tuple[str, Any]:
    """Run a txt2img job against the local web UI API.

    Posts the task's settings to ``/sdapi/v1/txt2img`` on 127.0.0.1:7860 and
    passes the JSON reply to parse_img_response().

    Raises requests.HTTPError when the API answers with an error status.
    """
    request_json = {
        "prompt": task.prompt,
        "styles": ["korean doll"],
        "seed": task.seed,
        "sampler_name": "DPM++ SDE Karras",
        "steps": task.steps,
        "cfg_scale": task.cfg_scale,
        "width": task.width,
        "height": task.height,
        "restore_faces": True,
    }

    # Create the request
    r = requests.post(
        "http://127.0.0.1:7860/sdapi/v1/txt2img",
        json=request_json,
        timeout=300,
    )
    print("Sent request, got", r.status_code)
    # Surface API failures as an HTTP error instead of a KeyError raised
    # later while looking for "images" in an error payload.
    r.raise_for_status()
    return parse_img_response(r.json())
def process_img2img(task: Task) -> Tuple[str, Any]:
    """Run an img2img job against the local web UI API.

    Posts the task's settings and its base64 source image to
    ``/sdapi/v1/img2img`` on 127.0.0.1:7860 and passes the JSON reply to
    parse_img_response().

    Raises ValueError when the task has no source image and
    requests.HTTPError when the API answers with an error status.
    """
    # Explicit check instead of an assert, which is stripped under -O.
    if task.image is None:
        raise ValueError("img2img task has no source image")

    request_json = {
        "init_images": [task.image],
        "resize_mode": 0,
        "denoising_strength": 0.75,
        "prompt": task.prompt,
        "styles": ["korean doll"],
        "seed": task.seed,
        "sampler_name": "DPM++ SDE Karras",
        "steps": task.steps,
        "cfg_scale": task.cfg_scale,
        "width": task.width,
        "height": task.height,
        "restore_faces": True,
    }

    # Create the request
    r = requests.post(
        "http://127.0.0.1:7860/sdapi/v1/img2img",
        json=request_json,
        timeout=300,
    )
    print("Sent request, got", r.status_code)
    # Surface API failures as an HTTP error instead of a KeyError raised
    # later while looking for "images" in an error payload.
    r.raise_for_status()
    return parse_img_response(r.json())
# Start the background task. The worker runs an endless loop, so it is marked
# as a daemon thread: a non-daemon thread would keep the interpreter alive
# after the bot has shut down.
thread = Thread(target=process_tasks, daemon=True)
thread.start()

# Run the bot (blocks until the client stops)
client.run(config["discord_token"])

# The worker is intentionally not joined: its loop has no exit condition, so
# join() would block forever and the process could never exit.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment