"""Command-line demo for a ComfyUI endpoint reached through the Vast Serverless SDK.

Two modes are supported:

* text-to-image from a prompt (default), using the ``Text2Image`` modifier;
* a custom ComfyUI workflow loaded from a JSON file (``--workflow FILE``).

Both modes send a request to the endpoint's ``/generate/sync`` route, look for
an image filename in the reply and download that image from the worker's
``/view`` route into ``generated_images/``.
"""

import argparse
import asyncio
import json
import logging
import os
import random
import sys
import uuid

import aiohttp
from vastai import Serverless

# ---------------------- Config ----------------------
DEFAULT_PROMPT = "a beautiful sunset over mountains, digital art, highly detailed"
ENDPOINT_NAME = "my-comfyui-endpoint"
DEFAULT_WIDTH = 512
DEFAULT_HEIGHT = 512
DEFAULT_STEPS = 20
COST = 100  # Fixed cost for ComfyUI requests
OUTPUT_DIR = "generated_images"  # Local directory downloaded images are written to

logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")
log = logging.getLogger(__name__)


# ---------------------- API Functions ----------------------
async def _request_sync(client: Serverless, endpoint_name: str, request_input: dict) -> dict:
    """Look up *endpoint_name* and send *request_input* to ``/generate/sync``."""
    endpoint = await client.get_endpoint(name=endpoint_name)
    return await endpoint.request("/generate/sync", {"input": request_input}, cost=COST)


async def call_generate(
    client: Serverless,
    *,
    endpoint_name: str,
    prompt: str,
    width: int,
    height: int,
    steps: int,
    seed: int,
) -> dict:
    """Generate an image using the ``Text2Image`` modifier.

    Args:
        client: Open Serverless client used to resolve the endpoint.
        endpoint_name: Name of the endpoint to send the request to.
        prompt: Text prompt forwarded in the ``modifications`` section.
        width: Requested image width.
        height: Requested image height.
        steps: Requested number of steps.
        seed: Seed forwarded with the request.

    Returns:
        Whatever ``endpoint.request`` returns for the ``/generate/sync`` call.
    """
    return await _request_sync(
        client,
        endpoint_name,
        {
            # A fresh id per call so each request is individually identifiable.
            "request_id": str(uuid.uuid4()),
            "modifier": "Text2Image",
            "modifications": {
                "prompt": prompt,
                "width": width,
                "height": height,
                "steps": steps,
                "seed": seed,
            },
        },
    )


async def call_generate_workflow(
    client: Serverless,
    *,
    endpoint_name: str,
    workflow_json: dict,
) -> dict:
    """Generate using a custom workflow JSON document.

    Args:
        client: Open Serverless client used to resolve the endpoint.
        endpoint_name: Name of the endpoint to send the request to.
        workflow_json: Parsed workflow, forwarded unchanged as ``workflow_json``.

    Returns:
        Whatever ``endpoint.request`` returns for the ``/generate/sync`` call.
    """
    return await _request_sync(
        client,
        endpoint_name,
        {
            "request_id": str(uuid.uuid4()),
            "workflow_json": workflow_json,
        },
    )


# ---------------------- Demo Class ----------------------
class APIDemo:
    """Runs the prompt and workflow demos against one named endpoint."""

    def __init__(self, client: Serverless, endpoint_name: str):
        self.client = client
        self.endpoint_name = endpoint_name

    def extract_filename(self, response: dict) -> str | None:
        """Extract the generated image filename from a ComfyUI response.

        Walks ``response["comfyui_response"][*]["outputs"][*]["images"]`` and
        returns the ``filename`` of the first image of the first node that has
        a non-empty image list. Entries that do not have the expected shape
        are skipped instead of raising.

        Returns:
            The filename, or ``None`` when no image entry is found.
        """
        if not isinstance(response, dict):
            return None
        comfy_response = response.get("comfyui_response")
        if not isinstance(comfy_response, dict):
            return None

        for data in comfy_response.values():
            if not isinstance(data, dict):
                continue
            outputs = data.get("outputs")
            if not isinstance(outputs, dict):
                continue
            for node_output in outputs.values():
                if not isinstance(node_output, dict):
                    continue
                images = node_output.get("images")
                if isinstance(images, list) and images and isinstance(images[0], dict):
                    return images[0].get("filename")
        return None

    async def save_image(self, worker_url: str, filename: str, local_name: str) -> str | None:
        """Fetch an image from the worker and save it under ``generated_images/``.

        Returns:
            The local path on success, ``None`` if the image could not be fetched.
        """
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        return await self._fetch_image(worker_url, filename, local_name)

    async def _fetch_image(self, worker_url: str, filename: str, local_name: str) -> str | None:
        """Fetch an image from the worker's ``/view`` route and save it locally.

        This is best-effort: any failure is logged and reported as ``None``
        rather than raised, so a failed download never aborts the demo.

        Args:
            worker_url: Base URL of the worker; an empty value short-circuits.
            filename: Image filename sent as the ``filename`` query parameter.
            local_name: File name to write inside ``generated_images/``.

        Returns:
            The local path on success, otherwise ``None``.
        """
        if not worker_url:
            return None

        # Strip a trailing slash so the request never targets "//view".
        url = f"{worker_url.rstrip('/')}/view"
        params = {"filename": filename, "type": "output"}
        path = f"{OUTPUT_DIR}/{local_name}"

        try:
            async with aiohttp.ClientSession() as session:
                # SECURITY: ssl=False disables TLS certificate verification.
                # NOTE(review): presumably needed because workers present
                # self-signed certificates -- confirm before reusing elsewhere.
                async with session.get(url, params=params, ssl=False) as resp:
                    if resp.status != 200:
                        log.warning("Image fetch from %s returned HTTP %s", url, resp.status)
                        return None
                    content = await resp.read()

            # Also ensured here so a direct call works without save_image().
            os.makedirs(OUTPUT_DIR, exist_ok=True)
            with open(path, "wb") as f:
                f.write(content)
        except Exception as exc:
            # Deliberately broad to keep the download best-effort, but the
            # cause is logged instead of being silently discarded.
            log.warning("Could not fetch image %r from %s: %s", filename, url, exc)
            return None

        print(f" 💾 Saved: {path}")
        return path

    async def _save_from_response(self, response: dict, local_name: str) -> str | None:
        """Locate the image named in *response*, download it, and report problems.

        Returns:
            The local path of the saved image, or ``None`` if nothing was saved.
        """
        if "response" not in response:
            print("❌ Unexpected response format")
            return None

        filename = self.extract_filename(response["response"])
        if not filename:
            print("❌ No image in response")
            return None

        path = await self.save_image(response.get("url", ""), filename, local_name)
        if not path:
            print("❌ Failed to fetch image")
        return path

    async def demo_prompt(
        self,
        prompt: str,
        width: int,
        height: int,
        steps: int,
        seed: int | None,
    ):
        """Demo: generate an image from a text prompt.

        When *seed* is ``None`` a random 32-bit seed is chosen. The image is
        saved as ``comfy_<seed>.png``.
        """
        print("=" * 60)
        print("COMFYUI TEXT-TO-IMAGE DEMO")
        print("=" * 60)

        if seed is None:
            seed = random.randint(0, 2**32 - 1)

        # Long prompts are truncated for display only; the full text is sent.
        shown_prompt = f"{prompt[:100]}..." if len(prompt) > 100 else prompt
        print(f"Prompt: {shown_prompt}")
        print(f"Size: {width}x{height}, Steps: {steps}, Seed: {seed}")
        print("\n🎨 Generating image...")

        response = await call_generate(
            self.client,
            endpoint_name=self.endpoint_name,
            prompt=prompt,
            width=width,
            height=height,
            steps=steps,
            seed=seed,
        )
        print("\n✅ Generation complete!")

        # The worker URL is needed to fetch the generated image.
        print(f"Worker URL: {response.get('url', '')}")

        await self._save_from_response(response, f"comfy_{seed}.png")

    async def demo_workflow(self, workflow_file: str):
        """Demo: generate using a custom workflow file.

        Logs an error and returns if *workflow_file* does not exist. The image
        is saved as ``workflow.png``.
        """
        print("=" * 60)
        print("COMFYUI CUSTOM WORKFLOW DEMO")
        print("=" * 60)

        try:
            with open(workflow_file, "r", encoding="utf-8") as f:
                workflow_json = json.load(f)
        except FileNotFoundError:
            log.error("Workflow file not found: %s", workflow_file)
            return

        print(f"Workflow: {workflow_file}")
        print("\n🎨 Generating...")

        response = await call_generate_workflow(
            self.client,
            endpoint_name=self.endpoint_name,
            workflow_json=workflow_json,
        )
        print("\n✅ Generation complete!")

        await self._save_from_response(response, "workflow.png")


# ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the demo."""
    p = argparse.ArgumentParser(description="Vast ComfyUI-JSON Demo (Serverless SDK)")
    p.add_argument(
        "--endpoint",
        default=ENDPOINT_NAME,
        help=f"Vast endpoint name (default: {ENDPOINT_NAME})",
    )
    p.add_argument(
        "--prompt",
        type=str,
        default=DEFAULT_PROMPT,
        metavar="TEXT",
        help=f"Prompt text (default: '{DEFAULT_PROMPT[:30]}...')",
    )
    p.add_argument(
        "--workflow",
        type=str,
        metavar="FILE",
        help="Use custom workflow JSON file instead",
    )
    p.add_argument(
        "--width",
        type=int,
        default=DEFAULT_WIDTH,
        help=f"Image width (default: {DEFAULT_WIDTH})",
    )
    p.add_argument(
        "--height",
        type=int,
        default=DEFAULT_HEIGHT,
        help=f"Image height (default: {DEFAULT_HEIGHT})",
    )
    p.add_argument(
        "--steps",
        type=int,
        default=DEFAULT_STEPS,
        help=f"Steps (default: {DEFAULT_STEPS})",
    )
    p.add_argument("--seed", type=int, default=None, help="Seed (default: random)")
    return p


async def main_async():
    """Parse arguments and run the selected demo; exit with status 1 on error."""
    args = build_arg_parser().parse_args()

    print("=" * 60)
    print(f"Using endpoint: {args.endpoint}")

    try:
        async with Serverless() as client:
            demo = APIDemo(client, args.endpoint)
            if args.workflow:
                await demo.demo_workflow(workflow_file=args.workflow)
            else:
                await demo.demo_prompt(
                    prompt=args.prompt,
                    width=args.width,
                    height=args.height,
                    steps=args.steps,
                    seed=args.seed,
                )
    except AttributeError as e:
        # An AttributeError mentioning "API key" is reported as a missing key.
        if "API key" in str(e):
            log.error("API key missing. Set VAST_API_KEY environment variable.")
        else:
            log.error("Error: %s", e)
        sys.exit(1)
    except Exception as e:
        # Top-level boundary: log the message together with the traceback.
        log.exception("Error: %s", e)
        sys.exit(1)


if __name__ == "__main__":
    asyncio.run(main_async())