-
-
Save python273/ae9d085ce9f2968b50c6ab90f2017076 to your computer and use it in GitHub Desktop.
Not really optimal: image generation blocks the event loop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import torch | |
from torch import autocast | |
# from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler | |
from diffusers import LMSDiscreteScheduler | |
from my_gen_pipeline import StableDiffusionPipeline | |
from datetime import datetime | |
from conf import UNIX_SOCKET_PATH | |
# torch.cuda.empty_cache() | |
pipe = None | |
lock = None | |
async def handle_connection(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
    """Serve one client connection.

    Protocol: the client sends a single UTF-8 prompt line terminated by
    '\n'; the server generates an image, saves it (plus a .txt file with
    the prompt) and replies with the image path followed by '\n'.
    """
    global lock
    prompt = await reader.readline()
    if not prompt.endswith(b'\n'):
        # Client disconnected before sending a full line -- close the
        # transport instead of leaking the connection.
        writer.close()
        await writer.wait_closed()
        return
    print()
    print(f'prompt: {repr(prompt)}')
    prompt = prompt.decode('utf-8').strip()
    print(f'decoded {prompt!r}')
    # The pipeline can only run one job at a time; serialize requests.
    async with lock:
        # isoformat() contains ':' (never '/'); the original replace('/', '_')
        # was a no-op, so sanitize the actually-present character.
        filename = f"images/img-{datetime.utcnow().isoformat().replace(':', '_')}.png"
        prompt_filename = filename + '.txt'

        def _generate():
            # Blocking GPU work; runs in a worker thread so it does not
            # stall the event loop (the original called this inline).
            with autocast("cuda"):
                return pipe.gen(
                    prompt,
                    height=640,
                    width=640,
                    guidance_scale=9,
                    num_inference_steps=65
                )["sample"][0]

        image = await asyncio.to_thread(_generate)
        with open(prompt_filename, 'w') as f:
            f.write(prompt)
        image.save(filename)
    print('sending', filename)
    # Reply with the generated image's path (L42 held a scrape artifact
    # '(unknown)'; the 'sending' log line shows the filename was intended).
    writer.write(f'{filename}\n'.encode('ascii'))
    await writer.drain()
    writer.close()
    await writer.wait_closed()
# Strong references to in-flight connection tasks: the event loop keeps only
# weak references, so without this set a running task could be collected.
background_tasks = set()


def create_connection_task(*args, **kwargs):
    """Connection callback: spawn a handler task and keep it alive until done."""
    conn_task = asyncio.create_task(handle_connection(*args, **kwargs))
    background_tasks.add(conn_task)
    conn_task.add_done_callback(background_tasks.discard)
async def main():
    """Load the Stable Diffusion pipeline, then serve prompts on a unix socket."""
    global pipe, lock
    lock = asyncio.Lock()
    # Alternative scheduler, currently disabled:
    # lms = LMSDiscreteScheduler(
    #     beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
    model = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        # scheduler=lms,
        revision="fp16",
        torch_dtype=torch.float16,
        use_auth_token=" TOKEN ",
    )
    # Move the pipeline onto the GPU before accepting any connections.
    pipe = model.to('cuda')
    server = await asyncio.start_unix_server(create_connection_task, UNIX_SOCKET_PATH)
    print(server)
    async with server:
        await server.serve_forever()
if __name__ == '__main__': | |
asyncio.run(main(), debug=True) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment