""" ComfyUI RunPod Serverless Handler Handles image/video generation workflows with ComfyUI API Wan22-I2V-Remix Workflow Node Mapping: - Node 148: LoadImage - image input - Node 134: CLIPTextEncode - positive prompt - Node 137: CLIPTextEncode - negative prompt - Node 147: easy int - resolution (720 default) - Node 150: INTConstant - steps (8 default) - Node 151: INTConstant - split_step (4 default) - Node 117: SaveVideo - output """ import os import sys import json import time import base64 import uuid import subprocess import signal import requests from pathlib import Path import runpod # Configuration COMFYUI_DIR = "/workspace/ComfyUI" COMFYUI_PORT = 8188 COMFYUI_HOST = f"http://127.0.0.1:{COMFYUI_PORT}" MAX_TIMEOUT = 600 # 10 minutes max for video generation POLL_INTERVAL = 1.0 STARTUP_TIMEOUT = 120 DEFAULT_WORKFLOW_PATH = "/workspace/workflows/Wan22-I2V-Remix-API.json" # Wan22-I2V-Remix node IDs NODE_IMAGE_INPUT = "148" NODE_POSITIVE_PROMPT = "134" NODE_NEGATIVE_PROMPT = "137" NODE_RESOLUTION = "147" NODE_STEPS = "150" NODE_SPLIT_STEP = "151" NODE_SAVE_VIDEO = "117" # Global ComfyUI process comfyui_process = None def start_comfyui(): """Start ComfyUI server if not already running.""" global comfyui_process if comfyui_process is not None and comfyui_process.poll() is None: return True # Debug: show model paths print("=== Model Path Debug ===") import subprocess for path in ["/runpod-volume", "/runpod-volume/models", "/workspace/ComfyUI/models"]: result = subprocess.run(["ls", "-la", path], capture_output=True, text=True) print(f"{path}:") print(result.stdout[:500] if result.stdout else f" Error: {result.stderr}") # Find any .safetensors files result = subprocess.run(["find", "/runpod-volume", "-name", "*.safetensors", "-type", "f"], capture_output=True, text=True, timeout=30) print(f"Safetensors on volume:\n{result.stdout[:1000] if result.stdout else 'None found'}") print("Starting ComfyUI server...") comfyui_process = subprocess.Popen( [ sys.executable, "main.py", "--listen", "127.0.0.1", "--port", str(COMFYUI_PORT), "--disable-auto-launch" ], cwd=COMFYUI_DIR, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, preexec_fn=os.setsid if hasattr(os, 'setsid') else None ) # Wait for server to be ready start_time = time.time() while time.time() - start_time < STARTUP_TIMEOUT: try: resp = requests.get(f"{COMFYUI_HOST}/system_stats", timeout=2) if resp.status_code == 200: print("ComfyUI server ready") return True except requests.exceptions.RequestException: pass time.sleep(1) print("ComfyUI server failed to start") return False def stop_comfyui(): """Stop ComfyUI server.""" global comfyui_process if comfyui_process is not None: try: os.killpg(os.getpgid(comfyui_process.pid), signal.SIGTERM) except (OSError, ProcessLookupError): comfyui_process.terminate() comfyui_process = None def load_default_workflow() -> dict: """Load the default Wan22-I2V-Remix workflow.""" workflow_path = Path(DEFAULT_WORKFLOW_PATH) if not workflow_path.exists(): raise FileNotFoundError(f"Default workflow not found: {DEFAULT_WORKFLOW_PATH}") with open(workflow_path) as f: return json.load(f) def convert_frontend_to_api(frontend_workflow: dict) -> dict: """Convert ComfyUI frontend format to API format.""" # If already in API format (no 'nodes' key), return as-is if "nodes" not in frontend_workflow: return frontend_workflow api_workflow = {} nodes = frontend_workflow.get("nodes", []) links = frontend_workflow.get("links", []) # Build set of active (non-bypassed) node IDs active_nodes = {str(node["id"]) for node in nodes if node.get("mode") != 4} # Build link lookup: link_id -> (source_node_id, source_slot) # Only include links from active nodes link_map = {} for link in links: link_id, src_node, src_slot, dst_node, dst_slot, link_type = link[:6] if str(src_node) in active_nodes: link_map[link_id] = (str(src_node), src_slot) # Widget mappings for each node type: list of input names in order WIDGET_MAPPINGS = { "LoadImage": ["image", "upload"], "CLIPTextEncode": ["text"], "easy int": ["value"], "INTConstant": ["value"], "SaveVideo": ["filename_prefix", "format", "codec"], "CreateVideo": ["fps"], "RIFE VFI": ["ckpt_name", "clear_cache_after_n_frames", "multiplier", "fast_mode", "ensemble", "scale_factor"], "CLIPLoader": ["clip_name", "type", "device"], "GetImageSize": [], "WanVideoTorchCompileSettings": [ "backend", "fullgraph", "mode", "dynamic", "dynamo_cache_size_limit", "compile_transformer_blocks_only", "compile_cache_max_entries", "compile_single_blocks", "compile_double_blocks" ], "WanVideoBlockSwap": [ "blocks_to_swap", "offload_txt_emb", "offload_img_emb", "offload_txt_clip_emb", "offload_modulation", "cpu_offload_streams", "use_async_transfer" ], "WanVideoModelLoader": [ "model", "base_precision", "quantization", "load_device", "attention_mode", "lora_scale_mode" ], "WanVideoLoraSelect": ["lora", "strength", "use_lora_sparse", "sparse_lora_blocks"], "WanVideoVAELoader": ["model_name", "precision"], "WanVideoTextEmbedBridge": [], "WanVideoImageToVideoEncode": [ "width", "height", "num_frames", "noise_aug_strength", "start_latent_strength", "end_latent_strength", "image_noise_aug", "mask_noise_aug", "force_offload" ], "WanVideoSampler": [ "shift", "cfg", "steps", "seed", "seed_mode", "force_offload", "scheduler", "riflex_freq_index", "riflex_freq_dim", "riflex_freq_scale", "use_comfy_pbar", "start_step", "end_step", "denoise" ], "WanVideoDecode": [ "enable_tiling", "tile_x", "tile_y", "tile_stride_x", "tile_stride_y", "tiling_decoder" ], "MathExpression|pysssss": ["expression"], } for node in nodes: # Skip bypassed/muted nodes (mode 4) if node.get("mode") == 4: continue node_id = str(node["id"]) class_type = node.get("type", "") inputs = {} # Process widget values using mappings widgets_values = node.get("widgets_values", []) widget_names = WIDGET_MAPPINGS.get(class_type, []) for i, value in enumerate(widgets_values): if i < len(widget_names) and widget_names[i]: inputs[widget_names[i]] = value # Process node inputs (connections) - these override widget values for inp in node.get("inputs", []): inp_name = inp["name"] link_id = inp.get("link") if link_id is not None and link_id in link_map: src_node, src_slot = link_map[link_id] inputs[inp_name] = [src_node, src_slot] api_workflow[node_id] = { "class_type": class_type, "inputs": inputs } # Add meta if title exists if "title" in node: api_workflow[node_id]["_meta"] = {"title": node["title"]} return api_workflow def upload_image(image_base64: str, filename: str = None) -> str: """Upload base64 image to ComfyUI and return the filename.""" if filename is None: filename = f"input_{uuid.uuid4().hex[:8]}.png" # Decode base64 image_data = base64.b64decode(image_base64) # Upload to ComfyUI files = { "image": (filename, image_data, "image/png"), } data = { "overwrite": "true" } resp = requests.post( f"{COMFYUI_HOST}/upload/image", files=files, data=data ) if resp.status_code != 200: raise Exception(f"Failed to upload image: {resp.text}") result = resp.json() return result.get("name", filename) def inject_wan22_params(workflow: dict, params: dict) -> dict: """Inject parameters into Wan22-I2V-Remix workflow nodes.""" workflow = json.loads(json.dumps(workflow)) # Deep copy # Image input (node 148) if "image_filename" in params and NODE_IMAGE_INPUT in workflow: workflow[NODE_IMAGE_INPUT]["inputs"]["image"] = params["image_filename"] # Positive prompt (node 134) if "prompt" in params and NODE_POSITIVE_PROMPT in workflow: workflow[NODE_POSITIVE_PROMPT]["inputs"]["text"] = params["prompt"] # Negative prompt (node 137) - optional override if "negative_prompt" in params and NODE_NEGATIVE_PROMPT in workflow: workflow[NODE_NEGATIVE_PROMPT]["inputs"]["text"] = params["negative_prompt"] # Resolution (node 147) if "resolution" in params and NODE_RESOLUTION in workflow: workflow[NODE_RESOLUTION]["inputs"]["value"] = params["resolution"] # Steps (node 150) if "steps" in params and NODE_STEPS in workflow: workflow[NODE_STEPS]["inputs"]["value"] = params["steps"] # Split step (node 151) if "split_step" in params and NODE_SPLIT_STEP in workflow: workflow[NODE_SPLIT_STEP]["inputs"]["value"] = params["split_step"] return workflow def queue_workflow(workflow: dict, client_id: str = None) -> str: """Queue workflow and return prompt_id.""" if client_id is None: client_id = uuid.uuid4().hex payload = { "prompt": workflow, "client_id": client_id } resp = requests.post( f"{COMFYUI_HOST}/prompt", json=payload ) if resp.status_code != 200: raise Exception(f"Failed to queue workflow: {resp.text}") result = resp.json() # Debug: print full queue response print(f"Queue response keys: {result.keys()}") if "node_errors" in result and result["node_errors"]: print(f"Node errors: {result['node_errors']}") if "error" in result: print(f"Queue error: {result['error']}") return result["prompt_id"] def get_history(prompt_id: str) -> dict: """Get execution history for a prompt.""" resp = requests.get(f"{COMFYUI_HOST}/history/{prompt_id}") if resp.status_code != 200: return {} return resp.json() def poll_for_completion(prompt_id: str, timeout: int = MAX_TIMEOUT) -> dict: """Poll until workflow completes or timeout.""" start_time = time.time() while time.time() - start_time < timeout: history = get_history(prompt_id) if prompt_id in history: status = history[prompt_id].get("status", {}) # Check for completion if status.get("completed", False): return history[prompt_id] # Check for error if status.get("status_str") == "error": raise Exception(f"Workflow execution failed: {status}") time.sleep(POLL_INTERVAL) raise TimeoutError(f"Workflow execution timed out after {timeout}s") def get_output_files(history: dict) -> list: """Extract output file info from history.""" outputs = [] if "outputs" not in history: return outputs for node_id, node_output in history["outputs"].items(): # Handle image outputs if "images" in node_output: for img in node_output["images"]: outputs.append({ "type": "image", "filename": img["filename"], "subfolder": img.get("subfolder", ""), "type_folder": img.get("type", "output") }) # Handle video outputs (SaveVideo node) if "videos" in node_output: for vid in node_output["videos"]: outputs.append({ "type": "video", "filename": vid["filename"], "subfolder": vid.get("subfolder", ""), "type_folder": vid.get("type", "output") }) # Handle video outputs (VideoHelperSuite gifs) if "gifs" in node_output: for vid in node_output["gifs"]: outputs.append({ "type": "video", "filename": vid["filename"], "subfolder": vid.get("subfolder", ""), "type_folder": vid.get("type", "output") }) # Handle generic files if "files" in node_output: for f in node_output["files"]: filename = f.get("filename", "") ext = Path(filename).suffix.lower() file_type = "video" if ext in [".mp4", ".webm", ".gif", ".mov"] else "image" outputs.append({ "type": file_type, "filename": filename, "subfolder": f.get("subfolder", ""), "type_folder": f.get("type", "output") }) return outputs def fetch_output(output_info: dict) -> bytes: """Fetch output file from ComfyUI.""" params = { "filename": output_info["filename"], "subfolder": output_info["subfolder"], "type": output_info["type_folder"] } resp = requests.get(f"{COMFYUI_HOST}/view", params=params) if resp.status_code != 200: raise Exception(f"Failed to fetch output: {resp.status_code}") return resp.content def handler(job: dict) -> dict: """ RunPod serverless handler. Input schema: { "image": "base64 encoded image (required)", "prompt": "positive prompt text (required)", "negative_prompt": "negative prompt (optional)", "resolution": 720 (optional, default 720), "steps": 8 (optional, default 8), "split_step": 4 (optional, default 4), "timeout": 600 (optional, max 600), "workflow": {} (optional, override default workflow) } """ job_input = job.get("input", {}) # Validate required inputs if "image" not in job_input or not job_input["image"]: return {"error": "Missing required 'image' (base64) in input"} if "prompt" not in job_input or not job_input["prompt"]: return {"error": "Missing required 'prompt' in input"} # Ensure ComfyUI is running if not start_comfyui(): return {"error": "Failed to start ComfyUI server"} try: # Load workflow (custom or default) if "workflow" in job_input and job_input["workflow"]: workflow = job_input["workflow"] # Convert frontend format if needed workflow = convert_frontend_to_api(workflow) else: # Load and convert default workflow frontend_workflow = load_default_workflow() workflow = convert_frontend_to_api(frontend_workflow) # Upload image image_filename = upload_image(job_input["image"]) print(f"Uploaded image: {image_filename}") # Build params for injection params = { "image_filename": image_filename, "prompt": job_input["prompt"] } if "negative_prompt" in job_input: params["negative_prompt"] = job_input["negative_prompt"] if "resolution" in job_input: params["resolution"] = int(job_input["resolution"]) if "steps" in job_input: params["steps"] = int(job_input["steps"]) if "split_step" in job_input: params["split_step"] = int(job_input["split_step"]) # Inject parameters into workflow workflow = inject_wan22_params(workflow, params) # Debug: print output chain nodes to verify connections print("=== Workflow Output Chain ===") # Check the output chain: 117 <- 116 <- 115 <- 158 <- 140 for node_id in ["117", "116", "115"]: if node_id in workflow: node = workflow[node_id] print(f"Node {node_id} ({node['class_type']}): {node['inputs']}") else: print(f"Node {node_id}: MISSING FROM WORKFLOW!") print(f"Total nodes in workflow: {len(workflow)}") # Queue workflow client_id = uuid.uuid4().hex prompt_id = queue_workflow(workflow, client_id) print(f"Queued workflow: {prompt_id}") # Poll for completion timeout = min(job_input.get("timeout", MAX_TIMEOUT), MAX_TIMEOUT) history = poll_for_completion(prompt_id, timeout) print("Workflow completed") # Debug: print history structure print(f"History keys: {history.keys()}") if "outputs" in history: print(f"Output nodes: {list(history['outputs'].keys())}") for node_id, node_out in history["outputs"].items(): print(f" Node {node_id}: {list(node_out.keys())}") if "status" in history: print(f"Status: {history['status']}") # Get output files outputs = get_output_files(history) if not outputs: return {"error": "No outputs generated"} # Fetch and encode outputs results = [] for output_info in outputs: data = fetch_output(output_info) print(f"Fetched output: {output_info['filename']} ({len(data)} bytes)") # Check size for video files if output_info["type"] == "video" and len(data) > 10 * 1024 * 1024: # For large videos, save to network volume and return path output_path = Path("/runpod-volume/outputs") / output_info["filename"] output_path.parent.mkdir(parents=True, exist_ok=True) output_path.write_bytes(data) results.append({ "type": output_info["type"], "filename": output_info["filename"], "path": str(output_path), "size": len(data) }) else: # Return as base64 results.append({ "type": output_info["type"], "filename": output_info["filename"], "data": base64.b64encode(data).decode("utf-8"), "size": len(data) }) return { "status": "success", "prompt_id": prompt_id, "outputs": results } except TimeoutError as e: return {"error": str(e), "status": "timeout"} except Exception as e: import traceback traceback.print_exc() return {"error": str(e), "status": "error"} # RunPod serverless entry point if __name__ == "__main__": print("Starting ComfyUI RunPod Handler...") runpod.serverless.start({"handler": handler})