Compare commits

...

58 Commits

Author SHA1 Message Date
Lucas Armand e143162438 bumpy pyworker version 2025-11-25 16:01:23 -08:00
Lucas Armand 7986e51e9e early errors 2025-11-24 15:24:06 -08:00
Lucas Armand 9c6ab78503 Move model log line 2025-11-24 15:22:23 -08:00
Lucas Armand 45e0c7d9ca Move model log rotate to top 2025-11-24 15:02:33 -08:00
Lucas Armand a4339bd3f1 hotfix: add f 2025-11-12 16:10:55 -08:00
Lucas Armand 2b26e5e20c hotfix: remove g 2025-11-12 16:01:57 -08:00
LucasArmandVast d3727d4fd7 Merge pull request #58 from vast-ai/update-client-scripts
Update client scripts
2025-11-12 10:22:42 -08:00
Lucas Armand a47c9d1ed0 remove test bugs 2025-11-11 18:13:46 -08:00
Lucas Armand 0b14562a63 dont exit on pyworker fail 2025-11-11 17:57:08 -08:00
Lucas Armand de9b50abb9 use set +e 2025-11-11 17:53:36 -08:00
Lucas Armand c510801723 fix 2025-11-11 17:49:34 -08:00
Lucas Armand a12523b1d2 Added bad code to tgi server to test 2025-11-11 17:41:12 -08:00
Lucas Armand eedf81c0a3 Updated readme and .gitignore 2025-11-11 17:18:40 -08:00
Lucas Armand 3adec1826d minor changes 2025-11-11 17:11:38 -08:00
Lucas Armand b55bfa9611 Updated clients, include vastai-sdk, handle non-UTF-8 2025-11-11 17:09:28 -08:00
LucasArmandVast 7db54f3bd7 Merge pull request #55 from vast-ai/use-mtoken
Use mtoken
2025-11-10 11:54:04 -08:00
LucasArmandVast d63a060202 Merge pull request #56 from vast-ai/obfuscate-mtoken
Obfuscate mtoken in logs
2025-11-10 11:53:17 -08:00
Lucas Armand c6521cb6d4 add ... 2025-11-07 10:10:35 -08:00
Lucas Armand b7fe4ebb91 Obfuscate mtoken in logs 2025-11-07 10:02:39 -08:00
Lucas Armand 8ae7b74605 bump version to 0.2.0 2025-11-05 13:32:21 -08:00
Lucas Armand 106067d716 bump version to 0.1.1 2025-11-04 17:15:59 -08:00
Lucas Armand f5134d4bf5 Fix spelling mistake 2025-11-04 16:59:39 -08:00
Lucas Armand 47e5460532 added mtoken 2025-11-04 15:55:14 -08:00
Colter-Downing ec2ac0a21a Merge pull request #52 from vast-ai/remove-sleeps-and-delays
Remove sleeps and delays
2025-10-30 11:53:39 -07:00
Abiola Akinnubi 2cde573c56 Merge pull request #48 from vast-ai/comfy-request-idx
Added request_idx to comfy auth_data
2025-10-30 11:27:35 -07:00
Abiola Akinnubi b2e4a5db0c Merge pull request #49 from vast-ai/unsecure_report_addr
Added caller for REPORT_ADDR to backend.py to use the report add
2025-10-30 10:39:46 -07:00
Abiola Akinnubi 7437028cb2 Added caller for REPORT_ADDR to backend.py 2025-10-29 18:02:17 -07:00
edgaratvast 02c8307af7 remove redis pubsub from pyworker (#53)
Co-authored-by: Edgar Lin <edgarlin2000@gmail.com>
2025-10-29 17:07:56 -07:00
Colter Downing 7c0f316eeb leave the env vars alone! 2025-10-29 11:36:46 -07:00
Colter Downing b4025a744f remove env var writing 2025-10-29 09:58:09 -07:00
Colter Downing d190308329 removed 5 sec sleep and warmup request on load 2025-10-29 09:57:46 -07:00
LucasArmandVast 9f5a432513 Merge pull request #51 from vast-ai/delete-reqs-hotfix
Redis subscriber queue patch
2025-10-28 16:07:28 -07:00
Lucas Armand e09f1fa953 patch for redis queue 2025-10-28 16:03:50 -07:00
edgaratvast ba6f1c2e4b Fix signature (#50)
* change order of fields in auth_data to match autoscaler for signature verification

* also ignore __request_id

* Revert "change order of fields in auth_data to match autoscaler for signature verification" so that it's alphabetical again

This reverts commit b8223879c9.

* enforce alphabetical json dumping of message for signature verification

---------

Co-authored-by: Edgar Lin <edgarlin2000@gmail.com>
2025-10-28 16:01:32 -07:00
Abiola Akinnubi 944f83fc03 Removed extra spaces from operator assignment 2025-10-28 21:03:52 +00:00
edgaratvast 298590fb88 Merge pull request #45 from vast-ai/new-pyworker
New PyWorker
2025-10-28 14:02:53 -07:00
Lucas Armand 814c3acd4c remove unused code 2025-10-28 13:43:57 -07:00
Lucas Armand 22bca74087 Prevent load time race 2025-10-27 18:25:21 -07:00
Lucas Armand 9c795e2a01 removed bad code 2025-10-27 17:03:13 -07:00
Lucas Armand 830b532781 Trying unified delete 2025-10-27 16:57:52 -07:00
LucasArmandVast d6a6e34c6b Merge branch 'main' into new-pyworker 2025-10-27 12:43:49 -07:00
Colter-Downing ac1e109c48 Merge pull request #47 from vast-ai/new-pyworker-vllm-prefix-cache
vLLM Prefix caching, benchmark bug fix, test load script
2025-10-27 12:30:34 -07:00
Abiola Akinnubi f56bbc0ebe Added request_idx to comfy auth_data 2025-10-27 03:17:06 +00:00
Rob Ballantyne 70d51bafe1 Merge pull request #36 from robballantyne/feat/comfyui-json-benchmark-workflow-from-file 2025-10-23 17:05:48 +01:00
Rob Ballantyne 63909736bb Merge pull request #4 from robballantyne/feat/comfyui-json-benchmark-workflow-from-file-no-silent-fail
Feat/comfyui json benchmark workflow from file no silent fail
2025-10-23 17:02:12 +01:00
Rob Ballantyne f4f7080df1 Re-add comment 2025-10-23 17:00:28 +01:00
Rob Ballantyne d51a338e8f log when benchmark file not used 2025-10-23 16:41:02 +01:00
Rob Ballantyne 92a04bd7af No silent fail if benchmark file is missing 2025-10-23 13:41:03 +01:00
LucasArmandVast c98d661513 Merge pull request #39 from vast-ai/remove-time-divide
PyWorker fixes for cur_load and acks bug
2025-10-13 10:06:22 -07:00
Lucas Armand f6fd1c6ac1 merge 2025-10-09 18:15:55 -07:00
Lucas Armand 055e346c8c Send metrics on request start 2025-10-09 10:13:50 -07:00
Lucas Armand 1cedb28acf Removed division by elapsed time, since autoscaler cur_load in units of workload 2025-10-08 16:54:18 -07:00
Rob Ballantyne ec25dda3ad Merge branch 'vast-ai:main' into feat/comfyui-json-benchmark-workflow-from-file 2025-10-08 14:49:32 +01:00
Colter-Downing 0397af719d Merge pull request #37 from robballantyne/bugfix/healthcheck-endpoint
Fix healthcheck endpoint URL

Tested and merged by Colter
2025-10-06 15:11:27 -07:00
Rob Ballantyne 3786cf978d Add awareness of errors thrown by the provisioning script 2025-10-05 23:14:59 +01:00
Rob Ballantyne a86d4bcf9c Import json 2025-10-05 23:05:33 +01:00
Rob Ballantyne e9b6a14a5e Import Path 2025-10-05 22:59:19 +01:00
Rob Ballantyne cadac033e1 Enables use of custom workflow for benchmarking
Retains existing method is misc/benchmark.json is nopt present
2025-10-05 22:53:22 +01:00
16 changed files with 789 additions and 788 deletions
+1
View File
@@ -3,3 +3,4 @@
__pycache__
bin/
lib64
.venv
+4 -3
View File
@@ -39,11 +39,12 @@ reporting these metrics to the autoscaler.
If you are using a Vast.ai template that includes PyWorker integration (marked as autoscaler compatible), it should work out of the box. The template will typically start the appropriate PyWorker server automatically. Here's a few:
* **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=72d8dcb41ea3a58e06c741e2c725bc00)
* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=ad72c8bf7cf695c3c9ddf0eaf6da0447)
* **vLLM:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=63ae93902bf3978bea033782592b784d)
* **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=6fa6bd5bdf5f0df63db80e40b086037d)
* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=e6748878ba688e765e3e9fca29541938)
Currently available workers:
* `hello_world`: A simple example worker for a basic LLM server.
* `openai`: A simple example worker for a basic vLLM server.
* `comfyui`: A worker for the ComfyUI image generation backend.
* `tgi`: A worker for the Text Generation Inference backend.
+24 -21
View File
@@ -30,7 +30,7 @@ from lib.data_types import (
BenchmarkResult
)
VERSION = "0.1.0"
VERSION = "0.2.1"
MSG_HISTORY_LEN = 100
log = logging.getLogger(__file__)
@@ -66,10 +66,17 @@ class Backend:
unsecured: bool = dataclasses.field(
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
)
report_addr: str = dataclasses.field(
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
)
mtoken: str = dataclasses.field(
default_factory=lambda: os.environ.get("MASTER_TOKEN", "")
)
def __post_init__(self):
self.metrics = Metrics()
self.metrics._set_version(self.version)
self.metrics._set_mtoken(self.mtoken)
self._total_pubkey_fetch_errors = 0
self._pubkey = self._fetch_pubkey()
self.__start_healthcheck: bool = False
@@ -104,23 +111,19 @@ class Backend:
#######################################Private#######################################
def _fetch_pubkey(self):
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
report_addr = self.report_addr.rstrip("/")
command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"]
try:
result = subprocess.check_output(command, universal_newlines=True)
log.debug("public key:")
log.debug(result)
key = None
for _ in range(5):
try:
key = RSA.import_key(result)
break
except ValueError as e:
log.debug(f"Error downloading key: {e}")
time.sleep(15)
if key is None:
self._total_pubkey_fetch_errors += 1
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
self.backend_errored("Failed to get autoscaler pubkey")
if key is not None:
return key
except (ValueError , subprocess.CalledProcessError) as e:
log.debug(f"Error downloading key: {e}")
self.backend_errored("Failed to get autoscaler pubkey")
async def __handle_request(
self,
@@ -286,7 +289,7 @@ class Backend:
message = {
key: value
for (key, value) in (dataclasses.asdict(auth_data).items())
if key != "signature"
if key != "signature" and key != "__request_id"
}
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
log.debug(
@@ -296,7 +299,7 @@ class Backend:
elif message in self.msg_history:
log.debug(f"message: {message} already in message history")
return False
elif verify_signature(json.dumps(message, indent=4), auth_data.signature):
elif verify_signature(json.dumps(message, indent=4, sort_keys=True), auth_data.signature):
self.reqnum = max(auth_data.reqnum, self.reqnum)
self.msg_history.append(message)
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
@@ -315,10 +318,10 @@ class Backend:
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
log.debug("already ran benchmark")
# trigger model load
payload = self.benchmark_handler.make_benchmark_payload()
_ = await self.__call_api(
handler=self.benchmark_handler, payload=payload
)
# payload = self.benchmark_handler.make_benchmark_payload()
# _ = await self.__call_api(
# handler=self.benchmark_handler, payload=payload
# )
return float(f.readline())
except FileNotFoundError:
pass
@@ -393,7 +396,7 @@ class Backend:
)
# some backends need a few seconds after logging successful startup before
# they can begin accepting requests
await sleep(5)
# await sleep(5)
try:
max_throughput = await run_benchmark()
self.__start_healthcheck = True
@@ -414,7 +417,7 @@ class Backend:
async def tail_log():
log.debug(f"tailing file: {self.model_log_file}")
async with await open_file(self.model_log_file) as f:
async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore') as f:
while True:
line = await f.readline()
if line:
+5 -3
View File
@@ -65,12 +65,12 @@ class ApiPayload(ABC):
class AuthData:
"""data used to authenticate requester"""
signature: str
cost: str
endpoint: str
reqnum: int
url: str
request_idx: int
signature: str
url: str
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]):
@@ -190,10 +190,11 @@ class SystemMetrics:
self.additional_disk_usage = disk_usage - self.last_disk_usage
self.last_disk_usage = disk_usage
def reset(self):
def reset(self, expected: float | None) -> None:
# autoscaler excepts model_loading_time to be populated only once, when the instance has
# finished benchmarking and is ready to receive requests. This applies to restarted instances
# as well: they should send model_loading_time once when they are done loading
if self.model_loading_time == expected:
self.model_loading_time = None
@@ -285,6 +286,7 @@ class AutoScalerData:
"""Data that is reported to autoscaler"""
id: int
mtoken: str
version: str
loadtime: float
cur_load: float
+62 -15
View File
@@ -28,6 +28,7 @@ def get_url() -> str:
@dataclass
class Metrics:
version: str = "0"
mtoken: str = ""
last_metric_update: float = 0.0
last_request_served: float = 0.0
update_pending: bool = False
@@ -142,17 +143,22 @@ class Metrics:
def _set_version(self, version: str) -> None:
self.version = version
def _set_mtoken(self, mtoken: str) -> None:
self.mtoken = mtoken
#######################################Private#######################################
async def __send_delete_requests_and_reset(self):
async def send_data(report_addr: str, success: bool) -> bool:
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
data = {
"worker_id": self.id,
"request_idxs": [r.request_idx for r in self.model_metrics.requests_deleting if r.success == success],
"success": success
"mtoken": self.mtoken,
"request_idxs": idxs,
"success": success_flag,
}
log.debug(f"Deleting requests that {'succeeded' if success else 'failed'}: {data['request_idxs']}")
log.debug(
f"Deleting requests that {'succeeded' if success_flag else 'failed'}: {data['request_idxs']}"
)
full_path = report_addr.rstrip("/") + "/delete_requests/"
for attempt in range(1, 4):
try:
@@ -162,26 +168,55 @@ class Metrics:
res.raise_for_status()
return True
except asyncio.TimeoutError:
log.debug(f"delete_requests timed out")
log.debug("delete_requests timed out")
except (ClientResponseError, Exception) as e:
log.debug(f"delete_requests failed with error: {e}")
await asyncio.sleep(2)
log.debug(f"retrying delete_request, attempt: {attempt}")
return False
# Take a snapshot of what we plan to send this tick.
# New arrivals after this snapshot will remain in the queue for the next tick.
snapshot = list(self.model_metrics.requests_deleting)
success_idxs = [r.request_idx for r in snapshot if r.success is True]
failed_idxs = [r.request_idx for r in snapshot if r.success is False]
if not success_idxs and not failed_idxs:
return # nothing to do
for report_addr in self.report_addr:
success = await send_data(report_addr, success=True) and await send_data(report_addr, success=False)
if success is True:
self.model_metrics.requests_deleting.clear()
# TODO: Add a Redis subscriber queue for delete_requests
if report_addr == "https://cloud.vast.ai/api/v0":
# Patch: ignore the Redis API report_addr
continue
sent_success = True
sent_failed = True
if success_idxs:
sent_success = await post(report_addr, success_idxs, True)
if failed_idxs:
sent_failed = await post(report_addr, failed_idxs, False)
if sent_success and sent_failed:
# Remove only the items we actually sent from the live queue.
sent_set = set(success_idxs) | set(failed_idxs)
self.model_metrics.requests_deleting[:] = [
r for r in self.model_metrics.requests_deleting
if r.request_idx not in sent_set
]
break
async def __send_metrics_and_reset(self):
loadtime_snapshot = self.system_metrics.model_loading_time
def compute_autoscaler_data() -> AutoScalerData:
return AutoScalerData(
id=self.id,
mtoken=self.mtoken,
version=self.version,
loadtime=(self.system_metrics.model_loading_time or 0.0),
loadtime=(loadtime_snapshot or 0.0),
new_load=self.model_metrics.workload_processing,
cur_load=self.model_metrics.cur_load,
rej_load=self.model_metrics.workload_rejected,
@@ -199,17 +234,25 @@ class Metrics:
async def send_data(report_addr: str) -> bool:
data = compute_autoscaler_data()
full_path = report_addr.rstrip("/") + "/worker_status/"
log_data = asdict(data)
def obfuscate(secret: str) -> str:
if secret is None:
return ""
return secret[:7] + "..." if len(secret) > 7 else ("*" * len(secret))
log_data["mtoken"] = obfuscate(log_data.get("mtoken"))
log.debug(
"\n".join(
[
"#" * 60,
f"sending data to autoscaler",
f"{json.dumps((asdict(data)), indent=2)}",
f"{json.dumps(log_data, indent=2)}",
"#" * 60,
]
)
)
full_path = report_addr.rstrip("/") + "/worker_status/"
for attempt in range(1, 4):
try:
session = await self.http()
@@ -229,11 +272,15 @@ class Metrics:
self.system_metrics.update_disk_usage()
sent = False
for report_addr in self.report_addr:
success = await send_data(report_addr)
if success is True:
if await send_data(report_addr):
sent = True
break
if sent:
# clear the one-shot loadtime only if we actually sent *this* value
self.system_metrics.reset(expected=loadtime_snapshot)
self.update_pending = False
self.model_metrics.reset()
self.system_metrics.reset()
self.last_metric_update = time.time()
+21 -1
View File
@@ -3,15 +3,17 @@ import logging
from typing import List
import ssl
from asyncio import run, gather
import asyncio
from lib.backend import Backend
from lib.metrics import Metrics
from aiohttp import web
log = logging.getLogger(__file__)
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
try:
log.debug("getting certificate...")
use_ssl = os.environ.get("USE_SSL", "false") == "true"
if use_ssl is True:
@@ -38,3 +40,21 @@ def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
await gather(site.start(), backend._start_tracking())
run(main())
except Exception as e:
err_msg = f"PyWorker failed to launch: {e}"
log.error(err_msg)
async def beacon():
metrics = Metrics()
metrics._set_version(getattr(backend, "version", "0"))
metrics._set_mtoken(getattr(backend, "mtoken", ""))
try:
while True:
metrics._model_errored(err_msg)
await metrics._Metrics__send_metrics_and_reset()
await asyncio.sleep(10)
finally:
await metrics.aclose()
run(beacon())
+1
View File
@@ -8,3 +8,4 @@ Requests~=2.32
transformers~=4.52
utils==1.0.*
hf_transfer>=0.1.9
vastai-sdk>=0.2.0
+47 -5
View File
@@ -9,7 +9,7 @@ ENV_PATH="$WORKSPACE_DIR/worker-env"
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
REPORT_ADDR="${REPORT_ADDR:-https://cloud.vast.ai/api/v0,https://run.vast.ai}"
REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}"
USE_SSL="${USE_SSL:-true}"
WORKER_PORT="${WORKER_PORT:-3000}"
mkdir -p "$WORKSPACE_DIR"
@@ -41,6 +41,14 @@ echo_var DEBUG_LOG
echo_var PYWORKER_LOG
echo_var MODEL_LOG
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
if [ -e "$MODEL_LOG" ]; then
echo "Rotating model log at $MODEL_LOG to $MODEL_LOG.old"
cat "$MODEL_LOG" >> "$MODEL_LOG.old"
: > "$MODEL_LOG"
fi
# Populate /etc/environment with quoted values
if ! grep -q "VAST" /etc/environment; then
env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do
@@ -124,9 +132,43 @@ cd "$SERVER_DIR"
echo "launching PyWorker server"
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
set +e
python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG"
PY_STATUS=${PIPESTATUS[0]}
set -e
if [ "${PY_STATUS}" -ne 0 ]; then
echo "PyWorker exited with status ${PY_STATUS}; notifying autoscaler..."
ERROR_MSG="PyWorker exited: code ${PY_STATUS}"
MTOKEN="${MASTER_TOKEN:-}"
VERSION="${PYWORKER_VERSION:-0}"
IFS=',' read -r -a REPORT_ADDRS <<< "${REPORT_ADDR}"
for addr in "${REPORT_ADDRS[@]}"; do
curl -sS -X POST -H 'Content-Type: application/json' \
-d "$(cat <<JSON
{
"id": ${CONTAINER_ID:-0},
"mtoken": "${MTOKEN}",
"version": "${VERSION}",
"loadtime": 0,
"new_load": 0,
"cur_load": 0,
"rej_load": 0,
"max_perf": 0,
"cur_perf": 0,
"error_msg": "${ERROR_MSG}",
"num_requests_working": 0,
"num_requests_recieved": 0,
"additional_disk_usage": 0,
"working_request_idxs": [],
"cur_capacity": 0,
"max_capacity": 0,
"url": "${URL}"
}
JSON
)" "${addr%/}/worker_status/" || true
done
fi
(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
echo "launching PyWorker server done"
+15 -3
View File
@@ -12,9 +12,21 @@ A docker image is provided but you may use any if the above requirements are met
## Benchmarking
A simple image generation benchmark runs when each worker initializes to validate GPU performance and identify underperforming machines.
### Custom Benchmark Workflows
The benchmark uses Stable Diffusion v1.5 with ComfyUI's default text-to-image workflow. Configure the benchmark complexity and duration using these variables:
You can provide a custom ComfyUI workflow for benchmarking by creating `workers/comfyui-json/misc/benchmark.json`. This allows you to test performance using your preferred models and workflow complexity.
**Ways to provide the benchmark file:**
- Fork this repository and add your `benchmark.json` file
- Write the file during worker provisioning (onstart script or setup phase)
An example file is provided in the repository. To ensure varied generations, use the placeholder `__RANDOM_INT__` in place of static seed values - it will be replaced with a random integer for each benchmark run.
### Default Benchmark (Fallback)
If `benchmark.json` is not available, a simple image generation benchmark runs when each worker initializes. This validates GPU performance and helps identify underperforming machines.
The default benchmark uses Stable Diffusion v1.5 with ComfyUI's standard text-to-image workflow. Configure it using these environment variables:
| Environment Variable | Default Value | Description |
| -------------------- | ------------- | ----------- |
@@ -24,7 +36,7 @@ The benchmark uses Stable Diffusion v1.5 with ComfyUI's default text-to-image wo
Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns.
### Calibrating Benchmark Duration
#### Calibrating Fallback Benchmark Duration
To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements.
+13 -133
View File
@@ -1,107 +1,16 @@
import logging
from .data_types import count_workload
import uuid
import random
from urllib.parse import urljoin
import json
import asyncio
import random
import requests
from vastai import Serverless
from lib.test_utils import print_truncate_res
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from .data_types import count_workload
async def main():
async with Serverless() as client:
endpoint = await client.get_endpoint(name="my-comfy-endpoint") # Change this to your endpoint name
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
def call_text2image_workflow(
endpoint_group_name: str, api_key: str, server_url: str
) -> None:
"""Simple Text2Image using the new modifier-based approach"""
def make_request(url: str, payload: dict, timeout: int = None, verify=True, context: str = "request"):
"""Helper function for making requests with consistent error handling"""
try:
response = requests.post(
url,
json=payload,
timeout=timeout,
verify=verify
)
response.raise_for_status()
return response.json()
except requests.exceptions.HTTPError as http_err:
log.error(f"HTTP error occurred during {context}: {http_err}")
log.error(f"Status Code: {response.status_code}")
log.error("Response content:", response.text)
return None
except requests.exceptions.Timeout:
log.error(f"Timeout occurred during {context}: {url}")
return None
except requests.exceptions.ConnectionError:
log.error(f"Connection error occurred during {context}: {url}")
return None
except json.JSONDecodeError as json_err:
log.error(f"Failed to decode JSON response during {context}: {json_err}")
if 'response' in locals():
print("Response content:", response.text)
return None
except Exception as err:
log.error(f"An unexpected error occurred during {context}: {err}")
if 'response' in locals():
log.error("Response content (if available):", response.text)
return None
WORKER_ENDPOINT = "/generate/sync"
# This worker has concurrency = 1. All workloads have cost value 1.0
COST = count_workload()
# Route to get worker URL
route_payload = {
"endpoint": endpoint_group_name,
"api_key": api_key,
"cost": COST,
}
# First request - get routing information
route_response = make_request(
url=urljoin(server_url, "/route/"),
payload=route_payload,
timeout=4,
context="route request"
)
if route_response is None:
return None
if "url" not in route_response or not route_response["url"]:
log.error("Error: No worker in 'Ready' state. Please wait while the serverless engine removes errored workers or finishes loading new workers.")
return None
if "status" in route_response:
print(f"Autoscaler status: {route_response['status']}")
return None
# Extract data from route response
url = route_response["url"]
auth_data = dict(
signature=route_response["signature"],
cost=route_response["cost"],
endpoint=route_response["endpoint"],
reqnum=route_response["reqnum"],
url=route_response["url"],
)
# Build the payload for the worker request
worker_payload = {
payload = {
"input": {
"request_id": str(uuid.uuid4()),
"modifier": "Text2Image",
@@ -116,40 +25,11 @@ def call_text2image_workflow(
}
}
req_data = dict(payload=worker_payload, auth_data=auth_data)
worker_url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {worker_url}")
# Second request - call the worker endpoint
worker_response = make_request(
url=worker_url,
payload=req_data,
verify=get_cert_file_path(),
context="worker request"
)
return worker_response
response = await endpoint.request("/generate/sync", payload, cost=count_workload())
# Get the file from the path on the local machine using SCP or SFTP
# or configure S3 to upload to cloud storage.
print(response["response"]["output"][0]["local_path"])
if __name__ == "__main__":
from lib.test_utils import test_args
args = test_args.parse_args()
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if endpoint_api_key:
result = call_text2image_workflow(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
if result is None:
log.error("Text2Image workflow failed")
else:
print(result)
else:
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name}")
asyncio.run(main())
+28 -4
View File
@@ -5,12 +5,13 @@ import dataclasses
from typing import Dict, Any
from functools import cache
from math import ceil
from pathlib import Path
import json
import logging
from lib.data_types import ApiPayload, JsonDataException
with open("workers/comfyui/misc/test_prompts.txt", "r") as f:
test_prompts = f.readlines()
log = logging.getLogger(__file__)
def count_workload() -> float:
# Always 100.0 where there is a single instance of ComfyUI handling requests
@@ -24,9 +25,32 @@ class ComfyWorkflowData(ApiPayload):
@classmethod
def for_test(cls):
"""
Use the variables available to simulate workflows of the required running time
If the user has provided a benchmark workflow we can use it here to properly gauge performance.
Otherwise, use the variables available to simulate workflows of the required running time
Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090)
"""
# Try to load benchmark.json
benchmark_file = Path("workers/comfyui-json/misc/benchmark.json")
if benchmark_file.exists():
try:
with open(benchmark_file, "r") as f:
benchmark_workflow = json.load(f)
return cls(
input={
"request_id": f"test-{random.randint(1000, 99999)}",
"workflow_json": benchmark_workflow
}
)
except (json.JSONDecodeError, IOError):
# JSON is malformed or file can't be read, fall through to default
log.error(f"Failed to benchmark using {benchmark_file}")
# Fallback: read prompts and construct payload
log.info("Using fallback method for benchmarking")
with open("workers/comfyui-json/misc/test_prompts.txt", "r") as f:
test_prompts = f.readlines()
test_prompt = random.choice(test_prompts).rstrip()
return cls(
input={
@@ -0,0 +1,107 @@
{
"3": {
"inputs": {
"seed": "__RANDOM_INT__",
"steps": 20,
"cfg": 8,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1,
"model": [
"4",
0
],
"positive": [
"6",
0
],
"negative": [
"7",
0
],
"latent_image": [
"5",
0
]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"4": {
"inputs": {
"ckpt_name": "v1-5-pruned-emaonly-fp16.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"5": {
"inputs": {
"width": 512,
"height": 512,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"6": {
"inputs": {
"text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"7": {
"inputs": {
"text": "text, watermark",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"8": {
"inputs": {
"samples": [
"3",
0
],
"vae": [
"4",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"8",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
}
}
+1
View File
@@ -19,6 +19,7 @@ MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: "
MODEL_SERVER_ERROR_LOG_MSGS = [
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
"Value not in list: ", # This error is emitted when the model file is not there at all
"[ERROR] Provisioning Script failed", # Error inserted by provisioning script if models/nodes fail to download
]
+6 -12
View File
@@ -7,20 +7,13 @@ from lib.test_utils import print_truncate_res
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
"""
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
"""
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
from vastai import Serverless
def call_default_workflow(
endpoint_group_name: str, api_key: str, server_url: str
) -> None:
ENDPOINT_NAME = "my-comfyui-endpoint"
COST = 100 # Use a constant cost for image generation
def call_default_workflow(client: Serverless) -> None:
WORKER_ENDPOINT = "/prompt"
COST = 100
route_payload = {
@@ -82,6 +75,7 @@ def call_custom_workflow_for_sd3(
endpoint=message["endpoint"],
reqnum=message["reqnum"],
url=message["url"],
request_idx=message["request_idx"],
)
workflow = {
"3": {
+340 -410
View File
@@ -1,14 +1,15 @@
import logging
import sys
import json
import os
import sys
import subprocess
from urllib.parse import urljoin
from typing import Dict, Any, Optional, Iterator, Union, List
import requests
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from .data_types.client import CompletionConfig, ChatCompletionConfig
import argparse
from typing import Any, Dict, List, Optional
from vastai import Serverless
import asyncio
# ---------------------- Logging ----------------------
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
@@ -16,135 +17,20 @@ logging.basicConfig(
)
log = logging.getLogger(__file__)
# ---------------------- Prompts ----------------------
COMPLETIONS_PROMPT = "the capital of USA is"
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?"
class APIClient:
"""Lightweight client focused solely on API communication"""
# Remove the generic WORKER_ENDPOINT since we're now going direct
DEFAULT_COST = 100
DEFAULT_TIMEOUT = 4
def __init__(
self,
endpoint_group_name: str,
api_key: str,
server_url: str,
endpoint_api_key: str,
):
self.endpoint_group_name = endpoint_group_name
self.api_key = api_key
self.server_url = server_url
self.endpoint_api_key = endpoint_api_key
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
"""Get worker URL and auth data from routing service"""
if not self.endpoint_api_key:
raise ValueError("No valid endpoint API key available")
route_payload = {
"endpoint": self.endpoint_group_name,
"api_key": self.endpoint_api_key,
"cost": cost,
}
response = requests.post(
urljoin(self.server_url, "/route/"),
json=route_payload,
timeout=self.DEFAULT_TIMEOUT,
)
response.raise_for_status()
return response.json()
def _create_auth_data(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""Create auth data from routing response"""
return {
"signature": message["signature"],
"cost": message["cost"],
"endpoint": message["endpoint"],
"reqnum": message["reqnum"],
"url": message["url"],
}
def _make_request(
self,
payload: Dict[str, Any],
endpoint: str,
method: str = "POST",
stream: bool = False,
) -> Union[Dict[str, Any], Iterator[str]]:
"""Make request directly to the specific worker endpoint"""
# Get worker URL and auth data
cost = payload.get("max_tokens", self.DEFAULT_COST)
message = self._get_worker_url(cost=cost)
worker_url = message["url"]
auth_data = self._create_auth_data(message)
req_data = {"payload": {"input": payload}, "auth_data": auth_data}
url = urljoin(worker_url, endpoint)
log.debug(f"Making direct request to: {url}")
log.debug(f"Payload: {req_data}")
# Make the request using the specified method
if method.upper() == "POST":
response = requests.post(
url, json=req_data, stream=stream, verify=get_cert_file_path()
)
elif method.upper() == "GET":
response = requests.get(
url, params=req_data, stream=stream, verify=get_cert_file_path()
)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status()
if stream:
return self._handle_streaming_response(response)
else:
return response.json()
def _handle_streaming_response(self, response: requests.Response) -> Iterator[str]:
"""Handle streaming response and yield tokens"""
try:
for line in response.iter_lines(decode_unicode=True):
if line:
if line.startswith("data: "):
data_str = line[6:]
if data_str.strip() == "[DONE]":
break
try:
data = json.loads(data_str)
yield data # Yield the full chunk
except json.JSONDecodeError:
continue
except Exception as e:
log.error(f"Error handling streaming response: {e}")
raise
def call_completions(
self, config: CompletionConfig
) -> Union[Dict[str, Any], Iterator[str]]:
payload = config.to_dict()
return self._make_request(
payload=payload, endpoint="/v1/completions", stream=config.stream
)
def call_chat_completions(
self, config: ChatCompletionConfig
) -> Union[Dict[str, Any], Iterator[str]]:
payload = config.to_dict()
return self._make_request(
payload=payload, endpoint="/v1/chat/completions", stream=config.stream
TOOLS_PROMPT = (
"Can you list the files in the current working directory and tell me what you see? "
"What do you think this directory might be for?"
)
ENDPOINT_NAME = "my-vllm-endpoint" # change this to your vLLM endpoint name
DEFAULT_MODEL = "Qwen/Qwen3-8B" # must support tool calling
MAX_TOKENS = 1024
DEFAULT_TEMPERATURE = 0.7
# ---------------------- Tooling ----------------------
class ToolManager:
"""Handles tool definitions and execution"""
@@ -164,7 +50,7 @@ class ToolManager:
@staticmethod
def get_ls_tool_definition() -> List[Dict[str, Any]]:
"""Get the ls tool definition"""
"""OpenAI-compatible tool schema"""
return [
{
"type": "function",
@@ -178,98 +64,217 @@ class ToolManager:
def execute_tool_call(self, tool_call: Dict[str, Any]) -> str:
"""Execute a tool call and return the result"""
function_name = tool_call["function"]["name"]
function_name = (tool_call.get("function") or {}).get("name")
if function_name == "list_files":
return self.list_files()
else:
raise ValueError(f"Unknown tool function: {function_name}")
# ----- Helpers to handle streamed tool_calls assembly -----
def _merge_tool_call_delta(state: Dict[int, Dict[str, Any]], tc_delta: Dict[str, Any]) -> None:
"""
OpenAI-style streaming sends partial tool_calls with an index and partial fields.
We merge into a per-index state dict until the assistant message finishes.
"""
idx = tc_delta.get("index")
if idx is None:
return
entry = state.setdefault(idx, {"id": None, "function": {"name": None, "arguments": ""}, "type": "function"})
if tc_delta.get("id"):
entry["id"] = tc_delta["id"]
fn_delta = tc_delta.get("function") or {}
if "name" in fn_delta and fn_delta["name"]:
entry["function"]["name"] = fn_delta["name"]
if "arguments" in fn_delta and fn_delta["arguments"]:
entry["function"]["arguments"] += fn_delta["arguments"]
def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[Dict[str, Any]]:
return [state[i] for i in sorted(state.keys())]
# ---- OpenAI-compatible calls (non-streaming) ----
async def call_completions(client: Serverless, *, model: str, prompt: str, **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = {
"input": {
"model": model,
"prompt": prompt,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
}
}
log.debug("POST /v1/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
return resp["response"]
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = {
"input": {
"model": model,
"messages": messages,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
}
}
log.debug("POST /v1/chat/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"])
return resp["response"]
# ---- Streaming variants ----
async def stream_completions(client: Serverless, *, model: str, prompt: str, **kwargs):
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = {
"input": {
"model": model,
"prompt": prompt,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"stream": True,
**({"stop": kwargs["stop"]} if "stop" in kwargs else {}),
}
}
log.debug("STREAM /v1/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
return resp["response"] # async generator
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs):
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = {
"input": {
"model": model,
"messages": messages,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"stream": True,
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
}
}
log.debug("STREAM /v1/chat/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
return resp["response"] # async generator
# ---------------------- Demo Runner ----------------------
class APIDemo:
"""Demo and testing functionality for the API client"""
def __init__(
self, client: APIClient, model: str, tool_manager: Optional[ToolManager] = None
):
def __init__(self, client: Serverless, model: str, tool_manager: Optional[ToolManager] = None):
self.client = client
self.model = model
self.tool_manager = tool_manager or ToolManager()
def handle_streaming_response(
self, response_stream, show_reasoning: bool = True
) -> str:
"""
Handle streaming chat response and display all output.
"""
# ----- Streaming handler -----
async def handle_streaming_response(self, stream, show_reasoning: bool = True) -> str:
full_response = ""
reasoning_content = ""
reasoning_started = False
content_started = False
printed_reasoning = False
printed_answer = False
for chunk in response_stream:
# Normalize the chunk
if isinstance(chunk, str):
chunk = chunk.strip()
if chunk.startswith("data: "):
chunk = chunk[6:].strip()
if chunk in ["[DONE]", ""]:
continue
try:
parsed_chunk = json.loads(chunk)
except json.JSONDecodeError:
continue
elif isinstance(chunk, dict):
parsed_chunk = chunk
else:
continue
async for chunk in stream:
choice = (chunk.get("choices") or [{}])[0]
delta = choice.get("delta", {})
# Parse delta from the chunk
choices = parsed_chunk.get("choices", [])
if not choices:
continue
delta = choices[0].get("delta", {})
reasoning_token = delta.get("reasoning_content", "")
content_token = delta.get("content", "")
# Print reasoning token if applicable
if show_reasoning and reasoning_token:
if not reasoning_started:
# reasoning tokens
rc = delta.get("reasoning_content")
if rc and show_reasoning:
if not printed_reasoning:
print("\n🧠 Reasoning: ", end="", flush=True)
reasoning_started = True
print(f"\033[90m{reasoning_token}\033[0m", end="", flush=True)
reasoning_content += reasoning_token
printed_reasoning = True
print(rc, end="", flush=True)
reasoning_content += rc
# Print content token
if content_token:
if not content_started:
if show_reasoning and reasoning_started:
print(f"\n💬 Response: ", end="", flush=True)
# content tokens
content_part = delta.get("content")
if content_part:
if not printed_answer:
if show_reasoning and printed_reasoning:
print("\n💬 Response: ", end="", flush=True)
else:
print("Assistant: ", end="", flush=True)
content_started = True
print(content_token, end="", flush=True)
full_response += content_token
print() # Ensure newline after response
printed_answer = True
print(content_part, end="", flush=True)
full_response += content_part
print() # newline
if show_reasoning:
if reasoning_started or content_started:
if printed_reasoning or printed_answer:
print("\nStreaming completed.")
if reasoning_started:
if printed_reasoning:
print(f"Reasoning tokens: {len(reasoning_content.split())}")
if content_started:
if printed_answer:
print(f"Response tokens: {len(full_response.split())}")
return full_response
def test_tool_support(self) -> bool:
"""Test if the endpoint supports function calling"""
log.debug("Testing endpoint tool calling support...")
async def demo_completions(self) -> None:
print("=" * 60)
print("COMPLETIONS DEMO")
print("=" * 60)
# Try a simple request with minimal tools to test support
response = await call_completions(
client=self.client,
model=self.model,
prompt=COMPLETIONS_PROMPT,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
print("\nResponse:")
print(json.dumps(response, indent=2))
async def demo_chat(self, use_streaming: bool = True) -> None:
print("=" * 60)
print(f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}")
print("=" * 60)
messages = [{"role": "user", "content": CHAT_PROMPT}]
if use_streaming:
stream = await stream_chat_completions(
client=self.client,
model=self.model,
messages=messages,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE
)
try:
await self.handle_streaming_response(stream, show_reasoning=True)
except Exception as e:
log.error("\nError during streaming: %s", e, exc_info=True)
else:
response = await call_chat_completions(
client=self.client,
model=self.model,
messages=messages,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE
)
choice = (response.get("choices") or [{}])[0]
message = choice.get("message", {})
content = message.get("content", "")
reasoning = message.get("reasoning_content", "") or message.get("reasoning", "")
if reasoning:
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
print(f"\n💬 Assistant: {content}")
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
async def test_tool_support(self) -> bool:
"""Probe that tool schema is accepted (no actual call)"""
messages = [{"role": "user", "content": "Hello"}]
minimal_tool = [
{
@@ -277,170 +282,147 @@ class APIDemo:
"function": {"name": "test_function", "description": "Test function"},
}
]
config = ChatCompletionConfig(
try:
_ = await call_chat_completions(
client=self.client,
model=self.model,
messages=messages,
max_tokens=10,
tools=minimal_tool,
tool_choice="none", # Don't actually call the tool
tool_choice="none",
max_tokens=10
)
try:
response = self.client.call_chat_completions(config)
return True
except Exception as e:
log.error(f"Error: Endpoint does not support tool calling: {e}")
log.error("Endpoint does not support tool calling: %s", e)
return False
def demo_completions(self) -> None:
"""Demo: test basic completions endpoint"""
print("=" * 60)
print("COMPLETIONS DEMO")
print("=" * 60)
config = CompletionConfig(
model=self.model, prompt=COMPLETIONS_PROMPT, stream=False
)
log.info(
f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'"
)
response = self.client.call_completions(config)
if isinstance(response, dict):
print("\nResponse:")
print(json.dumps(response, indent=2))
else:
log.error("Unexpected response format")
def demo_chat(self, use_streaming: bool = True) -> None:
"""
Demo: test chat completions endpoint with optional streaming
"""
print("=" * 60)
print(
f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}"
)
print("=" * 60)
config = ChatCompletionConfig(
model=self.model,
messages=[{"role": "user", "content": CHAT_PROMPT}],
stream=use_streaming,
)
log.info(f"Testing chat completions with model '{self.model}'...")
response = self.client.call_chat_completions(config)
if use_streaming:
try:
self.handle_streaming_response(response, show_reasoning=True)
except Exception as e:
log.error(f"\nError during streaming: {e}")
import traceback
traceback.print_exc()
return
else:
if isinstance(response, dict):
choice = response.get("choices", [{}])[0]
message = choice.get("message", {})
content = message.get("content", "")
reasoning = message.get("reasoning_content", "") or message.get(
"reasoning", ""
)
if reasoning:
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
print(f"\n💬 Assistant: {content}")
print(f"\nFull Response:")
print(json.dumps(response, indent=2))
else:
log.error("Unexpected response format")
def demo_ls_tool(self) -> None:
"""Demo: ask LLM to list files in the current directory and describe what it sees"""
async def demo_ls_tool(self) -> None:
"""Ask to list files using function calling, then provide final analysis"""
print("=" * 60)
print("TOOL USE DEMO: List Directory Contents")
print("=" * 60)
# Test if tools are supported first
if not self.test_tool_support():
if not await self.test_tool_support():
return
# Request with tool available
messages = [{"role": "user", "content": TOOLS_PROMPT}]
messages: List[Dict[str, Any]] = [{"role": "user", "content": TOOLS_PROMPT}]
config = ChatCompletionConfig(
# First pass: let the model decide tools, stream tool_calls and partial content
stream = await stream_chat_completions(
client=self.client,
model=self.model,
messages=messages,
tools=self.tool_manager.get_ls_tool_definition(),
tool_choice="auto",
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
log.info(f"Making initial request with tool using model '{self.model}'...")
response = self.client.call_chat_completions(config)
assistant_content_buf: List[str] = []
tool_calls_state: Dict[int, Dict[str, Any]] = {}
printed_reasoning = False
printed_answer = False
if not isinstance(response, dict):
raise ValueError("Expected dict response for tool use")
async for chunk in stream:
choice = (chunk.get("choices") or [{}])[0]
delta = choice.get("delta", {})
choice = response.get("choices", [{}])[0]
message = choice.get("message", {})
rc = delta.get("reasoning_content")
if rc:
if not printed_reasoning:
printed_reasoning = True
print("🧠 Reasoning: ", end="", flush=True)
print(rc, end="", flush=True)
print(f"Assistant response: {message.get('content', 'No content')}")
content_part = delta.get("content")
if content_part:
assistant_content_buf.append(content_part)
if not printed_answer:
printed_answer = True
print("\n💬 Response: ", end="", flush=True)
print(content_part, end="", flush=True)
# Check for tool calls
tool_calls = message.get("tool_calls")
if not tool_calls:
raise ValueError(
"No tool calls made - model may not support function calling"
)
if "tool_calls" in delta and delta["tool_calls"]:
for tc_delta in delta["tool_calls"]:
_merge_tool_call_delta(tool_calls_state, tc_delta)
print(f"Tool calls detected: {len(tool_calls)}")
# If no tool calls, were done.
if not tool_calls_state:
print("\n(No tool calls were made.)")
return
# Execute the tool call
for tool_call in tool_calls:
function_name = tool_call["function"]["name"]
print(f"Executing tool: {function_name}")
tool_result = self.tool_manager.execute_tool_call(tool_call)
print(f"Tool result:\n{tool_result}")
# Add tool result and continue conversation
messages.append(message) # Add assistant's message with tool call
messages.append(
{
"role": "tool",
"tool_call_id": tool_call["id"],
"content": tool_result,
# Build assistant message with tool_calls
assistant_message = {
"role": "assistant",
"content": "".join(assistant_content_buf) if assistant_content_buf else None,
"tool_calls": _tool_state_to_message_tool_calls(tool_calls_state),
}
)
messages.append(assistant_message)
# Get final response
final_config = ChatCompletionConfig(
# Execute tools and feed results back
for tc in assistant_message["tool_calls"]:
tool_name = (tc.get("function") or {}).get("name")
call_id = tc.get("id")
raw_args = (tc.get("function") or {}).get("arguments") or "{}"
try:
args = json.loads(raw_args) if raw_args.strip() else {}
except Exception as e:
tool_result = json.dumps({"error": f"Argument parse failed: {str(e)}", "raw_arguments": raw_args})
messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result})
continue
try:
if tool_name == "list_files":
tool_result = self.tool_manager.list_files()
else:
tool_result = json.dumps({"error": f"Unknown tool '{tool_name}'"})
except Exception as e:
tool_result = json.dumps({"error": f"Tool '{tool_name}' failed: {str(e)}"})
print("\n[Tool executed]", tool_name)
print(tool_result[:500] + ("..." if len(tool_result) > 500 else ""))
messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result})
# Second pass: get final streamed answer after tool results
stream2 = await stream_chat_completions(
client=self.client,
model=self.model,
messages=messages,
tools=self.tool_manager.get_ls_tool_definition(),
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
print("Getting final response...")
final_response = self.client.call_chat_completions(final_config)
final_buf = []
printed_reasoning2 = False
printed_answer2 = False
if isinstance(final_response, dict):
final_choice = final_response.get("choices", [{}])[0]
final_message = final_choice.get("message", {})
final_content = final_message.get("content", "")
async for chunk in stream2:
choice = (chunk.get("choices") or [{}])[0]
delta = choice.get("delta", {})
rc2 = delta.get("reasoning_content")
if rc2:
if not printed_reasoning2:
printed_reasoning2 = True
print("\n🧠 Reasoning (post-tools): ", end="", flush=True)
print(rc2, end="", flush=True)
c2 = delta.get("content")
if c2:
final_buf.append(c2)
if not printed_answer2:
printed_answer2 = True
print("\n💬 Response (final): ", end="", flush=True)
print(c2, end="", flush=True)
print("\n" + "=" * 60)
print("FINAL LLM ANALYSIS:")
print("=" * 60)
print(final_content)
print("".join(final_buf))
print("=" * 60)
def interactive_chat(self) -> None:
async def interactive_chat(self) -> None:
"""Interactive chat session with streaming"""
print("=" * 60)
print("INTERACTIVE STREAMING CHAT")
@@ -449,7 +431,7 @@ class APIDemo:
print("Type 'quit' to exit, 'clear' to clear history")
print()
messages = []
messages: List[Dict[str, Any]] = []
while True:
try:
@@ -467,16 +449,15 @@ class APIDemo:
messages.append({"role": "user", "content": user_input})
config = ChatCompletionConfig(
model=self.model, messages=messages, stream=True, temperature=0.7
)
print("Assistant: ", end="", flush=True)
response = self.client.call_chat_completions(config)
assistant_content = self.handle_streaming_response(
response, show_reasoning=True
stream = await stream_chat_completions(
client=self.client,
model=self.model,
messages=messages,
max_tokens=MAX_TOKENS,
temperature=0.7
)
assistant_content = await self.handle_streaming_response(stream, show_reasoning=True)
# Add assistant response to conversation history
messages.append({"role": "assistant", "content": assistant_content})
@@ -485,115 +466,64 @@ class APIDemo:
print("\n👋 Chat interrupted. Goodbye!")
break
except Exception as e:
log.error(f"\nError: {e}")
log.error("\nError: %s", e)
continue
def main():
"""Main function with CLI switches for different tests"""
from lib.test_utils import test_args
# ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
p.add_argument("--model", required=True, help="Model to use for requests (required)")
p.add_argument("--endpoint", default="my-vllm-endpoint", help="Vast endpoint name (default: my-vllm-endpoint)")
# Add mandatory model argument
test_args.add_argument(
"--model", required=True, help="Model to use for requests (required)"
)
modes = p.add_mutually_exclusive_group(required=False)
modes.add_argument("--completion", action="store_true", help="Test completions endpoint")
modes.add_argument("--chat", action="store_true", help="Test chat completions endpoint (non-streaming)")
modes.add_argument("--chat-stream", action="store_true", help="Test chat completions endpoint with streaming")
modes.add_argument("--tools", action="store_true", help="Test function calling with ls tool (non-streaming+streamed phases)")
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming chat session")
return p
# Add test mode arguments
test_args.add_argument(
"--completion", action="store_true", help="Test completions endpoint"
)
test_args.add_argument(
"--chat",
action="store_true",
help="Test chat completions endpoint (non-streaming)",
)
test_args.add_argument(
"--chat-stream",
action="store_true",
help="Test chat completions endpoint with streaming",
)
test_args.add_argument(
"--tools",
action="store_true",
help="Test function calling with ls tool (non-streaming)",
)
test_args.add_argument(
"--interactive",
action="store_true",
help="Start interactive streaming chat session",
)
args = test_args.parse_args()
async def main_async():
args = build_arg_parser().parse_args()
# Check that only one test mode is selected
test_modes = [
args.completion,
args.chat,
args.chat_stream,
args.tools,
args.interactive,
]
selected_count = sum(test_modes)
if selected_count == 0:
selected = sum([args.completion, args.chat, args.chat_stream, args.tools, args.interactive])
if selected == 0:
print("Please specify exactly one test mode:")
print(" --completion : Test completions endpoint")
print(" --chat : Test chat completions endpoint (non-streaming)")
print(" --chat-stream : Test chat completions endpoint with streaming")
print(" --tools : Test function calling with ls tool (non-streaming)")
print(" --tools : Test function calling with ls tool")
print(" --interactive : Start interactive streaming chat session")
print(
f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT"
)
print(f"\nExample: python {os.path.basename(sys.argv[0])} --model Qwen/Qwen3-8B --chat-stream --endpoint my-vllm-endpoint")
sys.exit(1)
elif selected_count > 1:
elif selected > 1:
print("Please specify exactly one test mode")
sys.exit(1)
try:
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if not endpoint_api_key:
log.error(
f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting."
)
sys.exit(1)
# Create the core API client
client = APIClient(
endpoint_group_name=args.endpoint_group_name,
api_key=args.api_key,
server_url=Endpoint.get_autoscaler_server_url(args.instance),
endpoint_api_key=endpoint_api_key,
)
# Create tool manager and demo (passing the model parameter)
tool_manager = ToolManager()
demo = APIDemo(client, args.model, tool_manager)
print(f"Using model: {args.model}")
print("=" * 60)
# Run the selected test
try:
async with Serverless() as client:
demo = APIDemo(client, args.model, ToolManager())
if args.completion:
demo.demo_completions()
await demo.demo_completions()
elif args.chat:
demo.demo_chat(use_streaming=False)
await demo.demo_chat(use_streaming=False)
elif args.chat_stream:
demo.demo_chat(use_streaming=True)
await demo.demo_chat(use_streaming=True)
elif args.tools:
demo.demo_ls_tool()
await demo.demo_ls_tool()
elif args.interactive:
demo.interactive_chat()
await demo.interactive_chat()
except Exception as e:
log.error(f"Error during test: {e}", exc_info=True)
log.error("Error during test: %s", e, exc_info=True)
sys.exit(1)
if __name__ == "__main__":
main()
asyncio.run(main_async())
+53 -117
View File
@@ -1,125 +1,61 @@
import logging
import sys
import json
from urllib.parse import urljoin
import requests
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from vastai import Serverless
import asyncio
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
ENDPOINT_NAME = "my-tgi-endpoint" # Change this to match your endpoint name
MAX_TOKENS = 1024
PROMPT = "Think step by step: Tell me about the Python programming language."
async def call_generate(client: Serverless) -> None:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> None:
WORKER_ENDPOINT = "/generate"
COST = 100
route_payload = {
"endpoint": endpoint_group_name,
"api_key": api_key,
"cost": COST,
payload = {
"inputs": PROMPT,
"parameters": {
"max_new_tokens": MAX_TOKENS,
"temperature": 0.7,
"return_full_text": False
}
response = requests.post(
urljoin(server_url, "/route/"),
json=route_payload,
timeout=4,
)
response.raise_for_status() # Raise an exception for bad status codes
message = response.json()
url = message["url"]
auth_data = dict(
signature=message["signature"],
cost=message["cost"],
endpoint=message["endpoint"],
reqnum=message["reqnum"],
url=url,
)
payload = dict(inputs="tell me about cats", parameters=dict(max_new_tokens=500))
req_data = dict(payload=payload, auth_data=auth_data)
url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {url}")
response = requests.post(
url,
json=req_data,
verify=get_cert_file_path(),
)
response.raise_for_status()
res = response.json()
print(res)
def call_generate_stream(
endpoint_group_name: str, api_key: str, server_url: str
) -> None:
WORKER_ENDPOINT = "/generate_stream"
COST = 100
route_payload = {
"endpoint": endpoint_group_name,
"api_key": api_key,
"cost": COST,
}
response = requests.post(
urljoin(server_url, "/route/"),
json=route_payload,
timeout=4,
)
response.raise_for_status() # Raise an exception for bad status codes
message = response.json()
url = message["url"]
print(f"url: {url}")
auth_data = dict(
signature=message["signature"],
cost=message["cost"],
endpoint=message["endpoint"],
reqnum=message["reqnum"],
url=message["url"],
)
payload = dict(inputs="tell me about dogs", parameters=dict(max_new_tokens=500))
req_data = dict(payload=payload, auth_data=auth_data)
url = urljoin(url, WORKER_ENDPOINT)
response = requests.post(url, json=req_data, stream=True)
response.raise_for_status() # Raise an exception for bad status codes
for line in response.iter_lines():
payload = line.decode().lstrip("data:").rstrip()
if payload:
try:
data = json.loads(payload)
print(data["token"]["text"], end="")
sys.stdout.flush()
except (json.JSONDecodeError, KeyError) as e:
log.warning(f"Failed to parse streaming response: {e}")
continue
print()
resp = await endpoint.request("/generate", payload, cost=MAX_TOKENS)
print(resp["response"]["generated_text"])
async def call_generate_stream(client: Serverless) -> None:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = {
"inputs": PROMPT,
"parameters": {
"max_new_tokens": MAX_TOKENS,
"temperature": 0.7,
"do_sample": True,
"return_full_text": False,
}
}
resp = await endpoint.request(
"/generate_stream",
payload,
cost=MAX_TOKENS,
stream=True,
)
stream = resp["response"]
printed_answer = False
async for event in stream:
tok = (event.get("token") or {}).get("text")
if tok:
if not printed_answer:
printed_answer = True
print("Answer:\n", end="", flush=True)
print(tok, end="", flush=True)
async def main():
async with Serverless() as client:
await call_generate(client)
await call_generate_stream(client)
if __name__ == "__main__":
from lib.test_utils import test_args
args = test_args.parse_args()
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if endpoint_api_key:
try:
call_generate(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
call_generate_stream(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
except Exception as e:
log.error(f"Error during API call: {e}")
else:
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ")
asyncio.run(main())