#!/usr/bin/env python3
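"""MCP server that exposes ComfyUI image generation over stdio.

Provides tools for generating images, managing workflow templates,
uploading and listing LoRA/checkpoint models (including chunked uploads
for large files), inspecting ComfyUI nodes, and retrieving outputs.

Configuration (environment variables):
    COMFYUI_SERVER_URL  Base URL of the ComfyUI instance (default http://localhost:8188)
    LOG_LEVEL           Logging level (default INFO)

Illustrative MCP client registration (adjust command and path for your setup):

    {"mcpServers": {"comfyui": {"command": "python3", "args": ["comfyui_mcp_server.py"]}}}
"""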

import asyncio
import base64
import json
import logging
import os
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import aiohttp
import websockets

import mcp.server.stdio
import mcp.types as types
from mcp.server import NotificationOptions, Server
from mcp.server.models import InitializationOptions


# Configure logging (level names must be uppercase for logging's lookup)
logging.basicConfig(
    level=os.getenv('LOG_LEVEL', 'INFO').upper(),
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Server configuration
COMFYUI_SERVER_URL = os.getenv('COMFYUI_SERVER_URL', 'http://localhost:8188')
WORKFLOW_DIR = Path(__file__).parent  # Workflows live alongside this script for gist compatibility
OUTPUT_DIR = Path("/comfyui/output")
LORA_DIR = Path("/comfyui/models/loras")
CHECKPOINT_DIR = Path("/comfyui/models/checkpoints")

# MCP Server instance
server = Server("comfyui-mcp")

# Temporary storage for in-progress chunked uploads, keyed by upload_id
chunked_uploads: Dict[str, Dict[str, Any]] = {}
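
# Chunked upload protocol (for LoRA files too large for a single base64
# payload): the client calls upload-lora-chunked-start with an upload_id,
# filename, and total_size; then upload-lora-chunked-append once per
# base64-encoded chunk (chunk_index starting at 0); and finally
# upload-lora-chunked-finish, which reassembles the chunks in index order,
# verifies the byte count, and writes the file into LORA_DIR.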


# Available tools
MCP_TOOLS = [
    'generate-image',
    'list-workflows',
    'get-workflow',
    'submit-workflow',
    'list-loras',
    'get-lora-info',
    'upload-lora',
    'upload-lora-chunked-start',
    'upload-lora-chunked-append',
    'upload-lora-chunked-finish',
    'list-checkpoints',
    'upload-checkpoint',
    'get-comfyui-nodes',
    'get-node-info',
    'validate-workflow',
    'get-generation-status',
    'get-system-stats',
    'list-outputs',
    'download-output'
]


class ComfyUIClient:
    """Client for interacting with the ComfyUI HTTP and websocket APIs"""

    def __init__(self, server_url: str):
        self.server_url = server_url.rstrip('/')
        self.client_id = str(uuid.uuid4())
        self.ws = None

    async def connect_websocket(self):
        """Connect to the ComfyUI websocket"""
        # Swap only the scheme: http -> ws, https -> wss. count=1 avoids
        # mangling a host or path that happens to contain "http".
        ws_url = f"{self.server_url.replace('http', 'ws', 1)}/ws?clientId={self.client_id}"
        self.ws = await websockets.connect(ws_url)
        logger.info(f"Connected to ComfyUI websocket: {ws_url}")

    async def disconnect_websocket(self):
        """Disconnect from the ComfyUI websocket"""
        if self.ws:
            await self.ws.close()
            self.ws = None

    async def queue_prompt(self, workflow: dict) -> str:
        """Queue a workflow prompt and return the prompt ID"""
        async with aiohttp.ClientSession() as session:
            data = {
                "prompt": workflow,
                "client_id": self.client_id
            }
            async with session.post(f"{self.server_url}/prompt", json=data) as resp:
                # Surface HTTP errors (e.g. a 400 for an invalid workflow)
                # instead of failing later on a missing 'prompt_id' key
                resp.raise_for_status()
                result = await resp.json()
                return result['prompt_id']

    async def get_history(self, prompt_id: str) -> dict:
        """Get generation history for a prompt ID"""
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{self.server_url}/history/{prompt_id}") as resp:
                return await resp.json()

    async def get_object_info(self) -> dict:
        """Get all available ComfyUI nodes and their info"""
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{self.server_url}/object_info") as resp:
                return await resp.json()

    async def get_system_stats(self) -> dict:
        """Get ComfyUI system statistics"""
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{self.server_url}/system_stats") as resp:
                return await resp.json()

    async def get_all_history(self, max_items: int = 100) -> dict:
        """Get all generation history, limited to the max_items most recent entries"""
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{self.server_url}/history") as resp:
                history = await resp.json()
                # Sort by timestamp (newest first) and truncate
                sorted_items = sorted(history.items(),
                                      key=lambda x: x[1].get('_timestamp', 0),
                                      reverse=True)[:max_items]
                return dict(sorted_items)

    async def download_output(self, filename: str, subfolder: str = "", output_type: str = "output") -> bytes:
        """Download an output image from ComfyUI"""
        async with aiohttp.ClientSession() as session:
            params = {
                "filename": filename,
                "type": output_type
            }
            if subfolder:
                params["subfolder"] = subfolder

            async with session.get(f"{self.server_url}/view", params=params) as resp:
                if resp.status == 200:
                    return await resp.read()
                else:
                    raise Exception(f"Failed to download {filename}: HTTP {resp.status}")

    async def wait_for_completion(self, prompt_id: str) -> List[str]:
        """Wait for a prompt to complete and return output image paths"""
        if not self.ws:
            await self.connect_websocket()

        output_images = []

        while True:
            message = await self.ws.recv()
            # Binary frames carry preview images; only JSON text frames
            # carry execution status
            if isinstance(message, str):
                data = json.loads(message)
                if data.get('type') == 'executing':
                    # A null node for our prompt_id signals execution finished
                    if data['data']['node'] is None and data['data'].get('prompt_id') == prompt_id:
                        break

        # Get the output images from history
        history = await self.get_history(prompt_id)
        if prompt_id in history:
            for node_id, node_output in history[prompt_id]['outputs'].items():
                if 'images' in node_output:
                    for image in node_output['images']:
                        filename = image['filename']
                        subfolder = image.get('subfolder', '')
                        if subfolder:
                            image_path = OUTPUT_DIR / subfolder / filename
                        else:
                            image_path = OUTPUT_DIR / filename
                        output_images.append(str(image_path))

        return output_images


# Workflow management functions
def list_workflows() -> List[str]:
    """List available workflow templates"""
    workflows = []
    if WORKFLOW_DIR.exists():
        for file in WORKFLOW_DIR.glob("*.json"):
            workflows.append(file.stem)
    return sorted(workflows)


def load_workflow(name: str) -> Optional[dict]:
    """Load a workflow template by name"""
    workflow_path = WORKFLOW_DIR / f"{name}.json"
    if workflow_path.exists():
        with open(workflow_path, 'r') as f:
            return json.load(f)
    return None


def list_lora_models(search_term: Optional[str] = None) -> List[str]:
    """List available LoRA models"""
    if not LORA_DIR.exists():
        return []

    lora_files = []
    for file in LORA_DIR.glob("*.*"):
        if file.suffix.lower() in ['.safetensors', '.pt', '.ckpt', '.bin']:
            filename = file.name
            if search_term is None or search_term.lower() in filename.lower():
                lora_files.append(filename)

    return sorted(lora_files)


def get_lora_info(lora_name: str) -> Optional[dict]:
    """Get LoRA model information, including sidecar metadata if available"""
    lora_path = LORA_DIR / lora_name
    if not lora_path.exists():
        return None

    info = {
        "name": lora_name,
        "path": str(lora_path),
        "size": lora_path.stat().st_size,
        "modified": datetime.fromtimestamp(lora_path.stat().st_mtime).isoformat()
    }

    # Check for sidecar metadata files (two naming conventions), e.g. for
    # Inkpunk_Flux.safetensors: Inkpunk_Flux.json or Inkpunk_Flux.metadata.json
    metadata_paths = [
        lora_path.with_suffix('.json'),
        lora_path.with_suffix('.metadata.json')
    ]

    for metadata_path in metadata_paths:
        if metadata_path.exists():
            try:
                with open(metadata_path, 'r') as f:
                    info["metadata"] = json.load(f)
                info["metadata_source"] = str(metadata_path.name)
                break
            except Exception as e:
                logger.warning(f"Failed to load metadata from {metadata_path}: {e}")

    return info


def list_checkpoints() -> List[str]:
    """List available checkpoint models"""
    if not CHECKPOINT_DIR.exists():
        return []

    checkpoint_files = []
    for file in CHECKPOINT_DIR.glob("*.*"):
        if file.suffix.lower() in ['.safetensors', '.pt', '.ckpt']:
            checkpoint_files.append(file.name)

    return sorted(checkpoint_files)
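

# ComfyUI API-format workflows are flat JSON objects: each key is a node id,
# each value holds a "class_type" and an "inputs" dict, and an input that is a
# two-element list ["<node_id>", <output_index>] is a link to another node's
# output. For example, ["1", 0] below means "output 0 of node 1", i.e. the
# MODEL output of the CheckpointLoaderSimple node.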
def create_simple_workflow(
    prompt: str,
    negative_prompt: str = "",
    checkpoint: str = "flux1-dev-fp8.safetensors",
    lora: Optional[str] = None,
    lora_strength: float = 1.0,
    width: int = 1024,
    height: int = 1024,
    batch_size: int = 1,
    steps: int = 25,
    cfg: float = 3.5,
    sampler: str = "euler_ancestral",
    seed: Optional[int] = None
) -> dict:
    """Create a simple text-to-image workflow"""
    if seed is None:
        # Random 48-bit seed
        seed = int.from_bytes(os.urandom(6), 'big')

    workflow = {
        "1": {
            "class_type": "CheckpointLoaderSimple",
            "inputs": {
                "ckpt_name": checkpoint
            }
        },
        "2": {
            "class_type": "CLIPTextEncode",
            "inputs": {
                "text": prompt,
                "clip": ["1", 1]
            }
        },
        "3": {
            "class_type": "CLIPTextEncode",
            "inputs": {
                "text": negative_prompt,
                "clip": ["1", 1]
            }
        },
        "4": {
            "class_type": "EmptyLatentImage",
            "inputs": {
                "width": width,
                "height": height,
                "batch_size": batch_size
            }
        },
        "5": {
            "class_type": "KSampler",
            "inputs": {
                "seed": seed,
                "steps": steps,
                "cfg": cfg,
                "sampler_name": sampler,
                "scheduler": "normal",
                "denoise": 1.0,
                "model": ["1", 0],
                "positive": ["2", 0],
                "negative": ["3", 0],
                "latent_image": ["4", 0]
            }
        },
        "6": {
            "class_type": "VAEDecode",
            "inputs": {
                "samples": ["5", 0],
                "vae": ["1", 2]
            }
        },
        "7": {
            "class_type": "SaveImage",
            "inputs": {
                "filename_prefix": f"comfyui_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                "images": ["6", 0]
            }
        }
    }

    # If a LoRA is specified, splice a LoraLoader between the checkpoint
    # loader and its consumers
    if lora:
        workflow["8"] = {
            "class_type": "LoraLoader",
            "inputs": {
                "lora_name": lora,
                "strength_model": lora_strength,
                "strength_clip": lora_strength,
                "model": ["1", 0],
                "clip": ["1", 1]
            }
        }
        # Re-point the text encoders and sampler at the LoRA's outputs
        workflow["2"]["inputs"]["clip"] = ["8", 1]
        workflow["3"]["inputs"]["clip"] = ["8", 1]
        workflow["5"]["inputs"]["model"] = ["8", 0]

    return workflow


# MCP Tool handlers
@server.list_tools()
async def handle_list_tools() -> List[types.Tool]:
    """List available MCP tools"""
    return [
        types.Tool(
            name="generate-image",
            description="Generate an image using ComfyUI with a simple workflow",
            inputSchema={
                "type": "object",
                "properties": {
                    "prompt": {"type": "string", "description": "The positive prompt for image generation"},
                    "negative_prompt": {"type": "string", "description": "The negative prompt (what to avoid)"},
                    "checkpoint": {"type": "string", "description": "Checkpoint model to use (use list-checkpoints to see available)"},
                    "lora": {"type": "string", "description": "LoRA model to use (optional, use list-loras to see available)"},
                    "lora_strength": {"type": "number", "description": "LoRA strength (0.0-2.0, default 1.0)"},
                    "width": {"type": "integer", "description": "Image width (default 1024)"},
                    "height": {"type": "integer", "description": "Image height (default 1024)"},
                    "batch_size": {"type": "integer", "description": "Number of images to generate (default 1)"},
                    "steps": {"type": "integer", "description": "Number of sampling steps (default 25)"},
                    "cfg": {"type": "number", "description": "CFG scale (default 3.5)"},
                    "sampler": {"type": "string", "description": "Sampler name (default euler_ancestral)"},
                    "seed": {"type": "integer", "description": "Random seed (optional)"}
                },
                "required": ["prompt"]
            }
        ),
        types.Tool(
            name="list-workflows",
            description="List available workflow templates",
            inputSchema={
                "type": "object",
                "properties": {},
                "required": []
            }
        ),
        types.Tool(
            name="get-workflow",
            description="Get a workflow template by name",
            inputSchema={
                "type": "object",
                "properties": {
                    "name": {"type": "string", "description": "Name of the workflow template"}
                },
                "required": ["name"]
            }
        ),
        types.Tool(
            name="submit-workflow",
            description="Submit a custom workflow to ComfyUI",
            inputSchema={
                "type": "object",
                "properties": {
                    "workflow": {"type": "object", "description": "The workflow JSON object"}
                },
                "required": ["workflow"]
            }
        ),
        types.Tool(
            name="list-loras",
            description="List available LoRA models",
            inputSchema={
                "type": "object",
                "properties": {
                    "search": {"type": "string", "description": "Search term to filter LoRAs (optional)"}
                },
                "required": []
            }
        ),
        types.Tool(
            name="get-lora-info",
            description="Get information about a specific LoRA model",
            inputSchema={
                "type": "object",
                "properties": {
                    "name": {"type": "string", "description": "LoRA filename"}
                },
                "required": ["name"]
            }
        ),
        types.Tool(
            name="upload-lora",
            description="Upload a LoRA model to ComfyUI",
            inputSchema={
                "type": "object",
                "properties": {
                    "filename": {"type": "string", "description": "Filename for the LoRA (must end with .safetensors)"},
                    "content": {"type": "string", "description": "Base64-encoded file content"},
                    "metadata": {"type": "object", "description": "Optional metadata for the LoRA"}
                },
                "required": ["filename", "content"]
            }
        ),
        types.Tool(
            name="upload-lora-chunked-start",
            description="Start a chunked upload for a large LoRA model",
            inputSchema={
                "type": "object",
                "properties": {
                    "upload_id": {"type": "string", "description": "Unique identifier for this upload session"},
                    "filename": {"type": "string", "description": "Filename for the LoRA (must end with .safetensors)"},
                    "total_size": {"type": "integer", "description": "Total file size in bytes"},
                    "metadata": {"type": "object", "description": "Optional metadata for the LoRA"}
                },
                "required": ["upload_id", "filename", "total_size"]
            }
        ),
        types.Tool(
            name="upload-lora-chunked-append",
            description="Append a chunk to an ongoing LoRA upload",
            inputSchema={
                "type": "object",
                "properties": {
                    "upload_id": {"type": "string", "description": "Upload session identifier"},
                    "chunk": {"type": "string", "description": "Base64-encoded chunk content"},
                    "chunk_index": {"type": "integer", "description": "Chunk sequence number (starting from 0)"}
                },
                "required": ["upload_id", "chunk", "chunk_index"]
            }
        ),
        types.Tool(
            name="upload-lora-chunked-finish",
            description="Finalize a chunked LoRA upload",
            inputSchema={
                "type": "object",
                "properties": {
                    "upload_id": {"type": "string", "description": "Upload session identifier"}
                },
                "required": ["upload_id"]
            }
        ),
        types.Tool(
            name="list-checkpoints",
            description="List available checkpoint models",
            inputSchema={
                "type": "object",
                "properties": {},
                "required": []
            }
        ),
        types.Tool(
            name="upload-checkpoint",
            description="Upload a checkpoint model to ComfyUI",
            inputSchema={
                "type": "object",
                "properties": {
                    "filename": {"type": "string", "description": "Filename for the checkpoint (must end with .safetensors, .ckpt, or .pt)"},
                    "content": {"type": "string", "description": "Base64-encoded file content"}
                },
                "required": ["filename", "content"]
            }
        ),
        types.Tool(
            name="get-comfyui-nodes",
            description="Get all available ComfyUI node types",
            inputSchema={
                "type": "object",
                "properties": {
                    "category": {"type": "string", "description": "Filter by category (optional)"}
                },
                "required": []
            }
        ),
        types.Tool(
            name="get-node-info",
            description="Get detailed information about a specific ComfyUI node",
            inputSchema={
                "type": "object",
                "properties": {
                    "node_type": {"type": "string", "description": "The node class type"}
                },
                "required": ["node_type"]
            }
        ),
        types.Tool(
            name="validate-workflow",
            description="Validate a workflow before submission",
            inputSchema={
                "type": "object",
                "properties": {
                    "workflow": {"type": "object", "description": "The workflow to validate"}
                },
                "required": ["workflow"]
            }
        ),
        types.Tool(
            name="get-generation-status",
            description="Get the status of a generation by prompt ID",
            inputSchema={
                "type": "object",
                "properties": {
                    "prompt_id": {"type": "string", "description": "The prompt ID returned from generation"}
                },
                "required": ["prompt_id"]
            }
        ),
        types.Tool(
            name="get-system-stats",
            description="Get ComfyUI system statistics",
            inputSchema={
                "type": "object",
                "properties": {},
                "required": []
            }
        ),
        types.Tool(
            name="list-outputs",
            description="List recently generated output images",
            inputSchema={
                "type": "object",
                "properties": {
                    "max_items": {"type": "integer", "description": "Maximum number of outputs to list (default 20)"}
                },
                "required": []
            }
        ),
        types.Tool(
            name="download-output",
            description="Download a generated output image as base64",
            inputSchema={
                "type": "object",
                "properties": {
                    "filename": {"type": "string", "description": "The filename of the output image"},
                    "subfolder": {"type": "string", "description": "Optional subfolder path"},
                    "save_to": {"type": "string", "description": "Optional local file path to save the downloaded image"}
                },
                "required": ["filename"]
            }
        )
    ]


@server.call_tool()
async def handle_call_tool(
    name: str,
    arguments: dict | None
) -> List[types.TextContent | types.ImageContent | types.EmbeddedResource]:
    """Handle tool execution requests"""

    if name not in MCP_TOOLS:
        raise ValueError(f"Unknown tool: {name}")

    # Normalize so the handlers below can call arguments.get() safely
    arguments = arguments or {}

    client = ComfyUIClient(COMFYUI_SERVER_URL)

    try:
        if name == "generate-image":
            # Create workflow from parameters
            workflow = create_simple_workflow(
                prompt=arguments.get("prompt", ""),
                negative_prompt=arguments.get("negative_prompt", ""),
                checkpoint=arguments.get("checkpoint", "flux1-dev-fp8.safetensors"),
                lora=arguments.get("lora"),
                lora_strength=arguments.get("lora_strength", 1.0),
                width=arguments.get("width", 1024),
                height=arguments.get("height", 1024),
                batch_size=arguments.get("batch_size", 1),
                steps=arguments.get("steps", 25),
                cfg=arguments.get("cfg", 3.5),
                sampler=arguments.get("sampler", "euler_ancestral"),
                seed=arguments.get("seed")
            )

            # Submit workflow and wait for completion
            prompt_id = await client.queue_prompt(workflow)
            logger.info(f"Queued prompt: {prompt_id}")

            output_images = await client.wait_for_completion(prompt_id)

            if output_images:
                result = f"Generated {len(output_images)} image(s):\n"
                result += "\n".join([f"- {img}" for img in output_images])
            else:
                result = "Generation completed but no images were saved."

            return [types.TextContent(type="text", text=result)]

        elif name == "list-workflows":
            workflows = list_workflows()
            if workflows:
                result = "Available workflow templates:\n"
                result += "\n".join([f"- {w}" for w in workflows])
            else:
                result = "No workflow templates found."
            return [types.TextContent(type="text", text=result)]

        elif name == "get-workflow":
            workflow_name = arguments.get("name")
            workflow = load_workflow(workflow_name)
            if workflow:
                result = f"Workflow '{workflow_name}':\n\n{json.dumps(workflow, indent=2)}"
            else:
                result = f"Workflow '{workflow_name}' not found."
            return [types.TextContent(type="text", text=result)]

        elif name == "submit-workflow":
            workflow = arguments.get("workflow")
            if not isinstance(workflow, dict):
                return [types.TextContent(type="text", text="Invalid workflow format. Must be a JSON object.")]

            prompt_id = await client.queue_prompt(workflow)
            output_images = await client.wait_for_completion(prompt_id)

            if output_images:
                result = f"Workflow completed. Generated {len(output_images)} image(s):\n"
                result += "\n".join([f"- {img}" for img in output_images])
            else:
                result = "Workflow completed but no images were saved."

            return [types.TextContent(type="text", text=result)]

        elif name == "list-loras":
            search_term = arguments.get("search")
            loras = list_lora_models(search_term)
            if loras:
                result = "Available LoRA models"
                if search_term:
                    result += f" matching '{search_term}'"
                result += ":\n"
                result += "\n".join([f"- {lora}" for lora in loras])
            else:
                result = "No LoRA models found."
            return [types.TextContent(type="text", text=result)]

        elif name == "get-lora-info":
            lora_name = arguments.get("name")
            info = get_lora_info(lora_name)
            if info:
                result = f"LoRA information for '{lora_name}':\n\n{json.dumps(info, indent=2)}"
            else:
                result = f"LoRA '{lora_name}' not found."
            return [types.TextContent(type="text", text=result)]

        elif name == "upload-lora":
            filename = arguments.get("filename")
            content = arguments.get("content")
            metadata = arguments.get("metadata")

            # Validate filename
            if not filename or not filename.endswith('.safetensors'):
                return [types.TextContent(type="text", text="Error: LoRA filename must end with .safetensors")]

            # Ensure filename is safe (no path traversal)
            safe_filename = Path(filename).name
            lora_path = LORA_DIR / safe_filename

            try:
                # Decode and save the file
                file_data = base64.b64decode(content)
                with open(lora_path, 'wb') as f:
                    f.write(file_data)

                # Save metadata if provided
                if metadata:
                    metadata_path = lora_path.with_suffix('.metadata.json')
                    with open(metadata_path, 'w') as f:
                        json.dump(metadata, f, indent=2)
                    result = f"Successfully uploaded LoRA: {safe_filename}\nMetadata saved to: {metadata_path.name}"
                else:
                    result = f"Successfully uploaded LoRA: {safe_filename}"

                # Verify it appears in the list
                loras = list_lora_models()
                if safe_filename in loras:
                    result += "\n✓ Verified: LoRA now appears in model list"

                return [types.TextContent(type="text", text=result)]

            except Exception as e:
                # Clean up on failure
                if lora_path.exists():
                    lora_path.unlink()
                return [types.TextContent(type="text", text=f"Error uploading LoRA: {str(e)}")]

        elif name == "upload-lora-chunked-start":
            upload_id = arguments.get("upload_id")
            filename = arguments.get("filename")
            total_size = arguments.get("total_size")
            metadata = arguments.get("metadata")

            # Validate filename
            if not filename or not filename.endswith('.safetensors'):
                return [types.TextContent(type="text", text="Error: LoRA filename must end with .safetensors")]

            # Ensure filename is safe (no path traversal)
            safe_filename = Path(filename).name

            # Initialize upload session
            chunked_uploads[upload_id] = {
                "filename": safe_filename,
                "total_size": total_size,
                "metadata": metadata,
                "chunks": {},
                "received_size": 0,
                "start_time": datetime.now()
            }

            return [types.TextContent(type="text", text=f"Chunked upload started for '{safe_filename}'\nUpload ID: {upload_id}\nExpected size: {total_size} bytes")]

        elif name == "upload-lora-chunked-append":
            upload_id = arguments.get("upload_id")
            chunk = arguments.get("chunk")
            chunk_index = arguments.get("chunk_index")

            # Check if upload session exists
            if upload_id not in chunked_uploads:
                return [types.TextContent(type="text", text=f"Error: Upload session '{upload_id}' not found")]

            session = chunked_uploads[upload_id]

            # Decode chunk
            try:
                chunk_data = base64.b64decode(chunk)
                # If a chunk is retransmitted, replace it without double-counting its bytes
                if chunk_index in session["chunks"]:
                    session["received_size"] -= len(session["chunks"][chunk_index])
                session["chunks"][chunk_index] = chunk_data
                session["received_size"] += len(chunk_data)

                # Calculate progress
                progress = (session["received_size"] / session["total_size"]) * 100

                return [types.TextContent(
                    type="text",
                    text=f"Chunk {chunk_index} received ({len(chunk_data)} bytes)\nProgress: {progress:.1f}% ({session['received_size']}/{session['total_size']} bytes)"
                )]
            except Exception as e:
                return [types.TextContent(type="text", text=f"Error processing chunk: {str(e)}")]

        elif name == "upload-lora-chunked-finish":
            upload_id = arguments.get("upload_id")

            # Check if upload session exists
            if upload_id not in chunked_uploads:
                return [types.TextContent(type="text", text=f"Error: Upload session '{upload_id}' not found")]

            session = chunked_uploads[upload_id]
            filename = session["filename"]
            metadata = session.get("metadata")

            try:
                # Combine all chunks in index order
                chunk_indices = sorted(session["chunks"].keys())
                combined_data = b"".join(session["chunks"][idx] for idx in chunk_indices)

                # Verify size
                if len(combined_data) != session["total_size"]:
                    return [types.TextContent(
                        type="text",
                        text=f"Error: Size mismatch. Expected {session['total_size']} bytes, got {len(combined_data)} bytes"
                    )]

                # Save the file
                lora_path = LORA_DIR / filename
                with open(lora_path, 'wb') as f:
                    f.write(combined_data)

                # Save metadata if provided
                if metadata:
                    metadata_path = lora_path.with_suffix('.metadata.json')
                    with open(metadata_path, 'w') as f:
                        json.dump(metadata, f, indent=2)

                # Calculate upload time
                upload_time = (datetime.now() - session["start_time"]).total_seconds()

                # Clean up session
                del chunked_uploads[upload_id]

                # Verify it appears in the list
                loras = list_lora_models()
                verified = filename in loras

                result = f"Successfully uploaded LoRA: {filename}\n"
                result += f"Upload time: {upload_time:.1f} seconds\n"
                result += f"File size: {len(combined_data) / (1024*1024):.1f} MB"
                if metadata:
                    result += "\nMetadata saved"
                if verified:
                    result += "\n✓ Verified: LoRA now appears in model list"

                return [types.TextContent(type="text", text=result)]

            except Exception as e:
                # Clean up on failure
                if upload_id in chunked_uploads:
                    del chunked_uploads[upload_id]
                return [types.TextContent(type="text", text=f"Error finalizing upload: {str(e)}")]

        elif name == "list-checkpoints":
            checkpoints = list_checkpoints()
            if checkpoints:
                result = "Available checkpoint models:\n"
                result += "\n".join([f"- {ckpt}" for ckpt in checkpoints])
            else:
                result = "No checkpoint models found."
            return [types.TextContent(type="text", text=result)]

        elif name == "upload-checkpoint":
            filename = arguments.get("filename")
            content = arguments.get("content")

            # Validate filename
            valid_extensions = ['.safetensors', '.ckpt', '.pt']
            if not filename or not any(filename.endswith(ext) for ext in valid_extensions):
                return [types.TextContent(type="text", text=f"Error: Checkpoint filename must end with one of: {', '.join(valid_extensions)}")]

            # Ensure filename is safe (no path traversal)
            safe_filename = Path(filename).name
            checkpoint_path = CHECKPOINT_DIR / safe_filename

            try:
                # Decode and save the file
                file_data = base64.b64decode(content)
                with open(checkpoint_path, 'wb') as f:
                    f.write(file_data)

                result = f"Successfully uploaded checkpoint: {safe_filename}"

                # Verify it appears in the list
                checkpoints = list_checkpoints()
                if safe_filename in checkpoints:
                    result += "\n✓ Verified: Checkpoint now appears in model list"

                return [types.TextContent(type="text", text=result)]

            except Exception as e:
                # Clean up on failure
                if checkpoint_path.exists():
                    checkpoint_path.unlink()
                return [types.TextContent(type="text", text=f"Error uploading checkpoint: {str(e)}")]

        elif name == "get-comfyui-nodes":
            node_info = await client.get_object_info()
            category_filter = arguments.get("category", "").lower()

            nodes_by_category = {}
            for node_type, info in node_info.items():
                category = info.get("category", "Uncategorized")
                if category_filter and category_filter not in category.lower():
                    continue
                if category not in nodes_by_category:
                    nodes_by_category[category] = []
                nodes_by_category[category].append(node_type)

            result = "Available ComfyUI nodes:\n\n"
            for category, nodes in sorted(nodes_by_category.items()):
                result += f"{category}:\n"
                for node in sorted(nodes):
                    result += f"  - {node}\n"
                result += "\n"

            return [types.TextContent(type="text", text=result)]

        elif name == "get-node-info":
            node_type = arguments.get("node_type")
            all_nodes = await client.get_object_info()

            if node_type in all_nodes:
                info = all_nodes[node_type]
                result = f"Node information for '{node_type}':\n\n{json.dumps(info, indent=2)}"
            else:
                result = f"Node type '{node_type}' not found."
            return [types.TextContent(type="text", text=result)]

        elif name == "validate-workflow":
            workflow = arguments.get("workflow")
            if not isinstance(workflow, dict):
                return [types.TextContent(type="text", text="Invalid workflow format. Must be a JSON object.")]

            # Basic validation - check node structure
            errors = []
            for node_id, node_data in workflow.items():
                if not isinstance(node_data, dict):
                    errors.append(f"Node {node_id}: Invalid node data format")
                    continue
                if "class_type" not in node_data:
                    errors.append(f"Node {node_id}: Missing 'class_type'")
                if "inputs" not in node_data:
                    errors.append(f"Node {node_id}: Missing 'inputs'")

            if errors:
                result = "Workflow validation failed:\n"
                result += "\n".join([f"- {error}" for error in errors])
            else:
                result = "Workflow validation passed."

            return [types.TextContent(type="text", text=result)]

        elif name == "get-generation-status":
            prompt_id = arguments.get("prompt_id")
            history = await client.get_history(prompt_id)

            if prompt_id in history:
                status = history[prompt_id]
                result = f"Generation status for prompt {prompt_id}:\n\n{json.dumps(status, indent=2)}"
            else:
                result = f"No history found for prompt ID: {prompt_id}"

            return [types.TextContent(type="text", text=result)]

        elif name == "get-system-stats":
            stats = await client.get_system_stats()
            result = f"ComfyUI System Statistics:\n\n{json.dumps(stats, indent=2)}"
            return [types.TextContent(type="text", text=result)]

        elif name == "list-outputs":
            max_items = arguments.get("max_items", 20)
            history = await client.get_all_history(max_items)

            outputs = []
            for prompt_id, data in history.items():
                if 'outputs' in data:
                    timestamp = data.get('_timestamp', 'unknown')
                    for node_id, node_output in data['outputs'].items():
                        if 'images' in node_output:
                            for image in node_output['images']:
                                outputs.append({
                                    'prompt_id': prompt_id,
                                    'filename': image['filename'],
                                    'subfolder': image.get('subfolder', ''),
                                    'timestamp': timestamp
                                })

            result = f"Recent output images ({len(outputs)} found):\n\n"
            for idx, output in enumerate(outputs):
                result += f"{idx + 1}. {output['filename']}\n"
                result += f"   Prompt ID: {output['prompt_id']}\n"
                if output['subfolder']:
                    result += f"   Subfolder: {output['subfolder']}\n"
                result += f"   Timestamp: {output['timestamp']}\n\n"

            if not outputs:
                result = "No output images found in recent history."

            return [types.TextContent(type="text", text=result)]

        elif name == "download-output":
            filename = arguments.get("filename")
            subfolder = arguments.get("subfolder", "")
            save_to = arguments.get("save_to")

            if not filename:
                return [types.TextContent(type="text", text="Error: filename is required")]

            try:
                # Download the image
                image_data = await client.download_output(filename, subfolder)

                # If save_to is specified, save locally and return a plain message
                if save_to:
                    save_path = Path(save_to)
                    save_path.parent.mkdir(parents=True, exist_ok=True)
                    with open(save_path, 'wb') as f:
                        f.write(image_data)
                    return [types.TextContent(type="text", text=f"Downloaded {filename} and saved to {save_path}")]

                # Otherwise return the image as base64 inside a JSON payload
                result = {
                    "filename": filename,
                    "size_bytes": len(image_data),
                    "base64": base64.b64encode(image_data).decode()
                }
                return [types.TextContent(type="text", text=json.dumps(result))]

            except Exception as e:
                return [types.TextContent(type="text", text=f"Error downloading {filename}: {str(e)}")]

    except Exception as e:
        logger.error(f"Error executing tool {name}: {e}")
        return [types.TextContent(type="text", text=f"Error: {str(e)}")]

    finally:
        await client.disconnect_websocket()


async def main():
    """Run the MCP server"""
    logger.info("Starting ComfyUI MCP Server")

    # No directories to create: all files live alongside this script for gist compatibility

    # Run the server over stdin/stdout streams
    async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
        await server.run(
            read_stream,
            write_stream,
            InitializationOptions(
                server_name="comfyui-mcp",
                server_version="1.0.0",
                capabilities=server.get_capabilities(
                    notification_options=NotificationOptions(),
                    experimental_capabilities={},
                ),
            ),
        )


if __name__ == "__main__":
    asyncio.run(main())