Compare commits

..

1 Commits

Author SHA1 Message Date
Edgar Lin 7d43bc8d68 remove redis pubsub from pyworker 2025-10-29 11:46:31 -07:00
17 changed files with 785 additions and 1272 deletions
+1 -2
View File
@@ -2,5 +2,4 @@
.envrc .envrc
__pycache__ __pycache__
bin/ bin/
lib64 lib64
.venv
+3 -4
View File
@@ -39,12 +39,11 @@ reporting these metrics to the autoscaler.
If you are using a Vast.ai template that includes PyWorker integration (marked as autoscaler compatible), it should work out of the box. The template will typically start the appropriate PyWorker server automatically. Here's a few: If you are using a Vast.ai template that includes PyWorker integration (marked as autoscaler compatible), it should work out of the box. The template will typically start the appropriate PyWorker server automatically. Here's a few:
* **vLLM:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=63ae93902bf3978bea033782592b784d) * **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=72d8dcb41ea3a58e06c741e2c725bc00)
* **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=6fa6bd5bdf5f0df63db80e40b086037d) * **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=ad72c8bf7cf695c3c9ddf0eaf6da0447)
* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=e6748878ba688e765e3e9fca29541938)
Currently available workers: Currently available workers:
* `openai`: A simple example worker for a basic vLLM server. * `hello_world`: A simple example worker for a basic LLM server.
* `comfyui`: A worker for the ComfyUI image generation backend. * `comfyui`: A worker for the ComfyUI image generation backend.
* `tgi`: A worker for the Text Generation Inference backend. * `tgi`: A worker for the Text Generation Inference backend.
+24 -27
View File
@@ -30,7 +30,7 @@ from lib.data_types import (
BenchmarkResult BenchmarkResult
) )
VERSION = "0.2.1" VERSION = "0.1.0"
MSG_HISTORY_LEN = 100 MSG_HISTORY_LEN = 100
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
@@ -66,17 +66,10 @@ class Backend:
unsecured: bool = dataclasses.field( unsecured: bool = dataclasses.field(
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))), default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
) )
report_addr: str = dataclasses.field(
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
)
mtoken: str = dataclasses.field(
default_factory=lambda: os.environ.get("MASTER_TOKEN", "")
)
def __post_init__(self): def __post_init__(self):
self.metrics = Metrics() self.metrics = Metrics()
self.metrics._set_version(self.version) self.metrics._set_version(self.version)
self.metrics._set_mtoken(self.mtoken)
self._total_pubkey_fetch_errors = 0 self._total_pubkey_fetch_errors = 0
self._pubkey = self._fetch_pubkey() self._pubkey = self._fetch_pubkey()
self.__start_healthcheck: bool = False self.__start_healthcheck: bool = False
@@ -111,19 +104,23 @@ class Backend:
#######################################Private####################################### #######################################Private#######################################
def _fetch_pubkey(self): def _fetch_pubkey(self):
report_addr = self.report_addr.rstrip("/") command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"] result = subprocess.check_output(command, universal_newlines=True)
try: log.debug("public key:")
result = subprocess.check_output(command, universal_newlines=True) log.debug(result)
log.debug("public key:") key = None
log.debug(result) for _ in range(5):
key = RSA.import_key(result) try:
if key is not None: key = RSA.import_key(result)
return key break
except (ValueError , subprocess.CalledProcessError) as e: except ValueError as e:
log.debug(f"Error downloading key: {e}") log.debug(f"Error downloading key: {e}")
self.backend_errored("Failed to get autoscaler pubkey") time.sleep(15)
if key is None:
self._total_pubkey_fetch_errors += 1
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
self.backend_errored("Failed to get autoscaler pubkey")
return key
async def __handle_request( async def __handle_request(
self, self,
@@ -318,10 +315,10 @@ class Backend:
with open(BENCHMARK_INDICATOR_FILE, "r") as f: with open(BENCHMARK_INDICATOR_FILE, "r") as f:
log.debug("already ran benchmark") log.debug("already ran benchmark")
# trigger model load # trigger model load
# payload = self.benchmark_handler.make_benchmark_payload() payload = self.benchmark_handler.make_benchmark_payload()
# _ = await self.__call_api( _ = await self.__call_api(
# handler=self.benchmark_handler, payload=payload handler=self.benchmark_handler, payload=payload
# ) )
return float(f.readline()) return float(f.readline())
except FileNotFoundError: except FileNotFoundError:
pass pass
@@ -396,7 +393,7 @@ class Backend:
) )
# some backends need a few seconds after logging successful startup before # some backends need a few seconds after logging successful startup before
# they can begin accepting requests # they can begin accepting requests
# await sleep(5) await sleep(5)
try: try:
max_throughput = await run_benchmark() max_throughput = await run_benchmark()
self.__start_healthcheck = True self.__start_healthcheck = True
@@ -417,7 +414,7 @@ class Backend:
async def tail_log(): async def tail_log():
log.debug(f"tailing file: {self.model_log_file}") log.debug(f"tailing file: {self.model_log_file}")
async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore') as f: async with await open_file(self.model_log_file) as f:
while True: while True:
line = await f.readline() line = await f.readline()
if line: if line:
-1
View File
@@ -286,7 +286,6 @@ class AutoScalerData:
"""Data that is reported to autoscaler""" """Data that is reported to autoscaler"""
id: int id: int
mtoken: str
version: str version: str
loadtime: float loadtime: float
cur_load: float cur_load: float
+2 -16
View File
@@ -28,7 +28,6 @@ def get_url() -> str:
@dataclass @dataclass
class Metrics: class Metrics:
version: str = "0" version: str = "0"
mtoken: str = ""
last_metric_update: float = 0.0 last_metric_update: float = 0.0
last_request_served: float = 0.0 last_request_served: float = 0.0
update_pending: bool = False update_pending: bool = False
@@ -143,16 +142,12 @@ class Metrics:
def _set_version(self, version: str) -> None: def _set_version(self, version: str) -> None:
self.version = version self.version = version
def _set_mtoken(self, mtoken: str) -> None:
self.mtoken = mtoken
#######################################Private####################################### #######################################Private#######################################
async def __send_delete_requests_and_reset(self): async def __send_delete_requests_and_reset(self):
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool: async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
data = { data = {
"worker_id": self.id, "worker_id": self.id,
"mtoken": self.mtoken,
"request_idxs": idxs, "request_idxs": idxs,
"success": success_flag, "success": success_flag,
} }
@@ -214,7 +209,6 @@ class Metrics:
def compute_autoscaler_data() -> AutoScalerData: def compute_autoscaler_data() -> AutoScalerData:
return AutoScalerData( return AutoScalerData(
id=self.id, id=self.id,
mtoken=self.mtoken,
version=self.version, version=self.version,
loadtime=(loadtime_snapshot or 0.0), loadtime=(loadtime_snapshot or 0.0),
new_load=self.model_metrics.workload_processing, new_load=self.model_metrics.workload_processing,
@@ -234,25 +228,17 @@ class Metrics:
async def send_data(report_addr: str) -> bool: async def send_data(report_addr: str) -> bool:
data = compute_autoscaler_data() data = compute_autoscaler_data()
log_data = asdict(data) full_path = report_addr.rstrip("/") + "/worker_status/"
def obfuscate(secret: str) -> str:
if secret is None:
return ""
return secret[:7] + "..." if len(secret) > 7 else ("*" * len(secret))
log_data["mtoken"] = obfuscate(log_data.get("mtoken"))
log.debug( log.debug(
"\n".join( "\n".join(
[ [
"#" * 60, "#" * 60,
f"sending data to autoscaler", f"sending data to autoscaler",
f"{json.dumps(log_data, indent=2)}", f"{json.dumps((asdict(data)), indent=2)}",
"#" * 60, "#" * 60,
] ]
) )
) )
full_path = report_addr.rstrip("/") + "/worker_status/"
for attempt in range(1, 4): for attempt in range(1, 4):
try: try:
session = await self.http() session = await self.http()
+25 -45
View File
@@ -3,58 +3,38 @@ import logging
from typing import List from typing import List
import ssl import ssl
from asyncio import run, gather from asyncio import run, gather
import asyncio
from lib.backend import Backend from lib.backend import Backend
from lib.metrics import Metrics
from aiohttp import web from aiohttp import web
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs): def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
try: log.debug("getting certificate...")
log.debug("getting certificate...") use_ssl = os.environ.get("USE_SSL", "false") == "true"
use_ssl = os.environ.get("USE_SSL", "false") == "true" if use_ssl is True:
if use_ssl is True: ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) ssl_context.load_cert_chain(
ssl_context.load_cert_chain( certfile="/etc/instance.crt",
certfile="/etc/instance.crt", keyfile="/etc/instance.key",
keyfile="/etc/instance.key", )
) else:
else: ssl_context = None
ssl_context = None
async def main(): async def main():
log.debug("starting server...") log.debug("starting server...")
app = web.Application() app = web.Application()
app.add_routes(routes) app.add_routes(routes)
runner = web.AppRunner(app) runner = web.AppRunner(app)
await runner.setup() await runner.setup()
site = web.TCPSite( site = web.TCPSite(
runner, runner,
ssl_context=ssl_context, ssl_context=ssl_context,
port=int(os.environ["WORKER_PORT"]), port=int(os.environ["WORKER_PORT"]),
**kwargs **kwargs
) )
await gather(site.start(), backend._start_tracking()) await gather(site.start(), backend._start_tracking())
run(main()) run(main())
except Exception as e:
err_msg = f"PyWorker failed to launch: {e}"
log.error(err_msg)
async def beacon():
metrics = Metrics()
metrics._set_version(getattr(backend, "version", "0"))
metrics._set_mtoken(getattr(backend, "mtoken", ""))
try:
while True:
metrics._model_errored(err_msg)
await metrics._Metrics__send_metrics_and_reset()
await asyncio.sleep(10)
finally:
await metrics.aclose()
run(beacon())
+1 -4
View File
@@ -1,6 +1,4 @@
aiohttp==3.10.1 aiohttp[speedups]==3.10.1
aiodns~=3.6.0
pycares~=4.11.0
anyio~=4.4 anyio~=4.4
lib~=4.0 lib~=4.0
nltk~=3.9 nltk~=3.9
@@ -10,4 +8,3 @@ Requests~=2.32
transformers~=4.52 transformers~=4.52
utils==1.0.* utils==1.0.*
hf_transfer>=0.1.9 hf_transfer>=0.1.9
vastai-sdk>=0.2.0
+5 -47
View File
@@ -41,14 +41,6 @@ echo_var DEBUG_LOG
echo_var PYWORKER_LOG echo_var PYWORKER_LOG
echo_var MODEL_LOG echo_var MODEL_LOG
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
if [ -e "$MODEL_LOG" ]; then
echo "Rotating model log at $MODEL_LOG to $MODEL_LOG.old"
cat "$MODEL_LOG" >> "$MODEL_LOG.old"
: > "$MODEL_LOG"
fi
# Populate /etc/environment with quoted values # Populate /etc/environment with quoted values
if ! grep -q "VAST" /etc/environment; then if ! grep -q "VAST" /etc/environment; then
env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do
@@ -132,43 +124,9 @@ cd "$SERVER_DIR"
echo "launching PyWorker server" echo "launching PyWorker server"
set +e # if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG" # from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
PY_STATUS=${PIPESTATUS[0]} [ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
set -e
if [ "${PY_STATUS}" -ne 0 ]; then (python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
echo "PyWorker exited with status ${PY_STATUS}; notifying autoscaler..." echo "launching PyWorker server done"
ERROR_MSG="PyWorker exited: code ${PY_STATUS}"
MTOKEN="${MASTER_TOKEN:-}"
VERSION="${PYWORKER_VERSION:-0}"
IFS=',' read -r -a REPORT_ADDRS <<< "${REPORT_ADDR}"
for addr in "${REPORT_ADDRS[@]}"; do
curl -sS -X POST -H 'Content-Type: application/json' \
-d "$(cat <<JSON
{
"id": ${CONTAINER_ID:-0},
"mtoken": "${MTOKEN}",
"version": "${VERSION}",
"loadtime": 0,
"new_load": 0,
"cur_load": 0,
"rej_load": 0,
"max_perf": 0,
"cur_perf": 0,
"error_msg": "${ERROR_MSG}",
"num_requests_working": 0,
"num_requests_recieved": 0,
"additional_disk_usage": 0,
"working_request_idxs": [],
"cur_capacity": 0,
"max_capacity": 0,
"url": "${URL}"
}
JSON
)" "${addr%/}/worker_status/" || true
done
fi
echo "launching PyWorker server done"
+10 -92
View File
@@ -1,16 +1,8 @@
# ComfyUI PyWorker # ComfyUI PyWorker
This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's. This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture.
The cost for each request has a static value of `1`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node. The cost for each request has a static value of `1`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node.
## Instance Setup
1. Pick a template
- [ComfyUI (Serverless)](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=ComfyUI%20(Serverless))
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
## Requirements ## Requirements
@@ -18,88 +10,6 @@ This worker requires both [ComfyUI](https://github.com/comfyanonymous/ComfyUI) a
A docker image is provided but you may use any if the above requirements are met. A docker image is provided but you may use any if the above requirements are met.
## Client
The client demonstrates how to use the Vast Serverless SDK to generate images, save them locally, and optionally upload to S3-compatible storage.
### Setup
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
```bash
git clone https://github.com/vast-ai/pyworker
cd pyworker
pip install uv
uv venv -p 3.12
source .venv/bin/activate
uv pip install -r requirements.txt
```
2. Set your API key:
```bash
export VAST_API_KEY=<your_api_key>
```
### Usage
```bash
# Default prompt
python -m workers.comfyui-json.client
# Custom prompt
python -m workers.comfyui-json.client --prompt "a cat sitting on a rainbow"
# With options
python -m workers.comfyui-json.client --prompt "sunset" --width 1024 --height 1024 --steps 30
# Using a custom workflow file
python -m workers.comfyui-json.client --workflow my_workflow.json
# With S3 upload
python -m workers.comfyui-json.client --s3
```
### CLI Flags
| Flag | Default | Description |
|------|---------|-------------|
| `--endpoint` | `my-comfyui-endpoint` | Vast endpoint name |
| `--prompt` | (default) | Text prompt for image generation |
| `--workflow` | (none) | Path to custom workflow JSON file |
| `--width` | 512 | Image width in pixels |
| `--height` | 512 | Image height in pixels |
| `--steps` | 20 | Number of denoising steps |
| `--seed` | (random) | Random seed for reproducibility |
| `--s3` | (disabled) | Upload generated images to S3 |
### Output
Images are saved to `./generated_images/comfy_{seed}.png`.
### S3 Upload (Optional)
You can optionally upload generated images to an S3-compatible storage service (AWS S3, Cloudflare R2, Backblaze B2, etc.) by using the `--s3` flag.
**1. Set environment variables:**
```bash
export S3_ENDPOINT_URL="https://your-account.r2.cloudflarestorage.com"
export S3_BUCKET_NAME="my-bucket"
export S3_ACCESS_KEY_ID="your-access-key-id"
export S3_SECRET_ACCESS_KEY="your-secret-access-key"
```
**2. Run with S3 upload enabled:**
```bash
python -m workers.comfyui-json.client --prompt "a beautiful landscape" --s3
```
Images will be saved locally AND uploaded to `s3://{bucket}/comfyui/{filename}`.
**Note:** Requires `boto3` (`pip install boto3`).
## Benchmarking ## Benchmarking
### Custom Benchmark Workflows ### Custom Benchmark Workflows
@@ -302,3 +212,11 @@ WEBHOOK_TIMEOUT=30 # Webhook timeout in seconds
} }
} }
``` ```
## Client Libraries
See the test client examples for implementation details on how to integrate with the ComfyUI worker.
---
See Vast's serverless documentation for more details on how to use ComfyUI with autoscaler.
+136 -293
View File
@@ -1,312 +1,155 @@
import os import logging
import sys
import json
import uuid import uuid
import random import random
import asyncio from urllib.parse import urljoin
import logging import json
import argparse
import aiohttp
from vastai import Serverless import requests
# ---------------------- Config ---------------------- from lib.test_utils import print_truncate_res
DEFAULT_PROMPT = "a beautiful sunset over mountains, digital art, highly detailed" from utils.endpoint_util import Endpoint
ENDPOINT_NAME = "my-comfyui-endpoint" from utils.ssl import get_cert_file_path
DEFAULT_WIDTH = 512 from .data_types import count_workload
DEFAULT_HEIGHT = 512
DEFAULT_STEPS = 20
COST = 100 # Fixed cost for ComfyUI requests
# Optional S3 Configuration (from environment variables) logging.basicConfig(
S3_ENDPOINT_URL = os.getenv("S3_ENDPOINT_URL") level=logging.DEBUG,
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME") format="%(asctime)s[%(levelname)-5s] %(message)s",
S3_ACCESS_KEY_ID = os.getenv("S3_ACCESS_KEY_ID") datefmt="%Y-%m-%d %H:%M:%S",
S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY") )
log = logging.getLogger(__file__)
logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")
log = logging.getLogger(__name__)
def get_s3_client(): def call_text2image_workflow(
"""Create and return an S3 client configured for the S3-compatible endpoint""" endpoint_group_name: str, api_key: str, server_url: str
try: ) -> None:
import boto3 """Simple Text2Image using the new modifier-based approach"""
from botocore.config import Config
except ImportError: def make_request(url: str, payload: dict, timeout: int = None, verify=True, context: str = "request"):
log.error("boto3 is required for S3 uploads. Install with: pip install boto3") """Helper function for making requests with consistent error handling"""
return None try:
response = requests.post(
url,
json=payload,
timeout=timeout,
verify=verify
)
response.raise_for_status()
return response.json()
except requests.exceptions.HTTPError as http_err:
log.error(f"HTTP error occurred during {context}: {http_err}")
log.error(f"Status Code: {response.status_code}")
log.error("Response content:", response.text)
return None
except requests.exceptions.Timeout:
log.error(f"Timeout occurred during {context}: {url}")
return None
except requests.exceptions.ConnectionError:
log.error(f"Connection error occurred during {context}: {url}")
return None
except json.JSONDecodeError as json_err:
log.error(f"Failed to decode JSON response during {context}: {json_err}")
if 'response' in locals():
print("Response content:", response.text)
return None
except Exception as err:
log.error(f"An unexpected error occurred during {context}: {err}")
if 'response' in locals():
log.error("Response content (if available):", response.text)
return None
WORKER_ENDPOINT = "/generate/sync"
if not all([S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY]): # This worker has concurrency = 1. All workloads have cost value 1.0
log.error("S3 environment variables not fully configured. Required:") COST = count_workload()
log.error(" S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY")
return None # Route to get worker URL
route_payload = {
return boto3.client( "endpoint": endpoint_group_name,
"s3", "api_key": api_key,
endpoint_url=S3_ENDPOINT_URL, "cost": COST,
aws_access_key_id=S3_ACCESS_KEY_ID, }
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
config=Config(signature_version="s3v4"), # First request - get routing information
route_response = make_request(
url=urljoin(server_url, "/route/"),
payload=route_payload,
timeout=4,
context="route request"
) )
if route_response is None:
# ---------------------- API Functions ---------------------- return None
async def call_generate(
client: Serverless, if "url" not in route_response or not route_response["url"]:
*, log.error("Error: No worker in 'Ready' state. Please wait while the serverless engine removes errored workers or finishes loading new workers.")
endpoint_name: str, return None
prompt: str,
width: int, if "status" in route_response:
height: int, print(f"Autoscaler status: {route_response['status']}")
steps: int, return None
seed: int,
) -> dict: # Extract data from route response
"""Generate image using Text2Image modifier""" url = route_response["url"]
endpoint = await client.get_endpoint(name=endpoint_name) auth_data = dict(
payload = { signature=route_response["signature"],
cost=route_response["cost"],
endpoint=route_response["endpoint"],
reqnum=route_response["reqnum"],
url=route_response["url"],
)
# Build the payload for the worker request
worker_payload = {
"input": { "input": {
"request_id": str(uuid.uuid4()), "request_id": str(uuid.uuid4()),
"modifier": "Text2Image", "modifier": "Text2Image",
"modifications": { "modifications": {
"prompt": prompt, "prompt": "a beautiful landscape with mountains and lakes",
"width": width, "width": 1024,
"height": height, "height": 1024,
"steps": steps, "steps": 20,
"seed": seed, "seed": random.randint(0, 2**32 - 1)
}, },
"workflow_json": {} # Empty since using modifier approach
} }
} }
return await endpoint.request("/generate/sync", payload, cost=COST)
req_data = dict(payload=worker_payload, auth_data=auth_data)
worker_url = urljoin(url, WORKER_ENDPOINT)
async def call_generate_workflow( print(f"url: {worker_url}")
client: Serverless,
*, # Second request - call the worker endpoint
endpoint_name: str, worker_response = make_request(
workflow_json: dict, url=worker_url,
) -> dict: payload=req_data,
"""Generate using custom workflow JSON""" verify=get_cert_file_path(),
endpoint = await client.get_endpoint(name=endpoint_name) context="worker request"
payload = { )
"input": {
"request_id": str(uuid.uuid4()), return worker_response
"workflow_json": workflow_json,
}
}
return await endpoint.request("/generate/sync", payload, cost=COST)
# ---------------------- Demo Class ----------------------
class APIDemo:
def __init__(self, client: Serverless, endpoint_name: str, upload_s3: bool = False):
self.client = client
self.endpoint_name = endpoint_name
self.upload_s3 = upload_s3
self.s3_client = get_s3_client() if upload_s3 else None
if upload_s3 and not self.s3_client:
log.warning("S3 upload requested but client creation failed. Images will only be saved locally.")
def extract_filename(self, response: dict) -> str | None:
"""Extract the generated image filename from ComfyUI response"""
if "comfyui_response" in response:
for data in response["comfyui_response"].values():
if isinstance(data, dict) and "outputs" in data:
for node_output in data["outputs"].values():
if "images" in node_output and node_output["images"]:
return node_output["images"][0].get("filename")
return None
async def save_image(self, worker_url: str, filename: str, local_name: str) -> str | None:
"""Fetch and save image locally from the worker, optionally upload to S3"""
os.makedirs("generated_images", exist_ok=True)
return await self._fetch_image(worker_url, filename, local_name)
def _upload_to_s3(self, local_path: str, s3_key: str) -> str | None:
"""Upload a local file to S3 and return the S3 URL"""
if not self.s3_client:
return None
try:
self.s3_client.upload_file(
local_path,
S3_BUCKET_NAME,
s3_key,
ExtraArgs={"ContentType": "image/png"}
)
s3_url = f"{S3_ENDPOINT_URL}/{S3_BUCKET_NAME}/{s3_key}"
print(f" ☁️ Uploaded to S3: {s3_key}")
return s3_url
except Exception as e:
log.error(f"Failed to upload to S3: {e}")
return None
async def _fetch_image(self, worker_url: str, filename: str, local_name: str) -> str | None:
"""Fetch image from worker's /view endpoint and save locally"""
if not worker_url:
return None
try:
url = f"{worker_url}/view"
params = {"filename": filename, "type": "output"}
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params, ssl=False) as resp:
if resp.status == 200:
path = f"generated_images/{local_name}"
image_data = await resp.read()
with open(path, "wb") as f:
f.write(image_data)
print(f" 💾 Saved: {path}")
# Upload to S3 if enabled
if self.upload_s3 and self.s3_client:
s3_key = f"comfyui/{local_name}"
self._upload_to_s3(path, s3_key)
return path
return None
except Exception:
return None
async def demo_prompt(
self,
prompt: str,
width: int,
height: int,
steps: int,
seed: int | None,
):
"""Demo: Generate image from text prompt"""
print("=" * 60)
print("COMFYUI TEXT-TO-IMAGE DEMO")
print("=" * 60)
if seed is None:
seed = random.randint(0, 2**32 - 1)
print(f"Prompt: {prompt[:100]}..." if len(prompt) > 100 else f"Prompt: {prompt}")
print(f"Size: {width}x{height}, Steps: {steps}, Seed: {seed}")
print("\n🎨 Generating image...")
response = await call_generate(
self.client,
endpoint_name=self.endpoint_name,
prompt=prompt,
width=width,
height=height,
steps=steps,
seed=seed,
)
print("\n✅ Generation complete!")
# Get worker URL for fetching images
worker_url = response.get("url", "")
print(f"Worker URL: {worker_url}")
# Fetch and save image
if "response" in response:
filename = self.extract_filename(response["response"])
if filename:
path = await self.save_image(worker_url, filename, f"comfy_{seed}.png")
if not path:
print(f"❌ Failed to fetch image")
else:
print("❌ No image in response")
else:
print("❌ Unexpected response format")
async def demo_workflow(self, workflow_file: str):
"""Demo: Generate using custom workflow file"""
print("=" * 60)
print("COMFYUI CUSTOM WORKFLOW DEMO")
print("=" * 60)
if not os.path.exists(workflow_file):
log.error(f"Workflow file not found: {workflow_file}")
return
with open(workflow_file, "r") as f:
workflow_json = json.load(f)
print(f"Workflow: {workflow_file}")
print("\n🎨 Generating...")
response = await call_generate_workflow(
self.client,
endpoint_name=self.endpoint_name,
workflow_json=workflow_json,
)
print("\n✅ Generation complete!")
worker_url = response.get("url", "")
if "response" in response:
filename = self.extract_filename(response["response"])
if filename:
path = await self.save_image(worker_url, filename, "workflow.png")
if not path:
print(f"❌ Failed to fetch image")
else:
print("❌ No image in response")
else:
print("❌ Unexpected response format")
# ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast ComfyUI-JSON Demo (Serverless SDK)")
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
p.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, metavar="TEXT",
help=f"Prompt text (default: '{DEFAULT_PROMPT[:30]}...')")
p.add_argument("--workflow", type=str, metavar="FILE", help="Use custom workflow JSON file instead")
p.add_argument("--width", type=int, default=DEFAULT_WIDTH, help=f"Image width (default: {DEFAULT_WIDTH})")
p.add_argument("--height", type=int, default=DEFAULT_HEIGHT, help=f"Image height (default: {DEFAULT_HEIGHT})")
p.add_argument("--steps", type=int, default=DEFAULT_STEPS, help=f"Steps (default: {DEFAULT_STEPS})")
p.add_argument("--seed", type=int, default=None, help="Seed (default: random)")
p.add_argument("--s3", action="store_true",
help="Upload generated images to S3 (requires S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY env vars)")
return p
async def main_async():
args = build_arg_parser().parse_args()
print("=" * 60)
print(f"Using endpoint: {args.endpoint}")
if args.s3:
print(f"S3 upload: enabled (bucket: {S3_BUCKET_NAME})")
try:
async with Serverless() as client:
demo = APIDemo(client, args.endpoint, upload_s3=args.s3)
if args.workflow:
await demo.demo_workflow(workflow_file=args.workflow)
else:
await demo.demo_prompt(
prompt=args.prompt,
width=args.width,
height=args.height,
steps=args.steps,
seed=args.seed,
)
except AttributeError as e:
if "API key" in str(e):
log.error("API key missing. Set VAST_API_KEY environment variable.")
else:
log.error(f"Error: {e}")
sys.exit(1)
except Exception as e:
log.error(f"Error: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main_async()) from lib.test_utils import test_args
args = test_args.parse_args()
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if endpoint_api_key:
result = call_text2image_workflow(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
if result is None:
log.error("Text2Image workflow failed")
else:
print(result)
else:
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name}")
-33
View File
@@ -4,7 +4,6 @@ import dataclasses
import base64 import base64
from typing import Optional, Union, Type from typing import Optional, Union, Type
import aiohttp
from aiohttp import web, ClientResponse from aiohttp import web, ClientResponse
from lib.backend import Backend, LogAction from lib.backend import Backend, LogAction
@@ -14,7 +13,6 @@ from .data_types import ComfyWorkflowData
MODEL_SERVER_URL = os.getenv("MODEL_SERVER_URL", "http://127.0.0.1:18288") MODEL_SERVER_URL = os.getenv("MODEL_SERVER_URL", "http://127.0.0.1:18288")
COMFYUI_URL = os.getenv("COMFYUI_URL", "http://127.0.0.1:18188") # Raw ComfyUI server
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded # This is the last log line that gets emitted once comfyui+extensions have been fully loaded
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: " MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: "
@@ -110,39 +108,8 @@ async def handle_ping(_):
return web.Response(body="pong") return web.Response(body="pong")
async def handle_view(request: web.Request) -> web.Response:
"""Proxy /view requests to raw ComfyUI server to fetch generated images"""
# Forward query params to raw ComfyUI (not the API wrapper)
query_string = request.query_string
url = f"{COMFYUI_URL}/view?{query_string}"
log.debug(f"Proxying /view request to: {url}")
try:
async with aiohttp.ClientSession() as session:
async with session.get(url) as resp:
if resp.status == 200:
content = await resp.read()
return web.Response(
body=content,
status=200,
content_type=resp.content_type or "image/png"
)
else:
text = await resp.text()
return web.Response(
text=text,
status=resp.status,
content_type="text/plain"
)
except Exception as e:
log.error(f"Error proxying /view: {e}")
return web.Response(text=str(e), status=500)
routes = [ routes = [
web.post("/generate/sync", backend.create_handler(ComfyWorkflowHandler())), web.post("/generate/sync", backend.create_handler(ComfyWorkflowHandler())),
web.get("/view", handle_view),
web.get("/ping", handle_ping), web.get("/ping", handle_ping),
] ]
+12 -6
View File
@@ -7,13 +7,20 @@ from lib.test_utils import print_truncate_res
from utils.endpoint_util import Endpoint from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path from utils.ssl import get_cert_file_path
from vastai import Serverless """
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
"""
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
ENDPOINT_NAME = "my-comfyui-endpoint" def call_default_workflow(
COST = 100 # Use a constant cost for image generation endpoint_group_name: str, api_key: str, server_url: str
) -> None:
def call_default_workflow(client: Serverless) -> None:
WORKER_ENDPOINT = "/prompt" WORKER_ENDPOINT = "/prompt"
COST = 100 COST = 100
route_payload = { route_payload = {
@@ -75,7 +82,6 @@ def call_custom_workflow_for_sd3(
endpoint=message["endpoint"], endpoint=message["endpoint"],
reqnum=message["reqnum"], reqnum=message["reqnum"],
url=message["url"], url=message["url"],
request_idx=message["request_idx"],
) )
workflow = { workflow = {
"3": { "3": {
+22 -29
View File
@@ -8,13 +8,14 @@ This is the base PyWorker for OpenAI compatible inference servers. See the [Ser
This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker. This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker.
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20(Serverless)) (recommended) - [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20%2B%20Qwen%2FQwen3-8B%20(Serverless)) (recommended)
- [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless)) - [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless))
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20%2B%20Qwen3-8B%20(Serverless))
All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected. All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected.
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface. 2. Follow the [getting started guide](https://docs.vast.ai/serverless/getting-started) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
## Client Setup (Demo) ## Client Setup (Demo)
@@ -33,30 +34,12 @@ uv pip install -r requirements.txt
Several examples have been provided in the client to help you get started with your own implementation. Several examples have been provided in the client to help you get started with your own implementation.
First, set your API key as an environment variable: ### Completions
Call to `/v1/completions` with json response
```bash ```bash
export VAST_API_KEY=<your_api_key> python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
```
The `--model` and `--endpoint` flags are optional. If not provided, they default to `Qwen/Qwen3-8B` and `my-vllm-endpoint` respectively.
### Chat Completion (streaming)
Call to `/v1/chat/completions` with streaming response
```bash
python -m workers.openai.client --chat-stream --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Interactive Chat (streaming)
Interactive session with calls to `/v1/chat/completions`.
Type `clear` to clear the chat history or `quit` to exit.
```bash
python -m workers.openai.client --interactive --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
``` ```
### Chat Completion (json) ### Chat Completion (json)
@@ -64,7 +47,15 @@ python -m workers.openai.client --interactive --endpoint <ENDPOINT_NAME> --model
Call to `/v1/chat/completions` with json response Call to `/v1/chat/completions` with json response
```bash ```bash
python -m workers.openai.client --chat --endpoint <ENDPOINT_NAME> --model <MODEL_NAME> python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
```
### Chat Completion (streaming)
Call to `/v1/chat/completions` with streaming response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
``` ```
### Tool Use (json) ### Tool Use (json)
@@ -74,14 +65,16 @@ Call to `/v1/chat/completions` with tool and json response.
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model. This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
```bash ```bash
python -m workers.openai.client --tools --endpoint <ENDPOINT_NAME> --model <MODEL_NAME> python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
``` ```
### Completions ### Interactive Chat (streaming)
Call to `/v1/completions` with json response Interactive session with calls to `/v1/chat/completions`.
Type `clear` to clear the chat history or `quit` to exit.
```bash ```bash
python -m workers.openai.client --completion --endpoint <ENDPOINT_NAME> --model <MODEL_NAME> python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
``` ```
+430 -376
View File
@@ -1,15 +1,14 @@
import logging import logging
import json
import os
import sys import sys
import json
import subprocess import subprocess
import argparse from urllib.parse import urljoin
from typing import Any, Dict, List, Optional from typing import Dict, Any, Optional, Iterator, Union, List
import requests
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from .data_types.client import CompletionConfig, ChatCompletionConfig
from vastai import Serverless
import asyncio
# ---------------------- Logging ----------------------
logging.basicConfig( logging.basicConfig(
level=logging.DEBUG, level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s", format="%(asctime)s[%(levelname)-5s] %(message)s",
@@ -17,20 +16,135 @@ logging.basicConfig(
) )
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
# ---------------------- Prompts ---------------------- COMPLETIONS_PROMPT = "the capital of USA is"
COMPLETIONS_PROMPT = "Zebras are primarily grazers and can subsist on lower-quality vegetation. They are preyed on mainly by"
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language." CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
TOOLS_PROMPT = ( TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?"
"Can you list the files in the current working directory and tell me what you see? "
"What do you think this directory might be for?"
) class APIClient:
"""Lightweight client focused solely on API communication"""
# Remove the generic WORKER_ENDPOINT since we're now going direct
DEFAULT_COST = 100
DEFAULT_TIMEOUT = 4
def __init__(
self,
endpoint_group_name: str,
api_key: str,
server_url: str,
endpoint_api_key: str,
):
self.endpoint_group_name = endpoint_group_name
self.api_key = api_key
self.server_url = server_url
self.endpoint_api_key = endpoint_api_key
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
"""Get worker URL and auth data from routing service"""
if not self.endpoint_api_key:
raise ValueError("No valid endpoint API key available")
route_payload = {
"endpoint": self.endpoint_group_name,
"api_key": self.endpoint_api_key,
"cost": cost,
}
response = requests.post(
urljoin(self.server_url, "/route/"),
json=route_payload,
timeout=self.DEFAULT_TIMEOUT,
)
response.raise_for_status()
return response.json()
def _create_auth_data(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""Create auth data from routing response"""
return {
"signature": message["signature"],
"cost": message["cost"],
"endpoint": message["endpoint"],
"reqnum": message["reqnum"],
"url": message["url"],
}
def _make_request(
self,
payload: Dict[str, Any],
endpoint: str,
method: str = "POST",
stream: bool = False,
) -> Union[Dict[str, Any], Iterator[str]]:
"""Make request directly to the specific worker endpoint"""
# Get worker URL and auth data
cost = payload.get("max_tokens", self.DEFAULT_COST)
message = self._get_worker_url(cost=cost)
worker_url = message["url"]
auth_data = self._create_auth_data(message)
req_data = {"payload": {"input": payload}, "auth_data": auth_data}
url = urljoin(worker_url, endpoint)
log.debug(f"Making direct request to: {url}")
log.debug(f"Payload: {req_data}")
# Make the request using the specified method
if method.upper() == "POST":
response = requests.post(
url, json=req_data, stream=stream, verify=get_cert_file_path()
)
elif method.upper() == "GET":
response = requests.get(
url, params=req_data, stream=stream, verify=get_cert_file_path()
)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status()
if stream:
return self._handle_streaming_response(response)
else:
return response.json()
def _handle_streaming_response(self, response: requests.Response) -> Iterator[str]:
"""Handle streaming response and yield tokens"""
try:
for line in response.iter_lines(decode_unicode=True):
if line:
if line.startswith("data: "):
data_str = line[6:]
if data_str.strip() == "[DONE]":
break
try:
data = json.loads(data_str)
yield data # Yield the full chunk
except json.JSONDecodeError:
continue
except Exception as e:
log.error(f"Error handling streaming response: {e}")
raise
def call_completions(
self, config: CompletionConfig
) -> Union[Dict[str, Any], Iterator[str]]:
payload = config.to_dict()
return self._make_request(
payload=payload, endpoint="/v1/completions", stream=config.stream
)
def call_chat_completions(
self, config: ChatCompletionConfig
) -> Union[Dict[str, Any], Iterator[str]]:
payload = config.to_dict()
return self._make_request(
payload=payload, endpoint="/v1/chat/completions", stream=config.stream
)
ENDPOINT_NAME = "my-vllm-endpoint" # change this to your vLLM endpoint name
DEFAULT_MODEL = "Qwen/Qwen3-8B" # must support tool calling
MAX_TOKENS = 1024
DEFAULT_TEMPERATURE = 0.7
# ---------------------- Tooling ----------------------
class ToolManager: class ToolManager:
"""Handles tool definitions and execution""" """Handles tool definitions and execution"""
@@ -50,7 +164,7 @@ class ToolManager:
@staticmethod @staticmethod
def get_ls_tool_definition() -> List[Dict[str, Any]]: def get_ls_tool_definition() -> List[Dict[str, Any]]:
"""OpenAI-compatible tool schema""" """Get the ls tool definition"""
return [ return [
{ {
"type": "function", "type": "function",
@@ -64,228 +178,98 @@ class ToolManager:
def execute_tool_call(self, tool_call: Dict[str, Any]) -> str: def execute_tool_call(self, tool_call: Dict[str, Any]) -> str:
"""Execute a tool call and return the result""" """Execute a tool call and return the result"""
function_name = (tool_call.get("function") or {}).get("name") function_name = tool_call["function"]["name"]
if function_name == "list_files": if function_name == "list_files":
return self.list_files() return self.list_files()
raise ValueError(f"Unknown tool function: {function_name}") else:
raise ValueError(f"Unknown tool function: {function_name}")
# ----- Helpers to handle streamed tool_calls assembly -----
def _merge_tool_call_delta(state: Dict[int, Dict[str, Any]], tc_delta: Dict[str, Any]) -> None:
"""
OpenAI-style streaming sends partial tool_calls with an index and partial fields.
We merge into a per-index state dict until the assistant message finishes.
"""
idx = tc_delta.get("index")
if idx is None:
return
entry = state.setdefault(idx, {"id": None, "function": {"name": None, "arguments": ""}, "type": "function"})
if tc_delta.get("id"):
entry["id"] = tc_delta["id"]
fn_delta = tc_delta.get("function") or {}
if "name" in fn_delta and fn_delta["name"]:
entry["function"]["name"] = fn_delta["name"]
if "arguments" in fn_delta and fn_delta["arguments"]:
entry["function"]["arguments"] += fn_delta["arguments"]
def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[Dict[str, Any]]:
return [state[i] for i in sorted(state.keys())]
# ---- OpenAI-compatible calls (non-streaming) ----
async def call_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
"model": model,
"prompt": prompt,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
}
}
log.debug("POST /v1/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
return resp["response"]
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
"model": model,
"messages": messages,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
}
}
log.debug("POST /v1/chat/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"])
return resp["response"]
# ---- Streaming variants ----
async def stream_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs):
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
"model": model,
"prompt": prompt,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"stream": True,
**({"stop": kwargs["stop"]} if "stop" in kwargs else {}),
}
}
log.debug("STREAM /v1/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
return resp["response"] # async generator
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs):
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
"model": model,
"messages": messages,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"stream": True,
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
}
}
log.debug("STREAM /v1/chat/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
return resp["response"] # async generator
# ---------------------- Demo Runner ----------------------
class APIDemo: class APIDemo:
"""Demo and testing functionality for the API client""" """Demo and testing functionality for the API client"""
def __init__(self, client: Serverless, model: str, endpoint_name: str, tool_manager: Optional[ToolManager] = None): def __init__(
self, client: APIClient, model: str, tool_manager: Optional[ToolManager] = None
):
self.client = client self.client = client
self.model = model self.model = model
self.endpoint_name = endpoint_name
self.tool_manager = tool_manager or ToolManager() self.tool_manager = tool_manager or ToolManager()
# ----- Streaming handler ----- def handle_streaming_response(
async def handle_streaming_response(self, stream, show_reasoning: bool = True) -> str: self, response_stream, show_reasoning: bool = True
) -> str:
"""
Handle streaming chat response and display all output.
"""
full_response = "" full_response = ""
reasoning_content = "" reasoning_content = ""
printed_reasoning = False reasoning_started = False
printed_answer = False content_started = False
finish_reason = None
async for chunk in stream: for chunk in response_stream:
choice = (chunk.get("choices") or [{}])[0] # Normalize the chunk
delta = choice.get("delta", {}) if isinstance(chunk, str):
chunk = chunk.strip()
# Track finish reason if chunk.startswith("data: "):
if choice.get("finish_reason"): chunk = chunk[6:].strip()
finish_reason = choice.get("finish_reason") if chunk in ["[DONE]", ""]:
continue
try:
parsed_chunk = json.loads(chunk)
except json.JSONDecodeError:
continue
elif isinstance(chunk, dict):
parsed_chunk = chunk
else:
continue
# reasoning tokens # Parse delta from the chunk
rc = delta.get("reasoning_content") choices = parsed_chunk.get("choices", [])
if rc and show_reasoning: if not choices:
if not printed_reasoning: continue
delta = choices[0].get("delta", {})
reasoning_token = delta.get("reasoning_content", "")
content_token = delta.get("content", "")
# Print reasoning token if applicable
if show_reasoning and reasoning_token:
if not reasoning_started:
print("\n🧠 Reasoning: ", end="", flush=True) print("\n🧠 Reasoning: ", end="", flush=True)
printed_reasoning = True reasoning_started = True
print(rc, end="", flush=True) print(f"\033[90m{reasoning_token}\033[0m", end="", flush=True)
reasoning_content += rc reasoning_content += reasoning_token
# content tokens # Print content token
content_part = delta.get("content") if content_token:
if content_part: if not content_started:
if not printed_answer: if show_reasoning and reasoning_started:
if show_reasoning and printed_reasoning: print(f"\n💬 Response: ", end="", flush=True)
print("\n💬 Response: ", end="", flush=True)
else: else:
print("Assistant: ", end="", flush=True) print("Assistant: ", end="", flush=True)
printed_answer = True content_started = True
print(content_part, end="", flush=True) print(content_token, end="", flush=True)
full_response += content_part full_response += content_token
print() # Ensure newline after response
print() # newline
if show_reasoning: if show_reasoning:
if printed_reasoning or printed_answer: if reasoning_started or content_started:
print("\nStreaming completed.") print("\nStreaming completed.")
if printed_reasoning: if reasoning_started:
print(f"Reasoning tokens: {len(reasoning_content.split())}") print(f"Reasoning tokens: {len(reasoning_content.split())}")
if printed_answer: if content_started:
print(f"Response tokens: {len(full_response.split())}") print(f"Response tokens: {len(full_response.split())}")
if finish_reason:
print(f"Finish reason: {finish_reason}")
return full_response return full_response
async def demo_completions(self) -> None:
print("=" * 60)
print("COMPLETIONS DEMO")
print("=" * 60)
response = await call_completions( def test_tool_support(self) -> bool:
client=self.client, """Test if the endpoint supports function calling"""
model=self.model, log.debug("Testing endpoint tool calling support...")
prompt=COMPLETIONS_PROMPT,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
print("\nResponse:")
print(json.dumps(response, indent=2))
async def demo_chat(self, use_streaming: bool = True) -> None: # Try a simple request with minimal tools to test support
print("=" * 60)
print(f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}")
print("=" * 60)
messages = [{"role": "user", "content": CHAT_PROMPT}]
if use_streaming:
stream = await stream_chat_completions(
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE
)
try:
await self.handle_streaming_response(stream, show_reasoning=True)
except Exception as e:
log.error("\nError during streaming: %s", e, exc_info=True)
else:
response = await call_chat_completions(
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE
)
choice = (response.get("choices") or [{}])[0]
message = choice.get("message", {})
content = message.get("content", "")
reasoning = message.get("reasoning_content", "") or message.get("reasoning", "")
if reasoning:
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
print(f"\n💬 Assistant: {content}")
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
async def test_tool_support(self) -> bool:
"""Probe that tool schema is accepted (no actual call)"""
messages = [{"role": "user", "content": "Hello"}] messages = [{"role": "user", "content": "Hello"}]
minimal_tool = [ minimal_tool = [
{ {
@@ -293,158 +277,179 @@ class APIDemo:
"function": {"name": "test_function", "description": "Test function"}, "function": {"name": "test_function", "description": "Test function"},
} }
] ]
config = ChatCompletionConfig(
model=self.model,
messages=messages,
max_tokens=10,
tools=minimal_tool,
tool_choice="none", # Don't actually call the tool
)
try: try:
_ = await call_chat_completions( response = self.client.call_chat_completions(config)
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
tools=minimal_tool,
tool_choice="none",
max_tokens=10
)
return True return True
except Exception as e: except Exception as e:
log.error("Endpoint does not support tool calling: %s", e) log.error(f"Error: Endpoint does not support tool calling: {e}")
return False return False
async def demo_ls_tool(self) -> None: def demo_completions(self) -> None:
"""Ask to list files using function calling, then provide final analysis""" """Demo: test basic completions endpoint"""
print("=" * 60)
print("COMPLETIONS DEMO")
print("=" * 60)
config = CompletionConfig(
model=self.model, prompt=COMPLETIONS_PROMPT, stream=False
)
log.info(
f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'"
)
response = self.client.call_completions(config)
if isinstance(response, dict):
print("\nResponse:")
print(json.dumps(response, indent=2))
else:
log.error("Unexpected response format")
def demo_chat(self, use_streaming: bool = True) -> None:
"""
Demo: test chat completions endpoint with optional streaming
"""
print("=" * 60)
print(
f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}"
)
print("=" * 60)
config = ChatCompletionConfig(
model=self.model,
messages=[{"role": "user", "content": CHAT_PROMPT}],
stream=use_streaming,
)
log.info(f"Testing chat completions with model '{self.model}'...")
response = self.client.call_chat_completions(config)
if use_streaming:
try:
self.handle_streaming_response(response, show_reasoning=True)
except Exception as e:
log.error(f"\nError during streaming: {e}")
import traceback
traceback.print_exc()
return
else:
if isinstance(response, dict):
choice = response.get("choices", [{}])[0]
message = choice.get("message", {})
content = message.get("content", "")
reasoning = message.get("reasoning_content", "") or message.get(
"reasoning", ""
)
if reasoning:
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
print(f"\n💬 Assistant: {content}")
print(f"\nFull Response:")
print(json.dumps(response, indent=2))
else:
log.error("Unexpected response format")
def demo_ls_tool(self) -> None:
"""Demo: ask LLM to list files in the current directory and describe what it sees"""
print("=" * 60) print("=" * 60)
print("TOOL USE DEMO: List Directory Contents") print("TOOL USE DEMO: List Directory Contents")
print("=" * 60) print("=" * 60)
if not await self.test_tool_support(): # Test if tools are supported first
if not self.test_tool_support():
return return
messages: List[Dict[str, Any]] = [{"role": "user", "content": TOOLS_PROMPT}] # Request with tool available
messages = [{"role": "user", "content": TOOLS_PROMPT}]
# First pass: let the model decide tools, stream tool_calls and partial content config = ChatCompletionConfig(
stream = await stream_chat_completions(
client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
tools=self.tool_manager.get_ls_tool_definition(), tools=self.tool_manager.get_ls_tool_definition(),
tool_choice="auto", tool_choice="auto",
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
) )
assistant_content_buf: List[str] = [] log.info(f"Making initial request with tool using model '{self.model}'...")
tool_calls_state: Dict[int, Dict[str, Any]] = {} response = self.client.call_chat_completions(config)
printed_reasoning = False
printed_answer = False
async for chunk in stream: if not isinstance(response, dict):
choice = (chunk.get("choices") or [{}])[0] raise ValueError("Expected dict response for tool use")
delta = choice.get("delta", {})
rc = delta.get("reasoning_content") choice = response.get("choices", [{}])[0]
if rc: message = choice.get("message", {})
if not printed_reasoning:
printed_reasoning = True
print("🧠 Reasoning: ", end="", flush=True)
print(rc, end="", flush=True)
content_part = delta.get("content") print(f"Assistant response: {message.get('content', 'No content')}")
if content_part:
assistant_content_buf.append(content_part)
if not printed_answer:
printed_answer = True
print("\n💬 Response: ", end="", flush=True)
print(content_part, end="", flush=True)
if "tool_calls" in delta and delta["tool_calls"]: # Check for tool calls
for tc_delta in delta["tool_calls"]: tool_calls = message.get("tool_calls")
_merge_tool_call_delta(tool_calls_state, tc_delta) if not tool_calls:
raise ValueError(
"No tool calls made - model may not support function calling"
)
# If no tool calls, were done. print(f"Tool calls detected: {len(tool_calls)}")
if not tool_calls_state:
print("\n(No tool calls were made.)")
return
# Build assistant message with tool_calls # Execute the tool call
assistant_message = { for tool_call in tool_calls:
"role": "assistant", function_name = tool_call["function"]["name"]
"content": "".join(assistant_content_buf) if assistant_content_buf else None, print(f"Executing tool: {function_name}")
"tool_calls": _tool_state_to_message_tool_calls(tool_calls_state),
}
messages.append(assistant_message)
# Execute tools and feed results back tool_result = self.tool_manager.execute_tool_call(tool_call)
for tc in assistant_message["tool_calls"]: print(f"Tool result:\n{tool_result}")
tool_name = (tc.get("function") or {}).get("name")
call_id = tc.get("id")
raw_args = (tc.get("function") or {}).get("arguments") or "{}"
try: # Add tool result and continue conversation
args = json.loads(raw_args) if raw_args.strip() else {} messages.append(message) # Add assistant's message with tool call
except Exception as e: messages.append(
tool_result = json.dumps({"error": f"Argument parse failed: {str(e)}", "raw_arguments": raw_args}) {
messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result}) "role": "tool",
continue "tool_call_id": tool_call["id"],
"content": tool_result,
}
)
try: # Get final response
if tool_name == "list_files": final_config = ChatCompletionConfig(
tool_result = self.tool_manager.list_files() model=self.model,
else: messages=messages,
tool_result = json.dumps({"error": f"Unknown tool '{tool_name}'"}) tools=self.tool_manager.get_ls_tool_definition(),
except Exception as e: )
tool_result = json.dumps({"error": f"Tool '{tool_name}' failed: {str(e)}"})
print("\n[Tool executed]", tool_name) print("Getting final response...")
print(tool_result[:500] + ("..." if len(tool_result) > 500 else "")) final_response = self.client.call_chat_completions(final_config)
messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result})
# Second pass: get final streamed answer after tool results if isinstance(final_response, dict):
stream2 = await stream_chat_completions( final_choice = final_response.get("choices", [{}])[0]
client=self.client, final_message = final_choice.get("message", {})
model=self.model, final_content = final_message.get("content", "")
messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
final_buf = [] print("\n" + "=" * 60)
printed_reasoning2 = False print("FINAL LLM ANALYSIS:")
printed_answer2 = False print("=" * 60)
print(final_content)
print("=" * 60)
async for chunk in stream2: def interactive_chat(self) -> None:
choice = (chunk.get("choices") or [{}])[0]
delta = choice.get("delta", {})
rc2 = delta.get("reasoning_content")
if rc2:
if not printed_reasoning2:
printed_reasoning2 = True
print("\n🧠 Reasoning (post-tools): ", end="", flush=True)
print(rc2, end="", flush=True)
c2 = delta.get("content")
if c2:
final_buf.append(c2)
if not printed_answer2:
printed_answer2 = True
print("\n💬 Response (final): ", end="", flush=True)
print(c2, end="", flush=True)
print("\n" + "=" * 60)
print("FINAL LLM ANALYSIS:")
print("=" * 60)
print("".join(final_buf))
print("=" * 60)
async def interactive_chat(self) -> None:
"""Interactive chat session with streaming""" """Interactive chat session with streaming"""
print("=" * 60) print("=" * 60)
print("INTERACTIVE STREAMING CHAT") print("INTERACTIVE STREAMING CHAT")
print("=" * 60) print("=" * 60)
print(f"Using model: {self.model}")
print("Type 'quit' to exit, 'clear' to clear history") print("Type 'quit' to exit, 'clear' to clear history")
print() print()
messages: List[Dict[str, Any]] = [] messages = []
while True: while True:
try: try:
@@ -462,16 +467,16 @@ class APIDemo:
messages.append({"role": "user", "content": user_input}) messages.append({"role": "user", "content": user_input})
print("Assistant: ", end="", flush=True) config = ChatCompletionConfig(
stream = await stream_chat_completions( model=self.model, messages=messages, stream=True, temperature=0.7
client=self.client, )
model=self.model,
messages=messages, print("Assistant: ", end="", flush=True)
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, response = self.client.call_chat_completions(config)
temperature=0.7 assistant_content = self.handle_streaming_response(
response, show_reasoning=True
) )
assistant_content = await self.handle_streaming_response(stream, show_reasoning=True)
# Add assistant response to conversation history # Add assistant response to conversation history
messages.append({"role": "assistant", "content": assistant_content}) messages.append({"role": "assistant", "content": assistant_content})
@@ -480,66 +485,115 @@ class APIDemo:
print("\n👋 Chat interrupted. Goodbye!") print("\n👋 Chat interrupted. Goodbye!")
break break
except Exception as e: except Exception as e:
log.error("\nError: %s", e) log.error(f"\nError: {e}")
continue continue
# ---------------------- CLI ---------------------- def main():
def build_arg_parser() -> argparse.ArgumentParser: """Main function with CLI switches for different tests"""
p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)") from lib.test_utils import test_args
p.add_argument("--model", default=DEFAULT_MODEL, help=f"Model to use for requests (default: {DEFAULT_MODEL})")
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
modes = p.add_mutually_exclusive_group(required=False) # Add mandatory model argument
modes.add_argument("--completion", action="store_true", help="Test completions endpoint") test_args.add_argument(
modes.add_argument("--chat", action="store_true", help="Test chat completions endpoint (non-streaming)") "--model", required=True, help="Model to use for requests (required)"
modes.add_argument("--chat-stream", action="store_true", help="Test chat completions endpoint with streaming") )
modes.add_argument("--tools", action="store_true", help="Test function calling with ls tool (non-streaming+streamed phases)")
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming chat session")
return p
# Add test mode arguments
test_args.add_argument(
"--completion", action="store_true", help="Test completions endpoint"
)
test_args.add_argument(
"--chat",
action="store_true",
help="Test chat completions endpoint (non-streaming)",
)
test_args.add_argument(
"--chat-stream",
action="store_true",
help="Test chat completions endpoint with streaming",
)
test_args.add_argument(
"--tools",
action="store_true",
help="Test function calling with ls tool (non-streaming)",
)
test_args.add_argument(
"--interactive",
action="store_true",
help="Start interactive streaming chat session",
)
async def main_async(): args = test_args.parse_args()
args = build_arg_parser().parse_args()
selected = sum([args.completion, args.chat, args.chat_stream, args.tools, args.interactive]) # Check that only one test mode is selected
if selected == 0: test_modes = [
args.completion,
args.chat,
args.chat_stream,
args.tools,
args.interactive,
]
selected_count = sum(test_modes)
if selected_count == 0:
print("Please specify exactly one test mode:") print("Please specify exactly one test mode:")
print(" --completion : Test completions endpoint") print(" --completion : Test completions endpoint")
print(" --chat : Test chat completions endpoint (non-streaming)") print(" --chat : Test chat completions endpoint (non-streaming)")
print(" --chat-stream : Test chat completions endpoint with streaming") print(" --chat-stream : Test chat completions endpoint with streaming")
print(" --tools : Test function calling with ls tool") print(" --tools : Test function calling with ls tool (non-streaming)")
print(" --interactive : Start interactive streaming chat session") print(" --interactive : Start interactive streaming chat session")
print(f"\nExample: python {os.path.basename(sys.argv[0])} --model Qwen/Qwen3-8B --chat-stream --endpoint my-vllm-endpoint") print(
f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT"
)
sys.exit(1) sys.exit(1)
elif selected > 1: elif selected_count > 1:
print("Please specify exactly one test mode") print("Please specify exactly one test mode")
sys.exit(1) sys.exit(1)
print("=" * 60)
print(f"Using model: {args.model}")
print(f"Using endpoint: {args.endpoint}")
try: try:
async with Serverless() as client: endpoint_api_key = Endpoint.get_endpoint_api_key(
demo = APIDemo(client, args.model, args.endpoint, ToolManager()) endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if args.completion: if not endpoint_api_key:
await demo.demo_completions() log.error(
elif args.chat: f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting."
await demo.demo_chat(use_streaming=False) )
elif args.chat_stream: sys.exit(1)
await demo.demo_chat(use_streaming=True)
elif args.tools: # Create the core API client
await demo.demo_ls_tool() client = APIClient(
elif args.interactive: endpoint_group_name=args.endpoint_group_name,
await demo.interactive_chat() api_key=args.api_key,
server_url=Endpoint.get_autoscaler_server_url(args.instance),
endpoint_api_key=endpoint_api_key,
)
# Create tool manager and demo (passing the model parameter)
tool_manager = ToolManager()
demo = APIDemo(client, args.model, tool_manager)
print(f"Using model: {args.model}")
print("=" * 60)
# Run the selected test
if args.completion:
demo.demo_completions()
elif args.chat:
demo.demo_chat(use_streaming=False)
elif args.chat_stream:
demo.demo_chat(use_streaming=True)
elif args.tools:
demo.demo_ls_tool()
elif args.interactive:
demo.interactive_chat()
except Exception as e: except Exception as e:
log.error("Error during test: %s", e, exc_info=True) log.error(f"Error during test: {e}", exc_info=True)
sys.exit(1) sys.exit(1)
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main_async()) main()
-2
View File
@@ -11,7 +11,6 @@ MODEL_SERVER_START_LOG_MSG = [
"llama runner started", # Ollama "llama runner started", # Ollama
'"message":"Connected","target":"text_generation_router"', # TGI '"message":"Connected","target":"text_generation_router"', # TGI
'"message":"Connected","target":"text_generation_router::server"', # TGI '"message":"Connected","target":"text_generation_router::server"', # TGI
"main: model loaded" # llama.cpp
] ]
MODEL_SERVER_ERROR_LOG_MSGS = [ MODEL_SERVER_ERROR_LOG_MSGS = [
@@ -35,7 +34,6 @@ backend = Backend(
model_server_url=os.environ["MODEL_SERVER_URL"], model_server_url=os.environ["MODEL_SERVER_URL"],
model_log_file=os.environ["MODEL_LOG"], model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=True, allow_parallel_requests=True,
max_wait_time=600.0,
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256), benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
log_actions=[ log_actions=[
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG], *[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
+9 -93
View File
@@ -1,103 +1,19 @@
# HuggingFace TGI PyWorker This is the base PyWorker for TGI, designed to create PyWorkers that can utilize various LLMs. It offers two primary endpoints:
This is the base PyWorker for HuggingFace Text Generation Inference (TGI) servers. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's. 1. `generate`: Generates the LLM's response to a given prompt in a single request.
2. `generate_stream`: Streams the LLM's response token by token.
## Instance Setup Both endpoints use the following API payload format:
1. Pick a template
This worker is compatible with any TGI backend. We have a template you can use or you can create your own.
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20(Serverless))
The template can be configured via the template interface. You may want to change the model or startup arguments.
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
## Client Setup (Demo)
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
```bash
git clone https://github.com/vast-ai/pyworker
cd pyworker
pip install uv
uv venv -p 3.12
source .venv/bin/activate
uv pip install -r requirements.txt
```
## Using the Test Client
The test client demonstrates both streaming and non-streaming generation using TGI's native API.
First, set your API key as an environment variable:
```bash
export VAST_API_KEY=<your_api_key>
```
The `--endpoint` flag is optional. If not provided, it defaults to `my-tgi-endpoint`.
### Generate (Streaming)
Call to `/generate_stream` with streaming response:
```bash
python -m workers.tgi.client --generate-stream --endpoint <ENDPOINT_NAME>
```
### Generate (Non-Streaming)
Call to `/generate` with json response:
```bash
python -m workers.tgi.client --generate --endpoint <ENDPOINT_NAME>
```
### Interactive Session (Streaming)
Interactive session with streaming responses. Type `quit` to exit.
```bash
python -m workers.tgi.client --interactive --endpoint <ENDPOINT_NAME>
```
## API Endpoints
TGI provides two primary endpoints:
### Generate (Non-Streaming)
`/generate` - Returns the complete response in a single request.
```json ```json
{ {
"inputs": "Your prompt here", "inputs": "PROMPT",
"parameters": { "parameters": {
"max_new_tokens": 1024, "max_new_tokens": 250
"temperature": 0.7,
"return_full_text": false
} }
} }
``` ```
### Generate Stream (Streaming) Note that the max_new_tokens parameter, rather than the prompt size, impacts performance. For example, if an
instance is benchmarked to process 100 tokens per second, a request with max_new_tokens = 200 will take
`/generate_stream` - Streams the response token by token. approximately 2 seconds to complete.
```json
{
"inputs": "Your prompt here",
"parameters": {
"max_new_tokens": 1024,
"temperature": 0.7,
"do_sample": true,
"return_full_text": false
}
}
```
## Performance Notes
The `max_new_tokens` parameter (not the prompt size) primarily impacts performance. For example, if an instance is benchmarked to process 100 tokens per second, a request with `max_new_tokens = 200` will take approximately 2 seconds to complete.
+105 -202
View File
@@ -1,13 +1,11 @@
import logging import logging
import json
import os
import sys import sys
import argparse import json
from urllib.parse import urljoin
import requests
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from vastai import Serverless
import asyncio
# ---------------------- Logging ----------------------
logging.basicConfig( logging.basicConfig(
level=logging.DEBUG, level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s", format="%(asctime)s[%(levelname)-5s] %(message)s",
@@ -15,208 +13,113 @@ logging.basicConfig(
) )
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
# ---------------------- Defaults ----------------------
DEFAULT_PROMPT = "Think step by step: Tell me about the Python programming language."
ENDPOINT_NAME = "TGI-Prod2" # change this to your TGI endpoint name def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> None:
MAX_TOKENS = 1024 WORKER_ENDPOINT = "/generate"
DEFAULT_TEMPERATURE = 0.7 COST = 100
route_payload = {
"endpoint": endpoint_group_name,
# ---------------------- API Calls ---------------------- "api_key": api_key,
async def call_generate(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs) -> dict: "cost": COST,
"""Non-streaming generation via /generate endpoint"""
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"inputs": prompt,
"parameters": {
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"return_full_text": False,
}
} }
log.debug("POST /generate %s", json.dumps(payload)[:500]) response = requests.post(
resp = await endpoint.request("/generate", payload, cost=payload["parameters"]["max_new_tokens"]) urljoin(server_url, "/route/"),
return resp["response"] json=route_payload,
timeout=4,
async def call_generate_stream(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs):
"""Streaming generation via /generate_stream endpoint"""
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"inputs": prompt,
"parameters": {
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"do_sample": True,
"return_full_text": False,
}
}
log.debug("STREAM /generate_stream %s", json.dumps(payload)[:500])
resp = await endpoint.request(
"/generate_stream",
payload,
cost=payload["parameters"]["max_new_tokens"],
stream=True,
) )
return resp["response"] # async generator response.raise_for_status() # Raise an exception for bad status codes
message = response.json()
url = message["url"]
auth_data = dict(
signature=message["signature"],
cost=message["cost"],
endpoint=message["endpoint"],
reqnum=message["reqnum"],
url=url,
)
payload = dict(inputs="tell me about cats", parameters=dict(max_new_tokens=500))
req_data = dict(payload=payload, auth_data=auth_data)
url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {url}")
response = requests.post(
url,
json=req_data,
verify=get_cert_file_path(),
)
response.raise_for_status()
res = response.json()
print(res)
# ---------------------- Demo Runner ---------------------- def call_generate_stream(
class APIDemo: endpoint_group_name: str, api_key: str, server_url: str
"""Demo and testing functionality for the TGI API client""" ) -> None:
WORKER_ENDPOINT = "/generate_stream"
def __init__(self, client: Serverless, endpoint_name: str): COST = 100
self.client = client route_payload = {
self.endpoint_name = endpoint_name "endpoint": endpoint_group_name,
"api_key": api_key,
async def handle_streaming_response(self, stream) -> str: "cost": COST,
"""Process streaming response and print tokens""" }
full_response = "" response = requests.post(
printed_answer = False urljoin(server_url, "/route/"),
json=route_payload,
async for event in stream: timeout=4,
tok = (event.get("token") or {}).get("text") )
if tok: response.raise_for_status() # Raise an exception for bad status codes
if not printed_answer: message = response.json()
printed_answer = True url = message["url"]
print("\n💬 Response: ", end="", flush=True) print(f"url: {url}")
print(tok, end="", flush=True) auth_data = dict(
full_response += tok signature=message["signature"],
cost=message["cost"],
print() # newline endpoint=message["endpoint"],
if printed_answer: reqnum=message["reqnum"],
print(f"\nStreaming completed. Response tokens: {len(full_response.split())}") url=message["url"],
)
return full_response payload = dict(inputs="tell me about dogs", parameters=dict(max_new_tokens=500))
req_data = dict(payload=payload, auth_data=auth_data)
async def demo_generate(self) -> None: url = urljoin(url, WORKER_ENDPOINT)
"""Demo non-streaming generation""" response = requests.post(url, json=req_data, stream=True)
print("=" * 60) response.raise_for_status() # Raise an exception for bad status codes
print("GENERATE DEMO (NON-STREAMING)") for line in response.iter_lines():
print("=" * 60) payload = line.decode().lstrip("data:").rstrip()
if payload:
response = await call_generate(
client=self.client,
endpoint_name=self.endpoint_name,
prompt=DEFAULT_PROMPT,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
print(f"\n💬 Response: {response.get('generated_text', '')}")
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
async def demo_generate_stream(self) -> None:
"""Demo streaming generation"""
print("=" * 60)
print("GENERATE DEMO (STREAMING)")
print("=" * 60)
stream = await call_generate_stream(
client=self.client,
endpoint_name=self.endpoint_name,
prompt=DEFAULT_PROMPT,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
try:
await self.handle_streaming_response(stream)
except Exception as e:
log.error("\nError during streaming: %s", e, exc_info=True)
async def interactive_chat(self) -> None:
"""Interactive session with streaming generation"""
print("=" * 60)
print("INTERACTIVE STREAMING SESSION")
print("=" * 60)
print(f"Using endpoint: {self.endpoint_name}")
print("Type 'quit' to exit")
print()
while True:
try: try:
user_input = input("You: ").strip() data = json.loads(payload)
print(data["token"]["text"], end="")
if user_input.lower() == "quit": sys.stdout.flush()
print("👋 Goodbye!") except (json.JSONDecodeError, KeyError) as e:
break log.warning(f"Failed to parse streaming response: {e}")
elif not user_input:
continue
print("Assistant: ", end="", flush=True)
stream = await call_generate_stream(
client=self.client,
endpoint_name=self.endpoint_name,
prompt=user_input,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
full_response = ""
async for event in stream:
tok = (event.get("token") or {}).get("text")
if tok:
print(tok, end="", flush=True)
full_response += tok
print() # newline
except KeyboardInterrupt:
print("\n👋 Session interrupted. Goodbye!")
break
except Exception as e:
log.error("\nError: %s", e)
continue continue
print()
# ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast TGI Demo (Serverless SDK)")
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
modes = p.add_mutually_exclusive_group(required=False)
modes.add_argument("--generate", action="store_true", help="Test generate endpoint (non-streaming)")
modes.add_argument("--generate-stream", action="store_true", help="Test generate endpoint with streaming")
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming session")
return p
async def main_async():
args = build_arg_parser().parse_args()
selected = sum([args.generate, args.generate_stream, args.interactive])
if selected == 0:
print("Please specify exactly one test mode:")
print(" --generate : Test generate endpoint (non-streaming)")
print(" --generate-stream : Test generate endpoint with streaming")
print(" --interactive : Start interactive streaming session")
print(f"\nExample: python {os.path.basename(sys.argv[0])} --generate-stream --endpoint my-tgi-endpoint")
sys.exit(1)
elif selected > 1:
print("Please specify exactly one test mode")
sys.exit(1)
print("=" * 60)
print(f"Using endpoint: {args.endpoint}")
try:
async with Serverless() as client:
demo = APIDemo(client, args.endpoint)
if args.generate:
await demo.demo_generate()
elif args.generate_stream:
await demo.demo_generate_stream()
elif args.interactive:
await demo.interactive_chat()
except Exception as e:
log.error("Error during test: %s", e, exc_info=True)
sys.exit(1)
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main_async()) from lib.test_utils import test_args
args = test_args.parse_args()
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if endpoint_api_key:
try:
call_generate(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
call_generate_stream(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
except Exception as e:
log.error(f"Error during API call: {e}")
else:
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ")