Skip to content

Instantly share code, notes, and snippets.

@memo
Last active January 8, 2025 18:23
Show Gist options
  • Save memo/3c51abce7af74ce779026ca48f59b9a1 to your computer and use it in GitHub Desktop.
Save memo/3c51abce7af74ce779026ca48f59b9a1 to your computer and use it in GitHub Desktop.
some utils for working with comfyui. Since comfyui enabling/disabling nodes changing pathways is done in the frontend and not the backend, it isn't supported in the api. So the code below also supports working with full workflow files, not just api files, to allow changing pathways at runtime. This is a hacky solution to get around this limitation.
import websocket
import uuid
import json
import urllib.request
import urllib.parse
import os
import json
import base64
from io import BytesIO
from PIL import Image
import time
import msa.logger
logger = msa.logger.getLogger(__name__)
# Module-level connection state, populated by connect().
_client_id = None  # uuid4 string identifying this client to the ComfyUI server
_server_url = None  # "host:port" (or full url) of the currently connected server
_ws = None  # websocket.WebSocket instance, or None until connect() succeeds
def connect(server_url, server_port, verbose=True):
    '''Open (or reuse) a websocket connection to a ComfyUI server.

    server_url: server host or url; "ws://" is prepended unless the final
        connect string already starts with "http".
    server_port: appended as ":port" when truthy and non-zero.
    verbose: when True, log a message if already connected to this url.

    Reconnects with a fresh client id only when the target url changed or no
    websocket exists yet; otherwise the existing connection is kept.
    Updates the module globals _client_id / _server_url / _ws.
    '''
    global _client_id
    global _server_url
    global _ws
    # int() both validates the port string and makes "0" / "" fall through.
    if server_port and int(server_port):
        server_url = f"{server_url}:{server_port}"
    if server_url != _server_url or _ws is None:
        _client_id = str(uuid.uuid4())
        _server_url = server_url
        connect_str = f"{server_url}/ws?clientId={_client_id}"
        if not connect_str.startswith("http"):
            connect_str = f"ws://{connect_str}"
        logger.info(f"msa.comfyui connecting to {connect_str}")
        _ws = websocket.WebSocket()
        _ws.connect(connect_str)
    elif verbose:
        logger.info(f"msa.comfyui already connected to {server_url} with client_id {_client_id}")
def load_workflow(workflow_path):
    '''Read a workflow json file (api or full format) and return it as a dict.

    Environment variables and '~' in the path are expanded, and a '.json'
    suffix is appended when missing.
    '''
    path = os.path.expandvars(workflow_path)
    path = os.path.expanduser(path)
    if not path.endswith('.json'):
        path += '.json'
    with open(path, 'r', encoding="utf8") as f:
        return json.load(f)
def is_api_workflow(workflow):
    '''Return True for API-format workflows (no top-level 'nodes' list), False for the full GUI format.'''
    return 'nodes' not in workflow
def find_all_nodes_by_field(workflow, fieldname, fieldvalue, class_type, verbose, break_on_error):
    '''Return a list of {node, id} dicts for every node whose *fieldname* equals *fieldvalue*.

    Works on both workflow formats:
      - API format (dict of node_id -> node): *fieldname* is looked up inside
        the node's optional '_meta' dict and on the node itself; *class_type*
        (if given) is compared against node['class_type'].
      - Full format (list under 'nodes'): *fieldname* is looked up directly on
        the node; *class_type* (if given) is compared against node['type'].
        WARNING: if title isn't changed in GUI, it's not saved in the json!

    When nothing matches: raises ValueError if break_on_error, otherwise logs
    (when verbose) and returns None.
    '''
    nodes = []
    if is_api_workflow(workflow):  # api workflow format
        for node_id, node in workflow.items():  # dict of node_id : node
            # '_meta' is optional on API nodes; .get avoids a KeyError the
            # original code hit on nodes without it.
            meta = node.get('_meta', {})
            if ((fieldname in meta and meta[fieldname] == fieldvalue)
                    or (fieldname in node and node[fieldname] == fieldvalue)):
                if class_type is None or node['class_type'] == class_type:
                    nodes.append(dict(node=node, id=node_id))
    else:  # full workflow format
        for node in workflow['nodes']:  # list of nodes
            if fieldname in node and node[fieldname] == fieldvalue:  # WARNING: if title isn't changed in GUI, it's not saved in the json!
                if class_type is None or node['type'] == class_type:
                    nodes.append(dict(node=node, id=node['id']))
    if not nodes:
        msg = f"Node with '{fieldname}' == '{fieldvalue}' not found"
        if break_on_error:
            raise ValueError(msg)
        if verbose:
            logger.error(msg)
        # Consistently return None on no match (the original returned None or
        # [] depending on *verbose*); both are falsy for callers.
        return None
    return nodes
def find_node_by_field(workflow, fieldname, fieldvalue, class_type, verbose, break_on_error):
    '''Return the first node matching *fieldname* == *fieldvalue*, or None if none found.'''
    found = find_all_nodes_by_field(workflow, fieldname, fieldvalue, class_type, verbose, break_on_error)
    return found[0]['node'] if found else None
def find_all_nodes_by_title(workflow, title, class_type=None, verbose=True, break_on_error=False):
    '''Return a list of {node, id} dicts for every node titled *title*.'''
    return find_all_nodes_by_field(
        workflow, 'title', title, class_type, verbose, break_on_error)
def find_nodeid_by_field(workflow, fieldname, fieldvalue, class_type=None, verbose=True, break_on_error=False):
    '''Return the id of the first node matching *fieldname* == *fieldvalue*, or None if none found.'''
    found = find_all_nodes_by_field(workflow, fieldname, fieldvalue, class_type, verbose, break_on_error)
    return found[0]['id'] if found else None
def find_nodeid_by_title(workflow, title, class_type=None, verbose=True, break_on_error=False):
    '''Return the id of the first node titled *title*, or None if none found.'''
    return find_nodeid_by_field(
        workflow, 'title', title, class_type, verbose, break_on_error)
def find_node_by_title(workflow, title, class_type=None, verbose=True, break_on_error=False):
    '''Return the first node titled *title*, or None if none found.'''
    return find_node_by_field(
        workflow, 'title', title, class_type, verbose, break_on_error)
def set_node_input(workflow, title, input_name, value, verbose=True, break_on_error=False, all=True):
    '''
    Set an input value on the node(s) titled *title*.

    This can write to API format or Full workflow format.
    In full workflow format 'input_name' can be a string key, or integer,
    indexing into 'widget_values' dict or array.

    value: new value; when None the write is skipped (logged, not an error).
    all: when True update every node with that title, otherwise only the first.
    Returns the matched {node, id} dicts (or None when nothing matched / value
    was None). Failures raise ValueError when break_on_error is True and are
    logged as warnings otherwise.
    '''
    is_api = is_api_workflow(workflow)
    if is_api:
        msg = f"Setting '{title}.{input_name}' to '{value}' ... "
    else:
        msg = f"Setting '{title}.widget_values[{input_name}]' to '{value}' ... "
    error = False
    nodes = None
    if value is None:
        msg += "SKIPPING. VALUE NONE."
    else:  # look for node
        nodes = find_all_nodes_by_title(workflow, title, verbose=False, break_on_error=False)
        if nodes:  # Node(s) found
            msg += f"{len(nodes)} NODES FOUND ... "
            if not all:
                # BUGFIX: keep a one-element list. The original assigned the
                # bare dict (nodes[0]), making the loop below iterate its
                # string keys and crash on node_and_id['node'].
                nodes = nodes[:1]
            for node_and_id in nodes:
                node = node_and_id['node']
                # API nodes store values under 'inputs' (dict); full-format
                # nodes store them under 'widgets_values' (usually a list).
                target = node['inputs'] if is_api else node['widgets_values']
                if isinstance(target, dict):
                    if input_name in target:
                        target[input_name] = value
                        msg += "SUCCESS."
                    else:
                        msg += "INPUT NOT FOUND ON NODE."
                        error = True
                else:
                    # list-like target: input_name is an integer index
                    try:
                        target[input_name] = value
                        msg += "SUCCESS."
                    except Exception as e:
                        msg += f"ERROR {e}"
                        error = True
        else:  # Node not found
            msg += "NODE NOT FOUND."
            error = True
    if error:
        if break_on_error:
            raise ValueError(msg)
        else:
            logger.warning(msg)
    elif verbose:
        logger.info(msg)
    return nodes
def encode_img_to_base64(img, format="JPEG"):
    '''
    Encode an image to a base64 string.
    img: PIL.Image.Image
    format: str, default "JPEG", other options: "PNG", "WEBP"
    '''
    buffer = BytesIO()
    img.save(buffer, format=format)
    return base64.b64encode(buffer.getvalue()).decode()
def resize_and_crop_img(img, size, crop=True, method=Image.LANCZOS):
    '''
    Resize and crop image to fill size while maintaining aspect ratio.
    img: PIL.Image.Image
    size: tuple (width, height)
    crop: when False, only the aspect-preserving resize is performed.
    method: PIL resampling filter used by resize().
    '''
    target_w, target_h = size
    src_w, src_h = img.size
    aspect = src_w / src_h
    # Scale so the resized image covers the target in both dimensions.
    if aspect > target_w / target_h:
        # Source wider than target: match heights, let width overflow.
        new_h = target_h
        new_w = int(new_h * aspect)
    else:
        # Source taller than target: match widths, let height overflow.
        new_w = target_w
        new_h = int(new_w / aspect)
    resized = img.resize((new_w, new_h), method)
    if not crop:
        return resized
    # Center-crop the overflow down to the exact target size.
    left = (new_w - target_w) // 2
    top = (new_h - target_h) // 2
    return resized.crop((left, top, left + target_w, top + target_h))
def get_image(filename, subfolder, folder_type, server_url=None):
    '''Fetch an output image from the server's /view endpoint; returns the raw bytes.'''
    if server_url is None:
        server_url = _server_url
    params = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type})
    with urllib.request.urlopen(f"http://{server_url}/view?{params}") as resp:
        return resp.read()
def get_history(prompt_id, server_url=None):
    '''Fetch the execution history for *prompt_id* from /history; returns the parsed json dict.'''
    url = f"http://{server_url if server_url is not None else _server_url}/history/{prompt_id}"
    with urllib.request.urlopen(url) as resp:
        return json.loads(resp.read())
def queue_prompt(prompt, server_url=None, client_id=None, do_wait=True, do_get_outputs=False):
    '''Submit an API-format workflow *prompt* to the server's /prompt endpoint.

    server_url / client_id default to the module globals set by connect().
    do_wait: block on the websocket until execution finishes.
    do_get_outputs: also download image/text outputs (forces do_wait).
    Returns the server's response dict; when waiting it is augmented with
    'history', 'status_str', 'completed' and (if requested) 'outputs'.
    NOTE(review): waiting requires a prior connect() — _ws must be open.
    '''
    start_time = time.time()
    if server_url is None:
        server_url = _server_url
    if client_id is None:
        client_id = _client_id
    p = {"prompt": prompt, "client_id": client_id}
    logger.info(f"Queuing prompt to server {server_url} with client_id {client_id}")
    data = json.dumps(p).encode('utf-8')
    req = urllib.request.Request(f"http://{server_url}/prompt", data=data)
    ret = json.loads(urllib.request.urlopen(req).read())
    if do_get_outputs:
        do_wait = True  # outputs only exist once execution has finished
    if not do_wait:
        return ret
    prompt_id = ret['prompt_id']
    # Drain websocket messages until the server reports our prompt done:
    # an 'executing' message with node == None signals end of execution.
    while True:
        out = _ws.recv()
        if isinstance(out, str):
            message = json.loads(out)
            if message['type'] == 'executing':
                data = message['data']
                if data['node'] is None and data['prompt_id'] == prompt_id:
                    break #Execution is done
        else:
            continue #previews are binary data
    end_time = time.time()
    duration = end_time - start_time
    logger.info(f"Queue ran in {duration:.2f} seconds")
    history = get_history(prompt_id)[prompt_id]
    ret['history'] = history
    status = history['status']
    status_str = status['status_str']
    if status_str == 'error':
        # Unpack error details from the status messages into one log line.
        for msg_arr in status['messages']:
            if 'error' in msg_arr[0]:
                status_str = f"{status_str}, {msg_arr[0]}, {msg_arr[1]['node_type']}, {msg_arr[1]['exception_type']}, {msg_arr[1]['exception_message']}"
        logger.error(f"STATUS: {status_str}")
        # return
    else:
        logger.info(f"STATUS: {status_str}")
    ret['status_str'] = status_str
    ret['completed'] = status['completed']
    all_outputs = history['outputs']
    logger.info(f"{len(all_outputs)} output nodes")
    # Log a per-node summary of what each output node produced.
    for node_id, node_outputs in all_outputs.items():
        node_title = prompt[node_id]['_meta']['title'].encode('utf-8', errors='replace') # because names have silly characters
        node_outputs_summary = {k:len(v) for k,v in node_outputs.items()}
        logger.info(f"{node_title} ({node_id}) => {node_outputs_summary}")
        pass
    if not do_get_outputs:
        return ret
    # ret_outputs = {}
    # Download image bytes and collect text outputs per node.
    all_images = {}
    all_texts = {}
    for node_id in all_outputs:
        node_output = all_outputs[node_id]
        if 'images' in node_output:
            node_images = []
            for image in node_output['images']:
                image_data = get_image(image['filename'], image['subfolder'], image['type'])
                node_images.append(image_data)
            all_images[node_id] = node_images
        if 'text' in node_output:
            node_text = []
            for text in node_output['text']:
                # text entries may already be str; decode only if bytes-like
                try:
                    node_text.append(text.decode('utf-8'))
                except:
                    node_text.append(text)
            all_texts[node_id] = node_text
    ret['outputs'] = dict(images=all_images, texts=all_texts)
    return ret
def test_workflow_interaction():
    '''Testing interaction with the workflow json, both api and full.

    Loads the bundled test workflows (test.api.json / test.json) and runs a
    set of lookup/write probes against each, logging the call and pretty
    printing the result.
    '''
    logger.info('*'*80)
    logger.info('*'*80)
    from pprint import pprint
    import os
    os.chdir(os.path.dirname(__file__))
    workflow_paths = [
        "test.api.json",
        "test.json",
    ]
    for workflow_path in workflow_paths:
        logger.info('='*40)
        logger.info(f"Loading '{workflow_path}'")
        workflow = load_workflow(workflow_path)
        pprint(workflow)
        # Each probe pairs a human-readable label (kept for the log
        # transcript) with a direct call — replaces the original eval()
        # dispatch, which is an avoidable code-execution idiom.
        # workflow is bound as a default arg to avoid late-binding surprises.
        probes = [
            ('find_node_by_title(workflow, "Load Image")',
             lambda wf=workflow: find_node_by_title(wf, "Load Image")),
            ('find_nodeid_by_title(workflow, "Load Image1")',
             lambda wf=workflow: find_nodeid_by_title(wf, "Load Image1")),
            ('find_node_by_title(workflow, "Load Image1")',
             lambda wf=workflow: find_node_by_title(wf, "Load Image1")),
            ('find_all_nodes_by_title(workflow, "Load Image1")',
             lambda wf=workflow: find_all_nodes_by_title(wf, "Load Image1")),
            ('set_node_input(workflow, "KSampler", "seed", 123)',
             lambda wf=workflow: set_node_input(wf, "KSampler", "seed", 123)),
        ]
        for label, probe in probes:
            logger.info('-'*20)
            logger.info(label)
            pprint(probe())
if __name__ == "__main__":
    # Run the interactive smoke test when executed as a script.
    test_workflow_interaction()
# print('-'*20)
# pprint(prompt)
# server_address = "
{
"1": {
"inputs": {
"image": "005_horsehead_nebula.jpg",
"upload": "image"
},
"class_type": "LoadImage",
"_meta": {
"title": "Load Image1"
}
},
"3": {
"inputs": {
"image": "00600 20240325.185243 red streak2 v177 flower_orchidcrystal - s554295367 n50 st0.6 g8 r0.8 graded.jpg",
"upload": "image"
},
"class_type": "LoadImage",
"_meta": {
"title": "Load Image1"
}
},
"4": {
"inputs": {
"image": "2.-peppered-rock-shield-lichen.jpg",
"upload": "image"
},
"class_type": "LoadImage",
"_meta": {
"title": "Load Image2"
}
},
"6": {
"inputs": {
"images": [
"1",
0
]
},
"class_type": "PreviewImage",
"_meta": {
"title": "Preview Image1"
}
},
"7": {
"inputs": {
"seed": 1234,
"steps": 21,
"cfg": 7,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 0.8
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"8": {
"inputs": {
"value": 0
},
"class_type": "INTConstant",
"_meta": {
"title": "INT Constant"
}
},
"9": {
"inputs": {
"value": 0
},
"class_type": "INTConstant",
"_meta": {
"title": "INT Constant"
}
},
"10": {
"inputs": {
"seed": 5678,
"steps": 22,
"cfg": 8,
"sampler_name": "euler_ancestral",
"scheduler": "karras",
"denoise": 1
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
}
}
{
"last_node_id": 10,
"last_link_id": 1,
"nodes": [
{
"id": 4,
"type": "LoadImage",
"pos": {
"0": 3461,
"1": 406
},
"size": {
"0": 315,
"1": 314
},
"flags": {},
"order": 0,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": null,
"shape": 3
},
{
"name": "MASK",
"type": "MASK",
"links": null,
"shape": 3
}
],
"title": "Load Image2",
"properties": {
"Node name for S&R": "LoadImage"
},
"widgets_values": [
"2.-peppered-rock-shield-lichen.jpg",
"image"
]
},
{
"id": 3,
"type": "LoadImage",
"pos": {
"0": 3043,
"1": 395
},
"size": {
"0": 315,
"1": 314
},
"flags": {},
"order": 1,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": null,
"shape": 3
},
{
"name": "MASK",
"type": "MASK",
"links": null,
"shape": 3
}
],
"title": "Load Image1",
"properties": {
"Node name for S&R": "LoadImage"
},
"widgets_values": [
"00600 20240325.185243 red streak2 v177 flower_orchidcrystal - s554295367 n50 st0.6 g8 r0.8 graded.jpg",
"image"
]
},
{
"id": 6,
"type": "PreviewImage",
"pos": {
"0": 2665,
"1": 393
},
"size": {
"0": 210,
"1": 26
},
"flags": {},
"order": 9,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 1
}
],
"outputs": [],
"title": "Preview Image1",
"properties": {
"Node name for S&R": "PreviewImage"
}
},
{
"id": 2,
"type": "LoadImage",
"pos": {
"0": 2338,
"1": 840
},
"size": {
"0": 315,
"1": 314
},
"flags": {},
"order": 2,
"mode": 4,
"inputs": [],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": null,
"shape": 3
},
{
"name": "MASK",
"type": "MASK",
"links": null,
"shape": 3
}
],
"title": "Load Image bypassed",
"properties": {
"Node name for S&R": "LoadImage"
},
"widgets_values": [
"1000_F_846598571_evJ8KPjrWWcRKT7FMWaYwaazi8CqT9V5.jpg",
"image"
]
},
{
"id": 5,
"type": "LoadImage",
"pos": {
"0": 2692,
"1": 842
},
"size": {
"0": 315,
"1": 314
},
"flags": {},
"order": 3,
"mode": 4,
"inputs": [],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": null,
"shape": 3
},
{
"name": "MASK",
"type": "MASK",
"links": null,
"shape": 3
}
],
"title": "Load Image bypassed2",
"properties": {
"Node name for S&R": "LoadImage"
},
"widgets_values": [
"6d1fb355c398b0a81bb5de9a075f3033.jpg",
"image"
]
},
{
"id": 1,
"type": "LoadImage",
"pos": {
"0": 2304,
"1": 393
},
"size": {
"0": 315,
"1": 314
},
"flags": {},
"order": 4,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
1
],
"slot_index": 0,
"shape": 3
},
{
"name": "MASK",
"type": "MASK",
"links": null,
"shape": 3
}
],
"title": "Load Image1",
"properties": {
"Node name for S&R": "LoadImage"
},
"widgets_values": [
"005_horsehead_nebula.jpg",
"image"
]
},
{
"id": 8,
"type": "INTConstant",
"pos": {
"0": 2846,
"1": 229
},
"size": [
200,
58
],
"flags": {},
"order": 5,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "value",
"type": "INT",
"links": null,
"shape": 3
}
],
"properties": {
"Node name for S&R": "INTConstant"
},
"widgets_values": [
0
],
"color": "#1b4669",
"bgcolor": "#29699c"
},
{
"id": 9,
"type": "INTConstant",
"pos": {
"0": 3116,
"1": 228
},
"size": [
200,
58
],
"flags": {},
"order": 6,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "value",
"type": "INT",
"links": null,
"shape": 3
}
],
"properties": {
"Node name for S&R": "INTConstant"
},
"widgets_values": [
0
],
"color": "#1b4669",
"bgcolor": "#29699c"
},
{
"id": 10,
"type": "KSampler",
"pos": {
"0": 3586,
"1": 864
},
"size": [
315,
474
],
"flags": {},
"order": 7,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": null
},
{
"name": "positive",
"type": "CONDITIONING",
"link": null
},
{
"name": "negative",
"type": "CONDITIONING",
"link": null
},
{
"name": "latent_image",
"type": "LATENT",
"link": null
}
],
"outputs": [
{
"name": "LATENT",
"type": "LATENT",
"links": null,
"shape": 3
}
],
"properties": {
"Node name for S&R": "KSampler"
},
"widgets_values": [
5678,
"randomize",
22,
8,
"euler_ancestral",
"karras",
1
]
},
{
"id": 7,
"type": "KSampler",
"pos": {
"0": 3137,
"1": 862
},
"size": {
"0": 315,
"1": 262
},
"flags": {},
"order": 8,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": null
},
{
"name": "positive",
"type": "CONDITIONING",
"link": null
},
{
"name": "negative",
"type": "CONDITIONING",
"link": null
},
{
"name": "latent_image",
"type": "LATENT",
"link": null
}
],
"outputs": [
{
"name": "LATENT",
"type": "LATENT",
"links": null,
"shape": 3
}
],
"properties": {
"Node name for S&R": "KSampler"
},
"widgets_values": [
1234,
"randomize",
21,
7,
"euler",
"normal",
0.8
]
}
],
"links": [
[
1,
1,
0,
6,
0,
"IMAGE"
]
],
"groups": [],
"config": {},
"extra": {
"ds": {
"scale": 1,
"offset": [
-1278.6666259765625,
-152.66665649414062
]
}
},
"version": 0.4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment