Created
July 10, 2025 11:40
-
-
Save dgwyer/4342db6afae325040f4659a11da31127 to your computer and use it in GitHub Desktop.
Run ComfyUI inference on RunPod from a Python script
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import websocket | |
import uuid | |
import json | |
import urllib.request | |
import csv | |
import os | |
from PIL import Image | |
import io | |
import time | |
import io as io_mod # For StringIO | |
# 🔧 CONFIG
# Host (no scheme) of the ComfyUI server; used to build the HTTP and
# websocket URLs below.
SERVER_ADDRESS = "123456-3000.proxy.runpod.net"  # add specific Pod URL here
# NOTE(review): presumably meant to select TLS (https/wss) vs. plain
# (http/ws) URL schemes -- confirm it is honored wherever URLs are built.
USE_HTTPS = True
# CSV file read by main(); rows are looked up by "id" and "prompt" columns.
CSV_PATH = "prompts.csv"
# Directory that receives the generated "<id>.png" files.
OUTPUT_DIR = "images"
# Any value other than "fixed" makes main() draw a fresh random seed per row.
SEED_MODE = "random"  # or "fixed"
# Seed used for every row when SEED_MODE == "fixed".
FIXED_SEED = 417056104978535
# Per-run client identifier sent with queued prompts and in the websocket URL.
client_id = str(uuid.uuid4())
# Module-level side effect: make sure the output directory exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)
def queue_prompt(prompt):
    """Submit a workflow to the server's ``/api/prompt`` endpoint.

    Args:
        prompt: Workflow graph (node id -> node definition) to queue.

    Returns:
        The decoded JSON reply from the server. ``get_images`` reads the
        ``"prompt_id"`` key from it.

    Raises:
        urllib.error.URLError: If the request cannot be completed
            (``HTTPError`` for non-success status codes).
    """
    payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode("utf-8")
    # Honor the USE_HTTPS switch instead of always forcing https.
    scheme = "https" if USE_HTTPS else "http"
    req = urllib.request.Request(f"{scheme}://{SERVER_ADDRESS}/api/prompt", data=payload)
    req.add_header("Content-Type", "application/json")
    req.add_header("User-Agent", "ComfyUIClient/1.0")
    # Close the HTTP response deterministically rather than leaving the
    # connection to be reclaimed by the garbage collector.
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())
def get_images(ws, prompt):
    """Queue *prompt* and collect the binary frames emitted for node "42".

    Args:
        ws: Connected websocket object; only its ``recv()`` method is used.
        prompt: Workflow graph passed on to ``queue_prompt``.

    Returns:
        A dict mapping node id to the list of payloads received while that
        node was executing. Only node "42" is ever recorded, so the dict is
        empty when no binary frame arrived for it.
    """
    prompt_id = queue_prompt(prompt)["prompt_id"]
    collected = {}
    active_node = ""
    while True:
        frame = ws.recv()
        if not isinstance(frame, str):
            # Non-text frame: keep it only while the SaveImageWebsocket
            # node (id "42") is the one executing.
            if active_node == "42":
                # The first 8 bytes are dropped -- presumably a binary
                # header preceding the image data; confirm against the
                # server's protocol.
                collected.setdefault(active_node, []).append(frame[8:])
            continue
        event = json.loads(frame)
        if event["type"] != "executing":
            continue
        info = event["data"]
        if info["prompt_id"] != prompt_id:
            continue
        if info["node"] is None:
            # A null node for our prompt ends the wait.
            return collected
        active_node = info["node"]
def build_prompt(prompt_text, seed):
    """Return the workflow graph for one text-to-image generation.

    Args:
        prompt_text: Text placed in the positive CLIPTextEncode node ("6").
        seed: Value placed in the KSampler node's ``seed`` input ("31").

    Returns:
        A dict mapping node id to ``{"class_type": ..., "inputs": {...}}``.
        Links between nodes are ``[source_node_id, output_index]`` lists.
    """

    def node(class_type, **inputs):
        # Keyword order is preserved, so inputs keep their listed order.
        return {"class_type": class_type, "inputs": inputs}

    return {
        "6": node("CLIPTextEncode", text=prompt_text, clip=["30", 1]),
        "8": node("VAEDecode", samples=["31", 0], vae=["30", 2]),
        "27": node("EmptySD3LatentImage", width=1024, height=1024, batch_size=1),
        "30": node(
            "CheckpointLoaderSimple",
            ckpt_name="flux1-schnell-fp8.safetensors",
        ),
        "31": node(
            "KSampler",
            seed=seed,
            steps=10,
            cfg=1,
            sampler_name="euler",
            scheduler="simple",
            denoise=1,
            model=["30", 0],
            positive=["6", 0],
            negative=["33", 0],
            latent_image=["27", 0],
        ),
        # Negative conditioning: same encoder, empty text.
        "33": node("CLIPTextEncode", text="", clip=["30", 1]),
        "40": node(
            "ImageResize+",
            width=400,
            height=400,
            interpolation="nearest",
            method="stretch",
            condition="always",
            multiple_of=0,
            image=["8", 0],
        ),
        # Node "42" is the one get_images() listens for.
        "42": node("SaveImageWebsocket", images=["40", 0]),
    }
def fix_malformed_csv_lines(lines):
    """Repair rows whose ID is separated from the prompt by "." instead of ",".

    Blank lines are dropped. The first non-blank line is treated as the
    header and kept untouched. For every later line, if the text before the
    first comma contains a period, the first period is replaced by a comma
    (``"4.controversy,..."`` -> ``"4,controversy,..."``).

    A line with no comma at all is repaired only when the text before its
    first period is purely digits (``"4.a cat"`` -> ``"4,a cat"``), so that
    ordinary sentences ending in a period are left alone.

    Args:
        lines: Iterable of raw text lines, line endings included.

    Returns:
        A new list of lines, in the original order, with blanks removed and
        malformed ID separators fixed.
    """
    fixed = []
    header_seen = False
    for line in lines:
        if not line.strip():
            continue  # Skip blanks
        if not header_seen:
            # Header is the first non-blank line, even if blank lines
            # precede it in the file.
            header_seen = True
            fixed.append(line)
            continue
        first_field, comma, rest = line.partition(",")
        prefix, dot, suffix = first_field.partition(".")
        # `comma` is empty when the line has no comma; in that case demand a
        # numeric ID before the period to avoid rewriting plain prose.
        if dot and (comma or prefix.strip().isdigit()):
            line = prefix + "," + suffix + comma + rest
        fixed.append(line)
    return fixed
import random # ← required for random seed | |
def main():
    """Generate one image per CSV row and save it under OUTPUT_DIR.

    Opens a websocket to the server, reads CSV_PATH (repairing malformed
    rows via ``fix_malformed_csv_lines``), and for each row with a numeric
    ``id`` and a non-empty ``prompt`` queues a workflow and writes the
    returned image(s) as PNG files. A failure on one row is reported and
    the loop moves on to the next row.

    The first image of a row is saved as ``<id>.png``; any further images
    from the same row get ``<id>_<n>.png`` so they do not overwrite it.
    """
    # Match the websocket scheme to the USE_HTTPS switch.
    scheme = "wss" if USE_HTTPS else "ws"
    ws = websocket.WebSocket()
    ws.connect(f"{scheme}://{SERVER_ADDRESS}/ws?clientId={client_id}")
    try:
        with open(CSV_PATH, encoding="utf-8") as f:
            fixed_lines = fix_malformed_csv_lines(f.readlines())
        reader = csv.DictReader(io_mod.StringIO("".join(fixed_lines)))
        print(f"Seed mode: {SEED_MODE}")
        for row in reader:
            if not row or not row.get("prompt") or not row.get("id"):
                print(f"⚠️ Skipping blank or incomplete row: {row}")
                continue
            raw_id = row["id"].strip()
            if not raw_id.isdigit():
                print(f"⚠️ Skipping row with non-numeric ID: {row}")
                continue
            image_id = raw_id
            prompt_text = row["prompt"].strip().strip('"').strip("'")
            # 🎲 Decide seed
            seed = FIXED_SEED if SEED_MODE == "fixed" else random.getrandbits(64)
            print(f"🎨 Generating image for ID: {image_id}, seed: {seed}")
            try:
                images = get_images(ws, build_prompt(prompt_text, seed))
                frames = images.get("42", [])
                if not frames:
                    # Report a readable reason instead of a bare KeyError.
                    raise RuntimeError("no image data received from node 42")
                for index, img_data in enumerate(frames):
                    image = Image.open(io.BytesIO(img_data))
                    # Keep "<id>.png" for the first image; suffix the rest so
                    # a multi-image result is not silently overwritten.
                    suffix = "" if index == 0 else f"_{index}"
                    out_path = os.path.join(OUTPUT_DIR, f"{image_id}{suffix}.png")
                    image.save(out_path)
                    print(f"✅ Saved {out_path}")
            except Exception as e:
                # Per-row boundary: report and continue with the next row.
                print(f"❌ Failed for ID {image_id}: {e}")
            time.sleep(1.0)
    finally:
        # Always release the websocket, even if reading the CSV fails.
        ws.close()
# Run the batch only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment