Compare commits

..

32 Commits

Author SHA1 Message Date
Lucas Armand 186b388f2f Merge branch 'main' into add-hacky-deployments-script 2026-03-23 14:20:45 -07:00
Lucas Armand d1c521f973 retry S3 download 2026-03-23 14:18:52 -07:00
Lucas Armand e1a5cf2b43 Retry until it loads 2026-03-23 14:16:41 -07:00
LucasArmandVast 2e8f18276f Add beta deployments script (#80) 2026-03-23 14:14:06 -07:00
Lucas Armand 87f968f961 Add hacky deployments script 2026-03-23 12:39:15 -07:00
Scott Darden eba9c480eb Merge pull request #79 from vast-ai/update-requirements
Updated requirements to only require vastai-sdk
2026-01-14 12:07:33 -08:00
Lucas Armand aaca1c9645 Updated requirements to only require vastai-sdk 2026-01-14 10:47:07 -08:00
LucasArmandVast f319db6bd5 flag for model log rotate (#78) 2026-01-12 17:03:18 -08:00
LucasArmandVast 4d786b4d17 SDK Versioning Improvements (#77)
* Add SDK_BRANCH
2026-01-02 10:23:07 -08:00
LucasArmandVast bd3e0032a1 Add SDK version checking (#76) 2025-12-17 21:01:52 -08:00
Lucas Armand e02f4bc943 Lowered concurrency of vLLM and TGI benchmarks 2025-12-17 11:55:33 -08:00
Lucas Armand bcb04b9a32 add missing comma 2025-12-17 11:40:40 -08:00
Lucas Armand 9daf171487 Increase queue limits for vLLM and TGI 2025-12-17 11:38:55 -08:00
LucasArmandVast 29f836eb1a Backwards compatible vLLM payload (#75)
* Support old vLLM payloads
2025-12-15 19:58:02 -08:00
LucasArmandVast 4380d98c01 Use PyWorker SDK (#67)
* Change PyWorker to Worker SDK
* Moved /lib to vast-sdk (https://github.com/vast-ai/vast-sdk)
2025-12-15 19:33:03 -08:00
Abiola Akinnubi 2ce741a8b7 Merge pull request #74 from vast-ai/AUTO-912
Mark pyworkers as "Error" if startup script fails. to avoid silent fail that waits for autoscaler.
2025-12-11 17:05:13 -08:00
Abiola Akinnubi 4ecc07032f Mark pyworkers as "Error" if startup script fails. to avoid silent fail that waits for autoscaler. 2025-12-11 12:51:56 -08:00
edgaratvast df61e6e946 correct version pin for aiohttp (#73)
Co-authored-by: Edgar Lin <edgarlin2000@gmail.com>
2025-12-10 19:34:52 -08:00
LucasArmandVast 70f8a8f534 Merge pull request #72 from vast-ai/hotfix-pin-pycares
Hotfix: pin pycares
2025-12-10 20:41:44 -05:00
Lucas Armand 7be8aa6397 pin pycares 2025-12-10 17:38:03 -08:00
Colter-Downing 138fc3ac47 Merge pull request #71 from vast-ai/AUTO-comfyui-updates
Auto comfyui updates
2025-12-04 10:55:12 -08:00
Colter Downing 222ac2a0dd default endpoint name 2025-12-04 10:54:55 -08:00
Colter Downing 40aed9b5f8 adding s3 as an option 2025-12-04 10:52:57 -08:00
Colter Downing d4d36bf86e done with comfy updates 2025-12-03 20:45:55 -08:00
Colter Downing e839cfc6e8 include view in API wrapper 2025-12-03 20:22:45 -08:00
Colter Downing f04138e13b update to be able to get images 2025-12-03 20:16:25 -08:00
Colter-Downing de3aa87c8f Merge pull request #70 from vast-ai/AUTO-tgi-client-edits
update tgi client
2025-12-03 18:40:01 -08:00
Colter Downing 6b5b1341a7 update tgi client 2025-12-03 18:38:42 -08:00
Colter-Downing 8be92c03de Merge pull request #69 from vast-ai/AUTO-874--fix-openai-worker-client
defaults to ENDPOINT_NAME and DEFAULT_MODEL but uses the flag first
2025-12-03 16:59:56 -08:00
Colter Downing adedb8ba90 defaults to ENDPOINT_NAME and DEFAULT_MODEL but uses the flag first if present 2025-12-03 16:57:28 -08:00
LucasArmandVast 2f543c01ad Merge pull request #68 from vast-ai/fix-vllm-concurrency
Increase model wait time for vLLM
2025-12-03 16:13:51 -05:00
Lucas Armand 0bcd2219ea Increase model wait time for vLLM 2025-12-03 12:38:52 -08:00
12 changed files with 921 additions and 133 deletions
+1 -11
View File
@@ -1,11 +1 @@
aiohttp[speedups]==3.10.1 vastai-sdk>=0.3.0
anyio~=4.4
lib~=4.0
nltk~=3.9
psutil~=6.0
pycryptodome~=3.20
Requests~=2.32
transformers~=4.52
utils==1.0.*
hf_transfer>=0.1.9
git+https://github.com/vast-ai/vast-sdk.git@session
+106 -2
View File
@@ -46,6 +46,33 @@ JSON
exit 1 exit 1
} }
function install_vastai_sdk() {
# If SDK_BRANCH is set, install vastai-sdk from the vast-sdk repo at that branch/tag/commit.
if [ -n "${SDK_BRANCH:-}" ]; then
if [ -n "${SDK_VERSION:-}" ]; then
echo "WARNING: Both SDK_BRANCH and SDK_VERSION are set; using SDK_BRANCH=${SDK_BRANCH}"
fi
echo "Installing vastai-sdk from https://github.com/vast-ai/vast-sdk/ @ ${SDK_BRANCH}"
if ! uv pip install "vastai-sdk @ git+https://github.com/vast-ai/vast-sdk.git@${SDK_BRANCH}"; then
report_error_and_exit "Failed to install vastai-sdk from vast-ai/vast-sdk@${SDK_BRANCH}"
fi
return 0
fi
if [ -n "${SDK_VERSION:-}" ]; then
echo "Installing vastai-sdk version ${SDK_VERSION}"
if ! uv pip install "vastai-sdk==${SDK_VERSION}"; then
report_error_and_exit "Failed to install vastai-sdk==${SDK_VERSION}"
fi
return 0
fi
echo "Installing default vastai-sdk"
if ! uv pip install vastai-sdk; then
report_error_and_exit "Failed to install vastai-sdk"
fi
}
[ -n "$BACKEND" ] && [ -z "$HF_TOKEN" ] && report_error_and_exit "HF_TOKEN must be set when BACKEND is set!" [ -n "$BACKEND" ] && [ -z "$HF_TOKEN" ] && report_error_and_exit "HF_TOKEN must be set when BACKEND is set!"
[ -z "$CONTAINER_ID" ] && report_error_and_exit "CONTAINER_ID must be set!" [ -z "$CONTAINER_ID" ] && report_error_and_exit "CONTAINER_ID must be set!"
[ "$BACKEND" = "comfyui" ] && [ -z "$COMFY_MODEL" ] && report_error_and_exit "For comfyui backends, COMFY_MODEL must be set!" [ "$BACKEND" = "comfyui" ] && [ -z "$COMFY_MODEL" ] && report_error_and_exit "For comfyui backends, COMFY_MODEL must be set!"
@@ -63,7 +90,8 @@ echo_var DEBUG_LOG
echo_var PYWORKER_LOG echo_var PYWORKER_LOG
echo_var MODEL_LOG echo_var MODEL_LOG
if [ -e "$MODEL_LOG" ]; then ROTATE_MODEL_LOG="${ROTATE_MODEL_LOG:-false}"
if [ "$ROTATE_MODEL_LOG" = "true" ] && [ -e "$MODEL_LOG" ]; then
echo "Rotating model log at $MODEL_LOG to $MODEL_LOG.old" echo "Rotating model log at $MODEL_LOG to $MODEL_LOG.old"
if ! cat "$MODEL_LOG" >> "$MODEL_LOG.old"; then if ! cat "$MODEL_LOG" >> "$MODEL_LOG.old"; then
report_error_and_exit "Failed to rotate model log" report_error_and_exit "Failed to rotate model log"
@@ -114,7 +142,7 @@ then
if ! uv venv --python-preference only-managed "$ENV_PATH" -p 3.10; then if ! uv venv --python-preference only-managed "$ENV_PATH" -p 3.10; then
report_error_and_exit "Failed to create virtual environment" report_error_and_exit "Failed to create virtual environment"
fi fi
if ! source "$ENV_PATH/bin/activate"; then if ! source "$ENV_PATH/bin/activate"; then
report_error_and_exit "Failed to activate virtual environment" report_error_and_exit "Failed to activate virtual environment"
fi fi
@@ -123,6 +151,8 @@ then
report_error_and_exit "Failed to install Python requirements" report_error_and_exit "Failed to install Python requirements"
fi fi
install_vastai_sdk
if ! touch ~/.no_auto_tmux; then if ! touch ~/.no_auto_tmux; then
report_error_and_exit "Failed to create ~/.no_auto_tmux" report_error_and_exit "Failed to create ~/.no_auto_tmux"
fi fi
@@ -184,6 +214,80 @@ fi
export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED
# ─── SDK Deployment Mode ───────────────────────────────────────────────
if [ "$IS_DEPLOYMENT" = "true" ]; then
echo "=== SDK Deployment Mode ==="
echo "DEPLOYMENT_ID: $DEPLOYMENT_ID"
DEPLOY_DIR="/workspace/deployment"
mkdir -p "$DEPLOY_DIR"
VAST_API_BASE="${VAST_API_BASE:-https://console.vast.ai}"
# Download deployment code, retrying until the blob is available on S3.
# The s3_key exists in the DB as soon as the deployment is created, but the
# actual upload may still be in flight from the client side.
echo "Downloading deployment code..."
RETRY=0
while true; do
DOWNLOAD_RESPONSE=$(curl -sS \
-H "Authorization: Bearer $CONTAINER_API_KEY" \
"${VAST_API_BASE}/api/v0/deployment/${DEPLOYMENT_ID}/download_url/")
DOWNLOAD_URL=$(python3 -c "
import sys, json
try:
d = json.load(sys.stdin)
print(d.get('download_url') or '')
except: print('')
" <<< "$DOWNLOAD_RESPONSE")
if [ -z "$DOWNLOAD_URL" ] || [ "$DOWNLOAD_URL" = "None" ]; then
RETRY=$((RETRY + 1))
echo "No download URL yet (attempt $RETRY), retrying in 10s... response: $DOWNLOAD_RESPONSE"
sleep 10
continue
fi
# Got a URL — try the actual S3 download
HTTP_CODE=$(curl -sS -L -o "$DEPLOY_DIR/deployment.tar.gz" -w "%{http_code}" "$DOWNLOAD_URL")
if [ "$HTTP_CODE" = "200" ]; then
break
fi
RETRY=$((RETRY + 1))
echo "S3 download returned HTTP $HTTP_CODE (attempt $RETRY), blob not yet uploaded. Retrying in 10s..."
rm -f "$DEPLOY_DIR/deployment.tar.gz"
sleep 10
done
cd "$DEPLOY_DIR" && tar xzf deployment.tar.gz
echo "Deployment code extracted."
# Source secrets if present
if [ -f "$DEPLOY_DIR/.secrets" ]; then
echo "Sourcing secrets..."
source "$DEPLOY_DIR/.secrets"
fi
# Run on_start.sh to completion if present
if [ -f "$DEPLOY_DIR/on_start.sh" ]; then
echo "Running on_start.sh..."
chmod +x "$DEPLOY_DIR/on_start.sh"
bash "$DEPLOY_DIR/on_start.sh"
echo "on_start.sh completed."
fi
# Install SDK (uses the install_vastai_sdk function which supports SDK_BRANCH/SDK_VERSION)
install_vastai_sdk
# Run deployment in serve mode
export VAST_DEPLOYMENT_MODE=serve
echo "Starting deployment: python3 $DEPLOY_DIR/deployment.py"
python3 "$DEPLOY_DIR/deployment.py"
exit $?
fi
# ─── End SDK Deployment Mode ───────────────────────────────────────────
if ! cd "$SERVER_DIR"; then if ! cd "$SERVER_DIR"; then
report_error_and_exit "Failed to cd into SERVER_DIR: $SERVER_DIR" report_error_and_exit "Failed to cd into SERVER_DIR: $SERVER_DIR"
fi fi
+1 -1
View File
@@ -2,7 +2,7 @@
This is the PyWorker implementation for running **ACE Step v1 3.5B** text-to-music workflows in ComfyUI. It provides a unified interface for executing complete ComfyUI audio-generation workflows through a proxy-based architecture and returning generated audio assets. This is the PyWorker implementation for running **ACE Step v1 3.5B** text-to-music workflows in ComfyUI. It provides a unified interface for executing complete ComfyUI audio-generation workflows through a proxy-based architecture and returning generated audio assets.
Each request has a static cost of `100`. ComfyUI does not support concurrent workloads, and there is no provision to run multiple ComfyUI instances per worker node. Each request has a static cost of `1000`. ComfyUI does not support concurrent workloads, and there is no provision to run multiple ComfyUI instances per worker node.
## Requirements ## Requirements
+142 -3
View File
@@ -1,8 +1,16 @@
# ComfyUI PyWorker # ComfyUI PyWorker
This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture. This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
The cost for each request has a static value of `100`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node. The cost for each request has a static value of `100`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node.
## Instance Setup
1. Pick a template
- [ComfyUI (Serverless)](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=ComfyUI%20(Serverless))
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
## Requirements ## Requirements
@@ -10,6 +18,137 @@ This worker requires both [ComfyUI](https://github.com/comfyanonymous/ComfyUI) a
A docker image is provided but you may use any if the above requirements are met. A docker image is provided but you may use any if the above requirements are met.
## Client
The client demonstrates how to use the Vast Serverless SDK to generate images, save them locally, and optionally upload to S3-compatible storage.
### Setup
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
```bash
git clone https://github.com/vast-ai/pyworker
cd pyworker
pip install uv
uv venv -p 3.12
source .venv/bin/activate
uv pip install -r requirements.txt
```
2. Set your API key:
```bash
export VAST_API_KEY=<your_api_key>
```
### Usage
```bash
# Default prompt
python -m workers.comfyui-json.client
# Custom prompt
python -m workers.comfyui-json.client --prompt "a cat sitting on a rainbow"
# With options
python -m workers.comfyui-json.client --prompt "sunset" --width 1024 --height 1024 --steps 30
# Using a custom workflow file
python -m workers.comfyui-json.client --workflow my_workflow.json
# With S3 upload
python -m workers.comfyui-json.client --s3
```
### CLI Flags
| Flag | Default | Description |
|------|---------|-------------|
| `--endpoint` | `my-comfyui-endpoint` | Vast endpoint name |
| `--prompt` | (default) | Text prompt for image generation |
| `--workflow` | (none) | Path to custom workflow JSON file |
| `--width` | 512 | Image width in pixels |
| `--height` | 512 | Image height in pixels |
| `--steps` | 20 | Number of denoising steps |
| `--seed` | (random) | Random seed for reproducibility |
| `--s3` | (disabled) | Upload generated images to S3 |
### Output
Images are saved to `./generated_images/comfy_{seed}.png`.
### S3 Upload (Optional)
You can optionally upload generated images to an S3-compatible storage service (AWS S3, Cloudflare R2, Backblaze B2, etc.) by using the `--s3` flag.
**1. Set environment variables:**
```bash
export S3_ENDPOINT_URL="https://your-account.r2.cloudflarestorage.com"
export S3_BUCKET_NAME="my-bucket"
export S3_ACCESS_KEY_ID="your-access-key-id"
export S3_SECRET_ACCESS_KEY="your-secret-access-key"
```
**2. Run with S3 upload enabled:**
```bash
python -m workers.comfyui-json.client --prompt "a beautiful landscape" --s3
```
Images will be saved locally AND uploaded to `s3://{bucket}/comfyui/{filename}`.
**Note:** Requires `boto3` (`pip install boto3`).
## Benchmarking
### Custom Benchmark Workflows
You can provide a custom ComfyUI workflow for benchmarking by creating `workers/comfyui-json/misc/benchmark.json`. This allows you to test performance using your preferred models and workflow complexity.
**Ways to provide the benchmark file:**
- Fork this repository and add your `benchmark.json` file
- Write the file during worker provisioning (onstart script or setup phase)
An example file is provided in the repository. To ensure varied generations, use the placeholder `__RANDOM_INT__` in place of static seed values - it will be replaced with a random integer for each benchmark run.
### Default Benchmark (Fallback)
If `benchmark.json` is not available, a simple image generation benchmark runs when each worker initializes. This validates GPU performance and helps identify underperforming machines.
The default benchmark uses Stable Diffusion v1.5 with ComfyUI's standard text-to-image workflow. Configure it using these environment variables:
| Environment Variable | Default Value | Description |
| -------------------- | ------------- | ----------- |
| BENCHMARK_TEST_WIDTH | 512 | Image width (pixels) |
| BENCHMARK_TEST_HEIGHT | 512 | Image height (pixels) |
| BENCHMARK_TEST_STEPS | 20 | Number of denoising steps |
Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns.
#### Calibrating Fallback Benchmark Duration
To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements.
**Example:** If your typical workflow should complete in 90 seconds on acceptable hardware:
```bash
# 1. Measure it/sec on your reference machine
# RTX 4090 typically achieves ~43 it/sec with SD1.5
# 2. Calculate required steps
# 90 seconds × 43 it/sec = 3870 steps
# 3. Configure benchmark
export BENCHMARK_TEST_STEPS=3870
# 4. Machines completing significantly slower than 90s indicate hardware issues
```
**Performance expectations:**
- Benchmark duration should remain consistent across identical GPU models
- Significant variation (>20%) may indicate thermal, power, or configuration issues
## Endpoint ## Endpoint
The worker provides a single endpoint: The worker provides a single endpoint:
@@ -170,4 +309,4 @@ See the client example for implementation details on how to integrate with the C
--- ---
See Vast's serverless documentation for more details on how to use ComfyUI with autoscaler. See Vast's serverless documentation for more details on how to use ComfyUI with autoscaler.
+301 -23
View File
@@ -1,34 +1,312 @@
import os
import sys
import json
import uuid import uuid
import random import random
import asyncio import asyncio
import random import logging
import argparse
import aiohttp
from vastai import Serverless from vastai import Serverless
async def main(): # ---------------------- Config ----------------------
async with Serverless() as client: DEFAULT_PROMPT = "a beautiful sunset over mountains, digital art, highly detailed"
endpoint = await client.get_endpoint(name="my-comfy-endpoint") # Change this to your endpoint name ENDPOINT_NAME = "my-comfyui-endpoint"
DEFAULT_WIDTH = 512
DEFAULT_HEIGHT = 512
DEFAULT_STEPS = 20
COST = 100 # Fixed cost for ComfyUI requests
payload = { # Optional S3 Configuration (from environment variables)
"input": { S3_ENDPOINT_URL = os.getenv("S3_ENDPOINT_URL")
"request_id": str(uuid.uuid4()), S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
"modifier": "Text2Image", S3_ACCESS_KEY_ID = os.getenv("S3_ACCESS_KEY_ID")
"modifications": { S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY")
"prompt": "a beautiful landscape with mountains and lakes",
"width": 1024, logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")
"height": 1024, log = logging.getLogger(__name__)
"steps": 20,
"seed": random.randint(0, 2**32 - 1)
}, def get_s3_client():
"workflow_json": {} # Empty since using modifier approach """Create and return an S3 client configured for the S3-compatible endpoint"""
} try:
import boto3
from botocore.config import Config
except ImportError:
log.error("boto3 is required for S3 uploads. Install with: pip install boto3")
return None
if not all([S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY]):
log.error("S3 environment variables not fully configured. Required:")
log.error(" S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY")
return None
return boto3.client(
"s3",
endpoint_url=S3_ENDPOINT_URL,
aws_access_key_id=S3_ACCESS_KEY_ID,
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
config=Config(signature_version="s3v4"),
)
# ---------------------- API Functions ----------------------
async def call_generate(
client: Serverless,
*,
endpoint_name: str,
prompt: str,
width: int,
height: int,
steps: int,
seed: int,
) -> dict:
"""Generate image using Text2Image modifier"""
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
"request_id": str(uuid.uuid4()),
"modifier": "Text2Image",
"modifications": {
"prompt": prompt,
"width": width,
"height": height,
"steps": steps,
"seed": seed,
},
} }
}
response = await endpoint.request("/generate/sync", payload, cost=100) return await endpoint.request("/generate/sync", payload, cost=COST)
async def call_generate_workflow(
client: Serverless,
*,
endpoint_name: str,
workflow_json: dict,
) -> dict:
"""Generate using custom workflow JSON"""
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
"request_id": str(uuid.uuid4()),
"workflow_json": workflow_json,
}
}
return await endpoint.request("/generate/sync", payload, cost=COST)
# ---------------------- Demo Class ----------------------
class APIDemo:
def __init__(self, client: Serverless, endpoint_name: str, upload_s3: bool = False):
self.client = client
self.endpoint_name = endpoint_name
self.upload_s3 = upload_s3
self.s3_client = get_s3_client() if upload_s3 else None
if upload_s3 and not self.s3_client:
log.warning("S3 upload requested but client creation failed. Images will only be saved locally.")
def extract_filename(self, response: dict) -> str | None:
"""Extract the generated image filename from ComfyUI response"""
if "comfyui_response" in response:
for data in response["comfyui_response"].values():
if isinstance(data, dict) and "outputs" in data:
for node_output in data["outputs"].values():
if "images" in node_output and node_output["images"]:
return node_output["images"][0].get("filename")
return None
async def save_image(self, worker_url: str, filename: str, local_name: str) -> str | None:
"""Fetch and save image locally from the worker, optionally upload to S3"""
os.makedirs("generated_images", exist_ok=True)
return await self._fetch_image(worker_url, filename, local_name)
def _upload_to_s3(self, local_path: str, s3_key: str) -> str | None:
"""Upload a local file to S3 and return the S3 URL"""
if not self.s3_client:
return None
try:
self.s3_client.upload_file(
local_path,
S3_BUCKET_NAME,
s3_key,
ExtraArgs={"ContentType": "image/png"}
)
s3_url = f"{S3_ENDPOINT_URL}/{S3_BUCKET_NAME}/{s3_key}"
print(f" ☁️ Uploaded to S3: {s3_key}")
return s3_url
except Exception as e:
log.error(f"Failed to upload to S3: {e}")
return None
async def _fetch_image(self, worker_url: str, filename: str, local_name: str) -> str | None:
"""Fetch image from worker's /view endpoint and save locally"""
if not worker_url:
return None
try:
url = f"{worker_url}/view"
params = {"filename": filename, "type": "output"}
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params, ssl=False) as resp:
if resp.status == 200:
path = f"generated_images/{local_name}"
image_data = await resp.read()
with open(path, "wb") as f:
f.write(image_data)
print(f" 💾 Saved: {path}")
# Upload to S3 if enabled
if self.upload_s3 and self.s3_client:
s3_key = f"comfyui/{local_name}"
self._upload_to_s3(path, s3_key)
return path
return None
except Exception:
return None
async def demo_prompt(
self,
prompt: str,
width: int,
height: int,
steps: int,
seed: int | None,
):
"""Demo: Generate image from text prompt"""
print("=" * 60)
print("COMFYUI TEXT-TO-IMAGE DEMO")
print("=" * 60)
if seed is None:
seed = random.randint(0, 2**32 - 1)
print(f"Prompt: {prompt[:100]}..." if len(prompt) > 100 else f"Prompt: {prompt}")
print(f"Size: {width}x{height}, Steps: {steps}, Seed: {seed}")
print("\n🎨 Generating image...")
response = await call_generate(
self.client,
endpoint_name=self.endpoint_name,
prompt=prompt,
width=width,
height=height,
steps=steps,
seed=seed,
)
print("\n✅ Generation complete!")
# Get worker URL for fetching images
worker_url = response.get("url", "")
print(f"Worker URL: {worker_url}")
# Fetch and save image
if "response" in response:
filename = self.extract_filename(response["response"])
if filename:
path = await self.save_image(worker_url, filename, f"comfy_{seed}.png")
if not path:
print(f"❌ Failed to fetch image")
else:
print("❌ No image in response")
else:
print("❌ Unexpected response format")
async def demo_workflow(self, workflow_file: str):
"""Demo: Generate using custom workflow file"""
print("=" * 60)
print("COMFYUI CUSTOM WORKFLOW DEMO")
print("=" * 60)
if not os.path.exists(workflow_file):
log.error(f"Workflow file not found: {workflow_file}")
return
with open(workflow_file, "r") as f:
workflow_json = json.load(f)
print(f"Workflow: {workflow_file}")
print("\n🎨 Generating...")
response = await call_generate_workflow(
self.client,
endpoint_name=self.endpoint_name,
workflow_json=workflow_json,
)
print("\n✅ Generation complete!")
worker_url = response.get("url", "")
if "response" in response:
filename = self.extract_filename(response["response"])
if filename:
path = await self.save_image(worker_url, filename, "workflow.png")
if not path:
print(f"❌ Failed to fetch image")
else:
print("❌ No image in response")
else:
print("❌ Unexpected response format")
# ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast ComfyUI-JSON Demo (Serverless SDK)")
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
p.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, metavar="TEXT",
help=f"Prompt text (default: '{DEFAULT_PROMPT[:30]}...')")
p.add_argument("--workflow", type=str, metavar="FILE", help="Use custom workflow JSON file instead")
p.add_argument("--width", type=int, default=DEFAULT_WIDTH, help=f"Image width (default: {DEFAULT_WIDTH})")
p.add_argument("--height", type=int, default=DEFAULT_HEIGHT, help=f"Image height (default: {DEFAULT_HEIGHT})")
p.add_argument("--steps", type=int, default=DEFAULT_STEPS, help=f"Steps (default: {DEFAULT_STEPS})")
p.add_argument("--seed", type=int, default=None, help="Seed (default: random)")
p.add_argument("--s3", action="store_true",
help="Upload generated images to S3 (requires S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY env vars)")
return p
async def main_async():
args = build_arg_parser().parse_args()
print("=" * 60)
print(f"Using endpoint: {args.endpoint}")
if args.s3:
print(f"S3 upload: enabled (bucket: {S3_BUCKET_NAME})")
try:
async with Serverless() as client:
demo = APIDemo(client, args.endpoint, upload_s3=args.s3)
if args.workflow:
await demo.demo_workflow(workflow_file=args.workflow)
else:
await demo.demo_prompt(
prompt=args.prompt,
width=args.width,
height=args.height,
steps=args.steps,
seed=args.seed,
)
except AttributeError as e:
if "API key" in str(e):
log.error("API key missing. Set VAST_API_KEY environment variable.")
else:
log.error(f"Error: {e}")
sys.exit(1)
except Exception as e:
log.error(f"Error: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
# Get the file from the path on the local machine using SCP or SFTP
# or configure S3 to upload to cloud storage.
print(response["response"]["output"][0]["local_path"])
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main_async())
+33 -26
View File
@@ -8,14 +8,13 @@ This is the base PyWorker for OpenAI compatible inference servers. See the [Ser
This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker. This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker.
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20%2B%20Qwen%2FQwen3-8B%20(Serverless)) (recommended) - [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20(Serverless)) (recommended)
- [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless)) - [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless))
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20%2B%20Qwen3-8B%20(Serverless))
All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected. All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected.
2. Follow the [getting started guide](https://docs.vast.ai/serverless/getting-started) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface. 2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
## Client Setup (Demo) ## Client Setup (Demo)
@@ -34,38 +33,20 @@ uv pip install -r requirements.txt
Several examples have been provided in the client to help you get started with your own implementation. Several examples have been provided in the client to help you get started with your own implementation.
### Completions First, set your API key as an environment variable:
Call to `/v1/completions` with json response
```bash ```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME> export VAST_API_KEY=<your_api_key>
``` ```
### Chat Completion (json) The `--model` and `--endpoint` flags are optional. If not provided, they default to `Qwen/Qwen3-8B` and `my-vllm-endpoint` respectively.
Call to `/v1/chat/completions` with json response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
```
### Chat Completion (streaming) ### Chat Completion (streaming)
Call to `/v1/chat/completions` with streaming response Call to `/v1/chat/completions` with streaming response
```bash ```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME> python -m workers.openai.client --chat-stream --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Tool Use (json)
Call to `/v1/chat/completions` with tool and json response.
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
``` ```
### Interactive Chat (streaming) ### Interactive Chat (streaming)
@@ -75,6 +56,32 @@ Interactive session with calls to `/v1/chat/completions`.
Type `clear` to clear the chat history or `quit` to exit. Type `clear` to clear the chat history or `quit` to exit.
```bash ```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME> python -m workers.openai.client --interactive --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Chat Completion (json)
Call to `/v1/chat/completions` with json response
```bash
python -m workers.openai.client --chat --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Tool Use (json)
Call to `/v1/chat/completions` with tool and json response.
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
```bash
python -m workers.openai.client --tools --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Completions
Call to `/v1/completions` with json response
```bash
python -m workers.openai.client --completion --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
``` ```
+32 -16
View File
@@ -18,7 +18,7 @@ logging.basicConfig(
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
# ---------------------- Prompts ---------------------- # ---------------------- Prompts ----------------------
COMPLETIONS_PROMPT = "the capital of USA is" COMPLETIONS_PROMPT = "Zebras are primarily grazers and can subsist on lower-quality vegetation. They are preyed on mainly by"
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language." CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
TOOLS_PROMPT = ( TOOLS_PROMPT = (
"Can you list the files in the current working directory and tell me what you see? " "Can you list the files in the current working directory and tell me what you see? "
@@ -97,9 +97,9 @@ def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[
# ---- OpenAI-compatible calls (non-streaming) ---- # ---- OpenAI-compatible calls (non-streaming) ----
async def call_completions(client: Serverless, *, model: str, prompt: str, **kwargs) -> Dict[str, Any]: async def call_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) endpoint = await client.get_endpoint(name=endpoint_name)
payload = { payload = {
"model": model, "model": model,
@@ -111,9 +111,9 @@ async def call_completions(client: Serverless, *, model: str, prompt: str, **kwa
resp = await endpoint.request("/v1/completions", payload, cost=payload["max_tokens"]) resp = await endpoint.request("/v1/completions", payload, cost=payload["max_tokens"])
return resp["response"] return resp["response"]
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]: async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) endpoint = await client.get_endpoint(name=endpoint_name)
payload = { payload = {
"model": model, "model": model,
@@ -128,9 +128,9 @@ async def call_chat_completions(client: Serverless, *, model: str, messages: Lis
return resp["response"] return resp["response"]
# ---- Streaming variants ---- # ---- Streaming variants ----
async def stream_completions(client: Serverless, *, model: str, prompt: str, **kwargs): async def stream_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs):
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) endpoint = await client.get_endpoint(name=endpoint_name)
payload = { payload = {
"model": model, "model": model,
@@ -144,9 +144,9 @@ async def stream_completions(client: Serverless, *, model: str, prompt: str, **k
resp = await endpoint.request("/v1/completions", payload, cost=payload["max_tokens"], stream=True) resp = await endpoint.request("/v1/completions", payload, cost=payload["max_tokens"], stream=True)
return resp["response"] # async generator return resp["response"] # async generator
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs): async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs):
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) endpoint = await client.get_endpoint(name=endpoint_name)
payload = { payload = {
"model": model, "model": model,
@@ -166,9 +166,10 @@ async def stream_chat_completions(client: Serverless, *, model: str, messages: L
class APIDemo: class APIDemo:
"""Demo and testing functionality for the API client""" """Demo and testing functionality for the API client"""
def __init__(self, client: Serverless, model: str, tool_manager: Optional[ToolManager] = None): def __init__(self, client: Serverless, model: str, endpoint_name: str, tool_manager: Optional[ToolManager] = None):
self.client = client self.client = client
self.model = model self.model = model
self.endpoint_name = endpoint_name
self.tool_manager = tool_manager or ToolManager() self.tool_manager = tool_manager or ToolManager()
# ----- Streaming handler ----- # ----- Streaming handler -----
@@ -177,10 +178,15 @@ class APIDemo:
reasoning_content = "" reasoning_content = ""
printed_reasoning = False printed_reasoning = False
printed_answer = False printed_answer = False
finish_reason = None
async for chunk in stream: async for chunk in stream:
choice = (chunk.get("choices") or [{}])[0] choice = (chunk.get("choices") or [{}])[0]
delta = choice.get("delta", {}) delta = choice.get("delta", {})
# Track finish reason
if choice.get("finish_reason"):
finish_reason = choice.get("finish_reason")
# reasoning tokens # reasoning tokens
rc = delta.get("reasoning_content") rc = delta.get("reasoning_content")
@@ -211,6 +217,8 @@ class APIDemo:
print(f"Reasoning tokens: {len(reasoning_content.split())}") print(f"Reasoning tokens: {len(reasoning_content.split())}")
if printed_answer: if printed_answer:
print(f"Response tokens: {len(full_response.split())}") print(f"Response tokens: {len(full_response.split())}")
if finish_reason:
print(f"Finish reason: {finish_reason}")
return full_response return full_response
@@ -223,6 +231,7 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
prompt=COMPLETIONS_PROMPT, prompt=COMPLETIONS_PROMPT,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE, temperature=DEFAULT_TEMPERATURE,
) )
@@ -241,6 +250,7 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE temperature=DEFAULT_TEMPERATURE
) )
@@ -253,6 +263,7 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE temperature=DEFAULT_TEMPERATURE
) )
@@ -279,6 +290,7 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
tools=minimal_tool, tools=minimal_tool,
tool_choice="none", tool_choice="none",
max_tokens=10 max_tokens=10
@@ -304,6 +316,7 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
tools=self.tool_manager.get_ls_tool_definition(), tools=self.tool_manager.get_ls_tool_definition(),
tool_choice="auto", tool_choice="auto",
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
@@ -381,6 +394,7 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE, temperature=DEFAULT_TEMPERATURE,
) )
@@ -419,7 +433,6 @@ class APIDemo:
print("=" * 60) print("=" * 60)
print("INTERACTIVE STREAMING CHAT") print("INTERACTIVE STREAMING CHAT")
print("=" * 60) print("=" * 60)
print(f"Using model: {self.model}")
print("Type 'quit' to exit, 'clear' to clear history") print("Type 'quit' to exit, 'clear' to clear history")
print() print()
@@ -445,7 +458,8 @@ class APIDemo:
stream = await stream_chat_completions( stream = await stream_chat_completions(
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
temperature=0.7 temperature=0.7
) )
@@ -465,8 +479,8 @@ class APIDemo:
# ---------------------- CLI ---------------------- # ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser: def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)") p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
p.add_argument("--model", required=True, help="Model to use for requests (required)") p.add_argument("--model", default=DEFAULT_MODEL, help=f"Model to use for requests (default: {DEFAULT_MODEL})")
p.add_argument("--endpoint", default="my-vllm-endpoint", help="Vast endpoint name (default: my-vllm-endpoint)") p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
modes = p.add_mutually_exclusive_group(required=False) modes = p.add_mutually_exclusive_group(required=False)
modes.add_argument("--completion", action="store_true", help="Test completions endpoint") modes.add_argument("--completion", action="store_true", help="Test completions endpoint")
@@ -494,12 +508,14 @@ async def main_async():
print("Please specify exactly one test mode") print("Please specify exactly one test mode")
sys.exit(1) sys.exit(1)
print(f"Using model: {args.model}")
print("=" * 60) print("=" * 60)
print(f"Using model: {args.model}")
print(f"Using endpoint: {args.endpoint}")
try: try:
async with Serverless() as client: async with Serverless() as client:
demo = APIDemo(client, args.model, ToolManager()) demo = APIDemo(client, args.model, args.endpoint, ToolManager())
if args.completion: if args.completion:
await demo.demo_completions() await demo.demo_completions()
+12 -4
View File
@@ -28,6 +28,12 @@ MODEL_INFO_LOG_MSGS = [
nltk.download("words") nltk.download("words")
WORD_LIST = nltk.corpus.words.words() WORD_LIST = nltk.corpus.words.words()
def request_parser(request):
data = request
if request.get("input") is not None:
data = request.get("input")
return data
def completions_benchmark_generator() -> dict: def completions_benchmark_generator() -> dict:
prompt = " ".join(random.choices(WORD_LIST, k=int(250))) prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
@@ -54,18 +60,20 @@ worker_config = WorkerConfig(
route="/v1/completions", route="/v1/completions",
workload_calculator= lambda data: data.get("max_tokens", 0), workload_calculator= lambda data: data.get("max_tokens", 0),
allow_parallel_requests=True, allow_parallel_requests=True,
max_queue_time=60.0, request_parser=request_parser,
max_queue_time=600.0,
benchmark_config=BenchmarkConfig( benchmark_config=BenchmarkConfig(
generator=completions_benchmark_generator, generator=completions_benchmark_generator,
concurrency=100, concurrency=10,
runs=2 runs=3
) )
), ),
HandlerConfig( HandlerConfig(
route="/v1/chat/completions", route="/v1/chat/completions",
workload_calculator= lambda data: data.get("max_tokens", 0), workload_calculator= lambda data: data.get("max_tokens", 0),
allow_parallel_requests=True, allow_parallel_requests=True,
max_queue_time=60.0, request_parser=request_parser,
max_queue_time=600.0,
) )
], ],
log_action_config=LogActionConfig( log_action_config=LogActionConfig(
+93 -9
View File
@@ -1,19 +1,103 @@
This is the base PyWorker for TGI, designed to create PyWorkers that can utilize various LLMs. It offers two primary endpoints: # HuggingFace TGI PyWorker
1. `generate`: Generates the LLM's response to a given prompt in a single request. This is the base PyWorker for HuggingFace Text Generation Inference (TGI) servers. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
2. `generate_stream`: Streams the LLM's response token by token.
Both endpoints use the following API payload format: ## Instance Setup
1. Pick a template
This worker is compatible with any TGI backend. We have a template you can use or you can create your own.
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20(Serverless))
The template can be configured via the template interface. You may want to change the model or startup arguments.
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
## Client Setup (Demo)
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
```bash
git clone https://github.com/vast-ai/pyworker
cd pyworker
pip install uv
uv venv -p 3.12
source .venv/bin/activate
uv pip install -r requirements.txt
```
## Using the Test Client
The test client demonstrates both streaming and non-streaming generation using TGI's native API.
First, set your API key as an environment variable:
```bash
export VAST_API_KEY=<your_api_key>
```
The `--endpoint` flag is optional. If not provided, it defaults to `my-tgi-endpoint`.
### Generate (Streaming)
Call to `/generate_stream` with streaming response:
```bash
python -m workers.tgi.client --generate-stream --endpoint <ENDPOINT_NAME>
```
### Generate (Non-Streaming)
Call to `/generate` with json response:
```bash
python -m workers.tgi.client --generate --endpoint <ENDPOINT_NAME>
```
### Interactive Session (Streaming)
Interactive session with streaming responses. Type `quit` to exit.
```bash
python -m workers.tgi.client --interactive --endpoint <ENDPOINT_NAME>
```
## API Endpoints
TGI provides two primary endpoints:
### Generate (Non-Streaming)
`/generate` - Returns the complete response in a single request.
```json ```json
{ {
"inputs": "PROMPT", "inputs": "Your prompt here",
"parameters": { "parameters": {
"max_new_tokens": 250 "max_new_tokens": 1024,
"temperature": 0.7,
"return_full_text": false
} }
} }
``` ```
Note that the max_new_tokens parameter, rather than the prompt size, impacts performance. For example, if an ### Generate Stream (Streaming)
instance is benchmarked to process 100 tokens per second, a request with max_new_tokens = 200 will take
approximately 2 seconds to complete. `/generate_stream` - Streams the response token by token.
```json
{
"inputs": "Your prompt here",
"parameters": {
"max_new_tokens": 1024,
"temperature": 0.7,
"do_sample": true,
"return_full_text": false
}
}
```
## Performance Notes
The `max_new_tokens` parameter (not the prompt size) primarily impacts performance. For example, if an instance is benchmarked to process 100 tokens per second, a request with `max_new_tokens = 200` will take approximately 2 seconds to complete.
+195 -34
View File
@@ -1,61 +1,222 @@
import logging
import json
import os
import sys
import argparse
from vastai import Serverless from vastai import Serverless
import asyncio import asyncio
ENDPOINT_NAME = "my-tgi-endpoint" # Change this to match your endpoint name # ---------------------- Logging ----------------------
MAX_TOKENS = 1024 logging.basicConfig(
PROMPT = "Think step by step: Tell me about the Python programming language." level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
async def call_generate(client: Serverless) -> None: # ---------------------- Defaults ----------------------
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) DEFAULT_PROMPT = "Think step by step: Tell me about the Python programming language."
ENDPOINT_NAME = "TGI-Prod2" # change this to your TGI endpoint name
MAX_TOKENS = 1024
DEFAULT_TEMPERATURE = 0.7
# ---------------------- API Calls ----------------------
async def call_generate(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs) -> dict:
"""Non-streaming generation via /generate endpoint"""
endpoint = await client.get_endpoint(name=endpoint_name)
payload = { payload = {
"inputs": PROMPT, "inputs": prompt,
"parameters": { "parameters": {
"max_new_tokens": MAX_TOKENS, "max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": 0.7, "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"return_full_text": False "return_full_text": False,
} }
} }
log.debug("POST /generate %s", json.dumps(payload)[:500])
resp = await endpoint.request("/generate", payload, cost=MAX_TOKENS) resp = await endpoint.request("/generate", payload, cost=payload["parameters"]["max_new_tokens"])
return resp["response"]
print(resp["response"]["generated_text"])
async def call_generate_stream(client: Serverless) -> None: async def call_generate_stream(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs):
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) """Streaming generation via /generate_stream endpoint"""
endpoint = await client.get_endpoint(name=endpoint_name)
payload = { payload = {
"inputs": PROMPT, "inputs": prompt,
"parameters": { "parameters": {
"max_new_tokens": MAX_TOKENS, "max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": 0.7, "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"do_sample": True, "do_sample": True,
"return_full_text": False, "return_full_text": False,
} }
} }
log.debug("STREAM /generate_stream %s", json.dumps(payload)[:500])
resp = await endpoint.request( resp = await endpoint.request(
"/generate_stream", "/generate_stream",
payload, payload,
cost=MAX_TOKENS, cost=payload["parameters"]["max_new_tokens"],
stream=True, stream=True,
) )
stream = resp["response"] return resp["response"] # async generator
printed_answer = False
async for event in stream:
tok = (event.get("token") or {}).get("text")
if tok:
if not printed_answer:
printed_answer = True
print("Answer:\n", end="", flush=True)
print(tok, end="", flush=True)
async def main(): # ---------------------- Demo Runner ----------------------
async with Serverless() as client: class APIDemo:
await call_generate(client) """Demo and testing functionality for the TGI API client"""
await call_generate_stream(client)
def __init__(self, client: Serverless, endpoint_name: str):
self.client = client
self.endpoint_name = endpoint_name
async def handle_streaming_response(self, stream) -> str:
"""Process streaming response and print tokens"""
full_response = ""
printed_answer = False
async for event in stream:
tok = (event.get("token") or {}).get("text")
if tok:
if not printed_answer:
printed_answer = True
print("\n💬 Response: ", end="", flush=True)
print(tok, end="", flush=True)
full_response += tok
print() # newline
if printed_answer:
print(f"\nStreaming completed. Response tokens: {len(full_response.split())}")
return full_response
async def demo_generate(self) -> None:
"""Demo non-streaming generation"""
print("=" * 60)
print("GENERATE DEMO (NON-STREAMING)")
print("=" * 60)
response = await call_generate(
client=self.client,
endpoint_name=self.endpoint_name,
prompt=DEFAULT_PROMPT,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
print(f"\n💬 Response: {response.get('generated_text', '')}")
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
async def demo_generate_stream(self) -> None:
"""Demo streaming generation"""
print("=" * 60)
print("GENERATE DEMO (STREAMING)")
print("=" * 60)
stream = await call_generate_stream(
client=self.client,
endpoint_name=self.endpoint_name,
prompt=DEFAULT_PROMPT,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
try:
await self.handle_streaming_response(stream)
except Exception as e:
log.error("\nError during streaming: %s", e, exc_info=True)
async def interactive_chat(self) -> None:
"""Interactive session with streaming generation"""
print("=" * 60)
print("INTERACTIVE STREAMING SESSION")
print("=" * 60)
print(f"Using endpoint: {self.endpoint_name}")
print("Type 'quit' to exit")
print()
while True:
try:
user_input = input("You: ").strip()
if user_input.lower() == "quit":
print("👋 Goodbye!")
break
elif not user_input:
continue
print("Assistant: ", end="", flush=True)
stream = await call_generate_stream(
client=self.client,
endpoint_name=self.endpoint_name,
prompt=user_input,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
full_response = ""
async for event in stream:
tok = (event.get("token") or {}).get("text")
if tok:
print(tok, end="", flush=True)
full_response += tok
print() # newline
except KeyboardInterrupt:
print("\n👋 Session interrupted. Goodbye!")
break
except Exception as e:
log.error("\nError: %s", e)
continue
# ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast TGI Demo (Serverless SDK)")
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
modes = p.add_mutually_exclusive_group(required=False)
modes.add_argument("--generate", action="store_true", help="Test generate endpoint (non-streaming)")
modes.add_argument("--generate-stream", action="store_true", help="Test generate endpoint with streaming")
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming session")
return p
async def main_async():
args = build_arg_parser().parse_args()
selected = sum([args.generate, args.generate_stream, args.interactive])
if selected == 0:
print("Please specify exactly one test mode:")
print(" --generate : Test generate endpoint (non-streaming)")
print(" --generate-stream : Test generate endpoint with streaming")
print(" --interactive : Start interactive streaming session")
print(f"\nExample: python {os.path.basename(sys.argv[0])} --generate-stream --endpoint my-tgi-endpoint")
sys.exit(1)
elif selected > 1:
print("Please specify exactly one test mode")
sys.exit(1)
print("=" * 60)
print(f"Using endpoint: {args.endpoint}")
try:
async with Serverless() as client:
demo = APIDemo(client, args.endpoint)
if args.generate:
await demo.demo_generate()
elif args.generate_stream:
await demo.demo_generate_stream()
elif args.interactive:
await demo.interactive_chat()
except Exception as e:
log.error("Error during test: %s", e, exc_info=True)
sys.exit(1)
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main_async())
+4 -3
View File
@@ -52,17 +52,18 @@ worker_config = WorkerConfig(
HandlerConfig( HandlerConfig(
route="/generate", route="/generate",
allow_parallel_requests=True, allow_parallel_requests=True,
max_queue_time=60.0, max_queue_time=600.0,
benchmark_config=BenchmarkConfig( benchmark_config=BenchmarkConfig(
generator=benchmark_generator, generator=benchmark_generator,
concurrency=50 concurrency=10,
runs=3
), ),
workload_calculator= lambda x: x["parameters"]["max_new_tokens"] workload_calculator= lambda x: x["parameters"]["max_new_tokens"]
), ),
HandlerConfig( HandlerConfig(
route="/generate_stream", route="/generate_stream",
allow_parallel_requests=True, allow_parallel_requests=True,
max_queue_time=60.0, max_queue_time=600.0,
workload_calculator= lambda x: x["parameters"]["max_new_tokens"] workload_calculator= lambda x: x["parameters"]["max_new_tokens"]
) )
], ],
+1 -1
View File
@@ -2,7 +2,7 @@
This is the PyWorker implementation for running **Wan 2.2 T2V A14B** text-to-video workflows in ComfyUI. It provides a unified interface for executing complete ComfyUI video-generation workflows through a proxy-based architecture and returning generated video assets. This is the PyWorker implementation for running **Wan 2.2 T2V A14B** text-to-video workflows in ComfyUI. It provides a unified interface for executing complete ComfyUI video-generation workflows through a proxy-based architecture and returning generated video assets.
Each request has a static cost of `100`. ComfyUI does not support concurrent workloads, and there is no provision to run multiple ComfyUI instances per worker node. Each request has a static cost of `10000`. ComfyUI does not support concurrent workloads, and there is no provision to run multiple ComfyUI instances per worker node.
## Requirements ## Requirements