Last active
January 8, 2025 18:23
-
-
Save memo/3c51abce7af74ce779026ca48f59b9a1 to your computer and use it in GitHub Desktop.
Some utils for working with ComfyUI. Since enabling/disabling nodes (i.e. changing pathways) in ComfyUI is handled in the frontend and not the backend, it isn't supported in the API. So the code below also supports working with full workflow files, not just API files, to allow changing pathways at runtime. This is a hacky solution to get around this limitation…
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import websocket | |
import uuid | |
import json | |
import urllib.request | |
import urllib.parse | |
import os | |
import json | |
import base64 | |
from io import BytesIO | |
from PIL import Image | |
import time | |
import msa.logger | |
logger = msa.logger.getLogger(__name__)
# Connection state shared by the helpers below; populated by connect().
_client_id = None  # str(uuid.uuid4()) sent to ComfyUI as ?clientId= on the websocket URL
_server_url = None  # "host" or "host:port" last passed to connect(); HTTP helpers prepend "http://" to it
_ws = None  # websocket.WebSocket opened by connect(); queue_prompt() reads completion messages from it
def connect(server_url, server_port, verbose=True):
    '''Open (or reuse) the module-level websocket connection to a ComfyUI server.

    server_url: host name/address, e.g. "127.0.0.1". Pass it without a scheme:
        get_image/get_history/queue_prompt prepend "http://" themselves. For
        the websocket only, a leading "ws://"/"wss://" is kept as-is and a
        leading "http://"/"https://" is mapped to "ws://"/"wss://".
    server_port: optional port, appended as ":port" when truthy and non-zero.
    verbose: log when an existing connection to the same server is reused.

    When a new connection is needed, a fresh client id is generated and any
    previous websocket is closed first. The module state (_ws, _client_id,
    _server_url) is only updated once the new connection has succeeded.
    '''
    global _client_id
    global _server_url
    global _ws
    if server_port and int(server_port):
        server_url = f"{server_url}:{server_port}"

    if server_url == _server_url and _ws is not None:
        if verbose:
            logger.info(f"msa.comfyui already connected to {server_url} with client_id {_client_id}")
        return

    client_id = str(uuid.uuid4())
    connect_str = f"{server_url}/ws?clientId={client_id}"
    # websocket-client only accepts ws:// and wss:// schemes.
    if connect_str.startswith("https://"):
        connect_str = "wss://" + connect_str[len("https://"):]
    elif connect_str.startswith("http://"):
        connect_str = "ws://" + connect_str[len("http://"):]
    elif not connect_str.startswith(("ws://", "wss://")):
        connect_str = f"ws://{connect_str}"

    if _ws is not None:
        # Don't leak the previous socket when switching servers / reconnecting.
        try:
            _ws.close()
        except Exception as e:  # best-effort cleanup: the old socket may already be dead
            logger.warning(f"msa.comfyui error closing previous connection: {e}")
        _ws = None  # so a failed connect below is retried on the next call

    logger.info(f"msa.comfyui connecting to {connect_str}")
    ws = websocket.WebSocket()
    ws.connect(connect_str)
    # Commit state only after a successful connect.
    _ws = ws
    _client_id = client_id
    _server_url = server_url
def load_workflow(workflow_path):
    '''Read a ComfyUI workflow JSON file (API or full format) and return it as a dict.

    Environment variables and a leading "~" in *workflow_path* are expanded,
    and ".json" is appended when the path does not already end with it.
    '''
    resolved = os.path.expanduser(os.path.expandvars(workflow_path))
    if not resolved.endswith('.json'):
        resolved += '.json'
    with open(resolved, 'r', encoding="utf8") as fh:
        return json.load(fh)
def is_api_workflow(workflow):
    '''Return True for an API-format workflow, False for a full (GUI) workflow.

    Full workflow files have a top-level 'nodes' list; API-format files are a
    flat mapping of node id -> node and have no such key.
    '''
    return 'nodes' not in workflow
def find_all_nodes_by_field(workflow, fieldname, fieldvalue, class_type, verbose, break_on_error):
    '''Find every node whose *fieldname* equals *fieldvalue*.

    Works on both workflow formats:
    - API format (dict of node_id -> node): the field is looked up in the
      node's '_meta' dict (where 'title' lives) or on the node itself, and
      *class_type* is compared against node['class_type'].
    - Full format (dict with a 'nodes' list): the field is looked up on the
      node itself, and *class_type* is compared against node['type'].
      WARNING: if a title isn't changed in the GUI, it's not saved in the
      json, so such nodes cannot be found by 'title'.

    class_type: optional node type filter; None matches any type.
    verbose: log an error when nothing matches (and break_on_error is False).
    break_on_error: raise ValueError when nothing matches.

    Returns a list of dict(node=<node dict>, id=<node id>), or None when
    nothing matched.
    '''
    nodes = []
    if is_api_workflow(workflow):  # api format: dict of node_id -> node
        for node_id, node in workflow.items():
            # '_meta' may be absent (or null) in some API exports; treat it as
            # empty instead of raising KeyError.
            meta = node.get('_meta') or {}
            matched = (fieldname in meta and meta[fieldname] == fieldvalue) or \
                      (fieldname in node and node[fieldname] == fieldvalue)
            if matched and (class_type is None or node.get('class_type') == class_type):
                nodes.append(dict(node=node, id=node_id))
    else:  # full format: workflow['nodes'] is a list of nodes
        for node in workflow['nodes']:
            # WARNING: if title isn't changed in GUI, it's not saved in the json!
            if fieldname in node and node[fieldname] == fieldvalue:
                if class_type is None or node.get('type') == class_type:
                    nodes.append(dict(node=node, id=node['id']))

    if not nodes:
        msg = f"Node with '{fieldname}' == '{fieldvalue}' not found"
        if break_on_error:
            raise ValueError(msg)
        if verbose:
            logger.error(msg)
        return None
    return nodes
def find_node_by_field(workflow, fieldname, fieldvalue, class_type=None, verbose=True, break_on_error=False):
    '''Return the first node (the node dict itself) whose *fieldname* equals *fieldvalue*.

    Matching and error handling are those of find_all_nodes_by_field. The
    keyword defaults are the same as find_nodeid_by_field's.

    Returns None when nothing matched (and break_on_error is False).
    '''
    matches = find_all_nodes_by_field(workflow, fieldname, fieldvalue, class_type, verbose, break_on_error)
    return matches[0]['node'] if matches else None
def find_all_nodes_by_title(workflow, title, class_type=None, verbose=True, break_on_error=False):
    '''Return every node titled *title* as a list of dict(node=..., id=...), or None.

    Thin wrapper around find_all_nodes_by_field with fieldname='title'.
    '''
    return find_all_nodes_by_field(
        workflow,
        fieldname='title',
        fieldvalue=title,
        class_type=class_type,
        verbose=verbose,
        break_on_error=break_on_error,
    )
def find_nodeid_by_field(workflow, fieldname, fieldvalue, class_type=None, verbose=True, break_on_error=False):
    '''Return the id of the first node whose *fieldname* equals *fieldvalue*, or None.

    Matching and error handling are those of find_all_nodes_by_field.
    '''
    matches = find_all_nodes_by_field(workflow, fieldname, fieldvalue, class_type, verbose, break_on_error)
    if not matches:
        return None
    first_match = matches[0]
    return first_match['id']
def find_nodeid_by_title(workflow, title, class_type=None, verbose=True, break_on_error=False):
    '''Return the id of the first node titled *title*, or None.

    Thin wrapper around find_nodeid_by_field with fieldname='title'.
    '''
    return find_nodeid_by_field(
        workflow,
        fieldname='title',
        fieldvalue=title,
        class_type=class_type,
        verbose=verbose,
        break_on_error=break_on_error,
    )
def find_node_by_title(workflow, title, class_type=None, verbose=True, break_on_error=False):
    '''Return the first node (the node dict itself) titled *title*, or None.

    Thin wrapper around find_node_by_field with fieldname='title'.
    '''
    return find_node_by_field(
        workflow,
        fieldname='title',
        fieldvalue=title,
        class_type=class_type,
        verbose=verbose,
        break_on_error=break_on_error,
    )
def _write_node_input(node, is_api, input_name, value):
    '''Write *value* into one node's inputs (API) or widgets_values (full); return (ok, status text).'''
    target = node.get('inputs' if is_api else 'widgets_values')
    if target is None:
        # e.g. a full-format PreviewImage node has no 'widgets_values' at all
        return False, "INPUT NOT FOUND ON NODE."
    if isinstance(target, dict):
        # Only overwrite keys that already exist; don't invent new inputs.
        if input_name not in target:
            return False, "INPUT NOT FOUND ON NODE."
        target[input_name] = value
        return True, "SUCCESS."
    # List-style widgets_values: input_name is an index.
    try:
        target[input_name] = value
    except (IndexError, KeyError, TypeError) as e:
        return False, f"ERROR {e}"
    return True, "SUCCESS."


def set_node_input(workflow, title, input_name, value, verbose=True, break_on_error=False, all=True):
    '''
    Set an input value on the node(s) titled *title*.

    This can write to API format or Full workflow format:
    - API format: writes node['inputs'][input_name]; the input must already exist.
    - Full format: writes node['widgets_values'][input_name], where
      'widgets_values' is either a dict (input_name is an existing string key)
      or a list (input_name is an integer index).

    value: value to write; None means "skip" and nothing is changed.
    all: if True update every node with this title, otherwise only the first.

    Any failure (node not found, input missing, bad index) raises ValueError
    when break_on_error, otherwise it is logged as a warning. Success is logged
    at info level when verbose.

    Returns the list of dict(node=..., id=...) that were targeted, or None when
    value is None or no node was found.
    '''
    is_api = is_api_workflow(workflow)
    if is_api:
        msg = f"Setting '{title}.{input_name}' to '{value}' ... "
    else:
        msg = f"Setting '{title}.widget_values[{input_name}]' to '{value}' ... "
    error = False
    nodes = None
    if value is None:
        msg += "SKIPPING. VALUE NONE."
    else:  # look for node
        nodes = find_all_nodes_by_title(workflow, title, verbose=False, break_on_error=False)
        if nodes:  # Node(s) found
            msg += f"{len(nodes)} NODES FOUND ... "
            if not all:
                # Keep a one-element list: nodes[0] would be a single dict and
                # iterating it would walk its keys instead of nodes.
                nodes = nodes[:1]
            for node_and_id in nodes:
                ok, node_msg = _write_node_input(node_and_id['node'], is_api, input_name, value)
                msg += node_msg
                if not ok:
                    error = True
        else:  # Node not found
            msg += "NODE NOT FOUND."
            error = True

    if error:
        if break_on_error:
            raise ValueError(msg)
        logger.warning(msg)
    elif verbose:
        logger.info(msg)
    return nodes
def encode_img_to_base64(img, format="JPEG"):
    '''
    Serialize *img* with img.save() and return the bytes as a base64 string.

    img: PIL.Image.Image
    format: container format passed to img.save(); default "JPEG",
        other options: "PNG", "WEBP"
    '''
    with BytesIO() as buffer:
        img.save(buffer, format=format)
        raw_bytes = buffer.getvalue()
    return base64.b64encode(raw_bytes).decode()
def resize_and_crop_img(img, size, crop=True, method=Image.LANCZOS):
    '''
    Resize image to cover size while maintaining aspect ratio, then optionally
    center-crop it to exactly size.

    img: PIL.Image.Image
    size: tuple (width, height) of positive ints
    crop: if True, center-crop the resized image to exactly *size*. If False,
        return the resized image, which is at least *size* in both dimensions
        and equal to it in one of them.
    method: resampling filter passed to img.resize()
    '''
    target_width, target_height = size
    img_width, img_height = img.size
    # Compare aspect ratios by cross-multiplying and scale with integer floor
    # division. The previous float ratio + int() truncation could land just
    # below an integer, leaving the resized image a pixel smaller than the
    # target so the crop box extended past its edge.
    if img_width * target_height > target_width * img_height:
        # Image is wider: match the target height, width overflows
        new_width = img_width * target_height // img_height
        new_height = target_height
    else:
        # Image is taller (or same ratio): match the target width, height overflows
        new_width = target_width
        new_height = img_height * target_width // img_width
    img_resized = img.resize((new_width, new_height), method)
    if not crop:
        return img_resized
    # Center crop box
    left = (new_width - target_width) // 2
    top = (new_height - target_height) // 2
    return img_resized.crop((left, top, left + target_width, top + target_height))
def get_image(filename, subfolder, folder_type, server_url=None):
    '''Download one file from the ComfyUI server's /view endpoint.

    filename, subfolder, folder_type: as found in a history 'images' entry
        (folder_type is sent as the 'type' query parameter).
    server_url: host[:port] without scheme; defaults to the one set by connect().

    Returns the raw response body as bytes.
    '''
    host = _server_url if server_url is None else server_url
    query = urllib.parse.urlencode({"filename": filename, "subfolder": subfolder, "type": folder_type})
    view_url = f"http://{host}/view?{query}"
    with urllib.request.urlopen(view_url) as response:
        return response.read()
def get_history(prompt_id, server_url=None):
    '''Fetch and JSON-decode the ComfyUI /history entry for *prompt_id*.

    server_url: host[:port] without scheme; defaults to the one set by connect().
    '''
    host = server_url if server_url is not None else _server_url
    with urllib.request.urlopen(f"http://{host}/history/{prompt_id}") as response:
        body = response.read()
    return json.loads(body)
def _wait_for_prompt(prompt_id):
    '''Block on the module websocket until ComfyUI reports *prompt_id* finished executing.'''
    while True:
        out = _ws.recv()
        if not isinstance(out, str):
            continue  # previews are binary data
        message = json.loads(out)
        if message.get('type') != 'executing':
            continue
        data = message.get('data') or {}
        if data.get('node') is None and data.get('prompt_id') == prompt_id:
            return  # Execution is done


def _log_status(status):
    '''Log a history 'status' dict (with error details on failure) and return the status string.'''
    status_str = status['status_str']
    if status_str == 'error':
        for msg_arr in status['messages']:
            # entries look like [message_type, details]; append details of error entries
            if 'error' in msg_arr[0]:
                details = msg_arr[1]
                status_str = f"{status_str}, {msg_arr[0]}, {details['node_type']}, {details['exception_type']}, {details['exception_message']}"
        logger.error(f"STATUS: {status_str}")
    else:
        logger.info(f"STATUS: {status_str}")
    return status_str


def _log_outputs_summary(prompt, all_outputs):
    '''Log, per output node, its title and how many items each output key holds.'''
    logger.info(f"{len(all_outputs)} output nodes")
    for node_id, node_outputs in all_outputs.items():
        # '_meta'/'title' may be missing; don't crash just for logging.
        meta = (prompt.get(node_id) or {}).get('_meta') or {}
        node_title = meta.get('title', '').encode('utf-8', errors='replace')  # because names have silly characters
        node_outputs_summary = {k: len(v) for k, v in node_outputs.items()}
        logger.info(f"{node_title} ({node_id}) => {node_outputs_summary}")


def _decode_text(text):
    '''Decode UTF-8 bytes to str; return anything else (incl. undecodable bytes) unchanged.'''
    if isinstance(text, (bytes, bytearray)):
        try:
            return text.decode('utf-8')
        except UnicodeDecodeError:
            pass
    return text


def _collect_outputs(all_outputs, server_url):
    '''Download output images and gather output texts, keyed by node id.'''
    all_images = {}
    all_texts = {}
    for node_id, node_output in all_outputs.items():
        if 'images' in node_output:
            all_images[node_id] = [
                get_image(image['filename'], image['subfolder'], image['type'], server_url=server_url)
                for image in node_output['images']
            ]
        if 'text' in node_output:
            all_texts[node_id] = [_decode_text(text) for text in node_output['text']]
    return dict(images=all_images, texts=all_texts)


def queue_prompt(prompt, server_url=None, client_id=None, do_wait=True, do_get_outputs=False):
    '''Queue an API-format *prompt* on the ComfyUI server and optionally wait for it.

    prompt: API-format workflow dict (node_id -> node).
    server_url: host[:port] without scheme; defaults to the one set by connect().
        It is used for the /prompt, /history and /view requests.
    client_id: defaults to the client id set by connect(). Completion is read
        from the websocket opened by connect(), which is bound to that id.
    do_wait: block until execution finishes, then fetch the history, log its
        status and a summary of the outputs.
    do_get_outputs: also download output images / collect texts (implies do_wait).

    Returns the server's /prompt response dict. When waiting, it additionally
    holds 'history', 'status_str' and 'completed'. With do_get_outputs it also
    holds 'outputs' = dict(images={node_id: [bytes, ...]}, texts={node_id: [...]}).

    Raises RuntimeError if waiting is requested but connect() has not been
    called. This is checked before anything is queued.
    '''
    start_time = time.time()
    if server_url is None:
        server_url = _server_url
    if client_id is None:
        client_id = _client_id
    if do_get_outputs:
        do_wait = True
    if do_wait and _ws is None:
        raise RuntimeError("msa.comfyui not connected: call connect() before waiting on a prompt")

    logger.info(f"Queuing prompt to server {server_url} with client_id {client_id}")
    payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode('utf-8')
    req = urllib.request.Request(f"http://{server_url}/prompt", data=payload)
    with urllib.request.urlopen(req) as response:
        ret = json.loads(response.read())
    if not do_wait:
        return ret

    prompt_id = ret['prompt_id']
    _wait_for_prompt(prompt_id)
    duration = time.time() - start_time
    logger.info(f"Queue ran in {duration:.2f} seconds")

    # Fetch history from the same server the prompt was queued on.
    history = get_history(prompt_id, server_url=server_url)[prompt_id]
    ret['history'] = history
    status = history['status']
    ret['status_str'] = _log_status(status)
    ret['completed'] = status['completed']

    all_outputs = history['outputs']
    _log_outputs_summary(prompt, all_outputs)
    if not do_get_outputs:
        return ret

    ret['outputs'] = _collect_outputs(all_outputs, server_url)
    return ret
def test_workflow_interaction():
    '''Testing interaction with the workflow json, both api and full.

    Loads "test.api.json" (API format) and "test.json" (full format) from the
    directory containing this file. For each, it pretty-prints the workflow,
    a few node lookups and one set_node_input() call. Nothing is asserted; the
    output is meant to be inspected by eye. set_node_input() only mutates the
    in-memory workflow; nothing is written back to disk.
    '''
    from pprint import pprint
    logger.info('*'*80)
    logger.info('*'*80)
    # Resolve fixtures relative to this file instead of os.chdir()-ing into its
    # directory. chdir changed the process-wide cwd as a side effect, and it
    # fails when __file__ is a bare relative name (dirname returns '').
    base_dir = os.path.dirname(os.path.abspath(__file__))
    # (label, call) pairs: the label is logged, the call is run on the workflow.
    # Calling functions directly avoids eval() on source strings.
    cmds = [
        ('find_node_by_title(workflow, "Load Image")',
         lambda wf: find_node_by_title(wf, "Load Image")),
        ('find_nodeid_by_title(workflow, "Load Image1")',
         lambda wf: find_nodeid_by_title(wf, "Load Image1")),
        ('find_node_by_title(workflow, "Load Image1")',
         lambda wf: find_node_by_title(wf, "Load Image1")),
        ('find_all_nodes_by_title(workflow, "Load Image1")',
         lambda wf: find_all_nodes_by_title(wf, "Load Image1")),
        ('set_node_input(workflow, "KSampler", "seed", 123)',
         lambda wf: set_node_input(wf, "KSampler", "seed", 123)),
    ]
    for workflow_path in ("test.api.json", "test.json"):
        logger.info('='*40)
        logger.info(f"Loading '{workflow_path}'")
        workflow = load_workflow(os.path.join(base_dir, workflow_path))
        pprint(workflow)
        for label, call in cmds:
            logger.info('-'*20)
            logger.info(label)
            pprint(call(workflow))
# Running this module directly only exercises the offline workflow helpers
# (loading the sample json files, node lookup and set_node_input); it never
# calls connect() or queue_prompt(), so no ComfyUI server is needed.
if __name__ == "__main__":
    test_workflow_interaction()
# print('-'*20) | |
# pprint(prompt) | |
# server_address = " |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"1": { | |
"inputs": { | |
"image": "005_horsehead_nebula.jpg", | |
"upload": "image" | |
}, | |
"class_type": "LoadImage", | |
"_meta": { | |
"title": "Load Image1" | |
} | |
}, | |
"3": { | |
"inputs": { | |
"image": "00600 20240325.185243 red streak2 v177 flower_orchidcrystal - s554295367 n50 st0.6 g8 r0.8 graded.jpg", | |
"upload": "image" | |
}, | |
"class_type": "LoadImage", | |
"_meta": { | |
"title": "Load Image1" | |
} | |
}, | |
"4": { | |
"inputs": { | |
"image": "2.-peppered-rock-shield-lichen.jpg", | |
"upload": "image" | |
}, | |
"class_type": "LoadImage", | |
"_meta": { | |
"title": "Load Image2" | |
} | |
}, | |
"6": { | |
"inputs": { | |
"images": [ | |
"1", | |
0 | |
] | |
}, | |
"class_type": "PreviewImage", | |
"_meta": { | |
"title": "Preview Image1" | |
} | |
}, | |
"7": { | |
"inputs": { | |
"seed": 1234, | |
"steps": 21, | |
"cfg": 7, | |
"sampler_name": "euler", | |
"scheduler": "normal", | |
"denoise": 0.8 | |
}, | |
"class_type": "KSampler", | |
"_meta": { | |
"title": "KSampler" | |
} | |
}, | |
"8": { | |
"inputs": { | |
"value": 0 | |
}, | |
"class_type": "INTConstant", | |
"_meta": { | |
"title": "INT Constant" | |
} | |
}, | |
"9": { | |
"inputs": { | |
"value": 0 | |
}, | |
"class_type": "INTConstant", | |
"_meta": { | |
"title": "INT Constant" | |
} | |
}, | |
"10": { | |
"inputs": { | |
"seed": 5678, | |
"steps": 22, | |
"cfg": 8, | |
"sampler_name": "euler_ancestral", | |
"scheduler": "karras", | |
"denoise": 1 | |
}, | |
"class_type": "KSampler", | |
"_meta": { | |
"title": "KSampler" | |
} | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"last_node_id": 10, | |
"last_link_id": 1, | |
"nodes": [ | |
{ | |
"id": 4, | |
"type": "LoadImage", | |
"pos": { | |
"0": 3461, | |
"1": 406 | |
}, | |
"size": { | |
"0": 315, | |
"1": 314 | |
}, | |
"flags": {}, | |
"order": 0, | |
"mode": 0, | |
"inputs": [], | |
"outputs": [ | |
{ | |
"name": "IMAGE", | |
"type": "IMAGE", | |
"links": null, | |
"shape": 3 | |
}, | |
{ | |
"name": "MASK", | |
"type": "MASK", | |
"links": null, | |
"shape": 3 | |
} | |
], | |
"title": "Load Image2", | |
"properties": { | |
"Node name for S&R": "LoadImage" | |
}, | |
"widgets_values": [ | |
"2.-peppered-rock-shield-lichen.jpg", | |
"image" | |
] | |
}, | |
{ | |
"id": 3, | |
"type": "LoadImage", | |
"pos": { | |
"0": 3043, | |
"1": 395 | |
}, | |
"size": { | |
"0": 315, | |
"1": 314 | |
}, | |
"flags": {}, | |
"order": 1, | |
"mode": 0, | |
"inputs": [], | |
"outputs": [ | |
{ | |
"name": "IMAGE", | |
"type": "IMAGE", | |
"links": null, | |
"shape": 3 | |
}, | |
{ | |
"name": "MASK", | |
"type": "MASK", | |
"links": null, | |
"shape": 3 | |
} | |
], | |
"title": "Load Image1", | |
"properties": { | |
"Node name for S&R": "LoadImage" | |
}, | |
"widgets_values": [ | |
"00600 20240325.185243 red streak2 v177 flower_orchidcrystal - s554295367 n50 st0.6 g8 r0.8 graded.jpg", | |
"image" | |
] | |
}, | |
{ | |
"id": 6, | |
"type": "PreviewImage", | |
"pos": { | |
"0": 2665, | |
"1": 393 | |
}, | |
"size": { | |
"0": 210, | |
"1": 26 | |
}, | |
"flags": {}, | |
"order": 9, | |
"mode": 0, | |
"inputs": [ | |
{ | |
"name": "images", | |
"type": "IMAGE", | |
"link": 1 | |
} | |
], | |
"outputs": [], | |
"title": "Preview Image1", | |
"properties": { | |
"Node name for S&R": "PreviewImage" | |
} | |
}, | |
{ | |
"id": 2, | |
"type": "LoadImage", | |
"pos": { | |
"0": 2338, | |
"1": 840 | |
}, | |
"size": { | |
"0": 315, | |
"1": 314 | |
}, | |
"flags": {}, | |
"order": 2, | |
"mode": 4, | |
"inputs": [], | |
"outputs": [ | |
{ | |
"name": "IMAGE", | |
"type": "IMAGE", | |
"links": null, | |
"shape": 3 | |
}, | |
{ | |
"name": "MASK", | |
"type": "MASK", | |
"links": null, | |
"shape": 3 | |
} | |
], | |
"title": "Load Image bypassed", | |
"properties": { | |
"Node name for S&R": "LoadImage" | |
}, | |
"widgets_values": [ | |
"1000_F_846598571_evJ8KPjrWWcRKT7FMWaYwaazi8CqT9V5.jpg", | |
"image" | |
] | |
}, | |
{ | |
"id": 5, | |
"type": "LoadImage", | |
"pos": { | |
"0": 2692, | |
"1": 842 | |
}, | |
"size": { | |
"0": 315, | |
"1": 314 | |
}, | |
"flags": {}, | |
"order": 3, | |
"mode": 4, | |
"inputs": [], | |
"outputs": [ | |
{ | |
"name": "IMAGE", | |
"type": "IMAGE", | |
"links": null, | |
"shape": 3 | |
}, | |
{ | |
"name": "MASK", | |
"type": "MASK", | |
"links": null, | |
"shape": 3 | |
} | |
], | |
"title": "Load Image bypassed2", | |
"properties": { | |
"Node name for S&R": "LoadImage" | |
}, | |
"widgets_values": [ | |
"6d1fb355c398b0a81bb5de9a075f3033.jpg", | |
"image" | |
] | |
}, | |
{ | |
"id": 1, | |
"type": "LoadImage", | |
"pos": { | |
"0": 2304, | |
"1": 393 | |
}, | |
"size": { | |
"0": 315, | |
"1": 314 | |
}, | |
"flags": {}, | |
"order": 4, | |
"mode": 0, | |
"inputs": [], | |
"outputs": [ | |
{ | |
"name": "IMAGE", | |
"type": "IMAGE", | |
"links": [ | |
1 | |
], | |
"slot_index": 0, | |
"shape": 3 | |
}, | |
{ | |
"name": "MASK", | |
"type": "MASK", | |
"links": null, | |
"shape": 3 | |
} | |
], | |
"title": "Load Image1", | |
"properties": { | |
"Node name for S&R": "LoadImage" | |
}, | |
"widgets_values": [ | |
"005_horsehead_nebula.jpg", | |
"image" | |
] | |
}, | |
{ | |
"id": 8, | |
"type": "INTConstant", | |
"pos": { | |
"0": 2846, | |
"1": 229 | |
}, | |
"size": [ | |
200, | |
58 | |
], | |
"flags": {}, | |
"order": 5, | |
"mode": 0, | |
"inputs": [], | |
"outputs": [ | |
{ | |
"name": "value", | |
"type": "INT", | |
"links": null, | |
"shape": 3 | |
} | |
], | |
"properties": { | |
"Node name for S&R": "INTConstant" | |
}, | |
"widgets_values": [ | |
0 | |
], | |
"color": "#1b4669", | |
"bgcolor": "#29699c" | |
}, | |
{ | |
"id": 9, | |
"type": "INTConstant", | |
"pos": { | |
"0": 3116, | |
"1": 228 | |
}, | |
"size": [ | |
200, | |
58 | |
], | |
"flags": {}, | |
"order": 6, | |
"mode": 0, | |
"inputs": [], | |
"outputs": [ | |
{ | |
"name": "value", | |
"type": "INT", | |
"links": null, | |
"shape": 3 | |
} | |
], | |
"properties": { | |
"Node name for S&R": "INTConstant" | |
}, | |
"widgets_values": [ | |
0 | |
], | |
"color": "#1b4669", | |
"bgcolor": "#29699c" | |
}, | |
{ | |
"id": 10, | |
"type": "KSampler", | |
"pos": { | |
"0": 3586, | |
"1": 864 | |
}, | |
"size": [ | |
315, | |
474 | |
], | |
"flags": {}, | |
"order": 7, | |
"mode": 0, | |
"inputs": [ | |
{ | |
"name": "model", | |
"type": "MODEL", | |
"link": null | |
}, | |
{ | |
"name": "positive", | |
"type": "CONDITIONING", | |
"link": null | |
}, | |
{ | |
"name": "negative", | |
"type": "CONDITIONING", | |
"link": null | |
}, | |
{ | |
"name": "latent_image", | |
"type": "LATENT", | |
"link": null | |
} | |
], | |
"outputs": [ | |
{ | |
"name": "LATENT", | |
"type": "LATENT", | |
"links": null, | |
"shape": 3 | |
} | |
], | |
"properties": { | |
"Node name for S&R": "KSampler" | |
}, | |
"widgets_values": [ | |
5678, | |
"randomize", | |
22, | |
8, | |
"euler_ancestral", | |
"karras", | |
1 | |
] | |
}, | |
{ | |
"id": 7, | |
"type": "KSampler", | |
"pos": { | |
"0": 3137, | |
"1": 862 | |
}, | |
"size": { | |
"0": 315, | |
"1": 262 | |
}, | |
"flags": {}, | |
"order": 8, | |
"mode": 0, | |
"inputs": [ | |
{ | |
"name": "model", | |
"type": "MODEL", | |
"link": null | |
}, | |
{ | |
"name": "positive", | |
"type": "CONDITIONING", | |
"link": null | |
}, | |
{ | |
"name": "negative", | |
"type": "CONDITIONING", | |
"link": null | |
}, | |
{ | |
"name": "latent_image", | |
"type": "LATENT", | |
"link": null | |
} | |
], | |
"outputs": [ | |
{ | |
"name": "LATENT", | |
"type": "LATENT", | |
"links": null, | |
"shape": 3 | |
} | |
], | |
"properties": { | |
"Node name for S&R": "KSampler" | |
}, | |
"widgets_values": [ | |
1234, | |
"randomize", | |
21, | |
7, | |
"euler", | |
"normal", | |
0.8 | |
] | |
} | |
], | |
"links": [ | |
[ | |
1, | |
1, | |
0, | |
6, | |
0, | |
"IMAGE" | |
] | |
], | |
"groups": [], | |
"config": {}, | |
"extra": { | |
"ds": { | |
"scale": 1, | |
"offset": [ | |
-1278.6666259765625, | |
-152.66665649414062 | |
] | |
} | |
}, | |
"version": 0.4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment