From 0bcd2219ea550c4f34bb67702bf555fd16a06057 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Wed, 3 Dec 2025 12:38:52 -0800 Subject: [PATCH 01/11] Increase model wait time for vLLM --- workers/openai/server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/workers/openai/server.py b/workers/openai/server.py index 8dc962f..63f21f9 100644 --- a/workers/openai/server.py +++ b/workers/openai/server.py @@ -35,6 +35,7 @@ backend = Backend( model_server_url=os.environ["MODEL_SERVER_URL"], model_log_file=os.environ["MODEL_LOG"], allow_parallel_requests=True, + max_wait_time=600.0, benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256), log_actions=[ *[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG], From adedb8ba909e387bb6fdd66faed0bbddf450f387 Mon Sep 17 00:00:00 2001 From: Colter Downing Date: Wed, 3 Dec 2025 16:57:28 -0800 Subject: [PATCH 02/11] defaults to ENDPOINT_NAME and DEFAULT_MODEL but uses the flag first if present --- workers/openai/README.md | 54 +++++++++++++++++++++++----------------- workers/openai/client.py | 48 +++++++++++++++++++++++------------ 2 files changed, 63 insertions(+), 39 deletions(-) diff --git a/workers/openai/README.md b/workers/openai/README.md index 2436784..0dbaaa4 100644 --- a/workers/openai/README.md +++ b/workers/openai/README.md @@ -34,38 +34,20 @@ uv pip install -r requirements.txt Several examples have been provided in the client to help you get started with your own implementation. -### Completions - -Call to `/v1/completions` with json response +First, set your API key as an environment variable: ```bash -python -m workers.openai.client -k -e --completion --model +export VAST_API_KEY= ``` -### Chat Completion (json) - -Call to `/v1/chat/completions` with json response - -```bash -python -m workers.openai.client -k -e --chat --model -``` +The `--model` and `--endpoint` flags are optional. If not provided, they default to `Qwen/Qwen3-8B` and `my-vllm-endpoint` respectively. ### Chat Completion (streaming) Call to `/v1/chat/completions` with streaming response ```bash -python -m workers.openai.client -k -e --chat-stream --model -``` - -### Tool Use (json) - -Call to `/v1/chat/completions` with tool and json response. - -This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model. - -```bash -python -m workers.openai.client -k -e --tools --model +python -m workers.openai.client --chat-stream --endpoint --model ``` ### Interactive Chat (streaming) @@ -75,6 +57,32 @@ Interactive session with calls to `/v1/chat/completions`. Type `clear` to clear the chat history or `quit` to exit. ```bash -python -m workers.openai.client -k -e --interactive --model +python -m workers.openai.client --interactive --endpoint --model +``` + +### Chat Completion (json) + +Call to `/v1/chat/completions` with json response + +```bash +python -m workers.openai.client --chat --endpoint --model +``` + +### Tool Use (json) + +Call to `/v1/chat/completions` with tool and json response. + +This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model. + +```bash +python -m workers.openai.client --tools --endpoint --model +``` + +### Completions + +Call to `/v1/completions` with json response + +```bash +python -m workers.openai.client --completion --endpoint --model ``` diff --git a/workers/openai/client.py b/workers/openai/client.py index 8c88444..a92ad95 100644 --- a/workers/openai/client.py +++ b/workers/openai/client.py @@ -18,7 +18,7 @@ logging.basicConfig( log = logging.getLogger(__file__) # ---------------------- Prompts ---------------------- -COMPLETIONS_PROMPT = "the capital of USA is" +COMPLETIONS_PROMPT = "Zebras are primarily grazers and can subsist on lower-quality vegetation. They are preyed on mainly by" CHAT_PROMPT = "Think step by step: Tell me about the Python programming language." TOOLS_PROMPT = ( "Can you list the files in the current working directory and tell me what you see? " @@ -97,9 +97,9 @@ def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[ # ---- OpenAI-compatible calls (non-streaming) ---- -async def call_completions(client: Serverless, *, model: str, prompt: str, **kwargs) -> Dict[str, Any]: +async def call_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs) -> Dict[str, Any]: - endpoint = await client.get_endpoint(name=ENDPOINT_NAME) + endpoint = await client.get_endpoint(name=endpoint_name) payload = { "input": { @@ -113,9 +113,9 @@ async def call_completions(client: Serverless, *, model: str, prompt: str, **kwa resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"]) return resp["response"] -async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]: +async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs) -> Dict[str, Any]: - endpoint = await client.get_endpoint(name=ENDPOINT_NAME) + endpoint = await client.get_endpoint(name=endpoint_name) payload = { "input": { @@ -132,9 +132,9 @@ async def call_chat_completions(client: Serverless, *, model: str, messages: Lis return resp["response"] # ---- Streaming variants ---- -async def stream_completions(client: Serverless, *, model: str, prompt: str, **kwargs): +async def stream_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs): - endpoint = await client.get_endpoint(name=ENDPOINT_NAME) + endpoint = await client.get_endpoint(name=endpoint_name) payload = { "input": { @@ -150,9 +150,9 @@ async def stream_completions(client: Serverless, *, model: str, prompt: str, **k resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True) return resp["response"] # async generator -async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs): +async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs): - endpoint = await client.get_endpoint(name=ENDPOINT_NAME) + endpoint = await client.get_endpoint(name=endpoint_name) payload = { "input": { @@ -174,9 +174,10 @@ async def stream_chat_completions(client: Serverless, *, model: str, messages: L class APIDemo: """Demo and testing functionality for the API client""" - def __init__(self, client: Serverless, model: str, tool_manager: Optional[ToolManager] = None): + def __init__(self, client: Serverless, model: str, endpoint_name: str, tool_manager: Optional[ToolManager] = None): self.client = client self.model = model + self.endpoint_name = endpoint_name self.tool_manager = tool_manager or ToolManager() # ----- Streaming handler ----- @@ -185,10 +186,15 @@ class APIDemo: reasoning_content = "" printed_reasoning = False printed_answer = False + finish_reason = None async for chunk in stream: choice = (chunk.get("choices") or [{}])[0] delta = choice.get("delta", {}) + + # Track finish reason + if choice.get("finish_reason"): + finish_reason = choice.get("finish_reason") # reasoning tokens rc = delta.get("reasoning_content") @@ -219,6 +225,8 @@ class APIDemo: print(f"Reasoning tokens: {len(reasoning_content.split())}") if printed_answer: print(f"Response tokens: {len(full_response.split())}") + if finish_reason: + print(f"Finish reason: {finish_reason}") return full_response @@ -231,6 +239,7 @@ class APIDemo: client=self.client, model=self.model, prompt=COMPLETIONS_PROMPT, + endpoint_name=self.endpoint_name, max_tokens=MAX_TOKENS, temperature=DEFAULT_TEMPERATURE, ) @@ -249,6 +258,7 @@ class APIDemo: client=self.client, model=self.model, messages=messages, + endpoint_name=self.endpoint_name, max_tokens=MAX_TOKENS, temperature=DEFAULT_TEMPERATURE ) @@ -261,6 +271,7 @@ class APIDemo: client=self.client, model=self.model, messages=messages, + endpoint_name=self.endpoint_name, max_tokens=MAX_TOKENS, temperature=DEFAULT_TEMPERATURE ) @@ -287,6 +298,7 @@ class APIDemo: client=self.client, model=self.model, messages=messages, + endpoint_name=self.endpoint_name, tools=minimal_tool, tool_choice="none", max_tokens=10 @@ -312,6 +324,7 @@ class APIDemo: client=self.client, model=self.model, messages=messages, + endpoint_name=self.endpoint_name, tools=self.tool_manager.get_ls_tool_definition(), tool_choice="auto", max_tokens=MAX_TOKENS, @@ -389,6 +402,7 @@ class APIDemo: client=self.client, model=self.model, messages=messages, + endpoint_name=self.endpoint_name, max_tokens=MAX_TOKENS, temperature=DEFAULT_TEMPERATURE, ) @@ -427,7 +441,6 @@ class APIDemo: print("=" * 60) print("INTERACTIVE STREAMING CHAT") print("=" * 60) - print(f"Using model: {self.model}") print("Type 'quit' to exit, 'clear' to clear history") print() @@ -453,7 +466,8 @@ class APIDemo: stream = await stream_chat_completions( client=self.client, model=self.model, - messages=messages, + messages=messages, + endpoint_name=self.endpoint_name, max_tokens=MAX_TOKENS, temperature=0.7 ) @@ -473,8 +487,8 @@ class APIDemo: # ---------------------- CLI ---------------------- def build_arg_parser() -> argparse.ArgumentParser: p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)") - p.add_argument("--model", required=True, help="Model to use for requests (required)") - p.add_argument("--endpoint", default="my-vllm-endpoint", help="Vast endpoint name (default: my-vllm-endpoint)") + p.add_argument("--model", default=DEFAULT_MODEL, help=f"Model to use for requests (default: {DEFAULT_MODEL})") + p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})") modes = p.add_mutually_exclusive_group(required=False) modes.add_argument("--completion", action="store_true", help="Test completions endpoint") @@ -502,12 +516,14 @@ async def main_async(): print("Please specify exactly one test mode") sys.exit(1) - print(f"Using model: {args.model}") print("=" * 60) + print(f"Using model: {args.model}") + print(f"Using endpoint: {args.endpoint}") + try: async with Serverless() as client: - demo = APIDemo(client, args.model, ToolManager()) + demo = APIDemo(client, args.model, args.endpoint, ToolManager()) if args.completion: await demo.demo_completions() From 6b5b1341a79387a0bb953eaf304335ab50a8c0bf Mon Sep 17 00:00:00 2001 From: Colter Downing Date: Wed, 3 Dec 2025 18:38:42 -0800 Subject: [PATCH 03/11] update tgi client --- workers/openai/README.md | 5 +- workers/tgi/README.md | 102 +++++++++++++++-- workers/tgi/client.py | 229 +++++++++++++++++++++++++++++++++------ 3 files changed, 290 insertions(+), 46 deletions(-) diff --git a/workers/openai/README.md b/workers/openai/README.md index 0dbaaa4..f7596f3 100644 --- a/workers/openai/README.md +++ b/workers/openai/README.md @@ -8,14 +8,13 @@ This is the base PyWorker for OpenAI compatible inference servers. See the [Ser This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker. -- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20%2B%20Qwen%2FQwen3-8B%20(Serverless)) (recommended) +- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20(Serverless)) (recommended) - [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless)) -- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20%2B%20Qwen3-8B%20(Serverless)) All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected. -2. Follow the [getting started guide](https://docs.vast.ai/serverless/getting-started) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface. +2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface. ## Client Setup (Demo) diff --git a/workers/tgi/README.md b/workers/tgi/README.md index 5cf8488..9147e38 100644 --- a/workers/tgi/README.md +++ b/workers/tgi/README.md @@ -1,19 +1,103 @@ -This is the base PyWorker for TGI, designed to create PyWorkers that can utilize various LLMs. It offers two primary endpoints: +# HuggingFace TGI PyWorker -1. `generate`: Generates the LLM's response to a given prompt in a single request. -2. `generate_stream`: Streams the LLM's response token by token. +This is the base PyWorker for HuggingFace Text Generation Inference (TGI) servers. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's. -Both endpoints use the following API payload format: +## Instance Setup + +1. Pick a template + +This worker is compatible with any TGI backend. We have a template you can use or you can create your own. + +- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20(Serverless)) + +The template can be configured via the template interface. You may want to change the model or startup arguments. + +2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface. + +## Client Setup (Demo) + +1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client. + +```bash +git clone https://github.com/vast-ai/pyworker +cd pyworker +pip install uv +uv venv -p 3.12 +source .venv/bin/activate +uv pip install -r requirements.txt +``` + +## Using the Test Client + +The test client demonstrates both streaming and non-streaming generation using TGI's native API. + +First, set your API key as an environment variable: + +```bash +export VAST_API_KEY= +``` + +The `--endpoint` flag is optional. If not provided, it defaults to `my-tgi-endpoint`. + +### Generate (Streaming) + +Call to `/generate_stream` with streaming response: + +```bash +python -m workers.tgi.client --generate-stream --endpoint +``` + +### Generate (Non-Streaming) + +Call to `/generate` with json response: + +```bash +python -m workers.tgi.client --generate --endpoint +``` + +### Interactive Session (Streaming) + +Interactive session with streaming responses. Type `quit` to exit. + +```bash +python -m workers.tgi.client --interactive --endpoint +``` + +## API Endpoints + +TGI provides two primary endpoints: + +### Generate (Non-Streaming) + +`/generate` - Returns the complete response in a single request. ```json { - "inputs": "PROMPT", + "inputs": "Your prompt here", "parameters": { - "max_new_tokens": 250 + "max_new_tokens": 1024, + "temperature": 0.7, + "return_full_text": false } } ``` -Note that the max_new_tokens parameter, rather than the prompt size, impacts performance. For example, if an -instance is benchmarked to process 100 tokens per second, a request with max_new_tokens = 200 will take -approximately 2 seconds to complete. +### Generate Stream (Streaming) + +`/generate_stream` - Streams the response token by token. + +```json +{ + "inputs": "Your prompt here", + "parameters": { + "max_new_tokens": 1024, + "temperature": 0.7, + "do_sample": true, + "return_full_text": false + } +} +``` + +## Performance Notes + +The `max_new_tokens` parameter (not the prompt size) primarily impacts performance. For example, if an instance is benchmarked to process 100 tokens per second, a request with `max_new_tokens = 200` will take approximately 2 seconds to complete. diff --git a/workers/tgi/client.py b/workers/tgi/client.py index f307602..23b40c2 100644 --- a/workers/tgi/client.py +++ b/workers/tgi/client.py @@ -1,61 +1,222 @@ +import logging +import json +import os +import sys +import argparse + from vastai import Serverless import asyncio -ENDPOINT_NAME = "my-tgi-endpoint" # Change this to match your endpoint name -MAX_TOKENS = 1024 -PROMPT = "Think step by step: Tell me about the Python programming language." +# ---------------------- Logging ---------------------- +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s[%(levelname)-5s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +log = logging.getLogger(__file__) -async def call_generate(client: Serverless) -> None: - endpoint = await client.get_endpoint(name=ENDPOINT_NAME) +# ---------------------- Defaults ---------------------- +DEFAULT_PROMPT = "Think step by step: Tell me about the Python programming language." + +ENDPOINT_NAME = "TGI-Prod2" # change this to your TGI endpoint name +MAX_TOKENS = 1024 +DEFAULT_TEMPERATURE = 0.7 + + +# ---------------------- API Calls ---------------------- +async def call_generate(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs) -> dict: + """Non-streaming generation via /generate endpoint""" + endpoint = await client.get_endpoint(name=endpoint_name) payload = { - "inputs": PROMPT, + "inputs": prompt, "parameters": { - "max_new_tokens": MAX_TOKENS, - "temperature": 0.7, - "return_full_text": False + "max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS), + "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), + "return_full_text": False, } } - - resp = await endpoint.request("/generate", payload, cost=MAX_TOKENS) - - print(resp["response"]["generated_text"]) + log.debug("POST /generate %s", json.dumps(payload)[:500]) + resp = await endpoint.request("/generate", payload, cost=payload["parameters"]["max_new_tokens"]) + return resp["response"] -async def call_generate_stream(client: Serverless) -> None: - endpoint = await client.get_endpoint(name=ENDPOINT_NAME) +async def call_generate_stream(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs): + """Streaming generation via /generate_stream endpoint""" + endpoint = await client.get_endpoint(name=endpoint_name) payload = { - "inputs": PROMPT, + "inputs": prompt, "parameters": { - "max_new_tokens": MAX_TOKENS, - "temperature": 0.7, + "max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS), + "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), "do_sample": True, "return_full_text": False, } } - + log.debug("STREAM /generate_stream %s", json.dumps(payload)[:500]) resp = await endpoint.request( "/generate_stream", payload, - cost=MAX_TOKENS, + cost=payload["parameters"]["max_new_tokens"], stream=True, ) - stream = resp["response"] + return resp["response"] # async generator - printed_answer = False - async for event in stream: - tok = (event.get("token") or {}).get("text") - if tok: - if not printed_answer: - printed_answer = True - print("Answer:\n", end="", flush=True) - print(tok, end="", flush=True) -async def main(): - async with Serverless() as client: - await call_generate(client) - await call_generate_stream(client) +# ---------------------- Demo Runner ---------------------- +class APIDemo: + """Demo and testing functionality for the TGI API client""" + + def __init__(self, client: Serverless, endpoint_name: str): + self.client = client + self.endpoint_name = endpoint_name + + async def handle_streaming_response(self, stream) -> str: + """Process streaming response and print tokens""" + full_response = "" + printed_answer = False + + async for event in stream: + tok = (event.get("token") or {}).get("text") + if tok: + if not printed_answer: + printed_answer = True + print("\nšŸ’¬ Response: ", end="", flush=True) + print(tok, end="", flush=True) + full_response += tok + + print() # newline + if printed_answer: + print(f"\nStreaming completed. Response tokens: {len(full_response.split())}") + + return full_response + + async def demo_generate(self) -> None: + """Demo non-streaming generation""" + print("=" * 60) + print("GENERATE DEMO (NON-STREAMING)") + print("=" * 60) + + response = await call_generate( + client=self.client, + endpoint_name=self.endpoint_name, + prompt=DEFAULT_PROMPT, + max_tokens=MAX_TOKENS, + temperature=DEFAULT_TEMPERATURE, + ) + + print(f"\nšŸ’¬ Response: {response.get('generated_text', '')}") + print(f"\nFull Response:\n{json.dumps(response, indent=2)}") + + async def demo_generate_stream(self) -> None: + """Demo streaming generation""" + print("=" * 60) + print("GENERATE DEMO (STREAMING)") + print("=" * 60) + + stream = await call_generate_stream( + client=self.client, + endpoint_name=self.endpoint_name, + prompt=DEFAULT_PROMPT, + max_tokens=MAX_TOKENS, + temperature=DEFAULT_TEMPERATURE, + ) + + try: + await self.handle_streaming_response(stream) + except Exception as e: + log.error("\nError during streaming: %s", e, exc_info=True) + + async def interactive_chat(self) -> None: + """Interactive session with streaming generation""" + print("=" * 60) + print("INTERACTIVE STREAMING SESSION") + print("=" * 60) + print(f"Using endpoint: {self.endpoint_name}") + print("Type 'quit' to exit") + print() + + while True: + try: + user_input = input("You: ").strip() + + if user_input.lower() == "quit": + print("šŸ‘‹ Goodbye!") + break + elif not user_input: + continue + + print("Assistant: ", end="", flush=True) + stream = await call_generate_stream( + client=self.client, + endpoint_name=self.endpoint_name, + prompt=user_input, + max_tokens=MAX_TOKENS, + temperature=DEFAULT_TEMPERATURE, + ) + + full_response = "" + async for event in stream: + tok = (event.get("token") or {}).get("text") + if tok: + print(tok, end="", flush=True) + full_response += tok + print() # newline + + except KeyboardInterrupt: + print("\nšŸ‘‹ Session interrupted. Goodbye!") + break + except Exception as e: + log.error("\nError: %s", e) + continue + + +# ---------------------- CLI ---------------------- +def build_arg_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(description="Vast TGI Demo (Serverless SDK)") + p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})") + + modes = p.add_mutually_exclusive_group(required=False) + modes.add_argument("--generate", action="store_true", help="Test generate endpoint (non-streaming)") + modes.add_argument("--generate-stream", action="store_true", help="Test generate endpoint with streaming") + modes.add_argument("--interactive", action="store_true", help="Start interactive streaming session") + return p + + +async def main_async(): + args = build_arg_parser().parse_args() + + selected = sum([args.generate, args.generate_stream, args.interactive]) + if selected == 0: + print("Please specify exactly one test mode:") + print(" --generate : Test generate endpoint (non-streaming)") + print(" --generate-stream : Test generate endpoint with streaming") + print(" --interactive : Start interactive streaming session") + print(f"\nExample: python {os.path.basename(sys.argv[0])} --generate-stream --endpoint my-tgi-endpoint") + sys.exit(1) + elif selected > 1: + print("Please specify exactly one test mode") + sys.exit(1) + + print("=" * 60) + print(f"Using endpoint: {args.endpoint}") + + try: + async with Serverless() as client: + demo = APIDemo(client, args.endpoint) + + if args.generate: + await demo.demo_generate() + elif args.generate_stream: + await demo.demo_generate_stream() + elif args.interactive: + await demo.interactive_chat() + + except Exception as e: + log.error("Error during test: %s", e, exc_info=True) + sys.exit(1) + if __name__ == "__main__": - asyncio.run(main()) + asyncio.run(main_async()) From f04138e13bee6835e9237fd129f4f7a40c1550ff Mon Sep 17 00:00:00 2001 From: Colter Downing Date: Wed, 3 Dec 2025 20:16:25 -0800 Subject: [PATCH 04/11] update to be able to get images --- workers/comfyui-json/client.py | 351 ++++++++++++++++++++++++++++++--- workers/comfyui-json/server.py | 32 +++ 2 files changed, 359 insertions(+), 24 deletions(-) diff --git a/workers/comfyui-json/client.py b/workers/comfyui-json/client.py index 93e184c..b80a9ba 100644 --- a/workers/comfyui-json/client.py +++ b/workers/comfyui-json/client.py @@ -1,35 +1,338 @@ -from .data_types import count_workload +import os +import sys +import json import uuid import random +import base64 import asyncio -import random +import logging +import argparse from vastai import Serverless -async def main(): - async with Serverless() as client: - endpoint = await client.get_endpoint(name="my-comfy-endpoint") # Change this to your endpoint name +# ---------------------- Config ---------------------- +DEFAULT_PROMPT = "a beautiful sunset over mountains, digital art, highly detailed" +ENDPOINT_NAME = "Comfy-Prod2" +DEFAULT_WIDTH = 512 +DEFAULT_HEIGHT = 512 +DEFAULT_STEPS = 20 +COST = 100 # Fixed cost for ComfyUI requests - payload = { - "input": { - "request_id": str(uuid.uuid4()), - "modifier": "Text2Image", - "modifications": { - "prompt": "a beautiful landscape with mountains and lakes", - "width": 1024, - "height": 1024, - "steps": 20, - "seed": random.randint(0, 2**32 - 1) - }, - "workflow_json": {} # Empty since using modifier approach - } +logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s") +log = logging.getLogger(__name__) + + +# ---------------------- API Functions ---------------------- +async def call_generate( + client: Serverless, + *, + endpoint_name: str, + prompt: str, + width: int, + height: int, + steps: int, + seed: int, +) -> dict: + """Generate image using Text2Image modifier""" + endpoint = await client.get_endpoint(name=endpoint_name) + payload = { + "input": { + "request_id": str(uuid.uuid4()), + "modifier": "Text2Image", + "modifications": { + "prompt": prompt, + "width": width, + "height": height, + "steps": steps, + "seed": seed, + }, } - - response = await endpoint.request("/generate/sync", payload, cost=count_workload()) + } + return await endpoint.request("/generate/sync", payload, cost=COST) + + +async def call_generate_workflow( + client: Serverless, + *, + endpoint_name: str, + workflow_json: dict, +) -> dict: + """Generate using custom workflow JSON""" + endpoint = await client.get_endpoint(name=endpoint_name) + payload = { + "input": { + "request_id": str(uuid.uuid4()), + "workflow_json": workflow_json, + } + } + return await endpoint.request("/generate/sync", payload, cost=COST) + + +# ---------------------- Demo Class ---------------------- +class APIDemo: + def __init__(self, client: Serverless, endpoint_name: str): + self.client = client + self.endpoint_name = endpoint_name + + def extract_images(self, response: dict) -> list: + """Extract image info from ComfyUI response""" + images = [] + + # Check for output array (S3/webhook configured) + if "output" in response: + for item in response["output"]: + if "url" in item: + images.append({"type": "url", "path": item["url"]}) + elif "local_path" in item: + images.append({"type": "local", "path": item["local_path"]}) + elif "base64" in item: + images.append({"type": "base64", "data": item["base64"]}) + + # Check for comfyui_response format (default) + if "comfyui_response" in response: + for prompt_id, data in response["comfyui_response"].items(): + if isinstance(data, dict) and "outputs" in data: + for node_id, node_output in data["outputs"].items(): + if "images" in node_output: + for img in node_output["images"]: + images.append({ + "type": "remote", + "filename": img.get("filename"), + "subfolder": img.get("subfolder", ""), + }) + + return images + + async def save_images(self, images: list, worker_url: str, prefix: str = "comfy") -> list: + """Save images locally by fetching from remote server""" + os.makedirs("generated_images", exist_ok=True) + saved = [] + seen = set() + + for i, img in enumerate(images): + if img["type"] == "base64": + data = img["data"] + if data.startswith("data:"): + data = data.split(",", 1)[-1] + path = f"generated_images/{prefix}_{i}.png" + with open(path, "wb") as f: + f.write(base64.b64decode(data)) + print(f" šŸ’¾ Saved: {path}") + saved.append(path) + + elif img["type"] == "url": + url = img["path"] + if url in seen: + continue + seen.add(url) + try: + import urllib.request + path = f"generated_images/{prefix}_{len(saved)}.png" + urllib.request.urlretrieve(url, path) + print(f" šŸ’¾ Downloaded: {path}") + saved.append(path) + except Exception as e: + print(f" šŸ”— URL: {url}") + saved.append(url) + + elif img["type"] == "local": + remote_path = img["path"] + if remote_path in seen: + continue + seen.add(remote_path) + filename = os.path.basename(remote_path) + # Try to fetch via /view endpoint + local_path = await self._fetch_image(worker_url, filename, "", f"{prefix}_{len(saved)}.png") + if local_path: + saved.append(local_path) + else: + print(f" šŸ“‚ Remote: {remote_path}") + saved.append(remote_path) + + elif img["type"] == "remote": + filename = img["filename"] + if filename in seen: + continue + seen.add(filename) + subfolder = img.get("subfolder", "") + # Try to fetch via /view endpoint + local_path = await self._fetch_image(worker_url, filename, subfolder, f"{prefix}_{len(saved)}.png") + if local_path: + saved.append(local_path) + else: + print(f" šŸ–¼ļø Remote: {filename}") + saved.append(filename) + + return saved + + async def _fetch_image(self, worker_url: str, filename: str, subfolder: str, local_name: str) -> str | None: + """Fetch image directly from worker's /view endpoint""" + if not worker_url: + print(f" āš ļø No worker URL available") + return None + + try: + import aiohttp + + params = {"filename": filename, "type": "output"} + if subfolder: + params["subfolder"] = subfolder + + url = f"{worker_url}/view" + print(f" šŸ”— Fetching from: {url}") + + async with aiohttp.ClientSession() as session: + async with session.get(url, params=params, ssl=False) as resp: + if resp.status == 200: + raw_bytes = await resp.read() + path = f"generated_images/{local_name}" + with open(path, "wb") as f: + f.write(raw_bytes) + print(f" šŸ’¾ Saved: {path}") + return path + else: + text = await resp.text() + print(f" āŒ HTTP {resp.status}: {text[:100]}") + return None + except Exception as e: + print(f" āŒ Fetch error: {e}") + return None + + async def demo_prompt( + self, + prompt: str, + width: int, + height: int, + steps: int, + seed: int | None, + ): + """Demo: Generate image from text prompt""" + print("=" * 60) + print("COMFYUI TEXT-TO-IMAGE DEMO") + print("=" * 60) + + if seed is None: + seed = random.randint(0, 2**32 - 1) + + print(f"Prompt: {prompt[:100]}..." if len(prompt) > 100 else f"Prompt: {prompt}") + print(f"Size: {width}x{height}, Steps: {steps}, Seed: {seed}") + print("\nšŸŽØ Generating image...") + + response = await call_generate( + self.client, + endpoint_name=self.endpoint_name, + prompt=prompt, + width=width, + height=height, + steps=steps, + seed=seed, + ) + + print("\nāœ… Generation complete!") + + # Get worker URL for fetching images + worker_url = response.get("url", "") + print(f"Worker URL: {worker_url}") + + # Extract and handle images + if "response" in response: + images = self.extract_images(response["response"]) + if images: + print(f"\nšŸ“ {len(images)} image(s) generated:") + await self.save_images(images, worker_url, prefix=f"comfy_{seed}") + else: + print("\nNo images found in response") + print(json.dumps(response, indent=2, default=str)[:2000]) + else: + print("\nUnexpected response format") + print(json.dumps(response, indent=2, default=str)[:2000]) + + async def demo_workflow(self, workflow_file: str): + """Demo: Generate using custom workflow file""" + print("=" * 60) + print("COMFYUI CUSTOM WORKFLOW DEMO") + print("=" * 60) + + if not os.path.exists(workflow_file): + log.error(f"Workflow file not found: {workflow_file}") + return + + with open(workflow_file, "r") as f: + workflow_json = json.load(f) + + print(f"Workflow: {workflow_file}") + print("\nšŸŽØ Generating...") + + response = await call_generate_workflow( + self.client, + endpoint_name=self.endpoint_name, + workflow_json=workflow_json, + ) + + print("\nāœ… Generation complete!") + + worker_url = response.get("url", "") + + if "response" in response: + images = self.extract_images(response["response"]) + if images: + print(f"\nšŸ“ {len(images)} image(s) generated:") + await self.save_images(images, worker_url, prefix="workflow") + else: + print("\nNo images found in response") + print(json.dumps(response, indent=2, default=str)[:2000]) + else: + print("\nUnexpected response format") + print(json.dumps(response, indent=2, default=str)[:2000]) + + +# ---------------------- CLI ---------------------- +def build_arg_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(description="Vast ComfyUI-JSON Demo (Serverless SDK)") + p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})") + p.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, metavar="TEXT", + help=f"Prompt text (default: '{DEFAULT_PROMPT[:30]}...')") + p.add_argument("--workflow", type=str, metavar="FILE", help="Use custom workflow JSON file instead") + p.add_argument("--width", type=int, default=DEFAULT_WIDTH, help=f"Image width (default: {DEFAULT_WIDTH})") + p.add_argument("--height", type=int, default=DEFAULT_HEIGHT, help=f"Image height (default: {DEFAULT_HEIGHT})") + p.add_argument("--steps", type=int, default=DEFAULT_STEPS, help=f"Steps (default: {DEFAULT_STEPS})") + p.add_argument("--seed", type=int, default=None, help="Seed (default: random)") + return p + + +async def main_async(): + args = build_arg_parser().parse_args() + + print("=" * 60) + print(f"Using endpoint: {args.endpoint}") + + try: + async with Serverless() as client: + demo = APIDemo(client, args.endpoint) + + if args.workflow: + await demo.demo_workflow(workflow_file=args.workflow) + else: + await demo.demo_prompt( + prompt=args.prompt, + width=args.width, + height=args.height, + steps=args.steps, + seed=args.seed, + ) + + except AttributeError as e: + if "API key" in str(e): + log.error("API key missing. Set VAST_API_KEY environment variable.") + else: + log.error(f"Error: {e}") + sys.exit(1) + except Exception as e: + log.error(f"Error: {e}") + import traceback + traceback.print_exc() + sys.exit(1) - # Get the file from the path on the local machine using SCP or SFTP - # or configure S3 to upload to cloud storage. - print(response["response"]["output"][0]["local_path"]) if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main_async()) diff --git a/workers/comfyui-json/server.py b/workers/comfyui-json/server.py index ed4e578..7998e71 100644 --- a/workers/comfyui-json/server.py +++ b/workers/comfyui-json/server.py @@ -4,6 +4,7 @@ import dataclasses import base64 from typing import Optional, Union, Type +import aiohttp from aiohttp import web, ClientResponse from lib.backend import Backend, LogAction @@ -108,8 +109,39 @@ async def handle_ping(_): return web.Response(body="pong") +async def handle_view(request: web.Request) -> web.Response: + """Proxy /view requests to ComfyUI to fetch generated images""" + # Forward query params to ComfyUI + query_string = request.query_string + url = f"{MODEL_SERVER_URL}/view?{query_string}" + + log.debug(f"Proxying /view request to: {url}") + + try: + async with aiohttp.ClientSession() as session: + async with session.get(url) as resp: + if resp.status == 200: + content = await resp.read() + return web.Response( + body=content, + status=200, + content_type=resp.content_type or "image/png" + ) + else: + text = await resp.text() + return web.Response( + text=text, + status=resp.status, + content_type="text/plain" + ) + except Exception as e: + log.error(f"Error proxying /view: {e}") + return web.Response(text=str(e), status=500) + + routes = [ web.post("/generate/sync", backend.create_handler(ComfyWorkflowHandler())), + web.get("/view", handle_view), web.get("/ping", handle_ping), ] From e839cfc6e8fa3a32eef152381359df27bf15a953 Mon Sep 17 00:00:00 2001 From: Colter Downing Date: Wed, 3 Dec 2025 20:22:45 -0800 Subject: [PATCH 05/11] include view in API wrapper --- workers/comfyui-json/server.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/workers/comfyui-json/server.py b/workers/comfyui-json/server.py index 7998e71..daf35e5 100644 --- a/workers/comfyui-json/server.py +++ b/workers/comfyui-json/server.py @@ -14,6 +14,7 @@ from .data_types import ComfyWorkflowData MODEL_SERVER_URL = os.getenv("MODEL_SERVER_URL", "http://127.0.0.1:18288") +COMFYUI_URL = os.getenv("COMFYUI_URL", "http://127.0.0.1:18188") # Raw ComfyUI server # This is the last log line that gets emitted once comfyui+extensions have been fully loaded MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: " @@ -110,10 +111,10 @@ async def handle_ping(_): async def handle_view(request: web.Request) -> web.Response: - """Proxy /view requests to ComfyUI to fetch generated images""" - # Forward query params to ComfyUI + """Proxy /view requests to raw ComfyUI server to fetch generated images""" + # Forward query params to raw ComfyUI (not the API wrapper) query_string = request.query_string - url = f"{MODEL_SERVER_URL}/view?{query_string}" + url = f"{COMFYUI_URL}/view?{query_string}" log.debug(f"Proxying /view request to: {url}") From d4d36bf86e03f40f727179975dfb8d53518e9ed2 Mon Sep 17 00:00:00 2001 From: Colter Downing Date: Wed, 3 Dec 2025 20:45:55 -0800 Subject: [PATCH 06/11] done with comfy updates --- workers/comfyui-json/README.md | 71 ++++++++++++--- workers/comfyui-json/client.py | 156 +++++++-------------------------- 2 files changed, 94 insertions(+), 133 deletions(-) diff --git a/workers/comfyui-json/README.md b/workers/comfyui-json/README.md index 7aa1ba3..5306a23 100644 --- a/workers/comfyui-json/README.md +++ b/workers/comfyui-json/README.md @@ -1,8 +1,16 @@ # ComfyUI PyWorker -This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture. +This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's. -The cost for each request has a static value of `1`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node. +The cost for each request has a static value of `1`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node. + +## Instance Setup + +1. Pick a template + +- [ComfyUI (Serverless)](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=ComfyUI%20(Serverless)) + +2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface. ## Requirements @@ -10,6 +18,57 @@ This worker requires both [ComfyUI](https://github.com/comfyanonymous/ComfyUI) a A docker image is provided but you may use any if the above requirements are met. +## Client + +The client demonstrates how to use the Vast Serverless SDK to generate images and save them locally. + +### Setup + +1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client. + +```bash +git clone https://github.com/vast-ai/pyworker +cd pyworker +pip install uv +uv venv -p 3.12 +source .venv/bin/activate +uv pip install -r requirements.txt +``` + +2. Set your API key: + +```bash +export VAST_API_KEY= +``` + +### Usage + +```bash +# Default prompt +python -m workers.comfyui-json.client + +# Custom prompt +python -m workers.comfyui-json.client --prompt "a cat sitting on a rainbow" + +# With options +python -m workers.comfyui-json.client --prompt "sunset" --width 1024 --height 1024 --steps 30 +``` + +### CLI Flags + +| Flag | Default | Description | +|------|---------|-------------| +| `--endpoint` | `my-comfyui-endpoint` | Vast endpoint name | +| `--prompt` | (default) | Text prompt for image generation | +| `--width` | 512 | Image width in pixels | +| `--height` | 512 | Image height in pixels | +| `--steps` | 20 | Number of denoising steps | +| `--seed` | (random) | Random seed for reproducibility | + +### Output + +Images are saved to `./generated_images/comfy_{seed}.png`. + ## Benchmarking ### Custom Benchmark Workflows @@ -212,11 +271,3 @@ WEBHOOK_TIMEOUT=30 # Webhook timeout in seconds } } ``` - -## Client Libraries - -See the test client examples for implementation details on how to integrate with the ComfyUI worker. - ---- - -See Vast's serverless documentation for more details on how to use ComfyUI with autoscaler. \ No newline at end of file diff --git a/workers/comfyui-json/client.py b/workers/comfyui-json/client.py index b80a9ba..a243183 100644 --- a/workers/comfyui-json/client.py +++ b/workers/comfyui-json/client.py @@ -3,16 +3,16 @@ import sys import json import uuid import random -import base64 import asyncio import logging import argparse +import aiohttp from vastai import Serverless # ---------------------- Config ---------------------- DEFAULT_PROMPT = "a beautiful sunset over mountains, digital art, highly detailed" -ENDPOINT_NAME = "Comfy-Prod2" +ENDPOINT_NAME = "my-comfyui-endpoint" DEFAULT_WIDTH = 512 DEFAULT_HEIGHT = 512 DEFAULT_STEPS = 20 @@ -74,128 +74,40 @@ class APIDemo: self.client = client self.endpoint_name = endpoint_name - def extract_images(self, response: dict) -> list: - """Extract image info from ComfyUI response""" - images = [] - - # Check for output array (S3/webhook configured) - if "output" in response: - for item in response["output"]: - if "url" in item: - images.append({"type": "url", "path": item["url"]}) - elif "local_path" in item: - images.append({"type": "local", "path": item["local_path"]}) - elif "base64" in item: - images.append({"type": "base64", "data": item["base64"]}) - - # Check for comfyui_response format (default) + def extract_filename(self, response: dict) -> str | None: + """Extract the generated image filename from ComfyUI response""" if "comfyui_response" in response: - for prompt_id, data in response["comfyui_response"].items(): + for data in response["comfyui_response"].values(): if isinstance(data, dict) and "outputs" in data: - for node_id, node_output in data["outputs"].items(): - if "images" in node_output: - for img in node_output["images"]: - images.append({ - "type": "remote", - "filename": img.get("filename"), - "subfolder": img.get("subfolder", ""), - }) - - return images + for node_output in data["outputs"].values(): + if "images" in node_output and node_output["images"]: + return node_output["images"][0].get("filename") + return None - async def save_images(self, images: list, worker_url: str, prefix: str = "comfy") -> list: - """Save images locally by fetching from remote server""" + async def save_image(self, worker_url: str, filename: str, local_name: str) -> str | None: + """Fetch and save image locally from the worker""" os.makedirs("generated_images", exist_ok=True) - saved = [] - seen = set() + return await self._fetch_image(worker_url, filename, local_name) - for i, img in enumerate(images): - if img["type"] == "base64": - data = img["data"] - if data.startswith("data:"): - data = data.split(",", 1)[-1] - path = f"generated_images/{prefix}_{i}.png" - with open(path, "wb") as f: - f.write(base64.b64decode(data)) - print(f" šŸ’¾ Saved: {path}") - saved.append(path) - - elif img["type"] == "url": - url = img["path"] - if url in seen: - continue - seen.add(url) - try: - import urllib.request - path = f"generated_images/{prefix}_{len(saved)}.png" - urllib.request.urlretrieve(url, path) - print(f" šŸ’¾ Downloaded: {path}") - saved.append(path) - except Exception as e: - print(f" šŸ”— URL: {url}") - saved.append(url) - - elif img["type"] == "local": - remote_path = img["path"] - if remote_path in seen: - continue - seen.add(remote_path) - filename = os.path.basename(remote_path) - # Try to fetch via /view endpoint - local_path = await self._fetch_image(worker_url, filename, "", f"{prefix}_{len(saved)}.png") - if local_path: - saved.append(local_path) - else: - print(f" šŸ“‚ Remote: {remote_path}") - saved.append(remote_path) - - elif img["type"] == "remote": - filename = img["filename"] - if filename in seen: - continue - seen.add(filename) - subfolder = img.get("subfolder", "") - # Try to fetch via /view endpoint - local_path = await self._fetch_image(worker_url, filename, subfolder, f"{prefix}_{len(saved)}.png") - if local_path: - saved.append(local_path) - else: - print(f" šŸ–¼ļø Remote: {filename}") - saved.append(filename) - - return saved - - async def _fetch_image(self, worker_url: str, filename: str, subfolder: str, local_name: str) -> str | None: - """Fetch image directly from worker's /view endpoint""" + async def _fetch_image(self, worker_url: str, filename: str, local_name: str) -> str | None: + """Fetch image from worker's /view endpoint and save locally""" if not worker_url: - print(f" āš ļø No worker URL available") return None try: - import aiohttp - - params = {"filename": filename, "type": "output"} - if subfolder: - params["subfolder"] = subfolder - url = f"{worker_url}/view" - print(f" šŸ”— Fetching from: {url}") + params = {"filename": filename, "type": "output"} async with aiohttp.ClientSession() as session: async with session.get(url, params=params, ssl=False) as resp: if resp.status == 200: - raw_bytes = await resp.read() path = f"generated_images/{local_name}" with open(path, "wb") as f: - f.write(raw_bytes) + f.write(await resp.read()) print(f" šŸ’¾ Saved: {path}") return path - else: - text = await resp.text() - print(f" āŒ HTTP {resp.status}: {text[:100]}") - return None - except Exception as e: - print(f" āŒ Fetch error: {e}") + return None + except Exception: return None async def demo_prompt( @@ -234,18 +146,17 @@ class APIDemo: worker_url = response.get("url", "") print(f"Worker URL: {worker_url}") - # Extract and handle images + # Fetch and save image if "response" in response: - images = self.extract_images(response["response"]) - if images: - print(f"\nšŸ“ {len(images)} image(s) generated:") - await self.save_images(images, worker_url, prefix=f"comfy_{seed}") + filename = self.extract_filename(response["response"]) + if filename: + path = await self.save_image(worker_url, filename, f"comfy_{seed}.png") + if not path: + print(f"āŒ Failed to fetch image") else: - print("\nNo images found in response") - print(json.dumps(response, indent=2, default=str)[:2000]) + print("āŒ No image in response") else: - print("\nUnexpected response format") - print(json.dumps(response, indent=2, default=str)[:2000]) + print("āŒ Unexpected response format") async def demo_workflow(self, workflow_file: str): """Demo: Generate using custom workflow file""" @@ -274,16 +185,15 @@ class APIDemo: worker_url = response.get("url", "") if "response" in response: - images = self.extract_images(response["response"]) - if images: - print(f"\nšŸ“ {len(images)} image(s) generated:") - await self.save_images(images, worker_url, prefix="workflow") + filename = self.extract_filename(response["response"]) + if filename: + path = await self.save_image(worker_url, filename, "workflow.png") + if not path: + print(f"āŒ Failed to fetch image") else: - print("\nNo images found in response") - print(json.dumps(response, indent=2, default=str)[:2000]) + print("āŒ No image in response") else: - print("\nUnexpected response format") - print(json.dumps(response, indent=2, default=str)[:2000]) + print("āŒ Unexpected response format") # ---------------------- CLI ---------------------- From 40aed9b5f8d85f3f5589cf10cc528901f57b8976 Mon Sep 17 00:00:00 2001 From: Colter Downing Date: Thu, 4 Dec 2025 10:52:57 -0800 Subject: [PATCH 07/11] adding s3 as an option --- workers/comfyui-json/README.md | 33 ++++++++++++++- workers/comfyui-json/client.py | 74 +++++++++++++++++++++++++++++++--- 2 files changed, 101 insertions(+), 6 deletions(-) diff --git a/workers/comfyui-json/README.md b/workers/comfyui-json/README.md index 5306a23..9517dbb 100644 --- a/workers/comfyui-json/README.md +++ b/workers/comfyui-json/README.md @@ -20,7 +20,7 @@ A docker image is provided but you may use any if the above requirements are met ## Client -The client demonstrates how to use the Vast Serverless SDK to generate images and save them locally. +The client demonstrates how to use the Vast Serverless SDK to generate images, save them locally, and optionally upload to S3-compatible storage. ### Setup @@ -52,6 +52,12 @@ python -m workers.comfyui-json.client --prompt "a cat sitting on a rainbow" # With options python -m workers.comfyui-json.client --prompt "sunset" --width 1024 --height 1024 --steps 30 + +# Using a custom workflow file +python -m workers.comfyui-json.client --workflow my_workflow.json + +# With S3 upload +python -m workers.comfyui-json.client --s3 ``` ### CLI Flags @@ -60,15 +66,40 @@ python -m workers.comfyui-json.client --prompt "sunset" --width 1024 --height 10 |------|---------|-------------| | `--endpoint` | `my-comfyui-endpoint` | Vast endpoint name | | `--prompt` | (default) | Text prompt for image generation | +| `--workflow` | (none) | Path to custom workflow JSON file | | `--width` | 512 | Image width in pixels | | `--height` | 512 | Image height in pixels | | `--steps` | 20 | Number of denoising steps | | `--seed` | (random) | Random seed for reproducibility | +| `--s3` | (disabled) | Upload generated images to S3 | ### Output Images are saved to `./generated_images/comfy_{seed}.png`. +### S3 Upload (Optional) + +You can optionally upload generated images to an S3-compatible storage service (AWS S3, Cloudflare R2, Backblaze B2, etc.) by using the `--s3` flag. + +**1. Set environment variables:** + +```bash +export S3_ENDPOINT_URL="https://your-account.r2.cloudflarestorage.com" +export S3_BUCKET_NAME="my-bucket" +export S3_ACCESS_KEY_ID="your-access-key-id" +export S3_SECRET_ACCESS_KEY="your-secret-access-key" +``` + +**2. Run with S3 upload enabled:** + +```bash +python -m workers.comfyui-json.client --prompt "a beautiful landscape" --s3 +``` + +Images will be saved locally AND uploaded to `s3://{bucket}/comfyui/{filename}`. + +**Note:** Requires `boto3` (`pip install boto3`). + ## Benchmarking ### Custom Benchmark Workflows diff --git a/workers/comfyui-json/client.py b/workers/comfyui-json/client.py index a243183..10a1d91 100644 --- a/workers/comfyui-json/client.py +++ b/workers/comfyui-json/client.py @@ -12,16 +12,45 @@ from vastai import Serverless # ---------------------- Config ---------------------- DEFAULT_PROMPT = "a beautiful sunset over mountains, digital art, highly detailed" -ENDPOINT_NAME = "my-comfyui-endpoint" +ENDPOINT_NAME = "Comfy-Prod" DEFAULT_WIDTH = 512 DEFAULT_HEIGHT = 512 DEFAULT_STEPS = 20 COST = 100 # Fixed cost for ComfyUI requests +# Optional S3 Configuration (from environment variables) +S3_ENDPOINT_URL = os.getenv("S3_ENDPOINT_URL") +S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME") +S3_ACCESS_KEY_ID = os.getenv("S3_ACCESS_KEY_ID") +S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY") + logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s") log = logging.getLogger(__name__) +def get_s3_client(): + """Create and return an S3 client configured for the S3-compatible endpoint""" + try: + import boto3 + from botocore.config import Config + except ImportError: + log.error("boto3 is required for S3 uploads. Install with: pip install boto3") + return None + + if not all([S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY]): + log.error("S3 environment variables not fully configured. Required:") + log.error(" S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY") + return None + + return boto3.client( + "s3", + endpoint_url=S3_ENDPOINT_URL, + aws_access_key_id=S3_ACCESS_KEY_ID, + aws_secret_access_key=S3_SECRET_ACCESS_KEY, + config=Config(signature_version="s3v4"), + ) + + # ---------------------- API Functions ---------------------- async def call_generate( client: Serverless, @@ -70,9 +99,14 @@ async def call_generate_workflow( # ---------------------- Demo Class ---------------------- class APIDemo: - def __init__(self, client: Serverless, endpoint_name: str): + def __init__(self, client: Serverless, endpoint_name: str, upload_s3: bool = False): self.client = client self.endpoint_name = endpoint_name + self.upload_s3 = upload_s3 + self.s3_client = get_s3_client() if upload_s3 else None + + if upload_s3 and not self.s3_client: + log.warning("S3 upload requested but client creation failed. Images will only be saved locally.") def extract_filename(self, response: dict) -> str | None: """Extract the generated image filename from ComfyUI response""" @@ -85,10 +119,29 @@ class APIDemo: return None async def save_image(self, worker_url: str, filename: str, local_name: str) -> str | None: - """Fetch and save image locally from the worker""" + """Fetch and save image locally from the worker, optionally upload to S3""" os.makedirs("generated_images", exist_ok=True) return await self._fetch_image(worker_url, filename, local_name) + def _upload_to_s3(self, local_path: str, s3_key: str) -> str | None: + """Upload a local file to S3 and return the S3 URL""" + if not self.s3_client: + return None + + try: + self.s3_client.upload_file( + local_path, + S3_BUCKET_NAME, + s3_key, + ExtraArgs={"ContentType": "image/png"} + ) + s3_url = f"{S3_ENDPOINT_URL}/{S3_BUCKET_NAME}/{s3_key}" + print(f" ā˜ļø Uploaded to S3: {s3_key}") + return s3_url + except Exception as e: + log.error(f"Failed to upload to S3: {e}") + return None + async def _fetch_image(self, worker_url: str, filename: str, local_name: str) -> str | None: """Fetch image from worker's /view endpoint and save locally""" if not worker_url: @@ -102,9 +155,16 @@ class APIDemo: async with session.get(url, params=params, ssl=False) as resp: if resp.status == 200: path = f"generated_images/{local_name}" + image_data = await resp.read() with open(path, "wb") as f: - f.write(await resp.read()) + f.write(image_data) print(f" šŸ’¾ Saved: {path}") + + # Upload to S3 if enabled + if self.upload_s3 and self.s3_client: + s3_key = f"comfyui/{local_name}" + self._upload_to_s3(path, s3_key) + return path return None except Exception: @@ -207,6 +267,8 @@ def build_arg_parser() -> argparse.ArgumentParser: p.add_argument("--height", type=int, default=DEFAULT_HEIGHT, help=f"Image height (default: {DEFAULT_HEIGHT})") p.add_argument("--steps", type=int, default=DEFAULT_STEPS, help=f"Steps (default: {DEFAULT_STEPS})") p.add_argument("--seed", type=int, default=None, help="Seed (default: random)") + p.add_argument("--s3", action="store_true", + help="Upload generated images to S3 (requires S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY env vars)") return p @@ -215,10 +277,12 @@ async def main_async(): print("=" * 60) print(f"Using endpoint: {args.endpoint}") + if args.s3: + print(f"S3 upload: enabled (bucket: {S3_BUCKET_NAME})") try: async with Serverless() as client: - demo = APIDemo(client, args.endpoint) + demo = APIDemo(client, args.endpoint, upload_s3=args.s3) if args.workflow: await demo.demo_workflow(workflow_file=args.workflow) From 222ac2a0ddfe77c96abd5036bb78af6534274d85 Mon Sep 17 00:00:00 2001 From: Colter Downing Date: Thu, 4 Dec 2025 10:54:55 -0800 Subject: [PATCH 08/11] default endpoint name --- workers/comfyui-json/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/workers/comfyui-json/client.py b/workers/comfyui-json/client.py index 10a1d91..d79b30d 100644 --- a/workers/comfyui-json/client.py +++ b/workers/comfyui-json/client.py @@ -12,7 +12,7 @@ from vastai import Serverless # ---------------------- Config ---------------------- DEFAULT_PROMPT = "a beautiful sunset over mountains, digital art, highly detailed" -ENDPOINT_NAME = "Comfy-Prod" +ENDPOINT_NAME = "my-comfyui-endpoint" DEFAULT_WIDTH = 512 DEFAULT_HEIGHT = 512 DEFAULT_STEPS = 20 From 7be8aa63978ef0e061d2dc155484b65dd99e9b02 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Wed, 10 Dec 2025 17:38:03 -0800 Subject: [PATCH 09/11] pin pycares --- requirements.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 377b20a..8583584 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ -aiohttp[speedups]==3.10.1 +aiohttp~=3.10.1 +aiodns~=3.6.0 +pycares~=4.11.0 anyio~=4.4 lib~=4.0 nltk~=3.9 From df61e6e9467a7ad7650995f6109bfab366717158 Mon Sep 17 00:00:00 2001 From: edgaratvast Date: Wed, 10 Dec 2025 19:34:52 -0800 Subject: [PATCH 10/11] correct version pin for aiohttp (#73) Co-authored-by: Edgar Lin --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 8583584..b484d2d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -aiohttp~=3.10.1 +aiohttp==3.10.1 aiodns~=3.6.0 pycares~=4.11.0 anyio~=4.4 @@ -10,4 +10,4 @@ Requests~=2.32 transformers~=4.52 utils==1.0.* hf_transfer>=0.1.9 -vastai-sdk>=0.2.0 \ No newline at end of file +vastai-sdk>=0.2.0 From 4ecc07032ff829baeae0f0a807bb5df8c9a7e536 Mon Sep 17 00:00:00 2001 From: Abiola Akinnubi Date: Thu, 11 Dec 2025 12:51:56 -0800 Subject: [PATCH 11/11] Mark pyworkers as "Error" if startup script fails. to avoid silent fail that waits for autoscaler. --- start_server.sh | 167 +++++++++++++++++++++++++++++++----------------- 1 file changed, 110 insertions(+), 57 deletions(-) diff --git a/start_server.sh b/start_server.sh index 4b07e01..2f5ecdc 100755 --- a/start_server.sh +++ b/start_server.sh @@ -22,10 +22,49 @@ function echo_var(){ echo "$1: ${!1}" } -[ -z "$BACKEND" ] && echo "BACKEND must be set!" && exit 1 -[ -z "$MODEL_LOG" ] && echo "MODEL_LOG must be set!" && exit 1 -[ -z "$HF_TOKEN" ] && echo "HF_TOKEN must be set!" && exit 1 -[ "$BACKEND" = "comfyui" ] && [ -z "$COMFY_MODEL" ] && echo "For comfyui backends, COMFY_MODEL must be set!" && exit 1 +function report_error_and_exit(){ + local error_msg="$1" + echo "ERROR: $error_msg" + + # Report error to autoscaler + MTOKEN="${MASTER_TOKEN:-}" + VERSION="${PYWORKER_VERSION:-0}" + + IFS=',' read -r -a REPORT_ADDRS <<< "${REPORT_ADDR}" + for addr in "${REPORT_ADDRS[@]}"; do + curl -sS -X POST -H 'Content-Type: application/json' \ + -d "$(cat <> "$MODEL_LOG.old" - : > "$MODEL_LOG" + if ! cat "$MODEL_LOG" >> "$MODEL_LOG.old"; then + report_error_and_exit "Failed to rotate model log" + fi + if ! : > "$MODEL_LOG"; then + report_error_and_exit "Failed to truncate model log" + fi fi # Populate /etc/environment with quoted values if ! grep -q "VAST" /etc/environment; then - env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do + if ! env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do name=${line%%=*} value=${line#*=} printf '%s="%s"\n' "$name" "$value" - done > /etc/environment + done > /etc/environment; then + echo "WARNING: Failed to populate /etc/environment, continuing anyway" + fi fi if [ ! -d "$ENV_PATH" ] then echo "setting up venv" if ! which uv; then - curl -LsSf https://astral.sh/uv/install.sh | sh - source ~/.local/bin/env + if ! curl -LsSf https://astral.sh/uv/install.sh | sh; then + report_error_and_exit "Failed to install uv package manager" + fi + if [[ -f ~/.local/bin/env ]]; then + if ! source ~/.local/bin/env; then + report_error_and_exit "Failed to source uv environment" + fi + else + echo "WARNING: ~/.local/bin/env not found after uv installation" + fi fi # Fork testing - [[ ! -d $SERVER_DIR ]] && git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR" + if [[ ! -d $SERVER_DIR ]]; then + if ! git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"; then + report_error_and_exit "Failed to clone pyworker repository" + fi + fi if [[ -n ${PYWORKER_REF:-} ]]; then - (cd "$SERVER_DIR" && git checkout "$PYWORKER_REF") + if ! (cd "$SERVER_DIR" && git checkout "$PYWORKER_REF"); then + report_error_and_exit "Failed to checkout pyworker reference: $PYWORKER_REF" + fi fi - uv venv --python-preference only-managed "$ENV_PATH" -p 3.10 - source "$ENV_PATH/bin/activate" + if ! uv venv --python-preference only-managed "$ENV_PATH" -p 3.10; then + report_error_and_exit "Failed to create virtual environment" + fi + + if ! source "$ENV_PATH/bin/activate"; then + report_error_and_exit "Failed to activate virtual environment" + fi - uv pip install -r "${SERVER_DIR}/requirements.txt" + if ! uv pip install -r "${SERVER_DIR}/requirements.txt"; then + report_error_and_exit "Failed to install Python requirements" + fi - touch ~/.no_auto_tmux + if ! touch ~/.no_auto_tmux; then + report_error_and_exit "Failed to create ~/.no_auto_tmux" + fi else - [[ -f ~/.local/bin/env ]] && source ~/.local/bin/env - source "$WORKSPACE_DIR/worker-env/bin/activate" + if [[ -f ~/.local/bin/env ]]; then + if ! source ~/.local/bin/env; then + report_error_and_exit "Failed to source uv environment" + fi + fi + if ! source "$WORKSPACE_DIR/worker-env/bin/activate"; then + report_error_and_exit "Failed to activate existing virtual environment" + fi echo "environment activated" echo "venv: $VIRTUAL_ENV" fi -[ ! -d "$SERVER_DIR/workers/$BACKEND" ] && echo "$BACKEND not supported!" && exit 1 +[ ! -d "$SERVER_DIR/workers/$BACKEND" ] && report_error_and_exit "$BACKEND not supported!" if [ "$USE_SSL" = true ]; then - cat << EOF > /etc/openssl-san.cnf + if ! cat << EOF > /etc/openssl-san.cnf [req] default_bits = 2048 distinguished_name = req_distinguished_name @@ -109,18 +183,25 @@ if [ "$USE_SSL" = true ]; then [alt_names] IP.1 = 0.0.0.0 EOF + then + report_error_and_exit "Failed to write OpenSSL config" + fi - openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \ + if ! openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \ -nodes \ -sha256 \ -keyout /etc/instance.key \ -out /etc/instance.csr \ - -config /etc/openssl-san.cnf + -config /etc/openssl-san.cnf; then + report_error_and_exit "Failed to generate SSL certificate request" + fi - curl --header 'Content-Type: application/octet-stream' \ - --data-binary @//etc/instance.csr \ + if ! curl --header 'Content-Type: application/octet-stream' \ + --data-binary @/etc/instance.csr \ -X \ - POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt; + POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt; then + report_error_and_exit "Failed to sign SSL certificate" + fi fi @@ -128,7 +209,9 @@ fi export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED -cd "$SERVER_DIR" +if ! cd "$SERVER_DIR"; then + report_error_and_exit "Failed to cd into SERVER_DIR: $SERVER_DIR" +fi echo "launching PyWorker server" @@ -138,37 +221,7 @@ PY_STATUS=${PIPESTATUS[0]} set -e if [ "${PY_STATUS}" -ne 0 ]; then - echo "PyWorker exited with status ${PY_STATUS}; notifying autoscaler..." - ERROR_MSG="PyWorker exited: code ${PY_STATUS}" - MTOKEN="${MASTER_TOKEN:-}" - VERSION="${PYWORKER_VERSION:-0}" - - IFS=',' read -r -a REPORT_ADDRS <<< "${REPORT_ADDR}" - for addr in "${REPORT_ADDRS[@]}"; do - curl -sS -X POST -H 'Content-Type: application/json' \ - -d "$(cat <