Compare commits

...

16 Commits

Author SHA1 Message Date
Colter Downing 68d8ce4bfd refactor: use endpoint_id instead of endpoint name for routing
- Update route_payload to use endpoint_id instead of endpoint name
- Update AuthData to expect endpoint_id (int) instead of endpoint (str)
- Update ClientState to track endpoint_id
- Update comfyui client functions to use endpoint_id
- Fetch endpoint info (id + api_key) instead of just api_key

This aligns with the autoscaler changes in AUTO-848 that switched
to ID-based endpoint lookups for improved security and consistency.
2025-12-06 14:46:41 -08:00
Colter-Downing 138fc3ac47 Merge pull request #71 from vast-ai/AUTO-comfyui-updates
Auto comfyui updates
2025-12-04 10:55:12 -08:00
Colter Downing 222ac2a0dd default endpoint name 2025-12-04 10:54:55 -08:00
Colter Downing 40aed9b5f8 adding s3 as an option 2025-12-04 10:52:57 -08:00
Colter Downing d4d36bf86e done with comfy updates 2025-12-03 20:45:55 -08:00
Colter Downing e839cfc6e8 include view in API wrapper 2025-12-03 20:22:45 -08:00
Colter Downing f04138e13b update to be able to get images 2025-12-03 20:16:25 -08:00
Colter-Downing de3aa87c8f Merge pull request #70 from vast-ai/AUTO-tgi-client-edits
update tgi client
2025-12-03 18:40:01 -08:00
Colter Downing 6b5b1341a7 update tgi client 2025-12-03 18:38:42 -08:00
Colter-Downing 8be92c03de Merge pull request #69 from vast-ai/AUTO-874--fix-openai-worker-client
defaults to ENDPOINT_NAME and DEFAULT_MODEL but uses the flag first
2025-12-03 16:59:56 -08:00
Colter Downing adedb8ba90 defaults to ENDPOINT_NAME and DEFAULT_MODEL but uses the flag first if present 2025-12-03 16:57:28 -08:00
LucasArmandVast 2f543c01ad Merge pull request #68 from vast-ai/fix-vllm-concurrency
Increase model wait time for vLLM
2025-12-03 16:13:51 -05:00
Lucas Armand 0bcd2219ea Increase model wait time for vLLM 2025-12-03 12:38:52 -08:00
LucasArmandVast 0339b471c5 Merge pull request #66 from vast-ai/synthesis
PyWorker Error Handling
2025-11-25 16:02:26 -08:00
LucasArmandVast 7a792fd176 Merge pull request #64 from vast-ai/add-llama-log
add llama log
2025-11-21 10:24:27 -08:00
Lucas Armand e0449cb3c7 add llama log 2025-11-21 10:22:16 -08:00
12 changed files with 804 additions and 135 deletions
+1 -1
View File
@@ -66,7 +66,7 @@ class AuthData:
"""data used to authenticate requester""" """data used to authenticate requester"""
cost: str cost: str
endpoint: str endpoint_id: int
reqnum: int reqnum: int
request_idx: int request_idx: int
signature: str signature: str
+7 -3
View File
@@ -75,6 +75,7 @@ def print_truncate_res(res: str):
@dataclass @dataclass
class ClientState: class ClientState:
endpoint_group_name: str endpoint_group_name: str
endpoint_id: int
api_key: str api_key: str
server_url: str server_url: str
worker_endpoint: str worker_endpoint: str
@@ -95,7 +96,7 @@ class ClientState:
self.status = ClientStatus.Error self.status = ClientStatus.Error
return return
route_payload = { route_payload = {
"endpoint": self.endpoint_group_name, "endpoint_id": self.endpoint_id,
"api_key": self.api_key, "api_key": self.api_key,
"cost": self.payload.count_workload(), "cost": self.payload.count_workload(),
} }
@@ -244,16 +245,19 @@ def run_test(
print_thread = threading.Thread(target=print_state, args=(clients, num_requests)) print_thread = threading.Thread(target=print_state, args=(clients, num_requests))
print_thread.daemon = True # makes threads get killed on program exit print_thread.daemon = True # makes threads get killed on program exit
print_thread.start() print_thread.start()
endpoint_api_key = Endpoint.get_endpoint_api_key( endpoint_info = Endpoint.get_endpoint_info(
endpoint_name=endpoint_group_name, account_api_key=api_key, instance=instance endpoint_name=endpoint_group_name, account_api_key=api_key, instance=instance
) )
if not endpoint_api_key: if not endpoint_info:
log.debug(f"Endpoint {endpoint_group_name} not found for API key") log.debug(f"Endpoint {endpoint_group_name} not found for API key")
return return
endpoint_id = endpoint_info["id"]
endpoint_api_key = endpoint_info["api_key"]
try: try:
for _ in range(num_requests): for _ in range(num_requests):
client = ClientState( client = ClientState(
endpoint_group_name=endpoint_group_name, endpoint_group_name=endpoint_group_name,
endpoint_id=endpoint_id,
api_key=endpoint_api_key, api_key=endpoint_api_key,
server_url=server_url, server_url=server_url,
worker_endpoint=worker_endpoint, worker_endpoint=worker_endpoint,
+92 -10
View File
@@ -1,8 +1,16 @@
# ComfyUI PyWorker # ComfyUI PyWorker
This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture. This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
The cost for each request has a static value of `1`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node. The cost for each request has a static value of `1`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node.
## Instance Setup
1. Pick a template
- [ComfyUI (Serverless)](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=ComfyUI%20(Serverless))
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
## Requirements ## Requirements
@@ -10,6 +18,88 @@ This worker requires both [ComfyUI](https://github.com/comfyanonymous/ComfyUI) a
A docker image is provided but you may use any if the above requirements are met. A docker image is provided but you may use any if the above requirements are met.
## Client
The client demonstrates how to use the Vast Serverless SDK to generate images, save them locally, and optionally upload to S3-compatible storage.
### Setup
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
```bash
git clone https://github.com/vast-ai/pyworker
cd pyworker
pip install uv
uv venv -p 3.12
source .venv/bin/activate
uv pip install -r requirements.txt
```
2. Set your API key:
```bash
export VAST_API_KEY=<your_api_key>
```
### Usage
```bash
# Default prompt
python -m workers.comfyui-json.client
# Custom prompt
python -m workers.comfyui-json.client --prompt "a cat sitting on a rainbow"
# With options
python -m workers.comfyui-json.client --prompt "sunset" --width 1024 --height 1024 --steps 30
# Using a custom workflow file
python -m workers.comfyui-json.client --workflow my_workflow.json
# With S3 upload
python -m workers.comfyui-json.client --s3
```
### CLI Flags
| Flag | Default | Description |
|------|---------|-------------|
| `--endpoint` | `my-comfyui-endpoint` | Vast endpoint name |
| `--prompt` | (default) | Text prompt for image generation |
| `--workflow` | (none) | Path to custom workflow JSON file |
| `--width` | 512 | Image width in pixels |
| `--height` | 512 | Image height in pixels |
| `--steps` | 20 | Number of denoising steps |
| `--seed` | (random) | Random seed for reproducibility |
| `--s3` | (disabled) | Upload generated images to S3 |
### Output
Images are saved to `./generated_images/comfy_{seed}.png`.
### S3 Upload (Optional)
You can optionally upload generated images to an S3-compatible storage service (AWS S3, Cloudflare R2, Backblaze B2, etc.) by using the `--s3` flag.
**1. Set environment variables:**
```bash
export S3_ENDPOINT_URL="https://your-account.r2.cloudflarestorage.com"
export S3_BUCKET_NAME="my-bucket"
export S3_ACCESS_KEY_ID="your-access-key-id"
export S3_SECRET_ACCESS_KEY="your-secret-access-key"
```
**2. Run with S3 upload enabled:**
```bash
python -m workers.comfyui-json.client --prompt "a beautiful landscape" --s3
```
Images will be saved locally AND uploaded to `s3://{bucket}/comfyui/{filename}`.
**Note:** Requires `boto3` (`pip install boto3`).
## Benchmarking ## Benchmarking
### Custom Benchmark Workflows ### Custom Benchmark Workflows
@@ -212,11 +302,3 @@ WEBHOOK_TIMEOUT=30 # Webhook timeout in seconds
} }
} }
``` ```
## Client Libraries
See the test client examples for implementation details on how to integrate with the ComfyUI worker.
---
See Vast's serverless documentation for more details on how to use ComfyUI with autoscaler.
+301 -24
View File
@@ -1,35 +1,312 @@
from .data_types import count_workload import os
import sys
import json
import uuid import uuid
import random import random
import asyncio import asyncio
import random import logging
import argparse
import aiohttp
from vastai import Serverless from vastai import Serverless
async def main(): # ---------------------- Config ----------------------
async with Serverless() as client: DEFAULT_PROMPT = "a beautiful sunset over mountains, digital art, highly detailed"
endpoint = await client.get_endpoint(name="my-comfy-endpoint") # Change this to your endpoint name ENDPOINT_NAME = "my-comfyui-endpoint"
DEFAULT_WIDTH = 512
DEFAULT_HEIGHT = 512
DEFAULT_STEPS = 20
COST = 100 # Fixed cost for ComfyUI requests
payload = { # Optional S3 Configuration (from environment variables)
"input": { S3_ENDPOINT_URL = os.getenv("S3_ENDPOINT_URL")
"request_id": str(uuid.uuid4()), S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
"modifier": "Text2Image", S3_ACCESS_KEY_ID = os.getenv("S3_ACCESS_KEY_ID")
"modifications": { S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY")
"prompt": "a beautiful landscape with mountains and lakes",
"width": 1024, logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")
"height": 1024, log = logging.getLogger(__name__)
"steps": 20,
"seed": random.randint(0, 2**32 - 1)
}, def get_s3_client():
"workflow_json": {} # Empty since using modifier approach """Create and return an S3 client configured for the S3-compatible endpoint"""
} try:
import boto3
from botocore.config import Config
except ImportError:
log.error("boto3 is required for S3 uploads. Install with: pip install boto3")
return None
if not all([S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY]):
log.error("S3 environment variables not fully configured. Required:")
log.error(" S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY")
return None
return boto3.client(
"s3",
endpoint_url=S3_ENDPOINT_URL,
aws_access_key_id=S3_ACCESS_KEY_ID,
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
config=Config(signature_version="s3v4"),
)
# ---------------------- API Functions ----------------------
async def call_generate(
client: Serverless,
*,
endpoint_name: str,
prompt: str,
width: int,
height: int,
steps: int,
seed: int,
) -> dict:
"""Generate image using Text2Image modifier"""
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
"request_id": str(uuid.uuid4()),
"modifier": "Text2Image",
"modifications": {
"prompt": prompt,
"width": width,
"height": height,
"steps": steps,
"seed": seed,
},
} }
}
response = await endpoint.request("/generate/sync", payload, cost=count_workload()) return await endpoint.request("/generate/sync", payload, cost=COST)
async def call_generate_workflow(
client: Serverless,
*,
endpoint_name: str,
workflow_json: dict,
) -> dict:
"""Generate using custom workflow JSON"""
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
"request_id": str(uuid.uuid4()),
"workflow_json": workflow_json,
}
}
return await endpoint.request("/generate/sync", payload, cost=COST)
# ---------------------- Demo Class ----------------------
class APIDemo:
def __init__(self, client: Serverless, endpoint_name: str, upload_s3: bool = False):
self.client = client
self.endpoint_name = endpoint_name
self.upload_s3 = upload_s3
self.s3_client = get_s3_client() if upload_s3 else None
if upload_s3 and not self.s3_client:
log.warning("S3 upload requested but client creation failed. Images will only be saved locally.")
def extract_filename(self, response: dict) -> str | None:
"""Extract the generated image filename from ComfyUI response"""
if "comfyui_response" in response:
for data in response["comfyui_response"].values():
if isinstance(data, dict) and "outputs" in data:
for node_output in data["outputs"].values():
if "images" in node_output and node_output["images"]:
return node_output["images"][0].get("filename")
return None
async def save_image(self, worker_url: str, filename: str, local_name: str) -> str | None:
"""Fetch and save image locally from the worker, optionally upload to S3"""
os.makedirs("generated_images", exist_ok=True)
return await self._fetch_image(worker_url, filename, local_name)
def _upload_to_s3(self, local_path: str, s3_key: str) -> str | None:
"""Upload a local file to S3 and return the S3 URL"""
if not self.s3_client:
return None
try:
self.s3_client.upload_file(
local_path,
S3_BUCKET_NAME,
s3_key,
ExtraArgs={"ContentType": "image/png"}
)
s3_url = f"{S3_ENDPOINT_URL}/{S3_BUCKET_NAME}/{s3_key}"
print(f" ☁️ Uploaded to S3: {s3_key}")
return s3_url
except Exception as e:
log.error(f"Failed to upload to S3: {e}")
return None
async def _fetch_image(self, worker_url: str, filename: str, local_name: str) -> str | None:
"""Fetch image from worker's /view endpoint and save locally"""
if not worker_url:
return None
try:
url = f"{worker_url}/view"
params = {"filename": filename, "type": "output"}
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params, ssl=False) as resp:
if resp.status == 200:
path = f"generated_images/{local_name}"
image_data = await resp.read()
with open(path, "wb") as f:
f.write(image_data)
print(f" 💾 Saved: {path}")
# Upload to S3 if enabled
if self.upload_s3 and self.s3_client:
s3_key = f"comfyui/{local_name}"
self._upload_to_s3(path, s3_key)
return path
return None
except Exception:
return None
async def demo_prompt(
self,
prompt: str,
width: int,
height: int,
steps: int,
seed: int | None,
):
"""Demo: Generate image from text prompt"""
print("=" * 60)
print("COMFYUI TEXT-TO-IMAGE DEMO")
print("=" * 60)
if seed is None:
seed = random.randint(0, 2**32 - 1)
print(f"Prompt: {prompt[:100]}..." if len(prompt) > 100 else f"Prompt: {prompt}")
print(f"Size: {width}x{height}, Steps: {steps}, Seed: {seed}")
print("\n🎨 Generating image...")
response = await call_generate(
self.client,
endpoint_name=self.endpoint_name,
prompt=prompt,
width=width,
height=height,
steps=steps,
seed=seed,
)
print("\n✅ Generation complete!")
# Get worker URL for fetching images
worker_url = response.get("url", "")
print(f"Worker URL: {worker_url}")
# Fetch and save image
if "response" in response:
filename = self.extract_filename(response["response"])
if filename:
path = await self.save_image(worker_url, filename, f"comfy_{seed}.png")
if not path:
print(f"❌ Failed to fetch image")
else:
print("❌ No image in response")
else:
print("❌ Unexpected response format")
async def demo_workflow(self, workflow_file: str):
"""Demo: Generate using custom workflow file"""
print("=" * 60)
print("COMFYUI CUSTOM WORKFLOW DEMO")
print("=" * 60)
if not os.path.exists(workflow_file):
log.error(f"Workflow file not found: {workflow_file}")
return
with open(workflow_file, "r") as f:
workflow_json = json.load(f)
print(f"Workflow: {workflow_file}")
print("\n🎨 Generating...")
response = await call_generate_workflow(
self.client,
endpoint_name=self.endpoint_name,
workflow_json=workflow_json,
)
print("\n✅ Generation complete!")
worker_url = response.get("url", "")
if "response" in response:
filename = self.extract_filename(response["response"])
if filename:
path = await self.save_image(worker_url, filename, "workflow.png")
if not path:
print(f"❌ Failed to fetch image")
else:
print("❌ No image in response")
else:
print("❌ Unexpected response format")
# ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast ComfyUI-JSON Demo (Serverless SDK)")
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
p.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, metavar="TEXT",
help=f"Prompt text (default: '{DEFAULT_PROMPT[:30]}...')")
p.add_argument("--workflow", type=str, metavar="FILE", help="Use custom workflow JSON file instead")
p.add_argument("--width", type=int, default=DEFAULT_WIDTH, help=f"Image width (default: {DEFAULT_WIDTH})")
p.add_argument("--height", type=int, default=DEFAULT_HEIGHT, help=f"Image height (default: {DEFAULT_HEIGHT})")
p.add_argument("--steps", type=int, default=DEFAULT_STEPS, help=f"Steps (default: {DEFAULT_STEPS})")
p.add_argument("--seed", type=int, default=None, help="Seed (default: random)")
p.add_argument("--s3", action="store_true",
help="Upload generated images to S3 (requires S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY env vars)")
return p
async def main_async():
args = build_arg_parser().parse_args()
print("=" * 60)
print(f"Using endpoint: {args.endpoint}")
if args.s3:
print(f"S3 upload: enabled (bucket: {S3_BUCKET_NAME})")
try:
async with Serverless() as client:
demo = APIDemo(client, args.endpoint, upload_s3=args.s3)
if args.workflow:
await demo.demo_workflow(workflow_file=args.workflow)
else:
await demo.demo_prompt(
prompt=args.prompt,
width=args.width,
height=args.height,
steps=args.steps,
seed=args.seed,
)
except AttributeError as e:
if "API key" in str(e):
log.error("API key missing. Set VAST_API_KEY environment variable.")
else:
log.error(f"Error: {e}")
sys.exit(1)
except Exception as e:
log.error(f"Error: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
# Get the file from the path on the local machine using SCP or SFTP
# or configure S3 to upload to cloud storage.
print(response["response"]["output"][0]["local_path"])
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main_async())
+33
View File
@@ -4,6 +4,7 @@ import dataclasses
import base64 import base64
from typing import Optional, Union, Type from typing import Optional, Union, Type
import aiohttp
from aiohttp import web, ClientResponse from aiohttp import web, ClientResponse
from lib.backend import Backend, LogAction from lib.backend import Backend, LogAction
@@ -13,6 +14,7 @@ from .data_types import ComfyWorkflowData
MODEL_SERVER_URL = os.getenv("MODEL_SERVER_URL", "http://127.0.0.1:18288") MODEL_SERVER_URL = os.getenv("MODEL_SERVER_URL", "http://127.0.0.1:18288")
COMFYUI_URL = os.getenv("COMFYUI_URL", "http://127.0.0.1:18188") # Raw ComfyUI server
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded # This is the last log line that gets emitted once comfyui+extensions have been fully loaded
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: " MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: "
@@ -108,8 +110,39 @@ async def handle_ping(_):
return web.Response(body="pong") return web.Response(body="pong")
async def handle_view(request: web.Request) -> web.Response:
"""Proxy /view requests to raw ComfyUI server to fetch generated images"""
# Forward query params to raw ComfyUI (not the API wrapper)
query_string = request.query_string
url = f"{COMFYUI_URL}/view?{query_string}"
log.debug(f"Proxying /view request to: {url}")
try:
async with aiohttp.ClientSession() as session:
async with session.get(url) as resp:
if resp.status == 200:
content = await resp.read()
return web.Response(
body=content,
status=200,
content_type=resp.content_type or "image/png"
)
else:
text = await resp.text()
return web.Response(
text=text,
status=resp.status,
content_type="text/plain"
)
except Exception as e:
log.error(f"Error proxying /view: {e}")
return web.Response(text=str(e), status=500)
routes = [ routes = [
web.post("/generate/sync", backend.create_handler(ComfyWorkflowHandler())), web.post("/generate/sync", backend.create_handler(ComfyWorkflowHandler())),
web.get("/view", handle_view),
web.get("/ping", handle_ping), web.get("/ping", handle_ping),
] ]
+14 -11
View File
@@ -13,11 +13,11 @@ from vastai import Serverless
ENDPOINT_NAME = "my-comfyui-endpoint" ENDPOINT_NAME = "my-comfyui-endpoint"
COST = 100 # Use a constant cost for image generation COST = 100 # Use a constant cost for image generation
def call_default_workflow(client: Serverless) -> None: def call_default_workflow(endpoint_id: int, api_key: str, server_url: str) -> None:
WORKER_ENDPOINT = "/prompt" WORKER_ENDPOINT = "/prompt"
COST = 100 COST = 100
route_payload = { route_payload = {
"endpoint": endpoint_group_name, "endpoint_id": endpoint_id,
"api_key": api_key, "api_key": api_key,
"cost": COST, "cost": COST,
} }
@@ -32,7 +32,7 @@ def call_default_workflow(client: Serverless) -> None:
auth_data = dict( auth_data = dict(
signature=message["signature"], signature=message["signature"],
cost=message["cost"], cost=message["cost"],
endpoint=message["endpoint"], endpoint_id=message["endpoint_id"],
reqnum=message["reqnum"], reqnum=message["reqnum"],
url=message["url"], url=message["url"],
) )
@@ -52,12 +52,12 @@ def call_default_workflow(client: Serverless) -> None:
def call_custom_workflow_for_sd3( def call_custom_workflow_for_sd3(
endpoint_group_name: str, api_key: str, server_url: str endpoint_id: int, api_key: str, server_url: str
) -> None: ) -> None:
WORKER_ENDPOINT = "/custom-workflow" WORKER_ENDPOINT = "/custom-workflow"
COST = 100 COST = 100
route_payload = { route_payload = {
"endpoint": endpoint_group_name, "endpoint_id": endpoint_id,
"api_key": api_key, "api_key": api_key,
"cost": COST, "cost": COST,
} }
@@ -72,7 +72,7 @@ def call_custom_workflow_for_sd3(
auth_data = dict( auth_data = dict(
signature=message["signature"], signature=message["signature"],
cost=message["cost"], cost=message["cost"],
endpoint=message["endpoint"], endpoint_id=message["endpoint_id"],
reqnum=message["reqnum"], reqnum=message["reqnum"],
url=message["url"], url=message["url"],
request_idx=message["request_idx"], request_idx=message["request_idx"],
@@ -146,25 +146,28 @@ def call_custom_workflow_for_sd3(
if __name__ == "__main__": if __name__ == "__main__":
from lib.test_utils import test_args from lib.test_utils import test_args
log = logging.getLogger(__name__)
args = test_args.parse_args() args = test_args.parse_args()
endpoint_api_key = Endpoint.get_endpoint_api_key( endpoint_info = Endpoint.get_endpoint_info(
endpoint_name=args.endpoint_group_name, endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key, account_api_key=args.api_key,
instance=args.instance, instance=args.instance,
) )
if endpoint_api_key: if endpoint_info:
endpoint_id = endpoint_info["id"]
endpoint_api_key = endpoint_info["api_key"]
try: try:
call_default_workflow( call_default_workflow(
endpoint_id=endpoint_id,
api_key=endpoint_api_key, api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url, server_url=args.server_url,
) )
call_custom_workflow_for_sd3( call_custom_workflow_for_sd3(
endpoint_id=endpoint_id,
api_key=endpoint_api_key, api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url, server_url=args.server_url,
) )
except Exception as e: except Exception as e:
log.error(f"Error during API call: {e}") log.error(f"Error during API call: {e}")
else: else:
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ") log.error(f"Failed to get endpoint info for {args.endpoint_group_name}")
+33 -26
View File
@@ -8,14 +8,13 @@ This is the base PyWorker for OpenAI compatible inference servers. See the [Ser
This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker. This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker.
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20%2B%20Qwen%2FQwen3-8B%20(Serverless)) (recommended) - [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20(Serverless)) (recommended)
- [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless)) - [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless))
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20%2B%20Qwen3-8B%20(Serverless))
All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected. All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected.
2. Follow the [getting started guide](https://docs.vast.ai/serverless/getting-started) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface. 2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
## Client Setup (Demo) ## Client Setup (Demo)
@@ -34,38 +33,20 @@ uv pip install -r requirements.txt
Several examples have been provided in the client to help you get started with your own implementation. Several examples have been provided in the client to help you get started with your own implementation.
### Completions First, set your API key as an environment variable:
Call to `/v1/completions` with json response
```bash ```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME> export VAST_API_KEY=<your_api_key>
``` ```
### Chat Completion (json) The `--model` and `--endpoint` flags are optional. If not provided, they default to `Qwen/Qwen3-8B` and `my-vllm-endpoint` respectively.
Call to `/v1/chat/completions` with json response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
```
### Chat Completion (streaming) ### Chat Completion (streaming)
Call to `/v1/chat/completions` with streaming response Call to `/v1/chat/completions` with streaming response
```bash ```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME> python -m workers.openai.client --chat-stream --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Tool Use (json)
Call to `/v1/chat/completions` with tool and json response.
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
``` ```
### Interactive Chat (streaming) ### Interactive Chat (streaming)
@@ -75,6 +56,32 @@ Interactive session with calls to `/v1/chat/completions`.
Type `clear` to clear the chat history or `quit` to exit. Type `clear` to clear the chat history or `quit` to exit.
```bash ```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME> python -m workers.openai.client --interactive --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Chat Completion (json)
Call to `/v1/chat/completions` with json response
```bash
python -m workers.openai.client --chat --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Tool Use (json)
Call to `/v1/chat/completions` with tool and json response.
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
```bash
python -m workers.openai.client --tools --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Completions
Call to `/v1/completions` with json response
```bash
python -m workers.openai.client --completion --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
``` ```
+32 -16
View File
@@ -18,7 +18,7 @@ logging.basicConfig(
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
# ---------------------- Prompts ---------------------- # ---------------------- Prompts ----------------------
COMPLETIONS_PROMPT = "the capital of USA is" COMPLETIONS_PROMPT = "Zebras are primarily grazers and can subsist on lower-quality vegetation. They are preyed on mainly by"
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language." CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
TOOLS_PROMPT = ( TOOLS_PROMPT = (
"Can you list the files in the current working directory and tell me what you see? " "Can you list the files in the current working directory and tell me what you see? "
@@ -97,9 +97,9 @@ def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[
# ---- OpenAI-compatible calls (non-streaming) ---- # ---- OpenAI-compatible calls (non-streaming) ----
async def call_completions(client: Serverless, *, model: str, prompt: str, **kwargs) -> Dict[str, Any]: async def call_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) endpoint = await client.get_endpoint(name=endpoint_name)
payload = { payload = {
"input": { "input": {
@@ -113,9 +113,9 @@ async def call_completions(client: Serverless, *, model: str, prompt: str, **kwa
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"]) resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
return resp["response"] return resp["response"]
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]: async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) endpoint = await client.get_endpoint(name=endpoint_name)
payload = { payload = {
"input": { "input": {
@@ -132,9 +132,9 @@ async def call_chat_completions(client: Serverless, *, model: str, messages: Lis
return resp["response"] return resp["response"]
# ---- Streaming variants ---- # ---- Streaming variants ----
async def stream_completions(client: Serverless, *, model: str, prompt: str, **kwargs): async def stream_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs):
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) endpoint = await client.get_endpoint(name=endpoint_name)
payload = { payload = {
"input": { "input": {
@@ -150,9 +150,9 @@ async def stream_completions(client: Serverless, *, model: str, prompt: str, **k
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True) resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
return resp["response"] # async generator return resp["response"] # async generator
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs): async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs):
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) endpoint = await client.get_endpoint(name=endpoint_name)
payload = { payload = {
"input": { "input": {
@@ -174,9 +174,10 @@ async def stream_chat_completions(client: Serverless, *, model: str, messages: L
class APIDemo: class APIDemo:
"""Demo and testing functionality for the API client""" """Demo and testing functionality for the API client"""
def __init__(self, client: Serverless, model: str, tool_manager: Optional[ToolManager] = None): def __init__(self, client: Serverless, model: str, endpoint_name: str, tool_manager: Optional[ToolManager] = None):
self.client = client self.client = client
self.model = model self.model = model
self.endpoint_name = endpoint_name
self.tool_manager = tool_manager or ToolManager() self.tool_manager = tool_manager or ToolManager()
# ----- Streaming handler ----- # ----- Streaming handler -----
@@ -185,10 +186,15 @@ class APIDemo:
reasoning_content = "" reasoning_content = ""
printed_reasoning = False printed_reasoning = False
printed_answer = False printed_answer = False
finish_reason = None
async for chunk in stream: async for chunk in stream:
choice = (chunk.get("choices") or [{}])[0] choice = (chunk.get("choices") or [{}])[0]
delta = choice.get("delta", {}) delta = choice.get("delta", {})
# Track finish reason
if choice.get("finish_reason"):
finish_reason = choice.get("finish_reason")
# reasoning tokens # reasoning tokens
rc = delta.get("reasoning_content") rc = delta.get("reasoning_content")
@@ -219,6 +225,8 @@ class APIDemo:
print(f"Reasoning tokens: {len(reasoning_content.split())}") print(f"Reasoning tokens: {len(reasoning_content.split())}")
if printed_answer: if printed_answer:
print(f"Response tokens: {len(full_response.split())}") print(f"Response tokens: {len(full_response.split())}")
if finish_reason:
print(f"Finish reason: {finish_reason}")
return full_response return full_response
@@ -231,6 +239,7 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
prompt=COMPLETIONS_PROMPT, prompt=COMPLETIONS_PROMPT,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE, temperature=DEFAULT_TEMPERATURE,
) )
@@ -249,6 +258,7 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE temperature=DEFAULT_TEMPERATURE
) )
@@ -261,6 +271,7 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE temperature=DEFAULT_TEMPERATURE
) )
@@ -287,6 +298,7 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
tools=minimal_tool, tools=minimal_tool,
tool_choice="none", tool_choice="none",
max_tokens=10 max_tokens=10
@@ -312,6 +324,7 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
tools=self.tool_manager.get_ls_tool_definition(), tools=self.tool_manager.get_ls_tool_definition(),
tool_choice="auto", tool_choice="auto",
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
@@ -389,6 +402,7 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE, temperature=DEFAULT_TEMPERATURE,
) )
@@ -427,7 +441,6 @@ class APIDemo:
print("=" * 60) print("=" * 60)
print("INTERACTIVE STREAMING CHAT") print("INTERACTIVE STREAMING CHAT")
print("=" * 60) print("=" * 60)
print(f"Using model: {self.model}")
print("Type 'quit' to exit, 'clear' to clear history") print("Type 'quit' to exit, 'clear' to clear history")
print() print()
@@ -453,7 +466,8 @@ class APIDemo:
stream = await stream_chat_completions( stream = await stream_chat_completions(
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
temperature=0.7 temperature=0.7
) )
@@ -473,8 +487,8 @@ class APIDemo:
# ---------------------- CLI ---------------------- # ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser: def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)") p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
p.add_argument("--model", required=True, help="Model to use for requests (required)") p.add_argument("--model", default=DEFAULT_MODEL, help=f"Model to use for requests (default: {DEFAULT_MODEL})")
p.add_argument("--endpoint", default="my-vllm-endpoint", help="Vast endpoint name (default: my-vllm-endpoint)") p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
modes = p.add_mutually_exclusive_group(required=False) modes = p.add_mutually_exclusive_group(required=False)
modes.add_argument("--completion", action="store_true", help="Test completions endpoint") modes.add_argument("--completion", action="store_true", help="Test completions endpoint")
@@ -502,12 +516,14 @@ async def main_async():
print("Please specify exactly one test mode") print("Please specify exactly one test mode")
sys.exit(1) sys.exit(1)
print(f"Using model: {args.model}")
print("=" * 60) print("=" * 60)
print(f"Using model: {args.model}")
print(f"Using endpoint: {args.endpoint}")
try: try:
async with Serverless() as client: async with Serverless() as client:
demo = APIDemo(client, args.model, ToolManager()) demo = APIDemo(client, args.model, args.endpoint, ToolManager())
if args.completion: if args.completion:
await demo.demo_completions() await demo.demo_completions()
+2
View File
@@ -11,6 +11,7 @@ MODEL_SERVER_START_LOG_MSG = [
"llama runner started", # Ollama "llama runner started", # Ollama
'"message":"Connected","target":"text_generation_router"', # TGI '"message":"Connected","target":"text_generation_router"', # TGI
'"message":"Connected","target":"text_generation_router::server"', # TGI '"message":"Connected","target":"text_generation_router::server"', # TGI
"main: model loaded" # llama.cpp
] ]
MODEL_SERVER_ERROR_LOG_MSGS = [ MODEL_SERVER_ERROR_LOG_MSGS = [
@@ -34,6 +35,7 @@ backend = Backend(
model_server_url=os.environ["MODEL_SERVER_URL"], model_server_url=os.environ["MODEL_SERVER_URL"],
model_log_file=os.environ["MODEL_LOG"], model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=True, allow_parallel_requests=True,
max_wait_time=600.0,
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256), benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
log_actions=[ log_actions=[
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG], *[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
+1 -1
View File
@@ -60,7 +60,7 @@ def do_one(endpoint_name: str,
worker_session): worker_session):
try: try:
workload = payload.count_workload() workload = payload.count_workload()
route_payload = {"endpoint": endpoint_name, "api_key": endpoint_api_key, "cost": workload} route_payload = {"endpoint_id": endpoint_id, "api_key": endpoint_api_key, "cost": workload}
headers = {"Authorization": f"Bearer {endpoint_api_key}"} headers = {"Authorization": f"Bearer {endpoint_api_key}"}
start = time.time() start = time.time()
r0 = route_session.post(urljoin(server_url, "/route/"), json=route_payload, headers=headers, timeout=4) r0 = route_session.post(urljoin(server_url, "/route/"), json=route_payload, headers=headers, timeout=4)
+93 -9
View File
@@ -1,19 +1,103 @@
This is the base PyWorker for TGI, designed to create PyWorkers that can utilize various LLMs. It offers two primary endpoints: # HuggingFace TGI PyWorker
1. `generate`: Generates the LLM's response to a given prompt in a single request. This is the base PyWorker for HuggingFace Text Generation Inference (TGI) servers. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
2. `generate_stream`: Streams the LLM's response token by token.
Both endpoints use the following API payload format: ## Instance Setup
1. Pick a template
This worker is compatible with any TGI backend. We have a template you can use or you can create your own.
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20(Serverless))
The template can be configured via the template interface. You may want to change the model or startup arguments.
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
## Client Setup (Demo)
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
```bash
git clone https://github.com/vast-ai/pyworker
cd pyworker
pip install uv
uv venv -p 3.12
source .venv/bin/activate
uv pip install -r requirements.txt
```
## Using the Test Client
The test client demonstrates both streaming and non-streaming generation using TGI's native API.
First, set your API key as an environment variable:
```bash
export VAST_API_KEY=<your_api_key>
```
The `--endpoint` flag is optional. If not provided, it defaults to `my-tgi-endpoint`.
### Generate (Streaming)
Call to `/generate_stream` with streaming response:
```bash
python -m workers.tgi.client --generate-stream --endpoint <ENDPOINT_NAME>
```
### Generate (Non-Streaming)
Call to `/generate` with json response:
```bash
python -m workers.tgi.client --generate --endpoint <ENDPOINT_NAME>
```
### Interactive Session (Streaming)
Interactive session with streaming responses. Type `quit` to exit.
```bash
python -m workers.tgi.client --interactive --endpoint <ENDPOINT_NAME>
```
## API Endpoints
TGI provides two primary endpoints:
### Generate (Non-Streaming)
`/generate` - Returns the complete response in a single request.
```json ```json
{ {
"inputs": "PROMPT", "inputs": "Your prompt here",
"parameters": { "parameters": {
"max_new_tokens": 250 "max_new_tokens": 1024,
"temperature": 0.7,
"return_full_text": false
} }
} }
``` ```
Note that the max_new_tokens parameter, rather than the prompt size, impacts performance. For example, if an ### Generate Stream (Streaming)
instance is benchmarked to process 100 tokens per second, a request with max_new_tokens = 200 will take
approximately 2 seconds to complete. `/generate_stream` - Streams the response token by token.
```json
{
"inputs": "Your prompt here",
"parameters": {
"max_new_tokens": 1024,
"temperature": 0.7,
"do_sample": true,
"return_full_text": false
}
}
```
## Performance Notes
The `max_new_tokens` parameter (not the prompt size) primarily impacts performance. For example, if an instance is benchmarked to process 100 tokens per second, a request with `max_new_tokens = 200` will take approximately 2 seconds to complete.
+195 -34
View File
@@ -1,61 +1,222 @@
import logging
import json
import os
import sys
import argparse
from vastai import Serverless from vastai import Serverless
import asyncio import asyncio
ENDPOINT_NAME = "my-tgi-endpoint" # Change this to match your endpoint name # ---------------------- Logging ----------------------
MAX_TOKENS = 1024 logging.basicConfig(
PROMPT = "Think step by step: Tell me about the Python programming language." level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
async def call_generate(client: Serverless) -> None: # ---------------------- Defaults ----------------------
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) DEFAULT_PROMPT = "Think step by step: Tell me about the Python programming language."
ENDPOINT_NAME = "TGI-Prod2" # change this to your TGI endpoint name
MAX_TOKENS = 1024
DEFAULT_TEMPERATURE = 0.7
# ---------------------- API Calls ----------------------
async def call_generate(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs) -> dict:
"""Non-streaming generation via /generate endpoint"""
endpoint = await client.get_endpoint(name=endpoint_name)
payload = { payload = {
"inputs": PROMPT, "inputs": prompt,
"parameters": { "parameters": {
"max_new_tokens": MAX_TOKENS, "max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": 0.7, "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"return_full_text": False "return_full_text": False,
} }
} }
log.debug("POST /generate %s", json.dumps(payload)[:500])
resp = await endpoint.request("/generate", payload, cost=MAX_TOKENS) resp = await endpoint.request("/generate", payload, cost=payload["parameters"]["max_new_tokens"])
return resp["response"]
print(resp["response"]["generated_text"])
async def call_generate_stream(client: Serverless) -> None: async def call_generate_stream(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs):
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) """Streaming generation via /generate_stream endpoint"""
endpoint = await client.get_endpoint(name=endpoint_name)
payload = { payload = {
"inputs": PROMPT, "inputs": prompt,
"parameters": { "parameters": {
"max_new_tokens": MAX_TOKENS, "max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": 0.7, "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"do_sample": True, "do_sample": True,
"return_full_text": False, "return_full_text": False,
} }
} }
log.debug("STREAM /generate_stream %s", json.dumps(payload)[:500])
resp = await endpoint.request( resp = await endpoint.request(
"/generate_stream", "/generate_stream",
payload, payload,
cost=MAX_TOKENS, cost=payload["parameters"]["max_new_tokens"],
stream=True, stream=True,
) )
stream = resp["response"] return resp["response"] # async generator
printed_answer = False
async for event in stream:
tok = (event.get("token") or {}).get("text")
if tok:
if not printed_answer:
printed_answer = True
print("Answer:\n", end="", flush=True)
print(tok, end="", flush=True)
async def main(): # ---------------------- Demo Runner ----------------------
async with Serverless() as client: class APIDemo:
await call_generate(client) """Demo and testing functionality for the TGI API client"""
await call_generate_stream(client)
def __init__(self, client: Serverless, endpoint_name: str):
self.client = client
self.endpoint_name = endpoint_name
async def handle_streaming_response(self, stream) -> str:
"""Process streaming response and print tokens"""
full_response = ""
printed_answer = False
async for event in stream:
tok = (event.get("token") or {}).get("text")
if tok:
if not printed_answer:
printed_answer = True
print("\n💬 Response: ", end="", flush=True)
print(tok, end="", flush=True)
full_response += tok
print() # newline
if printed_answer:
print(f"\nStreaming completed. Response tokens: {len(full_response.split())}")
return full_response
async def demo_generate(self) -> None:
"""Demo non-streaming generation"""
print("=" * 60)
print("GENERATE DEMO (NON-STREAMING)")
print("=" * 60)
response = await call_generate(
client=self.client,
endpoint_name=self.endpoint_name,
prompt=DEFAULT_PROMPT,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
print(f"\n💬 Response: {response.get('generated_text', '')}")
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
async def demo_generate_stream(self) -> None:
"""Demo streaming generation"""
print("=" * 60)
print("GENERATE DEMO (STREAMING)")
print("=" * 60)
stream = await call_generate_stream(
client=self.client,
endpoint_name=self.endpoint_name,
prompt=DEFAULT_PROMPT,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
try:
await self.handle_streaming_response(stream)
except Exception as e:
log.error("\nError during streaming: %s", e, exc_info=True)
async def interactive_chat(self) -> None:
"""Interactive session with streaming generation"""
print("=" * 60)
print("INTERACTIVE STREAMING SESSION")
print("=" * 60)
print(f"Using endpoint: {self.endpoint_name}")
print("Type 'quit' to exit")
print()
while True:
try:
user_input = input("You: ").strip()
if user_input.lower() == "quit":
print("👋 Goodbye!")
break
elif not user_input:
continue
print("Assistant: ", end="", flush=True)
stream = await call_generate_stream(
client=self.client,
endpoint_name=self.endpoint_name,
prompt=user_input,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
full_response = ""
async for event in stream:
tok = (event.get("token") or {}).get("text")
if tok:
print(tok, end="", flush=True)
full_response += tok
print() # newline
except KeyboardInterrupt:
print("\n👋 Session interrupted. Goodbye!")
break
except Exception as e:
log.error("\nError: %s", e)
continue
# ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast TGI Demo (Serverless SDK)")
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
modes = p.add_mutually_exclusive_group(required=False)
modes.add_argument("--generate", action="store_true", help="Test generate endpoint (non-streaming)")
modes.add_argument("--generate-stream", action="store_true", help="Test generate endpoint with streaming")
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming session")
return p
async def main_async():
args = build_arg_parser().parse_args()
selected = sum([args.generate, args.generate_stream, args.interactive])
if selected == 0:
print("Please specify exactly one test mode:")
print(" --generate : Test generate endpoint (non-streaming)")
print(" --generate-stream : Test generate endpoint with streaming")
print(" --interactive : Start interactive streaming session")
print(f"\nExample: python {os.path.basename(sys.argv[0])} --generate-stream --endpoint my-tgi-endpoint")
sys.exit(1)
elif selected > 1:
print("Please specify exactly one test mode")
sys.exit(1)
print("=" * 60)
print(f"Using endpoint: {args.endpoint}")
try:
async with Serverless() as client:
demo = APIDemo(client, args.endpoint)
if args.generate:
await demo.demo_generate()
elif args.generate_stream:
await demo.demo_generate_stream()
elif args.interactive:
await demo.interactive_chat()
except Exception as e:
log.error("Error during test: %s", e, exc_info=True)
sys.exit(1)
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main_async())