Compare commits

..

19 Commits

Author SHA1 Message Date
Lucas Armand b02ade1df5 changed to session 2025-12-15 14:23:58 -08:00
Lucas Armand 0b6f381dd7 Add misc 2025-12-15 11:49:36 -08:00
Lucas Armand 74f8b6a1ef Added wheres-my-pyworker 2025-12-15 10:33:44 -08:00
Lucas Armand fa2bf082c2 only require HF_Token on backend 2025-12-12 14:47:29 -08:00
Lucas Armand 6a57ff8e0a try reverting env var 2025-12-12 12:16:33 -08:00
Lucas Armand 375633cb18 Fix 2025-12-12 12:12:57 -08:00
Lucas Armand ccd29ed8b6 remove input wrapping for vllm 2025-12-12 11:48:54 -08:00
Lucas Armand 2b30c69933 updated cost 2025-12-12 10:43:05 -08:00
Lucas Armand 4d99c12820 Added clients, updated READMEs 2025-12-12 10:41:21 -08:00
Lucas Armand 6060f8ce0c updated start_server.sh 2025-12-12 10:04:33 -08:00
Lucas Armand 067fa936fb remove legacy pyworker 2025-12-11 16:55:48 -08:00
Lucas Armand 405a8f1c0d returned to worker-sdk 2025-12-10 16:37:09 -08:00
Lucas Armand 12f4f23d39 remove parse request 2025-12-10 15:16:23 -08:00
Lucas Armand e2a771bb5a update ace and wan workers 2025-12-10 15:09:27 -08:00
Lucas Armand 0cd64adfc4 remove input 2025-12-10 14:47:47 -08:00
Lucas Armand 6f795b8fb8 remove input from workers 2025-12-10 14:46:10 -08:00
Lucas Armand 4bcc508473 reduce vllm benchmark runs to 2 2025-11-25 16:54:17 -08:00
Lucas Armand 74d7330800 add wan and ace workers 2025-11-25 16:08:40 -08:00
Lucas Armand 2ce0450809 Add worker.pys 2025-11-25 16:08:38 -08:00
12 changed files with 139 additions and 967 deletions
+11 -1
View File
@@ -1 +1,11 @@
vastai-sdk>=0.3.0 aiohttp[speedups]==3.10.1
anyio~=4.4
lib~=4.0
nltk~=3.9
psutil~=6.0
pycryptodome~=3.20
Requests~=2.32
transformers~=4.52
utils==1.0.*
hf_transfer>=0.1.9
git+https://github.com/vast-ai/vast-sdk.git@session
+13 -157
View File
@@ -2,17 +2,10 @@
set -e -o pipefail set -e -o pipefail
# Check for force update flag
FORCE_UPDATE=false
if [ -f "/.force_update" ]; then
echo "Force update flag detected at /.force_update"
FORCE_UPDATE=true
fi
WORKSPACE_DIR="${WORKSPACE_DIR:-/workspace}" WORKSPACE_DIR="${WORKSPACE_DIR:-/workspace}"
SERVER_DIR="$WORKSPACE_DIR/vast-pyworker" SERVER_DIR="$WORKSPACE_DIR/vast-pyworker"
ENV_PATH="${ENV_PATH:-$WORKSPACE_DIR/worker-env}" ENV_PATH="$WORKSPACE_DIR/worker-env"
DEBUG_LOG="$WORKSPACE_DIR/debug.log" DEBUG_LOG="$WORKSPACE_DIR/debug.log"
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log" PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
@@ -53,42 +46,6 @@ JSON
exit 1 exit 1
} }
function install_vastai_sdk() {
local uv_flags=()
if [ "${USE_SYSTEM_PYTHON:-}" = "true" ]; then
uv_flags+=(--system --break-system-packages)
fi
if [ "$FORCE_UPDATE" = true ]; then
uv_flags+=(--force-reinstall)
echo "Force reinstalling vastai-sdk"
fi
# If SDK_BRANCH is set, install vastai-sdk from the vast-sdk repo at that branch/tag/commit.
if [ -n "${SDK_BRANCH:-}" ]; then
if [ -n "${SDK_VERSION:-}" ]; then
echo "WARNING: Both SDK_BRANCH and SDK_VERSION are set; using SDK_BRANCH=${SDK_BRANCH}"
fi
echo "Installing vastai-sdk from https://github.com/vast-ai/vast-sdk/ @ ${SDK_BRANCH}"
if ! uv pip install "${uv_flags[@]}" "vastai-sdk @ git+https://github.com/vast-ai/vast-sdk.git@${SDK_BRANCH}"; then
report_error_and_exit "Failed to install vastai-sdk from vast-ai/vast-sdk@${SDK_BRANCH}"
fi
return 0
fi
if [ -n "${SDK_VERSION:-}" ]; then
echo "Installing vastai-sdk version ${SDK_VERSION}"
if ! uv pip install "${uv_flags[@]}" "vastai-sdk==${SDK_VERSION}"; then
report_error_and_exit "Failed to install vastai-sdk==${SDK_VERSION}"
fi
return 0
fi
echo "Installing default vastai-sdk"
if ! uv pip install "${uv_flags[@]}" vastai-sdk; then
report_error_and_exit "Failed to install vastai-sdk"
fi
}
[ -n "$BACKEND" ] && [ -z "$HF_TOKEN" ] && report_error_and_exit "HF_TOKEN must be set when BACKEND is set!" [ -n "$BACKEND" ] && [ -z "$HF_TOKEN" ] && report_error_and_exit "HF_TOKEN must be set when BACKEND is set!"
[ -z "$CONTAINER_ID" ] && report_error_and_exit "CONTAINER_ID must be set!" [ -z "$CONTAINER_ID" ] && report_error_and_exit "CONTAINER_ID must be set!"
[ "$BACKEND" = "comfyui" ] && [ -z "$COMFY_MODEL" ] && report_error_and_exit "For comfyui backends, COMFY_MODEL must be set!" [ "$BACKEND" = "comfyui" ] && [ -z "$COMFY_MODEL" ] && report_error_and_exit "For comfyui backends, COMFY_MODEL must be set!"
@@ -106,8 +63,7 @@ echo_var DEBUG_LOG
echo_var PYWORKER_LOG echo_var PYWORKER_LOG
echo_var MODEL_LOG echo_var MODEL_LOG
ROTATE_MODEL_LOG="${ROTATE_MODEL_LOG:-false}" if [ -e "$MODEL_LOG" ]; then
if [ "$ROTATE_MODEL_LOG" = "true" ] && [ -e "$MODEL_LOG" ]; then
echo "Rotating model log at $MODEL_LOG to $MODEL_LOG.old" echo "Rotating model log at $MODEL_LOG to $MODEL_LOG.old"
if ! cat "$MODEL_LOG" >> "$MODEL_LOG.old"; then if ! cat "$MODEL_LOG" >> "$MODEL_LOG.old"; then
report_error_and_exit "Failed to rotate model log" report_error_and_exit "Failed to rotate model log"
@@ -128,21 +84,8 @@ if ! grep -q "VAST" /etc/environment; then
fi fi
fi fi
if [ "${USE_SYSTEM_PYTHON:-}" = "true" ]; then if [ ! -d "$ENV_PATH" ]
echo "Using system Python: $(which python3)" then
if ! which uv > /dev/null 2>&1; then
if ! curl -LsSf https://astral.sh/uv/install.sh | sh; then
report_error_and_exit "Failed to install uv package manager"
fi
if [[ -f ~/.local/bin/env ]]; then
if ! source ~/.local/bin/env; then
report_error_and_exit "Failed to source uv environment"
fi
fi
fi
install_vastai_sdk
touch ~/.no_auto_tmux
elif [ ! -d "$ENV_PATH" ]; then
echo "setting up venv" echo "setting up venv"
if ! which uv; then if ! which uv; then
if ! curl -LsSf https://astral.sh/uv/install.sh | sh; then if ! curl -LsSf https://astral.sh/uv/install.sh | sh; then
@@ -161,27 +104,10 @@ elif [ ! -d "$ENV_PATH" ]; then
if ! git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"; then if ! git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"; then
report_error_and_exit "Failed to clone pyworker repository" report_error_and_exit "Failed to clone pyworker repository"
fi fi
elif [ "$FORCE_UPDATE" = true ]; then
echo "Force updating pyworker repository"
if ! (cd "$SERVER_DIR" && git fetch --all); then
report_error_and_exit "Failed to fetch pyworker repository updates"
fi
fi fi
if [[ -n ${PYWORKER_REF:-} ]]; then if [[ -n ${PYWORKER_REF:-} ]]; then
if [ "$FORCE_UPDATE" = true ]; then if ! (cd "$SERVER_DIR" && git checkout "$PYWORKER_REF"); then
echo "Force updating to pyworker reference: $PYWORKER_REF" report_error_and_exit "Failed to checkout pyworker reference: $PYWORKER_REF"
if ! (cd "$SERVER_DIR" && git checkout "$PYWORKER_REF" && git pull); then
report_error_and_exit "Failed to force update pyworker reference: $PYWORKER_REF"
fi
else
if ! (cd "$SERVER_DIR" && git checkout "$PYWORKER_REF"); then
report_error_and_exit "Failed to checkout pyworker reference: $PYWORKER_REF"
fi
fi
elif [ "$FORCE_UPDATE" = true ]; then
echo "Force updating pyworker to latest"
if ! (cd "$SERVER_DIR" && git pull); then
report_error_and_exit "Failed to pull latest pyworker changes"
fi fi
fi fi
@@ -197,8 +123,6 @@ elif [ ! -d "$ENV_PATH" ]; then
report_error_and_exit "Failed to install Python requirements" report_error_and_exit "Failed to install Python requirements"
fi fi
install_vastai_sdk
if ! touch ~/.no_auto_tmux; then if ! touch ~/.no_auto_tmux; then
report_error_and_exit "Failed to create ~/.no_auto_tmux" report_error_and_exit "Failed to create ~/.no_auto_tmux"
fi fi
@@ -208,44 +132,11 @@ else
report_error_and_exit "Failed to source uv environment" report_error_and_exit "Failed to source uv environment"
fi fi
fi fi
if ! source "$ENV_PATH/bin/activate"; then if ! source "$WORKSPACE_DIR/worker-env/bin/activate"; then
report_error_and_exit "Failed to activate existing virtual environment" report_error_and_exit "Failed to activate existing virtual environment"
fi fi
echo "environment activated" echo "environment activated"
echo "venv: $VIRTUAL_ENV" echo "venv: $VIRTUAL_ENV"
# Handle force update for existing environment
if [ "$FORCE_UPDATE" = true ]; then
echo "Performing force update on existing environment"
if [[ -d $SERVER_DIR ]]; then
echo "Force updating pyworker repository"
if ! (cd "$SERVER_DIR" && git fetch --all); then
report_error_and_exit "Failed to fetch pyworker repository updates"
fi
if [[ -n ${PYWORKER_REF:-} ]]; then
echo "Force updating to pyworker reference: $PYWORKER_REF"
if ! (cd "$SERVER_DIR" && git checkout "$PYWORKER_REF" && git pull); then
report_error_and_exit "Failed to force update pyworker reference: $PYWORKER_REF"
fi
else
echo "Force updating pyworker to latest"
if ! (cd "$SERVER_DIR" && git pull); then
report_error_and_exit "Failed to pull latest pyworker changes"
fi
fi
fi
install_vastai_sdk
fi
fi
# Remove force update flag after successful update
if [ "$FORCE_UPDATE" = true ]; then
echo "Removing force update flag"
rm -f "/.force_update"
echo "Force update completed successfully"
fi fi
if [ "$USE_SSL" = true ]; then if [ "$USE_SSL" = true ]; then
@@ -283,51 +174,16 @@ EOF
report_error_and_exit "Failed to generate SSL certificate request" report_error_and_exit "Failed to generate SSL certificate request"
fi fi
max_retries=5 if ! curl --header 'Content-Type: application/octet-stream' \
retry_delay=2 --data-binary @/etc/instance.csr \
for attempt in $(seq 1 "$max_retries"); do -X \
http_code=$(curl -sS -o /etc/instance.crt -w '%{http_code}' \ POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt; then
--header 'Content-Type: application/octet-stream' \ report_error_and_exit "Failed to sign SSL certificate"
--data-binary @/etc/instance.csr \ fi
-X POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID")
if [ "$http_code" -ge 200 ] && [ "$http_code" -lt 300 ]; then
break
fi
echo "SSL cert signing attempt $attempt/$max_retries failed (HTTP $http_code)"
if [ "$attempt" -eq "$max_retries" ]; then
report_error_and_exit "Failed to sign SSL certificate after $max_retries attempts (HTTP $http_code)"
fi
sleep "$retry_delay"
retry_delay=$((retry_delay * 2))
done
fi fi
export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED
# ─── SDK Deployment Mode ───────────────────────────────────────────────
if [ "$IS_DEPLOYMENT" = "true" ]; then
echo "=== SDK Deployment Mode ==="
echo "DEPLOYMENT_ID: $DEPLOYMENT_ID"
DEPLOY_DIR="/workspace/deployment"
mkdir -p "$DEPLOY_DIR"
VAST_API_BASE="${VAST_API_BASE:-https://console.vast.ai}"
# Download deployment code, retrying until the blob is available on S3.
# The s3_key exists in the DB as soon as the deployment is created, but the
# actual upload may still be in flight from the client side.
# Install SDK (uses the install_vastai_sdk function which supports SDK_BRANCH/SDK_VERSION)
install_vastai_sdk
# Run deployment in serve mode
export VAST_DEPLOYMENT_MODE=serve
echo "Starting deployment: python3 $DEPLOY_DIR/deployment.py"
serve-vast-deployment
exit $?
fi
# ─── End SDK Deployment Mode ───────────────────────────────────────────
if ! cd "$SERVER_DIR"; then if ! cd "$SERVER_DIR"; then
report_error_and_exit "Failed to cd into SERVER_DIR: $SERVER_DIR" report_error_and_exit "Failed to cd into SERVER_DIR: $SERVER_DIR"
fi fi
+1 -1
View File
@@ -2,7 +2,7 @@
This is the PyWorker implementation for running **ACE Step v1 3.5B** text-to-music workflows in ComfyUI. It provides a unified interface for executing complete ComfyUI audio-generation workflows through a proxy-based architecture and returning generated audio assets. This is the PyWorker implementation for running **ACE Step v1 3.5B** text-to-music workflows in ComfyUI. It provides a unified interface for executing complete ComfyUI audio-generation workflows through a proxy-based architecture and returning generated audio assets.
Each request has a static cost of `1000`. ComfyUI does not support concurrent workloads, and there is no provision to run multiple ComfyUI instances per worker node. Each request has a static cost of `100`. ComfyUI does not support concurrent workloads, and there is no provision to run multiple ComfyUI instances per worker node.
## Requirements ## Requirements
+2 -141
View File
@@ -1,16 +1,8 @@
# ComfyUI PyWorker # ComfyUI PyWorker
This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's. This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture.
The cost for each request has a static value of `100`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node. The cost for each request has a static value of `100`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node.
## Instance Setup
1. Pick a template
- [ComfyUI (Serverless)](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=ComfyUI%20(Serverless))
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
## Requirements ## Requirements
@@ -18,137 +10,6 @@ This worker requires both [ComfyUI](https://github.com/comfyanonymous/ComfyUI) a
A docker image is provided but you may use any if the above requirements are met. A docker image is provided but you may use any if the above requirements are met.
## Client
The client demonstrates how to use the Vast Serverless SDK to generate images, save them locally, and optionally upload to S3-compatible storage.
### Setup
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
```bash
git clone https://github.com/vast-ai/pyworker
cd pyworker
pip install uv
uv venv -p 3.12
source .venv/bin/activate
uv pip install -r requirements.txt
```
2. Set your API key:
```bash
export VAST_API_KEY=<your_api_key>
```
### Usage
```bash
# Default prompt
python -m workers.comfyui-json.client
# Custom prompt
python -m workers.comfyui-json.client --prompt "a cat sitting on a rainbow"
# With options
python -m workers.comfyui-json.client --prompt "sunset" --width 1024 --height 1024 --steps 30
# Using a custom workflow file
python -m workers.comfyui-json.client --workflow my_workflow.json
# With S3 upload
python -m workers.comfyui-json.client --s3
```
### CLI Flags
| Flag | Default | Description |
|------|---------|-------------|
| `--endpoint` | `my-comfyui-endpoint` | Vast endpoint name |
| `--prompt` | (default) | Text prompt for image generation |
| `--workflow` | (none) | Path to custom workflow JSON file |
| `--width` | 512 | Image width in pixels |
| `--height` | 512 | Image height in pixels |
| `--steps` | 20 | Number of denoising steps |
| `--seed` | (random) | Random seed for reproducibility |
| `--s3` | (disabled) | Upload generated images to S3 |
### Output
Images are saved to `./generated_images/comfy_{seed}.png`.
### S3 Upload (Optional)
You can optionally upload generated images to an S3-compatible storage service (AWS S3, Cloudflare R2, Backblaze B2, etc.) by using the `--s3` flag.
**1. Set environment variables:**
```bash
export S3_ENDPOINT_URL="https://your-account.r2.cloudflarestorage.com"
export S3_BUCKET_NAME="my-bucket"
export S3_ACCESS_KEY_ID="your-access-key-id"
export S3_SECRET_ACCESS_KEY="your-secret-access-key"
```
**2. Run with S3 upload enabled:**
```bash
python -m workers.comfyui-json.client --prompt "a beautiful landscape" --s3
```
Images will be saved locally AND uploaded to `s3://{bucket}/comfyui/{filename}`.
**Note:** Requires `boto3` (`pip install boto3`).
## Benchmarking
### Custom Benchmark Workflows
You can provide a custom ComfyUI workflow for benchmarking by creating `workers/comfyui-json/misc/benchmark.json`. This allows you to test performance using your preferred models and workflow complexity.
**Ways to provide the benchmark file:**
- Fork this repository and add your `benchmark.json` file
- Write the file during worker provisioning (onstart script or setup phase)
An example file is provided in the repository. To ensure varied generations, use the placeholder `__RANDOM_INT__` in place of static seed values - it will be replaced with a random integer for each benchmark run.
### Default Benchmark (Fallback)
If `benchmark.json` is not available, a simple image generation benchmark runs when each worker initializes. This validates GPU performance and helps identify underperforming machines.
The default benchmark uses Stable Diffusion v1.5 with ComfyUI's standard text-to-image workflow. Configure it using these environment variables:
| Environment Variable | Default Value | Description |
| -------------------- | ------------- | ----------- |
| BENCHMARK_TEST_WIDTH | 512 | Image width (pixels) |
| BENCHMARK_TEST_HEIGHT | 512 | Image height (pixels) |
| BENCHMARK_TEST_STEPS | 20 | Number of denoising steps |
Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns.
#### Calibrating Fallback Benchmark Duration
To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements.
**Example:** If your typical workflow should complete in 90 seconds on acceptable hardware:
```bash
# 1. Measure it/sec on your reference machine
# RTX 4090 typically achieves ~43 it/sec with SD1.5
# 2. Calculate required steps
# 90 seconds × 43 it/sec = 3870 steps
# 3. Configure benchmark
export BENCHMARK_TEST_STEPS=3870
# 4. Machines completing significantly slower than 90s indicate hardware issues
```
**Performance expectations:**
- Benchmark duration should remain consistent across identical GPU models
- Significant variation (>20%) may indicate thermal, power, or configuration issues
## Endpoint ## Endpoint
The worker provides a single endpoint: The worker provides a single endpoint:
+22 -300
View File
@@ -1,312 +1,34 @@
import os
import sys
import json
import uuid import uuid
import random import random
import asyncio import asyncio
import logging import random
import argparse
import aiohttp
from vastai import Serverless from vastai import Serverless
# ---------------------- Config ---------------------- async def main():
DEFAULT_PROMPT = "a beautiful sunset over mountains, digital art, highly detailed" async with Serverless() as client:
ENDPOINT_NAME = "my-comfyui-endpoint" endpoint = await client.get_endpoint(name="my-comfy-endpoint") # Change this to your endpoint name
DEFAULT_WIDTH = 512
DEFAULT_HEIGHT = 512
DEFAULT_STEPS = 20
COST = 100 # Fixed cost for ComfyUI requests
# Optional S3 Configuration (from environment variables) payload = {
S3_ENDPOINT_URL = os.getenv("S3_ENDPOINT_URL") "input": {
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME") "request_id": str(uuid.uuid4()),
S3_ACCESS_KEY_ID = os.getenv("S3_ACCESS_KEY_ID") "modifier": "Text2Image",
S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY") "modifications": {
"prompt": "a beautiful landscape with mountains and lakes",
logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s") "width": 1024,
log = logging.getLogger(__name__) "height": 1024,
"steps": 20,
"seed": random.randint(0, 2**32 - 1)
def get_s3_client(): },
"""Create and return an S3 client configured for the S3-compatible endpoint""" "workflow_json": {} # Empty since using modifier approach
try: }
import boto3
from botocore.config import Config
except ImportError:
log.error("boto3 is required for S3 uploads. Install with: pip install boto3")
return None
if not all([S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY]):
log.error("S3 environment variables not fully configured. Required:")
log.error(" S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY")
return None
return boto3.client(
"s3",
endpoint_url=S3_ENDPOINT_URL,
aws_access_key_id=S3_ACCESS_KEY_ID,
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
config=Config(signature_version="s3v4"),
)
# ---------------------- API Functions ----------------------
async def call_generate(
client: Serverless,
*,
endpoint_name: str,
prompt: str,
width: int,
height: int,
steps: int,
seed: int,
) -> dict:
"""Generate image using Text2Image modifier"""
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
"request_id": str(uuid.uuid4()),
"modifier": "Text2Image",
"modifications": {
"prompt": prompt,
"width": width,
"height": height,
"steps": steps,
"seed": seed,
},
} }
}
return await endpoint.request("/generate/sync", payload, cost=COST)
response = await endpoint.request("/generate/sync", payload, cost=100)
async def call_generate_workflow( # Get the file from the path on the local machine using SCP or SFTP
client: Serverless, # or configure S3 to upload to cloud storage.
*, print(response["response"]["output"][0]["local_path"])
endpoint_name: str,
workflow_json: dict,
) -> dict:
"""Generate using custom workflow JSON"""
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
"request_id": str(uuid.uuid4()),
"workflow_json": workflow_json,
}
}
return await endpoint.request("/generate/sync", payload, cost=COST)
# ---------------------- Demo Class ----------------------
class APIDemo:
def __init__(self, client: Serverless, endpoint_name: str, upload_s3: bool = False):
self.client = client
self.endpoint_name = endpoint_name
self.upload_s3 = upload_s3
self.s3_client = get_s3_client() if upload_s3 else None
if upload_s3 and not self.s3_client:
log.warning("S3 upload requested but client creation failed. Images will only be saved locally.")
def extract_filename(self, response: dict) -> str | None:
"""Extract the generated image filename from ComfyUI response"""
if "comfyui_response" in response:
for data in response["comfyui_response"].values():
if isinstance(data, dict) and "outputs" in data:
for node_output in data["outputs"].values():
if "images" in node_output and node_output["images"]:
return node_output["images"][0].get("filename")
return None
async def save_image(self, worker_url: str, filename: str, local_name: str) -> str | None:
"""Fetch and save image locally from the worker, optionally upload to S3"""
os.makedirs("generated_images", exist_ok=True)
return await self._fetch_image(worker_url, filename, local_name)
def _upload_to_s3(self, local_path: str, s3_key: str) -> str | None:
"""Upload a local file to S3 and return the S3 URL"""
if not self.s3_client:
return None
try:
self.s3_client.upload_file(
local_path,
S3_BUCKET_NAME,
s3_key,
ExtraArgs={"ContentType": "image/png"}
)
s3_url = f"{S3_ENDPOINT_URL}/{S3_BUCKET_NAME}/{s3_key}"
print(f" ☁️ Uploaded to S3: {s3_key}")
return s3_url
except Exception as e:
log.error(f"Failed to upload to S3: {e}")
return None
async def _fetch_image(self, worker_url: str, filename: str, local_name: str) -> str | None:
"""Fetch image from worker's /view endpoint and save locally"""
if not worker_url:
return None
try:
url = f"{worker_url}/view"
params = {"filename": filename, "type": "output"}
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params, ssl=False) as resp:
if resp.status == 200:
path = f"generated_images/{local_name}"
image_data = await resp.read()
with open(path, "wb") as f:
f.write(image_data)
print(f" 💾 Saved: {path}")
# Upload to S3 if enabled
if self.upload_s3 and self.s3_client:
s3_key = f"comfyui/{local_name}"
self._upload_to_s3(path, s3_key)
return path
return None
except Exception:
return None
async def demo_prompt(
self,
prompt: str,
width: int,
height: int,
steps: int,
seed: int | None,
):
"""Demo: Generate image from text prompt"""
print("=" * 60)
print("COMFYUI TEXT-TO-IMAGE DEMO")
print("=" * 60)
if seed is None:
seed = random.randint(0, 2**32 - 1)
print(f"Prompt: {prompt[:100]}..." if len(prompt) > 100 else f"Prompt: {prompt}")
print(f"Size: {width}x{height}, Steps: {steps}, Seed: {seed}")
print("\n🎨 Generating image...")
response = await call_generate(
self.client,
endpoint_name=self.endpoint_name,
prompt=prompt,
width=width,
height=height,
steps=steps,
seed=seed,
)
print("\n✅ Generation complete!")
# Get worker URL for fetching images
worker_url = response.get("url", "")
print(f"Worker URL: {worker_url}")
# Fetch and save image
if "response" in response:
filename = self.extract_filename(response["response"])
if filename:
path = await self.save_image(worker_url, filename, f"comfy_{seed}.png")
if not path:
print(f"❌ Failed to fetch image")
else:
print("❌ No image in response")
else:
print("❌ Unexpected response format")
async def demo_workflow(self, workflow_file: str):
"""Demo: Generate using custom workflow file"""
print("=" * 60)
print("COMFYUI CUSTOM WORKFLOW DEMO")
print("=" * 60)
if not os.path.exists(workflow_file):
log.error(f"Workflow file not found: {workflow_file}")
return
with open(workflow_file, "r") as f:
workflow_json = json.load(f)
print(f"Workflow: {workflow_file}")
print("\n🎨 Generating...")
response = await call_generate_workflow(
self.client,
endpoint_name=self.endpoint_name,
workflow_json=workflow_json,
)
print("\n✅ Generation complete!")
worker_url = response.get("url", "")
if "response" in response:
filename = self.extract_filename(response["response"])
if filename:
path = await self.save_image(worker_url, filename, "workflow.png")
if not path:
print(f"❌ Failed to fetch image")
else:
print("❌ No image in response")
else:
print("❌ Unexpected response format")
# ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast ComfyUI-JSON Demo (Serverless SDK)")
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
p.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, metavar="TEXT",
help=f"Prompt text (default: '{DEFAULT_PROMPT[:30]}...')")
p.add_argument("--workflow", type=str, metavar="FILE", help="Use custom workflow JSON file instead")
p.add_argument("--width", type=int, default=DEFAULT_WIDTH, help=f"Image width (default: {DEFAULT_WIDTH})")
p.add_argument("--height", type=int, default=DEFAULT_HEIGHT, help=f"Image height (default: {DEFAULT_HEIGHT})")
p.add_argument("--steps", type=int, default=DEFAULT_STEPS, help=f"Steps (default: {DEFAULT_STEPS})")
p.add_argument("--seed", type=int, default=None, help="Seed (default: random)")
p.add_argument("--s3", action="store_true",
help="Upload generated images to S3 (requires S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY env vars)")
return p
async def main_async():
args = build_arg_parser().parse_args()
print("=" * 60)
print(f"Using endpoint: {args.endpoint}")
if args.s3:
print(f"S3 upload: enabled (bucket: {S3_BUCKET_NAME})")
try:
async with Serverless() as client:
demo = APIDemo(client, args.endpoint, upload_s3=args.s3)
if args.workflow:
await demo.demo_workflow(workflow_file=args.workflow)
else:
await demo.demo_prompt(
prompt=args.prompt,
width=args.width,
height=args.height,
steps=args.steps,
seed=args.seed,
)
except AttributeError as e:
if "API key" in str(e):
log.error("API key missing. Set VAST_API_KEY environment variable.")
else:
log.error(f"Error: {e}")
sys.exit(1)
except Exception as e:
log.error(f"Error: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main_async()) asyncio.run(main())
+22 -29
View File
@@ -8,13 +8,14 @@ This is the base PyWorker for OpenAI compatible inference servers. See the [Ser
This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker. This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker.
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20(Serverless)) (recommended) - [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20%2B%20Qwen%2FQwen3-8B%20(Serverless)) (recommended)
- [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless)) - [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless))
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20%2B%20Qwen3-8B%20(Serverless))
All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected. All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected.
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface. 2. Follow the [getting started guide](https://docs.vast.ai/serverless/getting-started) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
## Client Setup (Demo) ## Client Setup (Demo)
@@ -33,30 +34,12 @@ uv pip install -r requirements.txt
Several examples have been provided in the client to help you get started with your own implementation. Several examples have been provided in the client to help you get started with your own implementation.
First, set your API key as an environment variable: ### Completions
Call to `/v1/completions` with json response
```bash ```bash
export VAST_API_KEY=<your_api_key> python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
```
The `--model` and `--endpoint` flags are optional. If not provided, they default to `Qwen/Qwen3-8B` and `my-vllm-endpoint` respectively.
### Chat Completion (streaming)
Call to `/v1/chat/completions` with streaming response
```bash
python -m workers.openai.client --chat-stream --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Interactive Chat (streaming)
Interactive session with calls to `/v1/chat/completions`.
Type `clear` to clear the chat history or `quit` to exit.
```bash
python -m workers.openai.client --interactive --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
``` ```
### Chat Completion (json) ### Chat Completion (json)
@@ -64,7 +47,15 @@ python -m workers.openai.client --interactive --endpoint <ENDPOINT_NAME> --model
Call to `/v1/chat/completions` with json response Call to `/v1/chat/completions` with json response
```bash ```bash
python -m workers.openai.client --chat --endpoint <ENDPOINT_NAME> --model <MODEL_NAME> python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
```
### Chat Completion (streaming)
Call to `/v1/chat/completions` with streaming response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
``` ```
### Tool Use (json) ### Tool Use (json)
@@ -74,14 +65,16 @@ Call to `/v1/chat/completions` with tool and json response.
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model. This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
```bash ```bash
python -m workers.openai.client --tools --endpoint <ENDPOINT_NAME> --model <MODEL_NAME> python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
``` ```
### Completions ### Interactive Chat (streaming)
Call to `/v1/completions` with json response Interactive session with calls to `/v1/chat/completions`.
Type `clear` to clear the chat history or `quit` to exit.
```bash ```bash
python -m workers.openai.client --completion --endpoint <ENDPOINT_NAME> --model <MODEL_NAME> python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
``` ```
+15 -31
View File
@@ -18,7 +18,7 @@ logging.basicConfig(
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
# ---------------------- Prompts ---------------------- # ---------------------- Prompts ----------------------
COMPLETIONS_PROMPT = "Zebras are primarily grazers and can subsist on lower-quality vegetation. They are preyed on mainly by" COMPLETIONS_PROMPT = "the capital of USA is"
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language." CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
TOOLS_PROMPT = ( TOOLS_PROMPT = (
"Can you list the files in the current working directory and tell me what you see? " "Can you list the files in the current working directory and tell me what you see? "
@@ -97,9 +97,9 @@ def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[
# ---- OpenAI-compatible calls (non-streaming) ---- # ---- OpenAI-compatible calls (non-streaming) ----
async def call_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs) -> Dict[str, Any]: async def call_completions(client: Serverless, *, model: str, prompt: str, **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=endpoint_name) endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = { payload = {
"model": model, "model": model,
@@ -111,9 +111,9 @@ async def call_completions(client: Serverless, *, model: str, prompt: str, endpo
resp = await endpoint.request("/v1/completions", payload, cost=payload["max_tokens"]) resp = await endpoint.request("/v1/completions", payload, cost=payload["max_tokens"])
return resp["response"] return resp["response"]
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs) -> Dict[str, Any]: async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=endpoint_name) endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = { payload = {
"model": model, "model": model,
@@ -128,9 +128,9 @@ async def call_chat_completions(client: Serverless, *, model: str, messages: Lis
return resp["response"] return resp["response"]
# ---- Streaming variants ---- # ---- Streaming variants ----
async def stream_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs): async def stream_completions(client: Serverless, *, model: str, prompt: str, **kwargs):
endpoint = await client.get_endpoint(name=endpoint_name) endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = { payload = {
"model": model, "model": model,
@@ -144,9 +144,9 @@ async def stream_completions(client: Serverless, *, model: str, prompt: str, end
resp = await endpoint.request("/v1/completions", payload, cost=payload["max_tokens"], stream=True) resp = await endpoint.request("/v1/completions", payload, cost=payload["max_tokens"], stream=True)
return resp["response"] # async generator return resp["response"] # async generator
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs): async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs):
endpoint = await client.get_endpoint(name=endpoint_name) endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = { payload = {
"model": model, "model": model,
@@ -166,10 +166,9 @@ async def stream_chat_completions(client: Serverless, *, model: str, messages: L
class APIDemo: class APIDemo:
"""Demo and testing functionality for the API client""" """Demo and testing functionality for the API client"""
def __init__(self, client: Serverless, model: str, endpoint_name: str, tool_manager: Optional[ToolManager] = None): def __init__(self, client: Serverless, model: str, tool_manager: Optional[ToolManager] = None):
self.client = client self.client = client
self.model = model self.model = model
self.endpoint_name = endpoint_name
self.tool_manager = tool_manager or ToolManager() self.tool_manager = tool_manager or ToolManager()
# ----- Streaming handler ----- # ----- Streaming handler -----
@@ -178,16 +177,11 @@ class APIDemo:
reasoning_content = "" reasoning_content = ""
printed_reasoning = False printed_reasoning = False
printed_answer = False printed_answer = False
finish_reason = None
async for chunk in stream: async for chunk in stream:
choice = (chunk.get("choices") or [{}])[0] choice = (chunk.get("choices") or [{}])[0]
delta = choice.get("delta", {}) delta = choice.get("delta", {})
# Track finish reason
if choice.get("finish_reason"):
finish_reason = choice.get("finish_reason")
# reasoning tokens # reasoning tokens
rc = delta.get("reasoning_content") rc = delta.get("reasoning_content")
if rc and show_reasoning: if rc and show_reasoning:
@@ -217,8 +211,6 @@ class APIDemo:
print(f"Reasoning tokens: {len(reasoning_content.split())}") print(f"Reasoning tokens: {len(reasoning_content.split())}")
if printed_answer: if printed_answer:
print(f"Response tokens: {len(full_response.split())}") print(f"Response tokens: {len(full_response.split())}")
if finish_reason:
print(f"Finish reason: {finish_reason}")
return full_response return full_response
@@ -231,7 +223,6 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
prompt=COMPLETIONS_PROMPT, prompt=COMPLETIONS_PROMPT,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE, temperature=DEFAULT_TEMPERATURE,
) )
@@ -250,7 +241,6 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE temperature=DEFAULT_TEMPERATURE
) )
@@ -263,7 +253,6 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE temperature=DEFAULT_TEMPERATURE
) )
@@ -290,7 +279,6 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
tools=minimal_tool, tools=minimal_tool,
tool_choice="none", tool_choice="none",
max_tokens=10 max_tokens=10
@@ -316,7 +304,6 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
tools=self.tool_manager.get_ls_tool_definition(), tools=self.tool_manager.get_ls_tool_definition(),
tool_choice="auto", tool_choice="auto",
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
@@ -394,7 +381,6 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE, temperature=DEFAULT_TEMPERATURE,
) )
@@ -433,6 +419,7 @@ class APIDemo:
print("=" * 60) print("=" * 60)
print("INTERACTIVE STREAMING CHAT") print("INTERACTIVE STREAMING CHAT")
print("=" * 60) print("=" * 60)
print(f"Using model: {self.model}")
print("Type 'quit' to exit, 'clear' to clear history") print("Type 'quit' to exit, 'clear' to clear history")
print() print()
@@ -459,7 +446,6 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
temperature=0.7 temperature=0.7
) )
@@ -479,8 +465,8 @@ class APIDemo:
# ---------------------- CLI ---------------------- # ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser: def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)") p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
p.add_argument("--model", default=DEFAULT_MODEL, help=f"Model to use for requests (default: {DEFAULT_MODEL})") p.add_argument("--model", required=True, help="Model to use for requests (required)")
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})") p.add_argument("--endpoint", default="my-vllm-endpoint", help="Vast endpoint name (default: my-vllm-endpoint)")
modes = p.add_mutually_exclusive_group(required=False) modes = p.add_mutually_exclusive_group(required=False)
modes.add_argument("--completion", action="store_true", help="Test completions endpoint") modes.add_argument("--completion", action="store_true", help="Test completions endpoint")
@@ -508,14 +494,12 @@ async def main_async():
print("Please specify exactly one test mode") print("Please specify exactly one test mode")
sys.exit(1) sys.exit(1)
print("=" * 60)
print(f"Using model: {args.model}") print(f"Using model: {args.model}")
print(f"Using endpoint: {args.endpoint}") print("=" * 60)
try: try:
async with Serverless() as client: async with Serverless() as client:
demo = APIDemo(client, args.model, args.endpoint, ToolManager()) demo = APIDemo(client, args.model, ToolManager())
if args.completion: if args.completion:
await demo.demo_completions() await demo.demo_completions()
+4 -12
View File
@@ -28,12 +28,6 @@ MODEL_INFO_LOG_MSGS = [
nltk.download("words") nltk.download("words")
WORD_LIST = nltk.corpus.words.words() WORD_LIST = nltk.corpus.words.words()
def request_parser(request):
data = request
if request.get("input") is not None:
data = request.get("input")
return data
def completions_benchmark_generator() -> dict: def completions_benchmark_generator() -> dict:
prompt = " ".join(random.choices(WORD_LIST, k=int(250))) prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
@@ -60,20 +54,18 @@ worker_config = WorkerConfig(
route="/v1/completions", route="/v1/completions",
workload_calculator= lambda data: data.get("max_tokens", 0), workload_calculator= lambda data: data.get("max_tokens", 0),
allow_parallel_requests=True, allow_parallel_requests=True,
request_parser=request_parser, max_queue_time=60.0,
max_queue_time=600.0,
benchmark_config=BenchmarkConfig( benchmark_config=BenchmarkConfig(
generator=completions_benchmark_generator, generator=completions_benchmark_generator,
concurrency=10, concurrency=100,
runs=3 runs=2
) )
), ),
HandlerConfig( HandlerConfig(
route="/v1/chat/completions", route="/v1/chat/completions",
workload_calculator= lambda data: data.get("max_tokens", 0), workload_calculator= lambda data: data.get("max_tokens", 0),
allow_parallel_requests=True, allow_parallel_requests=True,
request_parser=request_parser, max_queue_time=60.0,
max_queue_time=600.0,
) )
], ],
log_action_config=LogActionConfig( log_action_config=LogActionConfig(
+9 -93
View File
@@ -1,103 +1,19 @@
# HuggingFace TGI PyWorker This is the base PyWorker for TGI, designed to create PyWorkers that can utilize various LLMs. It offers two primary endpoints:
This is the base PyWorker for HuggingFace Text Generation Inference (TGI) servers. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's. 1. `generate`: Generates the LLM's response to a given prompt in a single request.
2. `generate_stream`: Streams the LLM's response token by token.
## Instance Setup Both endpoints use the following API payload format:
1. Pick a template
This worker is compatible with any TGI backend. We have a template you can use or you can create your own.
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20(Serverless))
The template can be configured via the template interface. You may want to change the model or startup arguments.
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
## Client Setup (Demo)
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
```bash
git clone https://github.com/vast-ai/pyworker
cd pyworker
pip install uv
uv venv -p 3.12
source .venv/bin/activate
uv pip install -r requirements.txt
```
## Using the Test Client
The test client demonstrates both streaming and non-streaming generation using TGI's native API.
First, set your API key as an environment variable:
```bash
export VAST_API_KEY=<your_api_key>
```
The `--endpoint` flag is optional. If not provided, it defaults to `my-tgi-endpoint`.
### Generate (Streaming)
Call to `/generate_stream` with streaming response:
```bash
python -m workers.tgi.client --generate-stream --endpoint <ENDPOINT_NAME>
```
### Generate (Non-Streaming)
Call to `/generate` with json response:
```bash
python -m workers.tgi.client --generate --endpoint <ENDPOINT_NAME>
```
### Interactive Session (Streaming)
Interactive session with streaming responses. Type `quit` to exit.
```bash
python -m workers.tgi.client --interactive --endpoint <ENDPOINT_NAME>
```
## API Endpoints
TGI provides two primary endpoints:
### Generate (Non-Streaming)
`/generate` - Returns the complete response in a single request.
```json ```json
{ {
"inputs": "Your prompt here", "inputs": "PROMPT",
"parameters": { "parameters": {
"max_new_tokens": 1024, "max_new_tokens": 250
"temperature": 0.7,
"return_full_text": false
} }
} }
``` ```
### Generate Stream (Streaming) Note that the max_new_tokens parameter, rather than the prompt size, impacts performance. For example, if an
instance is benchmarked to process 100 tokens per second, a request with max_new_tokens = 200 will take
`/generate_stream` - Streams the response token by token. approximately 2 seconds to complete.
```json
{
"inputs": "Your prompt here",
"parameters": {
"max_new_tokens": 1024,
"temperature": 0.7,
"do_sample": true,
"return_full_text": false
}
}
```
## Performance Notes
The `max_new_tokens` parameter (not the prompt size) primarily impacts performance. For example, if an instance is benchmarked to process 100 tokens per second, a request with `max_new_tokens = 200` will take approximately 2 seconds to complete.
+33 -194
View File
@@ -1,222 +1,61 @@
import logging
import json
import os
import sys
import argparse
from vastai import Serverless from vastai import Serverless
import asyncio import asyncio
# ---------------------- Logging ---------------------- ENDPOINT_NAME = "my-tgi-endpoint" # Change this to match your endpoint name
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
# ---------------------- Defaults ----------------------
DEFAULT_PROMPT = "Think step by step: Tell me about the Python programming language."
ENDPOINT_NAME = "TGI-Prod2" # change this to your TGI endpoint name
MAX_TOKENS = 1024 MAX_TOKENS = 1024
DEFAULT_TEMPERATURE = 0.7 PROMPT = "Think step by step: Tell me about the Python programming language."
async def call_generate(client: Serverless) -> None:
# ---------------------- API Calls ---------------------- endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
async def call_generate(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs) -> dict:
"""Non-streaming generation via /generate endpoint"""
endpoint = await client.get_endpoint(name=endpoint_name)
payload = { payload = {
"inputs": prompt, "inputs": PROMPT,
"parameters": { "parameters": {
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS), "max_new_tokens": MAX_TOKENS,
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), "temperature": 0.7,
"return_full_text": False, "return_full_text": False
} }
} }
log.debug("POST /generate %s", json.dumps(payload)[:500])
resp = await endpoint.request("/generate", payload, cost=payload["parameters"]["max_new_tokens"]) resp = await endpoint.request("/generate", payload, cost=MAX_TOKENS)
return resp["response"]
print(resp["response"]["generated_text"])
async def call_generate_stream(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs): async def call_generate_stream(client: Serverless) -> None:
"""Streaming generation via /generate_stream endpoint""" endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
endpoint = await client.get_endpoint(name=endpoint_name)
payload = { payload = {
"inputs": prompt, "inputs": PROMPT,
"parameters": { "parameters": {
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS), "max_new_tokens": MAX_TOKENS,
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), "temperature": 0.7,
"do_sample": True, "do_sample": True,
"return_full_text": False, "return_full_text": False,
} }
} }
log.debug("STREAM /generate_stream %s", json.dumps(payload)[:500])
resp = await endpoint.request( resp = await endpoint.request(
"/generate_stream", "/generate_stream",
payload, payload,
cost=payload["parameters"]["max_new_tokens"], cost=MAX_TOKENS,
stream=True, stream=True,
) )
return resp["response"] # async generator stream = resp["response"]
printed_answer = False
async for event in stream:
tok = (event.get("token") or {}).get("text")
if tok:
if not printed_answer:
printed_answer = True
print("Answer:\n", end="", flush=True)
print(tok, end="", flush=True)
# ---------------------- Demo Runner ---------------------- async def main():
class APIDemo: async with Serverless() as client:
"""Demo and testing functionality for the TGI API client""" await call_generate(client)
await call_generate_stream(client)
def __init__(self, client: Serverless, endpoint_name: str):
self.client = client
self.endpoint_name = endpoint_name
async def handle_streaming_response(self, stream) -> str:
"""Process streaming response and print tokens"""
full_response = ""
printed_answer = False
async for event in stream:
tok = (event.get("token") or {}).get("text")
if tok:
if not printed_answer:
printed_answer = True
print("\n💬 Response: ", end="", flush=True)
print(tok, end="", flush=True)
full_response += tok
print() # newline
if printed_answer:
print(f"\nStreaming completed. Response tokens: {len(full_response.split())}")
return full_response
async def demo_generate(self) -> None:
"""Demo non-streaming generation"""
print("=" * 60)
print("GENERATE DEMO (NON-STREAMING)")
print("=" * 60)
response = await call_generate(
client=self.client,
endpoint_name=self.endpoint_name,
prompt=DEFAULT_PROMPT,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
print(f"\n💬 Response: {response.get('generated_text', '')}")
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
async def demo_generate_stream(self) -> None:
"""Demo streaming generation"""
print("=" * 60)
print("GENERATE DEMO (STREAMING)")
print("=" * 60)
stream = await call_generate_stream(
client=self.client,
endpoint_name=self.endpoint_name,
prompt=DEFAULT_PROMPT,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
try:
await self.handle_streaming_response(stream)
except Exception as e:
log.error("\nError during streaming: %s", e, exc_info=True)
async def interactive_chat(self) -> None:
"""Interactive session with streaming generation"""
print("=" * 60)
print("INTERACTIVE STREAMING SESSION")
print("=" * 60)
print(f"Using endpoint: {self.endpoint_name}")
print("Type 'quit' to exit")
print()
while True:
try:
user_input = input("You: ").strip()
if user_input.lower() == "quit":
print("👋 Goodbye!")
break
elif not user_input:
continue
print("Assistant: ", end="", flush=True)
stream = await call_generate_stream(
client=self.client,
endpoint_name=self.endpoint_name,
prompt=user_input,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
full_response = ""
async for event in stream:
tok = (event.get("token") or {}).get("text")
if tok:
print(tok, end="", flush=True)
full_response += tok
print() # newline
except KeyboardInterrupt:
print("\n👋 Session interrupted. Goodbye!")
break
except Exception as e:
log.error("\nError: %s", e)
continue
# ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast TGI Demo (Serverless SDK)")
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
modes = p.add_mutually_exclusive_group(required=False)
modes.add_argument("--generate", action="store_true", help="Test generate endpoint (non-streaming)")
modes.add_argument("--generate-stream", action="store_true", help="Test generate endpoint with streaming")
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming session")
return p
async def main_async():
args = build_arg_parser().parse_args()
selected = sum([args.generate, args.generate_stream, args.interactive])
if selected == 0:
print("Please specify exactly one test mode:")
print(" --generate : Test generate endpoint (non-streaming)")
print(" --generate-stream : Test generate endpoint with streaming")
print(" --interactive : Start interactive streaming session")
print(f"\nExample: python {os.path.basename(sys.argv[0])} --generate-stream --endpoint my-tgi-endpoint")
sys.exit(1)
elif selected > 1:
print("Please specify exactly one test mode")
sys.exit(1)
print("=" * 60)
print(f"Using endpoint: {args.endpoint}")
try:
async with Serverless() as client:
demo = APIDemo(client, args.endpoint)
if args.generate:
await demo.demo_generate()
elif args.generate_stream:
await demo.demo_generate_stream()
elif args.interactive:
await demo.interactive_chat()
except Exception as e:
log.error("Error during test: %s", e, exc_info=True)
sys.exit(1)
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main_async()) asyncio.run(main())
+3 -4
View File
@@ -52,18 +52,17 @@ worker_config = WorkerConfig(
HandlerConfig( HandlerConfig(
route="/generate", route="/generate",
allow_parallel_requests=True, allow_parallel_requests=True,
max_queue_time=600.0, max_queue_time=60.0,
benchmark_config=BenchmarkConfig( benchmark_config=BenchmarkConfig(
generator=benchmark_generator, generator=benchmark_generator,
concurrency=10, concurrency=50
runs=3
), ),
workload_calculator= lambda x: x["parameters"]["max_new_tokens"] workload_calculator= lambda x: x["parameters"]["max_new_tokens"]
), ),
HandlerConfig( HandlerConfig(
route="/generate_stream", route="/generate_stream",
allow_parallel_requests=True, allow_parallel_requests=True,
max_queue_time=600.0, max_queue_time=60.0,
workload_calculator= lambda x: x["parameters"]["max_new_tokens"] workload_calculator= lambda x: x["parameters"]["max_new_tokens"]
) )
], ],
+1 -1
View File
@@ -2,7 +2,7 @@
This is the PyWorker implementation for running **Wan 2.2 T2V A14B** text-to-video workflows in ComfyUI. It provides a unified interface for executing complete ComfyUI video-generation workflows through a proxy-based architecture and returning generated video assets. This is the PyWorker implementation for running **Wan 2.2 T2V A14B** text-to-video workflows in ComfyUI. It provides a unified interface for executing complete ComfyUI video-generation workflows through a proxy-based architecture and returning generated video assets.
Each request has a static cost of `10000`. ComfyUI does not support concurrent workloads, and there is no provision to run multiple ComfyUI instances per worker node. Each request has a static cost of `100`. ComfyUI does not support concurrent workloads, and there is no provision to run multiple ComfyUI instances per worker node.
## Requirements ## Requirements