# Run ComfyUI inference on RunPod from a Python script.
#
# Source: GitHub Gist by @dgwyer, created July 10, 2025
# (gist id 4342db6afae325040f4659a11da31127).
import websocket
import uuid
import json
import urllib.request
import csv
import os
from PIL import Image
import io
import time
import io as io_mod # For StringIO
# --- Configuration ----------------------------------------------------------

# Host of the RunPod proxy in front of ComfyUI; replace with your Pod's URL.
SERVER_ADDRESS = "123456-3000.proxy.runpod.net"
# Scheme toggle for the server connection.
# NOTE(review): confirm the request/websocket code actually reads this flag.
USE_HTTPS = True
CSV_PATH = "prompts.csv"          # input CSV with the prompts to render
OUTPUT_DIR = "images"             # directory the generated PNGs go into
SEED_MODE = "random"              # either "random" or "fixed"
FIXED_SEED = 417_056_104_978_535  # seed used when SEED_MODE == "fixed"

# A single id for this run; it is sent with every queued prompt and is part
# of the websocket URL, so status messages are routed back to this client.
client_id = f"{uuid.uuid4()}"

# Create the output directory up front (no-op if it already exists).
os.makedirs(OUTPUT_DIR, exist_ok=True)
def queue_prompt(prompt):
    """Submit a workflow to the ComfyUI server's ``/api/prompt`` endpoint.

    The request carries this run's ``client_id`` so that websocket status
    messages for the prompt are delivered to this client.

    Args:
        prompt: Workflow graph in ComfyUI API format (node id -> node dict).

    Returns:
        The decoded JSON response; callers read its ``"prompt_id"`` key.

    Raises:
        urllib.error.HTTPError / urllib.error.URLError: on HTTP or
            connection failures (raised by ``urllib.request.urlopen``).
    """
    payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode("utf-8")
    # Honor the USE_HTTPS config flag instead of hard-coding the scheme.
    scheme = "https" if USE_HTTPS else "http"
    req = urllib.request.Request(f"{scheme}://{SERVER_ADDRESS}/api/prompt", data=payload)
    req.add_header("Content-Type", "application/json")
    req.add_header("User-Agent", "ComfyUIClient/1.0")
    # Use the response as a context manager so the connection is closed
    # deterministically rather than leaked until garbage collection.
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())
def get_images(ws, prompt, output_node="42"):
    """Queue *prompt* and collect the image bytes streamed back over *ws*.

    Args:
        ws: A connected websocket for this script's ``client_id``. Its
            ``recv()`` returns ``str`` for JSON status messages and bytes for
            binary frames.
        prompt: Workflow graph in ComfyUI API format.
        output_node: Id of the node whose binary frames are collected.
            Defaults to ``"42"``, the SaveImageWebsocket node id used by
            ``build_prompt``.

    Returns:
        Dict mapping ``output_node`` to a list of image byte strings. The
        dict is empty if that node sent no frames.

    Raises:
        RuntimeError: If the server reports an ``execution_error`` for this
            prompt. This gives a clearer error than a later missing-key failure.
        KeyError: If the queue response has no ``"prompt_id"``.
    """
    prompt_id = queue_prompt(prompt)["prompt_id"]
    output_images = {}
    current_node = ""
    while True:
        out = ws.recv()
        if not isinstance(out, str):
            # Binary frame: keep it only while the output node is executing.
            # The first 8 bytes are discarded as a header preceding the image.
            if current_node == output_node:
                output_images.setdefault(current_node, []).append(out[8:])
            continue

        message = json.loads(out)
        data = message.get("data") or {}
        if data.get("prompt_id") != prompt_id:
            continue  # status for another prompt, or a message without an id
        msg_type = message.get("type")
        if msg_type == "executing":
            if data.get("node") is None:
                break  # node == None marks the end of this prompt's execution
            current_node = data["node"]
        elif msg_type == "execution_error":
            # NOTE(review): field names assume ComfyUI's execution_error
            # payload; .get() keeps this safe if they are absent.
            raise RuntimeError(
                f"ComfyUI execution error in node {data.get('node_id')}: "
                f"{data.get('exception_message', 'unknown error')}"
            )
    return output_images
def build_prompt(
    prompt_text,
    seed,
    *,
    negative_text="",
    width=1024,
    height=1024,
    steps=10,
    ckpt_name="flux1-schnell-fp8.safetensors",
    output_width=400,
    output_height=400,
):
    """Return a ComfyUI API-format workflow that renders *prompt_text*.

    The graph does the following, in order:

    - loads a checkpoint (node "30");
    - encodes the positive (node "6") and negative (node "33") text;
    - samples an empty latent (node "27") with KSampler (node "31");
    - decodes with the VAE (node "8") and resizes the result (node "40");
    - sends the image back through a SaveImageWebsocket node ("42"), which is
      the node id get_images collects.

    Inputs that reference another node are ``[node_id, output_index]`` pairs.

    Every keyword-only argument defaults to the value previously hard-coded,
    so ``build_prompt(text, seed)`` yields exactly the original graph.

    Args:
        prompt_text: Positive prompt text.
        seed: KSampler seed.
        negative_text: Negative prompt text (empty by default).
        width, height: Size of the empty latent image.
        steps: KSampler step count.
        ckpt_name: Checkpoint file name loaded by CheckpointLoaderSimple.
        output_width, output_height: Target size of the resize node.

    Returns:
        A freshly built dict on every call; callers may mutate it.
    """
    return {
        "6": {
            "class_type": "CLIPTextEncode",
            "inputs": {
                "text": prompt_text,
                "clip": ["30", 1],
            },
        },
        "8": {
            "class_type": "VAEDecode",
            "inputs": {
                "samples": ["31", 0],
                "vae": ["30", 2],
            },
        },
        "27": {
            "class_type": "EmptySD3LatentImage",
            "inputs": {
                "width": width,
                "height": height,
                "batch_size": 1,
            },
        },
        "30": {
            "class_type": "CheckpointLoaderSimple",
            "inputs": {
                "ckpt_name": ckpt_name,
            },
        },
        "31": {
            "class_type": "KSampler",
            "inputs": {
                "seed": seed,
                "steps": steps,
                "cfg": 1,
                "sampler_name": "euler",
                "scheduler": "simple",
                "denoise": 1,
                "model": ["30", 0],
                "positive": ["6", 0],
                "negative": ["33", 0],
                "latent_image": ["27", 0],
            },
        },
        "33": {
            "class_type": "CLIPTextEncode",
            "inputs": {
                "text": negative_text,
                "clip": ["30", 1],
            },
        },
        # NOTE(review): "ImageResize+" is not a core node; presumably it comes
        # from a custom node pack installed on the server — confirm.
        "40": {
            "class_type": "ImageResize+",
            "inputs": {
                "width": output_width,
                "height": output_height,
                "interpolation": "nearest",
                "method": "stretch",
                "condition": "always",
                "multiple_of": 0,
                "image": ["8", 0],
            },
        },
        "42": {
            "class_type": "SaveImageWebsocket",
            "inputs": {
                "images": ["40", 0],
            },
        },
    }
def fix_malformed_csv_lines(lines):
    """Repair rows whose numeric ID was separated by "." instead of ",".

    Example: ``4.controversy,...`` becomes ``4,controversy,...``. Only the
    first period is replaced, and only when the text before it is all digits
    and it appears before the row's first comma.

    Other rules:

    - Blank lines between records are dropped.
    - The first non-blank line is kept verbatim as the header.
    - Continuation lines of a multi-line quoted field are passed through
      untouched, so prompt text containing periods or blank lines is never
      rewritten.

    Args:
        lines: Raw text lines (for example from ``file.readlines()``).

    Returns:
        A new list of lines suitable for ``"".join(...)`` and CSV parsing.
    """
    fixed = []
    header_seen = False
    in_quoted_field = False
    for line in lines:
        starts_record = not in_quoted_field
        # An odd number of quote characters opens or closes a quoted field;
        # escaped quotes ("") come in pairs and do not change the parity.
        if line.count('"') % 2:
            in_quoted_field = not in_quoted_field

        if not starts_record:
            fixed.append(line)  # inside a quoted multi-line field: keep as-is
            continue
        if not line.strip():
            continue  # skip blank lines between records
        if not header_seen:
            header_seen = True
            fixed.append(line)  # keep header
            continue

        first_field, comma, _ = line.partition(",")
        id_part, dot, _ = first_field.partition(".")
        if comma and dot and id_part.strip().isdigit():
            # The first "." lies inside first_field, i.e. before the first comma.
            dot_index = len(id_part)
            line = line[:dot_index] + "," + line[dot_index + 1:]
        fixed.append(line)
    return fixed
import random # ← required for random seed
def _read_prompt_rows(path):
    """Load *path* and return a ``csv.DictReader`` over its repaired lines.

    The file is decoded as ``utf-8-sig`` so that a leading byte-order mark
    (U+FEFF, common in spreadsheet exports) is dropped. With plain ``utf-8``
    the mark would stay attached to the first header name, ``row.get("id")``
    would be ``None`` for every row, and every row would be skipped. Lines
    pass through ``fix_malformed_csv_lines`` before parsing.
    """
    with open(path, encoding="utf-8-sig") as f:
        raw_lines = f.readlines()
    fixed_lines = fix_malformed_csv_lines(raw_lines)
    return csv.DictReader(io_mod.StringIO("".join(fixed_lines)))


def main():
    """Generate images for each row of ``CSV_PATH`` and save them as PNGs.

    The first image for a row is written to ``OUTPUT_DIR/<id>.png``. Any
    additional images for that row get an ``_<n>`` suffix rather than
    overwriting the first.

    Rows with a missing prompt/id, or with a non-numeric id, are skipped
    with a warning. Each row is seeded as follows:

    - ``FIXED_SEED`` when ``SEED_MODE == "fixed"``;
    - a random 64-bit seed otherwise.

    Per-row failures are printed and the loop continues. The websocket is
    always closed, even if the run is interrupted.
    """
    # Parse the CSV first, so a missing or unreadable file fails before any
    # network connection is opened.
    reader = _read_prompt_rows(CSV_PATH)

    ws_scheme = "wss" if USE_HTTPS else "ws"
    ws = websocket.WebSocket()
    ws.connect(f"{ws_scheme}://{SERVER_ADDRESS}/ws?clientId={client_id}")
    try:
        print(f"Seed mode: {SEED_MODE}")
        for row in reader:
            if not row or not row.get("prompt") or not row.get("id"):
                print(f"⚠️ Skipping blank or incomplete row: {row}")
                continue
            image_id = row["id"].strip()
            if not image_id.isdigit():
                print(f"⚠️ Skipping row with non-numeric ID: {row}")
                continue
            prompt_text = row["prompt"].strip().strip('"').strip("'")
            # 🎲 Decide seed
            seed = FIXED_SEED if SEED_MODE == "fixed" else random.getrandbits(64)
            print(f"🎨 Generating image for ID: {image_id}, seed: {seed}")
            try:
                images = get_images(ws, build_prompt(prompt_text, seed))
                blobs = images.get("42", [])  # "42" = SaveImageWebsocket node
                if not blobs:
                    print(f"❌ No images returned for ID {image_id}")
                for i, img_data in enumerate(blobs):
                    # Keep "<id>.png" for the first image; suffix any extras so
                    # they do not overwrite it.
                    suffix = f"_{i}" if i else ""
                    out_path = os.path.join(OUTPUT_DIR, f"{image_id}{suffix}.png")
                    with Image.open(io.BytesIO(img_data)) as image:
                        image.save(out_path)
                    print(f"✅ Saved {out_path}")
            except Exception as e:
                # Per-row boundary: report the failure and move on to the next row.
                print(f"❌ Failed for ID {image_id}: {e}")
            time.sleep(1.0)
    finally:
        # Release the connection on normal exit, errors, and Ctrl-C alike.
        ws.close()


if __name__ == "__main__":
    main()
# Discussion of this script takes place in the comments on the original GitHub Gist.