Files
comfyui-serverless/handler.py
Nick dba11a9f45
All checks were successful
Build and Push Docker Image / build (push) Successful in 4m2s
Skip bypassed nodes (mode 4) in workflow conversion
Bypassed/muted nodes should not be included in the API workflow,
and connections from bypassed nodes should be ignored.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-27 23:19:24 +13:00

547 lines
18 KiB
Python

"""
ComfyUI RunPod Serverless Handler
Handles image/video generation workflows with ComfyUI API
Wan22-I2V-Remix Workflow Node Mapping:
- Node 148: LoadImage - image input
- Node 134: CLIPTextEncode - positive prompt
- Node 137: CLIPTextEncode - negative prompt
- Node 147: easy int - resolution (720 default)
- Node 150: INTConstant - steps (8 default)
- Node 151: INTConstant - split_step (4 default)
- Node 117: SaveVideo - output
"""
import os
import sys
import json
import time
import base64
import uuid
import subprocess
import signal
import requests
from pathlib import Path
import runpod
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Local ComfyUI install and the loopback address its HTTP API is served on.
COMFYUI_DIR = "/workspace/ComfyUI"
COMFYUI_PORT = 8188
COMFYUI_HOST = "http://127.0.0.1:" + str(COMFYUI_PORT)

# Timing knobs, all in seconds.
MAX_TIMEOUT = 600  # hard cap on how long one job may wait (video generation is slow)
POLL_INTERVAL = 1.0  # pause between history polls
STARTUP_TIMEOUT = 120  # how long to wait for the server to answer /system_stats

# Workflow used when a job does not supply its own.
DEFAULT_WORKFLOW_PATH = "/workspace/workflows/Wan22-I2V-Remix.json"

# Node IDs inside the Wan22-I2V-Remix workflow that receive job parameters.
NODE_IMAGE_INPUT = "148"  # LoadImage
NODE_POSITIVE_PROMPT = "134"  # CLIPTextEncode (positive)
NODE_NEGATIVE_PROMPT = "137"  # CLIPTextEncode (negative)
NODE_RESOLUTION = "147"  # easy int
NODE_STEPS = "150"  # INTConstant
NODE_SPLIT_STEP = "151"  # INTConstant
NODE_SAVE_VIDEO = "117"  # SaveVideo (output)

# Popen handle of the ComfyUI child process; None until start_comfyui()
# launches it, and reset to None by stop_comfyui().
comfyui_process = None
def start_comfyui():
    """Ensure a ComfyUI server is running and answering HTTP requests.

    Launches ``main.py`` from COMFYUI_DIR as a child process when there is no
    live child yet, then polls ``/system_stats`` until it answers HTTP 200.

    Returns:
        True once the server answers; False if the child process exits first
        or the server does not answer within STARTUP_TIMEOUT seconds.
    """
    global comfyui_process

    if comfyui_process is None or comfyui_process.poll() is not None:
        print("Starting ComfyUI server...")
        # stdout/stderr are inherited instead of piped: nothing in this handler
        # reads a pipe, and an unread pipe fills up and blocks ComfyUI as soon
        # as it has logged enough output (progress bars, model loading, ...).
        comfyui_process = subprocess.Popen(
            [
                sys.executable, "main.py",
                "--listen", "127.0.0.1",
                "--port", str(COMFYUI_PORT),
                "--disable-auto-launch",
            ],
            cwd=COMFYUI_DIR,
            # New session / process group so stop_comfyui() can signal the
            # whole group; unlike preexec_fn=os.setsid this is thread-safe.
            start_new_session=True,
        )

    # An already-running child is re-checked too, so a server that is still
    # booting (e.g. after an earlier startup timeout) is not reported ready.
    deadline = time.monotonic() + STARTUP_TIMEOUT
    while time.monotonic() < deadline:
        try:
            resp = requests.get(f"{COMFYUI_HOST}/system_stats", timeout=2)
            if resp.status_code == 200:
                print("ComfyUI server ready")
                return True
        except requests.exceptions.RequestException:
            pass  # not listening yet; keep polling
        if comfyui_process.poll() is not None:
            # The child died: no point waiting out the full startup timeout.
            print(f"ComfyUI server exited during startup (code {comfyui_process.returncode})")
            return False
        time.sleep(1)
    print("ComfyUI server failed to start")
    return False
def stop_comfyui(timeout: float = 10.0):
    """Stop the ComfyUI child process started by start_comfyui(), if any.

    Sends SIGTERM to the child's process group, or terminates just the child
    where process groups are unavailable. It then waits up to ``timeout``
    seconds for the child to exit. If it is still alive, it escalates to
    SIGKILL and reaps the child so no zombie is left behind.

    Args:
        timeout: Seconds to wait for a graceful exit before force-killing.
    """
    global comfyui_process
    proc = comfyui_process
    if proc is None:
        return
    comfyui_process = None

    try:
        os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
    except (AttributeError, OSError):
        # AttributeError: no killpg/getpgid on this platform.
        # OSError (incl. ProcessLookupError): the group is already gone.
        proc.terminate()

    try:
        proc.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        try:
            os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
        except (AttributeError, OSError):
            proc.kill()
        proc.wait()  # reap so the child does not linger as a zombie
def load_default_workflow(path=None) -> dict:
    """Load a workflow JSON file (the Wan22-I2V-Remix workflow by default).

    Args:
        path: Optional path to load instead of DEFAULT_WORKFLOW_PATH.

    Returns:
        The parsed JSON document.

    Raises:
        FileNotFoundError: if the path does not point to a regular file.
    """
    workflow_path = Path(DEFAULT_WORKFLOW_PATH if path is None else path)
    if not workflow_path.is_file():
        raise FileNotFoundError(f"Default workflow not found: {workflow_path}")
    # Decode explicitly as UTF-8 rather than with the locale's default encoding;
    # workflow files may contain non-ASCII node titles and prompts.
    return json.loads(workflow_path.read_text(encoding="utf-8"))
# Node modes (the frontend "mode" field) that keep a node out of the executed
# graph: 2 = never/muted, 4 = bypass.
_MODE_NEVER = 2
_MODE_BYPASS = 4

# Editor-only node types with no backend class; they must not appear in an API
# prompt. "Reroute" is additionally resolved as a pass-through of its input.
_FRONTEND_ONLY_TYPES = frozenset({"Note", "MarkdownNote", "PrimitiveNode", "Reroute"})

# Widget names, in widgets_values order, for the node types used by the
# supported workflows. Unlisted types contribute no widget inputs (unless their
# widgets_values is already a name -> value dict).
_WIDGET_MAPPINGS = {
    "LoadImage": ["image", "upload"],
    "CLIPTextEncode": ["text"],
    "easy int": ["value"],
    "INTConstant": ["value"],
    "SaveVideo": ["filename_prefix", "format", "codec"],
    "CreateVideo": ["fps"],
    "RIFE VFI": ["ckpt_name", "clear_cache_after_n_frames", "multiplier", "fast_mode", "ensemble", "scale_factor"],
    "CLIPLoader": ["clip_name", "type", "device"],
    "GetImageSize": [],
    "WanVideoTorchCompileSettings": [
        "backend", "fullgraph", "mode", "dynamic", "dynamo_cache_size_limit",
        "compile_transformer_blocks_only", "compile_cache_max_entries",
        "compile_single_blocks", "compile_double_blocks"
    ],
    "WanVideoBlockSwap": [
        "blocks_to_swap", "offload_txt_emb", "offload_img_emb", "offload_txt_clip_emb",
        "offload_modulation", "cpu_offload_streams", "use_async_transfer"
    ],
    "WanVideoModelLoader": [
        "model", "base_precision", "quantization", "load_device", "attention_mode", "lora_scale_mode"
    ],
    "WanVideoLoraSelect": ["lora", "strength", "use_lora_sparse", "sparse_lora_blocks"],
    "WanVideoVAELoader": ["model_name", "precision"],
    "WanVideoTextEmbedBridge": [],
    "WanVideoImageToVideoEncode": [
        "width", "height", "num_frames", "noise_aug_strength",
        "start_latent_strength", "end_latent_strength", "image_noise_aug",
        "mask_noise_aug", "force_offload"
    ],
    "WanVideoSampler": [
        "shift", "cfg", "steps", "seed", "seed_mode", "force_offload", "scheduler",
        "riflex_freq_index", "riflex_freq_dim", "riflex_freq_scale",
        "use_comfy_pbar", "start_step", "end_step", "denoise"
    ],
    "WanVideoDecode": [
        "enable_tiling", "tile_x", "tile_y", "tile_stride_x", "tile_stride_y", "tiling_decoder"
    ],
    "MathExpression|pysssss": ["expression"],
}


def _link_endpoints(link):
    """Return (link_id, source_node_id, source_slot) for a frontend link.

    Accepts the array form [id, origin_id, origin_slot, target_id, target_slot,
    type] as well as the object form with id/origin_id/origin_slot keys.
    """
    if isinstance(link, dict):
        return link["id"], link["origin_id"], link["origin_slot"]
    return link[0], link[1], link[2]


def _widget_inputs(class_type, widgets_values):
    """Map a node's widgets_values to named inputs."""
    if isinstance(widgets_values, dict):
        # Some custom nodes already store widgets as name -> value.
        return dict(widgets_values)
    names = _WIDGET_MAPPINGS.get(class_type, [])
    # zip stops at the shorter side; empty names mark positions to skip.
    return {name: value for name, value in zip(names, widgets_values or []) if name}


def _bypass_input_link(node, output_slot, wanted_type):
    """Return the link feeding a bypassed node's output, or None.

    Mirrors the ComfyUI frontend: try the input at the same index as the
    requested output slot first, then every input in order, and use the first
    one whose type equals the consuming input's type.
    """
    node_inputs = node.get("inputs") or []
    for idx in [output_slot, *range(len(node_inputs))]:
        if isinstance(idx, int) and 0 <= idx < len(node_inputs) and node_inputs[idx].get("type") == wanted_type:
            return node_inputs[idx].get("link")
    return None


def _resolve_source(link_id, wanted_type, links, nodes_by_id):
    """Follow a link upstream through bypassed nodes and Reroutes.

    Returns [node_id, slot] of the first real producer. Returns None when the
    chain ends at one of:
    - a muted or editor-only node;
    - an unknown node or link;
    - a bypassed node with no input of ``wanted_type``;
    - a cycle.
    """
    visited = set()
    while link_id is not None and link_id not in visited:
        visited.add(link_id)
        endpoint = links.get(link_id)
        if endpoint is None:
            return None
        src_id, src_slot = endpoint
        src = nodes_by_id.get(src_id)
        if src is None:
            return None
        if src.get("type") == "Reroute":
            src_inputs = src.get("inputs") or []
            link_id = src_inputs[0].get("link") if src_inputs else None
        elif src.get("mode") == _MODE_BYPASS:
            link_id = _bypass_input_link(src, src_slot, wanted_type)
        elif src.get("mode") == _MODE_NEVER or src.get("type") in _FRONTEND_ONLY_TYPES:
            return None
        else:
            return [src_id, src_slot]
    return None


def convert_frontend_to_api(frontend_workflow: dict) -> dict:
    """Convert a ComfyUI frontend (editor) workflow to the API prompt format.

    Handling rules:
    - A dict without a "nodes" key is assumed to be in API format already and
      is returned unchanged (same object).
    - Muted (mode 2) nodes, bypassed (mode 4) nodes and editor-only nodes
      (Note, MarkdownNote, PrimitiveNode, Reroute) are omitted.
    - Connections through a bypassed node or Reroute are routed to the
      upstream producer.
    - Connections that cannot be resolved are dropped.

    Each emitted node is {"class_type", "inputs"} plus "_meta" {"title"} when
    the node has a title. Widget values come from _WIDGET_MAPPINGS and linked
    inputs override widget values with the same name.
    """
    if "nodes" not in frontend_workflow:
        return frontend_workflow

    nodes = frontend_workflow.get("nodes") or []
    nodes_by_id = {str(node["id"]): node for node in nodes}
    links = {}
    for raw_link in frontend_workflow.get("links") or []:
        link_id, src_id, src_slot = _link_endpoints(raw_link)
        links[link_id] = (str(src_id), src_slot)

    api_workflow = {}
    for node in nodes:
        class_type = node.get("type", "")
        if node.get("mode") in (_MODE_NEVER, _MODE_BYPASS) or class_type in _FRONTEND_ONLY_TYPES:
            continue

        inputs = _widget_inputs(class_type, node.get("widgets_values"))
        # Process node inputs (connections) - these override widget values.
        for inp in node.get("inputs") or []:
            link_id = inp.get("link")
            if link_id is None:
                continue
            source = _resolve_source(link_id, inp.get("type"), links, nodes_by_id)
            if source is not None:
                inputs[inp["name"]] = source

        node_id = str(node["id"])
        api_workflow[node_id] = {"class_type": class_type, "inputs": inputs}
        if "title" in node:
            api_workflow[node_id]["_meta"] = {"title": node["title"]}
    return api_workflow
def upload_image(image_base64: str, filename: str = None) -> str:
"""Upload base64 image to ComfyUI and return the filename."""
if filename is None:
filename = f"input_{uuid.uuid4().hex[:8]}.png"
# Decode base64
image_data = base64.b64decode(image_base64)
# Upload to ComfyUI
files = {
"image": (filename, image_data, "image/png"),
}
data = {
"overwrite": "true"
}
resp = requests.post(
f"{COMFYUI_HOST}/upload/image",
files=files,
data=data
)
if resp.status_code != 200:
raise Exception(f"Failed to upload image: {resp.text}")
result = resp.json()
return result.get("name", filename)
def inject_wan22_params(workflow: dict, params: dict) -> dict:
    """Return a copy of an API-format workflow with job parameters written in.

    Each recognised key in ``params`` is stored into the ``inputs`` of its
    target Wan22-I2V-Remix node, but only when that node exists in the
    workflow. Unknown keys and absent nodes are ignored. The caller's
    ``workflow`` is not modified.

    Args:
        workflow: API-format workflow (node id -> {"class_type", "inputs", ...}).
        params: May contain image_filename, prompt, negative_prompt,
            resolution, steps and split_step.
    """
    # (parameter key, target node id, input field), applied in this order.
    targets = (
        ("image_filename", NODE_IMAGE_INPUT, "image"),
        ("prompt", NODE_POSITIVE_PROMPT, "text"),
        ("negative_prompt", NODE_NEGATIVE_PROMPT, "text"),
        ("resolution", NODE_RESOLUTION, "value"),
        ("steps", NODE_STEPS, "value"),
        ("split_step", NODE_SPLIT_STEP, "value"),
    )
    # A JSON round-trip yields a deep copy sharing nothing with the caller.
    patched = json.loads(json.dumps(workflow))
    for param_key, node_id, field in targets:
        if param_key not in params or node_id not in patched:
            continue
        patched[node_id]["inputs"][field] = params[param_key]
    return patched
def queue_workflow(workflow: dict, client_id: str = None, timeout: float = 30) -> str:
    """Submit an API-format workflow to ComfyUI's /prompt endpoint.

    Args:
        workflow: API-format workflow to execute.
        client_id: Client identifier sent with the prompt; random if None.
        timeout: Seconds to wait on the HTTP request (requests has no default).

    Returns:
        The prompt_id ComfyUI assigned to the queued workflow.

    Raises:
        Exception: on a non-200 response, or if the response lacks a prompt_id.
    """
    if client_id is None:
        client_id = uuid.uuid4().hex
    payload = {
        "prompt": workflow,
        "client_id": client_id,
    }
    resp = requests.post(f"{COMFYUI_HOST}/prompt", json=payload, timeout=timeout)
    if resp.status_code != 200:
        raise Exception(f"Failed to queue workflow: {resp.text}")

    result = resp.json()
    # Debug: print full queue response
    print(f"Queue response keys: {result.keys()}")
    if result.get("node_errors"):
        print(f"Node errors: {result['node_errors']}")
    if "error" in result:
        print(f"Queue error: {result['error']}")

    prompt_id = result.get("prompt_id")
    if not prompt_id:
        # Otherwise callers only see an opaque KeyError('prompt_id').
        raise Exception(f"ComfyUI did not return a prompt_id: {result}")
    return prompt_id
def get_history(prompt_id: str, timeout: float = 10) -> dict:
    """Fetch ComfyUI's execution history for ``prompt_id``.

    Args:
        prompt_id: ID returned by queue_workflow().
        timeout: Seconds to wait on the HTTP request; without one a stalled
            connection would block the polling loop forever.

    Returns:
        The decoded JSON response, or {} when ComfyUI answers with a non-200
        status.
    """
    resp = requests.get(f"{COMFYUI_HOST}/history/{prompt_id}", timeout=timeout)
    if resp.status_code != 200:
        return {}
    return resp.json()
def poll_for_completion(prompt_id: str, timeout: int = MAX_TIMEOUT) -> dict:
    """Poll ComfyUI's history until ``prompt_id`` finishes.

    Args:
        prompt_id: ID returned by queue_workflow().
        timeout: Maximum number of seconds to wait.

    Returns:
        The history entry for ``prompt_id`` once its status reports completed.

    Raises:
        Exception: if the status reports ``status_str == "error"``.
        TimeoutError: if the prompt has not completed within ``timeout`` seconds.
    """
    # monotonic() is unaffected by wall-clock adjustments (NTP sync, manual
    # changes) that would otherwise stretch or cut short the deadline.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        entry = get_history(prompt_id).get(prompt_id)
        if entry is not None:
            status = entry.get("status") or {}
            if status.get("completed", False):
                return entry
            if status.get("status_str") == "error":
                raise Exception(f"Workflow execution failed: {status}")
        time.sleep(POLL_INTERVAL)
    raise TimeoutError(f"Workflow execution timed out after {timeout}s")
# File extensions treated as video regardless of which output list they are in.
_VIDEO_EXTENSIONS = frozenset({".mp4", ".webm", ".gif", ".mov", ".mkv"})

# Output lists in a node's history entry, in the order they are collected, with
# the type assumed when the extension is not a known video extension.
_OUTPUT_CATEGORIES = (
    ("images", "image"),
    ("videos", "video"),  # SaveVideo-style outputs
    ("gifs", "video"),  # VideoHelperSuite outputs
    ("files", "image"),  # generic file outputs
)


def get_output_files(history: dict) -> list:
    """Extract descriptors for every output file listed in a history entry.

    ComfyUI's core SaveVideo node reports its video under "images" (flagged
    "animated"), so the type is decided by file extension first. The list a
    file appears in is only a fallback. Entries without a filename are skipped.

    Returns:
        A list of dicts with keys "type" ("image" or "video"), "filename",
        "subfolder" and "type_folder" (the ComfyUI folder type, default
        "output"), in node order and then category order.
    """
    outputs = []
    for node_output in (history.get("outputs") or {}).values():
        for category, default_type in _OUTPUT_CATEGORIES:
            for item in node_output.get(category) or []:
                filename = item.get("filename")
                if not filename:
                    continue
                is_video = Path(filename).suffix.lower() in _VIDEO_EXTENSIONS
                outputs.append({
                    "type": "video" if is_video else default_type,
                    "filename": filename,
                    "subfolder": item.get("subfolder", ""),
                    "type_folder": item.get("type", "output"),
                })
    return outputs
def fetch_output(output_info: dict, timeout: float = 120) -> bytes:
    """Download one output file from ComfyUI's /view endpoint.

    Args:
        output_info: Descriptor from get_output_files() with "filename",
            "subfolder" and "type_folder" keys.
        timeout: Seconds to wait on the HTTP request (requests has no default).

    Returns:
        The raw file bytes.

    Raises:
        Exception: if ComfyUI answers with a non-200 status.
    """
    params = {
        "filename": output_info["filename"],
        "subfolder": output_info["subfolder"],
        "type": output_info["type_folder"],
    }
    resp = requests.get(f"{COMFYUI_HOST}/view", params=params, timeout=timeout)
    if resp.status_code != 200:
        raise Exception(f"Failed to fetch output: {resp.status_code} ({output_info['filename']})")
    return resp.content
def handler(job: dict) -> dict:
"""
RunPod serverless handler.
Input schema:
{
"image": "base64 encoded image (required)",
"prompt": "positive prompt text (required)",
"negative_prompt": "negative prompt (optional)",
"resolution": 720 (optional, default 720),
"steps": 8 (optional, default 8),
"split_step": 4 (optional, default 4),
"timeout": 600 (optional, max 600),
"workflow": {} (optional, override default workflow)
}
"""
job_input = job.get("input", {})
# Validate required inputs
if "image" not in job_input or not job_input["image"]:
return {"error": "Missing required 'image' (base64) in input"}
if "prompt" not in job_input or not job_input["prompt"]:
return {"error": "Missing required 'prompt' in input"}
# Ensure ComfyUI is running
if not start_comfyui():
return {"error": "Failed to start ComfyUI server"}
try:
# Load workflow (custom or default)
if "workflow" in job_input and job_input["workflow"]:
workflow = job_input["workflow"]
# Convert frontend format if needed
workflow = convert_frontend_to_api(workflow)
else:
# Load and convert default workflow
frontend_workflow = load_default_workflow()
workflow = convert_frontend_to_api(frontend_workflow)
# Upload image
image_filename = upload_image(job_input["image"])
print(f"Uploaded image: {image_filename}")
# Build params for injection
params = {
"image_filename": image_filename,
"prompt": job_input["prompt"]
}
if "negative_prompt" in job_input:
params["negative_prompt"] = job_input["negative_prompt"]
if "resolution" in job_input:
params["resolution"] = int(job_input["resolution"])
if "steps" in job_input:
params["steps"] = int(job_input["steps"])
if "split_step" in job_input:
params["split_step"] = int(job_input["split_step"])
# Inject parameters into workflow
workflow = inject_wan22_params(workflow, params)
# Debug: print output chain nodes to verify connections
print("=== Workflow Output Chain ===")
# Check the output chain: 117 <- 116 <- 115 <- 158 <- 140
for node_id in ["117", "116", "115"]:
if node_id in workflow:
node = workflow[node_id]
print(f"Node {node_id} ({node['class_type']}): {node['inputs']}")
else:
print(f"Node {node_id}: MISSING FROM WORKFLOW!")
print(f"Total nodes in workflow: {len(workflow)}")
# Queue workflow
client_id = uuid.uuid4().hex
prompt_id = queue_workflow(workflow, client_id)
print(f"Queued workflow: {prompt_id}")
# Poll for completion
timeout = min(job_input.get("timeout", MAX_TIMEOUT), MAX_TIMEOUT)
history = poll_for_completion(prompt_id, timeout)
print("Workflow completed")
# Debug: print history structure
print(f"History keys: {history.keys()}")
if "outputs" in history:
print(f"Output nodes: {list(history['outputs'].keys())}")
for node_id, node_out in history["outputs"].items():
print(f" Node {node_id}: {list(node_out.keys())}")
if "status" in history:
print(f"Status: {history['status']}")
# Get output files
outputs = get_output_files(history)
if not outputs:
return {"error": "No outputs generated"}
# Fetch and encode outputs
results = []
for output_info in outputs:
data = fetch_output(output_info)
print(f"Fetched output: {output_info['filename']} ({len(data)} bytes)")
# Check size for video files
if output_info["type"] == "video" and len(data) > 10 * 1024 * 1024:
# For large videos, save to network volume and return path
output_path = Path("/runpod-volume/outputs") / output_info["filename"]
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_bytes(data)
results.append({
"type": output_info["type"],
"filename": output_info["filename"],
"path": str(output_path),
"size": len(data)
})
else:
# Return as base64
results.append({
"type": output_info["type"],
"filename": output_info["filename"],
"data": base64.b64encode(data).decode("utf-8"),
"size": len(data)
})
return {
"status": "success",
"prompt_id": prompt_id,
"outputs": results
}
except TimeoutError as e:
return {"error": str(e), "status": "timeout"}
except Exception as e:
import traceback
traceback.print_exc()
return {"error": str(e), "status": "error"}
# Script entry point: hand `handler` to the RunPod serverless worker when this
# file is executed directly (nothing runs on import).
if __name__ == "__main__":
    print("Starting ComfyUI RunPod Handler...")
    runpod.serverless.start(dict(handler=handler))