Compare commits

..

60 Commits

Author SHA1 Message Date
Colter Downing 62fbfb061d more logs 2025-11-24 18:40:45 -08:00
Colter Downing c772e1651b debug logs 2025-11-24 18:21:35 -08:00
Colter Downing ecc6a3ce0d catch all exceptions, add logs 2025-11-24 18:06:17 -08:00
Lucas Armand 7986e51e9e early errors 2025-11-24 15:24:06 -08:00
Lucas Armand 9c6ab78503 Move model log line 2025-11-24 15:22:23 -08:00
Lucas Armand 45e0c7d9ca Move model log rotate to top 2025-11-24 15:02:33 -08:00
Lucas Armand a4339bd3f1 hotfix: add f 2025-11-12 16:10:55 -08:00
Lucas Armand 2b26e5e20c hotfix: remove g 2025-11-12 16:01:57 -08:00
LucasArmandVast d3727d4fd7 Merge pull request #58 from vast-ai/update-client-scripts
Update client scripts
2025-11-12 10:22:42 -08:00
Lucas Armand a47c9d1ed0 remove test bugs 2025-11-11 18:13:46 -08:00
Lucas Armand 0b14562a63 dont exit on pyworker fail 2025-11-11 17:57:08 -08:00
Lucas Armand de9b50abb9 use set +e 2025-11-11 17:53:36 -08:00
Lucas Armand c510801723 fix 2025-11-11 17:49:34 -08:00
Lucas Armand a12523b1d2 Added bad code to tgi server to test 2025-11-11 17:41:12 -08:00
Lucas Armand eedf81c0a3 Updated readme and .gitignore 2025-11-11 17:18:40 -08:00
Lucas Armand 3adec1826d minor changes 2025-11-11 17:11:38 -08:00
Lucas Armand b55bfa9611 Updated clients, include vastai-sdk, handle non-UTF-8 2025-11-11 17:09:28 -08:00
LucasArmandVast 7db54f3bd7 Merge pull request #55 from vast-ai/use-mtoken
Use mtoken
2025-11-10 11:54:04 -08:00
LucasArmandVast d63a060202 Merge pull request #56 from vast-ai/obfuscate-mtoken
Obfuscate mtoken in logs
2025-11-10 11:53:17 -08:00
Lucas Armand c6521cb6d4 add ... 2025-11-07 10:10:35 -08:00
Lucas Armand b7fe4ebb91 Obfuscate mtoken in logs 2025-11-07 10:02:39 -08:00
Lucas Armand 8ae7b74605 bump version to 0.2.0 2025-11-05 13:32:21 -08:00
Lucas Armand 106067d716 bump version to 0.1.1 2025-11-04 17:15:59 -08:00
Lucas Armand f5134d4bf5 Fix spelling mistake 2025-11-04 16:59:39 -08:00
Lucas Armand 47e5460532 added mtoken 2025-11-04 15:55:14 -08:00
Colter-Downing ec2ac0a21a Merge pull request #52 from vast-ai/remove-sleeps-and-delays
Remove sleeps and delays
2025-10-30 11:53:39 -07:00
Abiola Akinnubi 2cde573c56 Merge pull request #48 from vast-ai/comfy-request-idx
Added request_idx to comfy auth_data
2025-10-30 11:27:35 -07:00
Abiola Akinnubi b2e4a5db0c Merge pull request #49 from vast-ai/unsecure_report_addr
Added caller for REPORT_ADDR to backend.py to use the report add
2025-10-30 10:39:46 -07:00
Abiola Akinnubi 7437028cb2 Added caller for REPORT_ADDR to backend.py 2025-10-29 18:02:17 -07:00
edgaratvast 02c8307af7 remove redis pubsub from pyworker (#53)
Co-authored-by: Edgar Lin <edgarlin2000@gmail.com>
2025-10-29 17:07:56 -07:00
Colter Downing 7c0f316eeb leave the env vars alone! 2025-10-29 11:36:46 -07:00
Colter Downing b4025a744f remove env var writing 2025-10-29 09:58:09 -07:00
Colter Downing d190308329 removed 5 sec sleep and warmup request on load 2025-10-29 09:57:46 -07:00
LucasArmandVast 9f5a432513 Merge pull request #51 from vast-ai/delete-reqs-hotfix
Redis subscriber queue patch
2025-10-28 16:07:28 -07:00
Lucas Armand e09f1fa953 patch for redis queue 2025-10-28 16:03:50 -07:00
edgaratvast ba6f1c2e4b Fix signature (#50)
* change order of fields in auth_data to match autoscaler for signature verification

* also ignore __request_id

* Revert "change order of fields in auth_data to match autoscaler for signature verification" so that it's alphabetical again

This reverts commit b8223879c9.

* enforce alphabetical json dumping of message for signature verification

---------

Co-authored-by: Edgar Lin <edgarlin2000@gmail.com>
2025-10-28 16:01:32 -07:00
Abiola Akinnubi 944f83fc03 Removed extra spaces from operator assignment 2025-10-28 21:03:52 +00:00
edgaratvast 298590fb88 Merge pull request #45 from vast-ai/new-pyworker
New PyWorker
2025-10-28 14:02:53 -07:00
Lucas Armand 814c3acd4c remove unused code 2025-10-28 13:43:57 -07:00
Lucas Armand 22bca74087 Prevent load time race 2025-10-27 18:25:21 -07:00
Lucas Armand 9c795e2a01 removed bad code 2025-10-27 17:03:13 -07:00
Lucas Armand 830b532781 Trying unified delete 2025-10-27 16:57:52 -07:00
LucasArmandVast d6a6e34c6b Merge branch 'main' into new-pyworker 2025-10-27 12:43:49 -07:00
Colter-Downing ac1e109c48 Merge pull request #47 from vast-ai/new-pyworker-vllm-prefix-cache
vLLM Prefix caching, benchmark bug fix, test load script
2025-10-27 12:30:34 -07:00
Abiola Akinnubi f56bbc0ebe Added request_idx to comfy auth_data 2025-10-27 03:17:06 +00:00
Rob Ballantyne 70d51bafe1 Merge pull request #36 from robballantyne/feat/comfyui-json-benchmark-workflow-from-file 2025-10-23 17:05:48 +01:00
Rob Ballantyne 63909736bb Merge pull request #4 from robballantyne/feat/comfyui-json-benchmark-workflow-from-file-no-silent-fail
Feat/comfyui json benchmark workflow from file no silent fail
2025-10-23 17:02:12 +01:00
Rob Ballantyne f4f7080df1 Re-add comment 2025-10-23 17:00:28 +01:00
Rob Ballantyne d51a338e8f log when benchmark file not used 2025-10-23 16:41:02 +01:00
Rob Ballantyne 92a04bd7af No silent fail if benchmark file is missing 2025-10-23 13:41:03 +01:00
LucasArmandVast c98d661513 Merge pull request #39 from vast-ai/remove-time-divide
PyWorker fixes for cur_load and acks bug
2025-10-13 10:06:22 -07:00
Lucas Armand f6fd1c6ac1 merge 2025-10-09 18:15:55 -07:00
Lucas Armand 055e346c8c Send metrics on request start 2025-10-09 10:13:50 -07:00
Lucas Armand 1cedb28acf Removed division by elapsed time, since autoscaler cur_load in units of workload 2025-10-08 16:54:18 -07:00
Rob Ballantyne ec25dda3ad Merge branch 'vast-ai:main' into feat/comfyui-json-benchmark-workflow-from-file 2025-10-08 14:49:32 +01:00
Colter-Downing 0397af719d Merge pull request #37 from robballantyne/bugfix/healthcheck-endpoint
Fix healthcheck endpoint URL

Tested and merged by Colter
2025-10-06 15:11:27 -07:00
Rob Ballantyne 3786cf978d Add awareness of errors thrown by the provisioning script 2025-10-05 23:14:59 +01:00
Rob Ballantyne a86d4bcf9c Import json 2025-10-05 23:05:33 +01:00
Rob Ballantyne e9b6a14a5e Import Path 2025-10-05 22:59:19 +01:00
Rob Ballantyne cadac033e1 Enables use of custom workflow for benchmarking
Retains existing method is misc/benchmark.json is nopt present
2025-10-05 22:53:22 +01:00
16 changed files with 877 additions and 795 deletions
+2 -1
View File
@@ -2,4 +2,5 @@
.envrc .envrc
__pycache__ __pycache__
bin/ bin/
lib64 lib64
.venv
+4 -3
View File
@@ -39,11 +39,12 @@ reporting these metrics to the autoscaler.
If you are using a Vast.ai template that includes PyWorker integration (marked as autoscaler compatible), it should work out of the box. The template will typically start the appropriate PyWorker server automatically. Here's a few: If you are using a Vast.ai template that includes PyWorker integration (marked as autoscaler compatible), it should work out of the box. The template will typically start the appropriate PyWorker server automatically. Here's a few:
* **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=72d8dcb41ea3a58e06c741e2c725bc00) * **vLLM:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=63ae93902bf3978bea033782592b784d)
* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=ad72c8bf7cf695c3c9ddf0eaf6da0447) * **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=6fa6bd5bdf5f0df63db80e40b086037d)
* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=e6748878ba688e765e3e9fca29541938)
Currently available workers: Currently available workers:
* `hello_world`: A simple example worker for a basic LLM server. * `openai`: A simple example worker for a basic vLLM server.
* `comfyui`: A worker for the ComfyUI image generation backend. * `comfyui`: A worker for the ComfyUI image generation backend.
* `tgi`: A worker for the Text Generation Inference backend. * `tgi`: A worker for the Text Generation Inference backend.
+62 -33
View File
@@ -30,7 +30,7 @@ from lib.data_types import (
BenchmarkResult BenchmarkResult
) )
VERSION = "0.1.0" VERSION = "0.2.0"
MSG_HISTORY_LEN = 100 MSG_HISTORY_LEN = 100
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
@@ -66,10 +66,17 @@ class Backend:
unsecured: bool = dataclasses.field( unsecured: bool = dataclasses.field(
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))), default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
) )
report_addr: str = dataclasses.field(
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
)
mtoken: str = dataclasses.field(
default_factory=lambda: os.environ.get("MASTER_TOKEN", "")
)
def __post_init__(self): def __post_init__(self):
self.metrics = Metrics() self.metrics = Metrics()
self.metrics._set_version(self.version) self.metrics._set_version(self.version)
self.metrics._set_mtoken(self.mtoken)
self._total_pubkey_fetch_errors = 0 self._total_pubkey_fetch_errors = 0
self._pubkey = self._fetch_pubkey() self._pubkey = self._fetch_pubkey()
self.__start_healthcheck: bool = False self.__start_healthcheck: bool = False
@@ -104,23 +111,19 @@ class Backend:
#######################################Private####################################### #######################################Private#######################################
def _fetch_pubkey(self): def _fetch_pubkey(self):
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"] report_addr = self.report_addr.rstrip("/")
result = subprocess.check_output(command, universal_newlines=True) command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"]
log.debug("public key:") try:
log.debug(result) result = subprocess.check_output(command, universal_newlines=True)
key = None log.debug("public key:")
for _ in range(5): log.debug(result)
try: key = RSA.import_key(result)
key = RSA.import_key(result) if key is not None:
break return key
except ValueError as e: except (ValueError , subprocess.CalledProcessError) as e:
log.debug(f"Error downloading key: {e}") log.debug(f"Error downloading key: {e}")
time.sleep(15) self.backend_errored("Failed to get autoscaler pubkey")
if key is None:
self._total_pubkey_fetch_errors += 1
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
self.backend_errored("Failed to get autoscaler pubkey")
return key
async def __handle_request( async def __handle_request(
self, self,
@@ -232,10 +235,14 @@ class Backend:
log.debug("No healthcheck endpoint defined, skipping healthcheck") log.debug("No healthcheck endpoint defined, skipping healthcheck")
return return
first_healthcheck = True
while True: while True:
await sleep(10) await sleep(10)
if self.__start_healthcheck is False: if self.__start_healthcheck is False:
continue continue
if first_healthcheck:
log.info(f"[healthcheck] First healthcheck starting (model is now loaded)")
first_healthcheck = False
try: try:
log.debug(f"Performing healthcheck on {health_check_url}") log.debug(f"Performing healthcheck on {health_check_url}")
async with self.healthcheck_session.get(health_check_url) as response: async with self.healthcheck_session.get(health_check_url) as response:
@@ -253,9 +260,22 @@ class Backend:
self.backend_errored(str(e)) self.backend_errored(str(e))
async def _start_tracking(self) -> None: async def _start_tracking(self) -> None:
await gather( log.info("Starting tracking tasks (read_logs, send_metrics_loop, healthcheck, send_delete_requests_loop)")
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop() task_names = ["read_logs", "send_metrics_loop", "healthcheck", "send_delete_requests_loop"]
results = await gather(
self.__read_logs(),
self.metrics._send_metrics_loop(),
self.__healthcheck(),
self.metrics._send_delete_requests_loop(),
return_exceptions=True
) )
# If we get here, one or more tasks exited (they should run forever)
log.error(f"CRITICAL: _start_tracking gather returned! This should never happen. Results: {results}")
for name, result in zip(task_names, results):
if isinstance(result, Exception):
log.error(f"Tracking task '{name}' crashed with exception: {result}", exc_info=result)
elif result is not None:
log.warning(f"Tracking task '{name}' exited unexpectedly with result: {result}")
def backend_errored(self, msg: str) -> None: def backend_errored(self, msg: str) -> None:
self.metrics._model_errored(msg) self.metrics._model_errored(msg)
@@ -286,7 +306,7 @@ class Backend:
message = { message = {
key: value key: value
for (key, value) in (dataclasses.asdict(auth_data).items()) for (key, value) in (dataclasses.asdict(auth_data).items())
if key != "signature" if key != "signature" and key != "__request_id"
} }
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN): if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
log.debug( log.debug(
@@ -296,7 +316,7 @@ class Backend:
elif message in self.msg_history: elif message in self.msg_history:
log.debug(f"message: {message} already in message history") log.debug(f"message: {message} already in message history")
return False return False
elif verify_signature(json.dumps(message, indent=4), auth_data.signature): elif verify_signature(json.dumps(message, indent=4, sort_keys=True), auth_data.signature):
self.reqnum = max(auth_data.reqnum, self.reqnum) self.reqnum = max(auth_data.reqnum, self.reqnum)
self.msg_history.append(message) self.msg_history.append(message)
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:] self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
@@ -315,10 +335,10 @@ class Backend:
with open(BENCHMARK_INDICATOR_FILE, "r") as f: with open(BENCHMARK_INDICATOR_FILE, "r") as f:
log.debug("already ran benchmark") log.debug("already ran benchmark")
# trigger model load # trigger model load
payload = self.benchmark_handler.make_benchmark_payload() # payload = self.benchmark_handler.make_benchmark_payload()
_ = await self.__call_api( # _ = await self.__call_api(
handler=self.benchmark_handler, payload=payload # handler=self.benchmark_handler, payload=payload
) # )
return float(f.readline()) return float(f.readline())
except FileNotFoundError: except FileNotFoundError:
pass pass
@@ -393,18 +413,23 @@ class Backend:
) )
# some backends need a few seconds after logging successful startup before # some backends need a few seconds after logging successful startup before
# they can begin accepting requests # they can begin accepting requests
await sleep(5) # await sleep(5)
try: try:
max_throughput = await run_benchmark() max_throughput = await run_benchmark()
log.info(f"[benchmark] Benchmark complete, max_throughput={max_throughput}, setting healthcheck=True")
self.__start_healthcheck = True self.__start_healthcheck = True
self.metrics._model_loaded( self.metrics._model_loaded(
max_throughput=max_throughput, max_throughput=max_throughput,
) )
log.info(f"[benchmark] _model_loaded() called, returning from handle_log_line")
except ClientConnectorError as e: except ClientConnectorError as e:
log.debug( log.debug(
f"failed to connect to comfyui api during benchmark" f"failed to connect to model api during benchmark"
) )
self.backend_errored(str(e)) self.backend_errored(str(e))
except Exception as e:
log.error(f"Unexpected error during benchmark: {e}", exc_info=True)
self.backend_errored(f"Benchmark failed: {e}")
case LogAction.ModelError if msg in log_line: case LogAction.ModelError if msg in log_line:
log.debug(f"Got log line indicating error: {log_line}") log.debug(f"Got log line indicating error: {log_line}")
self.backend_errored(msg) self.backend_errored(msg)
@@ -414,12 +439,16 @@ class Backend:
async def tail_log(): async def tail_log():
log.debug(f"tailing file: {self.model_log_file}") log.debug(f"tailing file: {self.model_log_file}")
async with await open_file(self.model_log_file) as f: async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore') as f:
while True: while True:
line = await f.readline() try:
if line: line = await f.readline()
await handle_log_line(line.rstrip()) if line:
else: await handle_log_line(line.rstrip())
else:
await asyncio.sleep(LOG_POLL_INTERVAL)
except Exception as e:
log.error(f"Error processing log line: {e}", exc_info=True)
await asyncio.sleep(LOG_POLL_INTERVAL) await asyncio.sleep(LOG_POLL_INTERVAL)
########### ###########
+6 -4
View File
@@ -65,12 +65,12 @@ class ApiPayload(ABC):
class AuthData: class AuthData:
"""data used to authenticate requester""" """data used to authenticate requester"""
signature: str
cost: str cost: str
endpoint: str endpoint: str
reqnum: int reqnum: int
url: str
request_idx: int request_idx: int
signature: str
url: str
@classmethod @classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]): def from_json_msg(cls, json_msg: Dict[str, Any]):
@@ -190,11 +190,12 @@ class SystemMetrics:
self.additional_disk_usage = disk_usage - self.last_disk_usage self.additional_disk_usage = disk_usage - self.last_disk_usage
self.last_disk_usage = disk_usage self.last_disk_usage = disk_usage
def reset(self): def reset(self, expected: float | None) -> None:
# autoscaler excepts model_loading_time to be populated only once, when the instance has # autoscaler excepts model_loading_time to be populated only once, when the instance has
# finished benchmarking and is ready to receive requests. This applies to restarted instances # finished benchmarking and is ready to receive requests. This applies to restarted instances
# as well: they should send model_loading_time once when they are done loading # as well: they should send model_loading_time once when they are done loading
self.model_loading_time = None if self.model_loading_time == expected:
self.model_loading_time = None
@dataclass @dataclass
@@ -285,6 +286,7 @@ class AutoScalerData:
"""Data that is reported to autoscaler""" """Data that is reported to autoscaler"""
id: int id: int
mtoken: str
version: str version: str
loadtime: float loadtime: float
cur_load: float cur_load: float
+100 -18
View File
@@ -1,4 +1,5 @@
import os import os
import sys
import time import time
import logging import logging
import json import json
@@ -17,6 +18,14 @@ DELETE_REQUESTS_INTERVAL = 1
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
def _flush_logs():
"""Force flush all log handlers and stdout/stderr."""
for handler in logging.root.handlers:
handler.flush()
sys.stdout.flush()
sys.stderr.flush()
@cache @cache
def get_url() -> str: def get_url() -> str:
use_ssl = os.environ.get("USE_SSL", "false") == "true" use_ssl = os.environ.get("USE_SSL", "false") == "true"
@@ -28,6 +37,7 @@ def get_url() -> str:
@dataclass @dataclass
class Metrics: class Metrics:
version: str = "0" version: str = "0"
mtoken: str = ""
last_metric_update: float = 0.0 last_metric_update: float = 0.0
last_request_served: float = 0.0 last_request_served: float = 0.0
update_pending: bool = False update_pending: bool = False
@@ -118,22 +128,41 @@ class Metrics:
await self.__send_delete_requests_and_reset() await self.__send_delete_requests_and_reset()
async def _send_metrics_loop(self) -> Awaitable[NoReturn]: async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
loop_count = 0
first_loaded_send_done = False
while True: while True:
await sleep(METRICS_UPDATE_INTERVAL) await sleep(METRICS_UPDATE_INTERVAL)
loop_count += 1
elapsed = time.time() - self.last_metric_update elapsed = time.time() - self.last_metric_update
# Log heartbeat every 30 seconds to confirm loop is running
if loop_count % 30 == 0:
log.debug(f"[heartbeat] metrics loop alive, loop_count={loop_count}, model_loaded={self.system_metrics.model_is_loaded}")
_flush_logs()
# Extra logging for first few iterations after model loads
if self.system_metrics.model_is_loaded and not first_loaded_send_done:
log.info(f"[transition] First iteration with model_loaded=True, loop_count={loop_count}, elapsed={elapsed:.1f}")
_flush_logs()
if self.system_metrics.model_is_loaded is False and elapsed >= 10: if self.system_metrics.model_is_loaded is False and elapsed >= 10:
log.debug(f"sending loading model metrics after {int(elapsed)}s wait") log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
await self.__send_metrics_and_reset() await self.__send_metrics_and_reset()
elif self.update_pending or elapsed > 10: elif self.update_pending or elapsed > 10:
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait") log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
await self.__send_metrics_and_reset() await self.__send_metrics_and_reset()
if self.system_metrics.model_is_loaded and not first_loaded_send_done:
first_loaded_send_done = True
log.info(f"[transition] First loaded metrics send complete, continuing to next iteration...")
_flush_logs()
def _model_loaded(self, max_throughput: float) -> None: def _model_loaded(self, max_throughput: float) -> None:
log.info(f"MODEL LOADED: Setting model_is_loaded=True, max_throughput={max_throughput}")
_flush_logs()
self.system_metrics.model_loading_time = ( self.system_metrics.model_loading_time = (
time.time() - self.system_metrics.model_loading_start time.time() - self.system_metrics.model_loading_start
) )
self.system_metrics.model_is_loaded = True self.system_metrics.model_is_loaded = True
self.model_metrics.max_throughput = max_throughput self.model_metrics.max_throughput = max_throughput
log.info(f"MODEL LOADED: model_loading_time={self.system_metrics.model_loading_time}")
_flush_logs()
def _model_errored(self, error_msg: str) -> None: def _model_errored(self, error_msg: str) -> None:
self.model_metrics.set_errored(error_msg) self.model_metrics.set_errored(error_msg)
@@ -142,17 +171,22 @@ class Metrics:
def _set_version(self, version: str) -> None: def _set_version(self, version: str) -> None:
self.version = version self.version = version
def _set_mtoken(self, mtoken: str) -> None:
self.mtoken = mtoken
#######################################Private####################################### #######################################Private#######################################
async def __send_delete_requests_and_reset(self): async def __send_delete_requests_and_reset(self):
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
async def send_data(report_addr: str, success: bool) -> bool:
data = { data = {
"worker_id": self.id, "worker_id": self.id,
"request_idxs": [r.request_idx for r in self.model_metrics.requests_deleting if r.success == success], "mtoken": self.mtoken,
"success": success "request_idxs": idxs,
"success": success_flag,
} }
log.debug(f"Deleting requests that {'succeeded' if success else 'failed'}: {data['request_idxs']}") log.debug(
f"Deleting requests that {'succeeded' if success_flag else 'failed'}: {data['request_idxs']}"
)
full_path = report_addr.rstrip("/") + "/delete_requests/" full_path = report_addr.rstrip("/") + "/delete_requests/"
for attempt in range(1, 4): for attempt in range(1, 4):
try: try:
@@ -162,26 +196,55 @@ class Metrics:
res.raise_for_status() res.raise_for_status()
return True return True
except asyncio.TimeoutError: except asyncio.TimeoutError:
log.debug(f"delete_requests timed out") log.debug("delete_requests timed out")
except (ClientResponseError, Exception) as e: except (ClientResponseError, Exception) as e:
log.debug(f"delete_requests failed with error: {e}") log.debug(f"delete_requests failed with error: {e}")
await asyncio.sleep(2) await asyncio.sleep(2)
log.debug(f"retrying delete_request, attempt: {attempt}") log.debug(f"retrying delete_request, attempt: {attempt}")
return False
# Take a snapshot of what we plan to send this tick.
# New arrivals after this snapshot will remain in the queue for the next tick.
snapshot = list(self.model_metrics.requests_deleting)
success_idxs = [r.request_idx for r in snapshot if r.success is True]
failed_idxs = [r.request_idx for r in snapshot if r.success is False]
if not success_idxs and not failed_idxs:
return # nothing to do
for report_addr in self.report_addr: for report_addr in self.report_addr:
success = await send_data(report_addr, success=True) and await send_data(report_addr, success=False) # TODO: Add a Redis subscriber queue for delete_requests
if success is True: if report_addr == "https://cloud.vast.ai/api/v0":
self.model_metrics.requests_deleting.clear() # Patch: ignore the Redis API report_addr
continue
sent_success = True
sent_failed = True
if success_idxs:
sent_success = await post(report_addr, success_idxs, True)
if failed_idxs:
sent_failed = await post(report_addr, failed_idxs, False)
if sent_success and sent_failed:
# Remove only the items we actually sent from the live queue.
sent_set = set(success_idxs) | set(failed_idxs)
self.model_metrics.requests_deleting[:] = [
r for r in self.model_metrics.requests_deleting
if r.request_idx not in sent_set
]
break break
async def __send_metrics_and_reset(self): async def __send_metrics_and_reset(self):
loadtime_snapshot = self.system_metrics.model_loading_time
def compute_autoscaler_data() -> AutoScalerData: def compute_autoscaler_data() -> AutoScalerData:
return AutoScalerData( return AutoScalerData(
id=self.id, id=self.id,
mtoken=self.mtoken,
version=self.version, version=self.version,
loadtime=(self.system_metrics.model_loading_time or 0.0), loadtime=(loadtime_snapshot or 0.0),
new_load=self.model_metrics.workload_processing, new_load=self.model_metrics.workload_processing,
cur_load=self.model_metrics.cur_load, cur_load=self.model_metrics.cur_load,
rej_load=self.model_metrics.workload_rejected, rej_load=self.model_metrics.workload_rejected,
@@ -199,17 +262,25 @@ class Metrics:
async def send_data(report_addr: str) -> bool: async def send_data(report_addr: str) -> bool:
data = compute_autoscaler_data() data = compute_autoscaler_data()
full_path = report_addr.rstrip("/") + "/worker_status/" log_data = asdict(data)
def obfuscate(secret: str) -> str:
if secret is None:
return ""
return secret[:7] + "..." if len(secret) > 7 else ("*" * len(secret))
log_data["mtoken"] = obfuscate(log_data.get("mtoken"))
log.debug( log.debug(
"\n".join( "\n".join(
[ [
"#" * 60, "#" * 60,
f"sending data to autoscaler", f"sending data to autoscaler",
f"{json.dumps((asdict(data)), indent=2)}", f"{json.dumps(log_data, indent=2)}",
"#" * 60, "#" * 60,
] ]
) )
) )
full_path = report_addr.rstrip("/") + "/worker_status/"
for attempt in range(1, 4): for attempt in range(1, 4):
try: try:
session = await self.http() session = await self.http()
@@ -228,12 +299,23 @@ class Metrics:
########### ###########
self.system_metrics.update_disk_usage() self.system_metrics.update_disk_usage()
had_loadtime = loadtime_snapshot is not None and loadtime_snapshot > 0
sent = False
for report_addr in self.report_addr: for report_addr in self.report_addr:
success = await send_data(report_addr) if await send_data(report_addr):
if success is True: sent = True
break break
self.update_pending = False
self.model_metrics.reset() if sent:
self.system_metrics.reset() if had_loadtime:
self.last_metric_update = time.time() log.info(f"FIRST LOADTIME METRICS SENT SUCCESSFULLY! loadtime={loadtime_snapshot}")
_flush_logs()
# clear the one-shot loadtime only if we actually sent *this* value
self.system_metrics.reset(expected=loadtime_snapshot)
self.update_pending = False
self.model_metrics.reset()
self.last_metric_update = time.time()
if had_loadtime:
log.info(f"POST-SEND: reset complete, last_metric_update={self.last_metric_update}, continuing loop...")
_flush_logs()
+65 -25
View File
@@ -1,40 +1,80 @@
import os import os
import logging import logging
import signal
import sys
from typing import List from typing import List
import ssl import ssl
from asyncio import run, gather from asyncio import run, gather
import asyncio
from lib.backend import Backend from lib.backend import Backend
from lib.metrics import Metrics
from aiohttp import web from aiohttp import web
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
def _setup_signal_handlers():
"""Setup signal handlers to log when process receives termination signals."""
def signal_handler(signum, frame):
sig_name = signal.Signals(signum).name
log.error(f"SIGNAL RECEIVED: {sig_name} ({signum}) - process is being terminated")
sys.stdout.flush()
sys.stderr.flush()
sys.exit(128 + signum)
# Handle common termination signals
for sig in [signal.SIGTERM, signal.SIGINT, signal.SIGHUP]:
try:
signal.signal(sig, signal_handler)
except (OSError, ValueError):
pass # Some signals may not be available
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs): def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
log.debug("getting certificate...") _setup_signal_handlers()
use_ssl = os.environ.get("USE_SSL", "false") == "true" try:
if use_ssl is True: log.debug("getting certificate...")
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) use_ssl = os.environ.get("USE_SSL", "false") == "true"
ssl_context.load_cert_chain( if use_ssl is True:
certfile="/etc/instance.crt", ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
keyfile="/etc/instance.key", ssl_context.load_cert_chain(
) certfile="/etc/instance.crt",
else: keyfile="/etc/instance.key",
ssl_context = None )
else:
ssl_context = None
async def main(): async def main():
log.debug("starting server...") log.debug("starting server...")
app = web.Application() app = web.Application()
app.add_routes(routes) app.add_routes(routes)
runner = web.AppRunner(app) runner = web.AppRunner(app)
await runner.setup() await runner.setup()
site = web.TCPSite( site = web.TCPSite(
runner, runner,
ssl_context=ssl_context, ssl_context=ssl_context,
port=int(os.environ["WORKER_PORT"]), port=int(os.environ["WORKER_PORT"]),
**kwargs **kwargs
) )
await gather(site.start(), backend._start_tracking()) await gather(site.start(), backend._start_tracking())
run(main()) run(main())
except Exception as e:
err_msg = f"PyWorker failed to launch: {e}"
log.error(err_msg)
async def beacon():
metrics = Metrics()
metrics._set_version(getattr(backend, "version", "0"))
metrics._set_mtoken(getattr(backend, "mtoken", ""))
try:
while True:
metrics._model_errored(err_msg)
await metrics._Metrics__send_metrics_and_reset()
await asyncio.sleep(10)
finally:
await metrics.aclose()
run(beacon())
+1
View File
@@ -8,3 +8,4 @@ Requests~=2.32
transformers~=4.52 transformers~=4.52
utils==1.0.* utils==1.0.*
hf_transfer>=0.1.9 hf_transfer>=0.1.9
vastai-sdk>=0.2.0
+48 -6
View File
@@ -9,7 +9,7 @@ ENV_PATH="$WORKSPACE_DIR/worker-env"
DEBUG_LOG="$WORKSPACE_DIR/debug.log" DEBUG_LOG="$WORKSPACE_DIR/debug.log"
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log" PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
REPORT_ADDR="${REPORT_ADDR:-https://cloud.vast.ai/api/v0,https://run.vast.ai}" REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}"
USE_SSL="${USE_SSL:-true}" USE_SSL="${USE_SSL:-true}"
WORKER_PORT="${WORKER_PORT:-3000}" WORKER_PORT="${WORKER_PORT:-3000}"
mkdir -p "$WORKSPACE_DIR" mkdir -p "$WORKSPACE_DIR"
@@ -41,6 +41,14 @@ echo_var DEBUG_LOG
echo_var PYWORKER_LOG echo_var PYWORKER_LOG
echo_var MODEL_LOG echo_var MODEL_LOG
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
if [ -e "$MODEL_LOG" ]; then
echo "Rotating model log at $MODEL_LOG to $MODEL_LOG.old"
cat "$MODEL_LOG" >> "$MODEL_LOG.old"
: > "$MODEL_LOG"
fi
# Populate /etc/environment with quoted values # Populate /etc/environment with quoted values
if ! grep -q "VAST" /etc/environment; then if ! grep -q "VAST" /etc/environment; then
env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do
@@ -124,9 +132,43 @@ cd "$SERVER_DIR"
echo "launching PyWorker server" echo "launching PyWorker server"
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines set +e
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG"
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG" PY_STATUS=${PIPESTATUS[0]}
set -e
(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") & if [ "${PY_STATUS}" -ne 0 ]; then
echo "launching PyWorker server done" echo "PyWorker exited with status ${PY_STATUS}; notifying autoscaler..."
ERROR_MSG="PyWorker exited: code ${PY_STATUS}"
MTOKEN="${MASTER_TOKEN:-}"
VERSION="${PYWORKER_VERSION:-0}"
IFS=',' read -r -a REPORT_ADDRS <<< "${REPORT_ADDR}"
for addr in "${REPORT_ADDRS[@]}"; do
curl -sS -X POST -H 'Content-Type: application/json' \
-d "$(cat <<JSON
{
"id": ${CONTAINER_ID:-0},
"mtoken": "${MTOKEN}",
"version": "${VERSION}",
"loadtime": 0,
"new_load": 0,
"cur_load": 0,
"rej_load": 0,
"max_perf": 0,
"cur_perf": 0,
"error_msg": "${ERROR_MSG}",
"num_requests_working": 0,
"num_requests_recieved": 0,
"additional_disk_usage": 0,
"working_request_idxs": [],
"cur_capacity": 0,
"max_capacity": 0,
"url": "${URL}"
}
JSON
)" "${addr%/}/worker_status/" || true
done
fi
echo "launching PyWorker server done"
+15 -3
View File
@@ -12,9 +12,21 @@ A docker image is provided but you may use any if the above requirements are met
## Benchmarking ## Benchmarking
A simple image generation benchmark runs when each worker initializes to validate GPU performance and identify underperforming machines. ### Custom Benchmark Workflows
The benchmark uses Stable Diffusion v1.5 with ComfyUI's default text-to-image workflow. Configure the benchmark complexity and duration using these variables: You can provide a custom ComfyUI workflow for benchmarking by creating `workers/comfyui-json/misc/benchmark.json`. This allows you to test performance using your preferred models and workflow complexity.
**Ways to provide the benchmark file:**
- Fork this repository and add your `benchmark.json` file
- Write the file during worker provisioning (onstart script or setup phase)
An example file is provided in the repository. To ensure varied generations, use the placeholder `__RANDOM_INT__` in place of static seed values - it will be replaced with a random integer for each benchmark run.
### Default Benchmark (Fallback)
If `benchmark.json` is not available, a simple image generation benchmark runs when each worker initializes. This validates GPU performance and helps identify underperforming machines.
The default benchmark uses Stable Diffusion v1.5 with ComfyUI's standard text-to-image workflow. Configure it using these environment variables:
| Environment Variable | Default Value | Description | | Environment Variable | Default Value | Description |
| -------------------- | ------------- | ----------- | | -------------------- | ------------- | ----------- |
@@ -24,7 +36,7 @@ The benchmark uses Stable Diffusion v1.5 with ComfyUI's default text-to-image wo
Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns. Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns.
### Calibrating Benchmark Duration #### Calibrating Fallback Benchmark Duration
To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements. To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements.
+26 -146
View File
@@ -1,155 +1,35 @@
import logging from .data_types import count_workload
import uuid import uuid
import random import random
from urllib.parse import urljoin import asyncio
import json import random
import requests from vastai import Serverless
from lib.test_utils import print_truncate_res async def main():
from utils.endpoint_util import Endpoint async with Serverless() as client:
from utils.ssl import get_cert_file_path endpoint = await client.get_endpoint(name="my-comfy-endpoint") # Change this to your endpoint name
from .data_types import count_workload
logging.basicConfig( payload = {
level=logging.DEBUG, "input": {
format="%(asctime)s[%(levelname)-5s] %(message)s", "request_id": str(uuid.uuid4()),
datefmt="%Y-%m-%d %H:%M:%S", "modifier": "Text2Image",
) "modifications": {
log = logging.getLogger(__file__) "prompt": "a beautiful landscape with mountains and lakes",
"width": 1024,
"height": 1024,
def call_text2image_workflow( "steps": 20,
endpoint_group_name: str, api_key: str, server_url: str "seed": random.randint(0, 2**32 - 1)
) -> None: },
"""Simple Text2Image using the new modifier-based approach""" "workflow_json": {} # Empty since using modifier approach
}
def make_request(url: str, payload: dict, timeout: int = None, verify=True, context: str = "request"):
"""Helper function for making requests with consistent error handling"""
try:
response = requests.post(
url,
json=payload,
timeout=timeout,
verify=verify
)
response.raise_for_status()
return response.json()
except requests.exceptions.HTTPError as http_err:
log.error(f"HTTP error occurred during {context}: {http_err}")
log.error(f"Status Code: {response.status_code}")
log.error("Response content:", response.text)
return None
except requests.exceptions.Timeout:
log.error(f"Timeout occurred during {context}: {url}")
return None
except requests.exceptions.ConnectionError:
log.error(f"Connection error occurred during {context}: {url}")
return None
except json.JSONDecodeError as json_err:
log.error(f"Failed to decode JSON response during {context}: {json_err}")
if 'response' in locals():
print("Response content:", response.text)
return None
except Exception as err:
log.error(f"An unexpected error occurred during {context}: {err}")
if 'response' in locals():
log.error("Response content (if available):", response.text)
return None
WORKER_ENDPOINT = "/generate/sync"
# This worker has concurrency = 1. All workloads have cost value 1.0
COST = count_workload()
# Route to get worker URL
route_payload = {
"endpoint": endpoint_group_name,
"api_key": api_key,
"cost": COST,
}
# First request - get routing information
route_response = make_request(
url=urljoin(server_url, "/route/"),
payload=route_payload,
timeout=4,
context="route request"
)
if route_response is None:
return None
if "url" not in route_response or not route_response["url"]:
log.error("Error: No worker in 'Ready' state. Please wait while the serverless engine removes errored workers or finishes loading new workers.")
return None
if "status" in route_response:
print(f"Autoscaler status: {route_response['status']}")
return None
# Extract data from route response
url = route_response["url"]
auth_data = dict(
signature=route_response["signature"],
cost=route_response["cost"],
endpoint=route_response["endpoint"],
reqnum=route_response["reqnum"],
url=route_response["url"],
)
# Build the payload for the worker request
worker_payload = {
"input": {
"request_id": str(uuid.uuid4()),
"modifier": "Text2Image",
"modifications": {
"prompt": "a beautiful landscape with mountains and lakes",
"width": 1024,
"height": 1024,
"steps": 20,
"seed": random.randint(0, 2**32 - 1)
},
"workflow_json": {} # Empty since using modifier approach
} }
}
response = await endpoint.request("/generate/sync", payload, cost=count_workload())
req_data = dict(payload=worker_payload, auth_data=auth_data)
worker_url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {worker_url}")
# Second request - call the worker endpoint
worker_response = make_request(
url=worker_url,
payload=req_data,
verify=get_cert_file_path(),
context="worker request"
)
return worker_response
# Get the file from the path on the local machine using SCP or SFTP
# or configure S3 to upload to cloud storage.
print(response["response"]["output"][0]["local_path"])
if __name__ == "__main__": if __name__ == "__main__":
from lib.test_utils import test_args asyncio.run(main())
args = test_args.parse_args()
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if endpoint_api_key:
result = call_text2image_workflow(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
if result is None:
log.error("Text2Image workflow failed")
else:
print(result)
else:
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name}")
+28 -4
View File
@@ -5,12 +5,13 @@ import dataclasses
from typing import Dict, Any from typing import Dict, Any
from functools import cache from functools import cache
from math import ceil from math import ceil
from pathlib import Path
import json
import logging
from lib.data_types import ApiPayload, JsonDataException from lib.data_types import ApiPayload, JsonDataException
log = logging.getLogger(__file__)
with open("workers/comfyui/misc/test_prompts.txt", "r") as f:
test_prompts = f.readlines()
def count_workload() -> float: def count_workload() -> float:
# Always 100.0 where there is a single instance of ComfyUI handling requests # Always 100.0 where there is a single instance of ComfyUI handling requests
@@ -24,9 +25,32 @@ class ComfyWorkflowData(ApiPayload):
@classmethod @classmethod
def for_test(cls): def for_test(cls):
""" """
Use the variables available to simulate workflows of the required running time If the user has provided a benchmark workflow we can use it here to properly gauge performance.
Otherwise, use the variables available to simulate workflows of the required running time
Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090) Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090)
""" """
# Try to load benchmark.json
benchmark_file = Path("workers/comfyui-json/misc/benchmark.json")
if benchmark_file.exists():
try:
with open(benchmark_file, "r") as f:
benchmark_workflow = json.load(f)
return cls(
input={
"request_id": f"test-{random.randint(1000, 99999)}",
"workflow_json": benchmark_workflow
}
)
except (json.JSONDecodeError, IOError):
# JSON is malformed or file can't be read, fall through to default
log.error(f"Failed to benchmark using {benchmark_file}")
# Fallback: read prompts and construct payload
log.info("Using fallback method for benchmarking")
with open("workers/comfyui-json/misc/test_prompts.txt", "r") as f:
test_prompts = f.readlines()
test_prompt = random.choice(test_prompts).rstrip() test_prompt = random.choice(test_prompts).rstrip()
return cls( return cls(
input={ input={
@@ -0,0 +1,107 @@
{
"3": {
"inputs": {
"seed": "__RANDOM_INT__",
"steps": 20,
"cfg": 8,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1,
"model": [
"4",
0
],
"positive": [
"6",
0
],
"negative": [
"7",
0
],
"latent_image": [
"5",
0
]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"4": {
"inputs": {
"ckpt_name": "v1-5-pruned-emaonly-fp16.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"5": {
"inputs": {
"width": 512,
"height": 512,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"6": {
"inputs": {
"text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"7": {
"inputs": {
"text": "text, watermark",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"8": {
"inputs": {
"samples": [
"3",
0
],
"vae": [
"4",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"8",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
}
}
+1
View File
@@ -19,6 +19,7 @@ MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: "
MODEL_SERVER_ERROR_LOG_MSGS = [ MODEL_SERVER_ERROR_LOG_MSGS = [
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted "MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
"Value not in list: ", # This error is emitted when the model file is not there at all "Value not in list: ", # This error is emitted when the model file is not there at all
"[ERROR] Provisioning Script failed", # Error inserted by provisioning script if models/nodes fail to download
] ]
+6 -12
View File
@@ -7,20 +7,13 @@ from lib.test_utils import print_truncate_res
from utils.endpoint_util import Endpoint from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path from utils.ssl import get_cert_file_path
""" from vastai import Serverless
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
"""
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
def call_default_workflow( ENDPOINT_NAME = "my-comfyui-endpoint"
endpoint_group_name: str, api_key: str, server_url: str COST = 100 # Use a constant cost for image generation
) -> None:
def call_default_workflow(client: Serverless) -> None:
WORKER_ENDPOINT = "/prompt" WORKER_ENDPOINT = "/prompt"
COST = 100 COST = 100
route_payload = { route_payload = {
@@ -82,6 +75,7 @@ def call_custom_workflow_for_sd3(
endpoint=message["endpoint"], endpoint=message["endpoint"],
reqnum=message["reqnum"], reqnum=message["reqnum"],
url=message["url"], url=message["url"],
request_idx=message["request_idx"],
) )
workflow = { workflow = {
"3": { "3": {
+357 -427
View File
@@ -1,14 +1,15 @@
import logging import logging
import sys
import json import json
import os
import sys
import subprocess import subprocess
from urllib.parse import urljoin import argparse
from typing import Dict, Any, Optional, Iterator, Union, List from typing import Any, Dict, List, Optional
import requests
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from .data_types.client import CompletionConfig, ChatCompletionConfig
from vastai import Serverless
import asyncio
# ---------------------- Logging ----------------------
logging.basicConfig( logging.basicConfig(
level=logging.DEBUG, level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s", format="%(asctime)s[%(levelname)-5s] %(message)s",
@@ -16,135 +17,20 @@ logging.basicConfig(
) )
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
# ---------------------- Prompts ----------------------
COMPLETIONS_PROMPT = "the capital of USA is" COMPLETIONS_PROMPT = "the capital of USA is"
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language." CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?" TOOLS_PROMPT = (
"Can you list the files in the current working directory and tell me what you see? "
"What do you think this directory might be for?"
class APIClient: )
"""Lightweight client focused solely on API communication"""
# Remove the generic WORKER_ENDPOINT since we're now going direct
DEFAULT_COST = 100
DEFAULT_TIMEOUT = 4
def __init__(
self,
endpoint_group_name: str,
api_key: str,
server_url: str,
endpoint_api_key: str,
):
self.endpoint_group_name = endpoint_group_name
self.api_key = api_key
self.server_url = server_url
self.endpoint_api_key = endpoint_api_key
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
"""Get worker URL and auth data from routing service"""
if not self.endpoint_api_key:
raise ValueError("No valid endpoint API key available")
route_payload = {
"endpoint": self.endpoint_group_name,
"api_key": self.endpoint_api_key,
"cost": cost,
}
response = requests.post(
urljoin(self.server_url, "/route/"),
json=route_payload,
timeout=self.DEFAULT_TIMEOUT,
)
response.raise_for_status()
return response.json()
def _create_auth_data(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""Create auth data from routing response"""
return {
"signature": message["signature"],
"cost": message["cost"],
"endpoint": message["endpoint"],
"reqnum": message["reqnum"],
"url": message["url"],
}
def _make_request(
self,
payload: Dict[str, Any],
endpoint: str,
method: str = "POST",
stream: bool = False,
) -> Union[Dict[str, Any], Iterator[str]]:
"""Make request directly to the specific worker endpoint"""
# Get worker URL and auth data
cost = payload.get("max_tokens", self.DEFAULT_COST)
message = self._get_worker_url(cost=cost)
worker_url = message["url"]
auth_data = self._create_auth_data(message)
req_data = {"payload": {"input": payload}, "auth_data": auth_data}
url = urljoin(worker_url, endpoint)
log.debug(f"Making direct request to: {url}")
log.debug(f"Payload: {req_data}")
# Make the request using the specified method
if method.upper() == "POST":
response = requests.post(
url, json=req_data, stream=stream, verify=get_cert_file_path()
)
elif method.upper() == "GET":
response = requests.get(
url, params=req_data, stream=stream, verify=get_cert_file_path()
)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status()
if stream:
return self._handle_streaming_response(response)
else:
return response.json()
def _handle_streaming_response(self, response: requests.Response) -> Iterator[str]:
"""Handle streaming response and yield tokens"""
try:
for line in response.iter_lines(decode_unicode=True):
if line:
if line.startswith("data: "):
data_str = line[6:]
if data_str.strip() == "[DONE]":
break
try:
data = json.loads(data_str)
yield data # Yield the full chunk
except json.JSONDecodeError:
continue
except Exception as e:
log.error(f"Error handling streaming response: {e}")
raise
def call_completions(
self, config: CompletionConfig
) -> Union[Dict[str, Any], Iterator[str]]:
payload = config.to_dict()
return self._make_request(
payload=payload, endpoint="/v1/completions", stream=config.stream
)
def call_chat_completions(
self, config: ChatCompletionConfig
) -> Union[Dict[str, Any], Iterator[str]]:
payload = config.to_dict()
return self._make_request(
payload=payload, endpoint="/v1/chat/completions", stream=config.stream
)
ENDPOINT_NAME = "my-vllm-endpoint" # change this to your vLLM endpoint name
DEFAULT_MODEL = "Qwen/Qwen3-8B" # must support tool calling
MAX_TOKENS = 1024
DEFAULT_TEMPERATURE = 0.7
# ---------------------- Tooling ----------------------
class ToolManager: class ToolManager:
"""Handles tool definitions and execution""" """Handles tool definitions and execution"""
@@ -164,7 +50,7 @@ class ToolManager:
@staticmethod @staticmethod
def get_ls_tool_definition() -> List[Dict[str, Any]]: def get_ls_tool_definition() -> List[Dict[str, Any]]:
"""Get the ls tool definition""" """OpenAI-compatible tool schema"""
return [ return [
{ {
"type": "function", "type": "function",
@@ -178,98 +64,217 @@ class ToolManager:
def execute_tool_call(self, tool_call: Dict[str, Any]) -> str: def execute_tool_call(self, tool_call: Dict[str, Any]) -> str:
"""Execute a tool call and return the result""" """Execute a tool call and return the result"""
function_name = tool_call["function"]["name"] function_name = (tool_call.get("function") or {}).get("name")
if function_name == "list_files": if function_name == "list_files":
return self.list_files() return self.list_files()
else: raise ValueError(f"Unknown tool function: {function_name}")
raise ValueError(f"Unknown tool function: {function_name}")
# ----- Helpers to handle streamed tool_calls assembly -----
def _merge_tool_call_delta(state: Dict[int, Dict[str, Any]], tc_delta: Dict[str, Any]) -> None:
"""
OpenAI-style streaming sends partial tool_calls with an index and partial fields.
We merge into a per-index state dict until the assistant message finishes.
"""
idx = tc_delta.get("index")
if idx is None:
return
entry = state.setdefault(idx, {"id": None, "function": {"name": None, "arguments": ""}, "type": "function"})
if tc_delta.get("id"):
entry["id"] = tc_delta["id"]
fn_delta = tc_delta.get("function") or {}
if "name" in fn_delta and fn_delta["name"]:
entry["function"]["name"] = fn_delta["name"]
if "arguments" in fn_delta and fn_delta["arguments"]:
entry["function"]["arguments"] += fn_delta["arguments"]
def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[Dict[str, Any]]:
return [state[i] for i in sorted(state.keys())]
# ---- OpenAI-compatible calls (non-streaming) ----
async def call_completions(client: Serverless, *, model: str, prompt: str, **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = {
"input": {
"model": model,
"prompt": prompt,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
}
}
log.debug("POST /v1/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
return resp["response"]
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = {
"input": {
"model": model,
"messages": messages,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
}
}
log.debug("POST /v1/chat/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"])
return resp["response"]
# ---- Streaming variants ----
async def stream_completions(client: Serverless, *, model: str, prompt: str, **kwargs):
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = {
"input": {
"model": model,
"prompt": prompt,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"stream": True,
**({"stop": kwargs["stop"]} if "stop" in kwargs else {}),
}
}
log.debug("STREAM /v1/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
return resp["response"] # async generator
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs):
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = {
"input": {
"model": model,
"messages": messages,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"stream": True,
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
}
}
log.debug("STREAM /v1/chat/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
return resp["response"] # async generator
# ---------------------- Demo Runner ----------------------
class APIDemo: class APIDemo:
"""Demo and testing functionality for the API client""" """Demo and testing functionality for the API client"""
def __init__( def __init__(self, client: Serverless, model: str, tool_manager: Optional[ToolManager] = None):
self, client: APIClient, model: str, tool_manager: Optional[ToolManager] = None
):
self.client = client self.client = client
self.model = model self.model = model
self.tool_manager = tool_manager or ToolManager() self.tool_manager = tool_manager or ToolManager()
def handle_streaming_response( # ----- Streaming handler -----
self, response_stream, show_reasoning: bool = True async def handle_streaming_response(self, stream, show_reasoning: bool = True) -> str:
) -> str:
"""
Handle streaming chat response and display all output.
"""
full_response = "" full_response = ""
reasoning_content = "" reasoning_content = ""
reasoning_started = False printed_reasoning = False
content_started = False printed_answer = False
for chunk in response_stream: async for chunk in stream:
# Normalize the chunk choice = (chunk.get("choices") or [{}])[0]
if isinstance(chunk, str): delta = choice.get("delta", {})
chunk = chunk.strip()
if chunk.startswith("data: "):
chunk = chunk[6:].strip()
if chunk in ["[DONE]", ""]:
continue
try:
parsed_chunk = json.loads(chunk)
except json.JSONDecodeError:
continue
elif isinstance(chunk, dict):
parsed_chunk = chunk
else:
continue
# Parse delta from the chunk # reasoning tokens
choices = parsed_chunk.get("choices", []) rc = delta.get("reasoning_content")
if not choices: if rc and show_reasoning:
continue if not printed_reasoning:
delta = choices[0].get("delta", {})
reasoning_token = delta.get("reasoning_content", "")
content_token = delta.get("content", "")
# Print reasoning token if applicable
if show_reasoning and reasoning_token:
if not reasoning_started:
print("\n🧠 Reasoning: ", end="", flush=True) print("\n🧠 Reasoning: ", end="", flush=True)
reasoning_started = True printed_reasoning = True
print(f"\033[90m{reasoning_token}\033[0m", end="", flush=True) print(rc, end="", flush=True)
reasoning_content += reasoning_token reasoning_content += rc
# Print content token # content tokens
if content_token: content_part = delta.get("content")
if not content_started: if content_part:
if show_reasoning and reasoning_started: if not printed_answer:
print(f"\n💬 Response: ", end="", flush=True) if show_reasoning and printed_reasoning:
print("\n💬 Response: ", end="", flush=True)
else: else:
print("Assistant: ", end="", flush=True) print("Assistant: ", end="", flush=True)
content_started = True printed_answer = True
print(content_token, end="", flush=True) print(content_part, end="", flush=True)
full_response += content_token full_response += content_part
print() # Ensure newline after response
print() # newline
if show_reasoning: if show_reasoning:
if reasoning_started or content_started: if printed_reasoning or printed_answer:
print("\nStreaming completed.") print("\nStreaming completed.")
if reasoning_started: if printed_reasoning:
print(f"Reasoning tokens: {len(reasoning_content.split())}") print(f"Reasoning tokens: {len(reasoning_content.split())}")
if content_started: if printed_answer:
print(f"Response tokens: {len(full_response.split())}") print(f"Response tokens: {len(full_response.split())}")
return full_response return full_response
async def demo_completions(self) -> None:
print("=" * 60)
print("COMPLETIONS DEMO")
print("=" * 60)
def test_tool_support(self) -> bool: response = await call_completions(
"""Test if the endpoint supports function calling""" client=self.client,
log.debug("Testing endpoint tool calling support...") model=self.model,
prompt=COMPLETIONS_PROMPT,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
print("\nResponse:")
print(json.dumps(response, indent=2))
# Try a simple request with minimal tools to test support async def demo_chat(self, use_streaming: bool = True) -> None:
print("=" * 60)
print(f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}")
print("=" * 60)
messages = [{"role": "user", "content": CHAT_PROMPT}]
if use_streaming:
stream = await stream_chat_completions(
client=self.client,
model=self.model,
messages=messages,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE
)
try:
await self.handle_streaming_response(stream, show_reasoning=True)
except Exception as e:
log.error("\nError during streaming: %s", e, exc_info=True)
else:
response = await call_chat_completions(
client=self.client,
model=self.model,
messages=messages,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE
)
choice = (response.get("choices") or [{}])[0]
message = choice.get("message", {})
content = message.get("content", "")
reasoning = message.get("reasoning_content", "") or message.get("reasoning", "")
if reasoning:
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
print(f"\n💬 Assistant: {content}")
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
async def test_tool_support(self) -> bool:
"""Probe that tool schema is accepted (no actual call)"""
messages = [{"role": "user", "content": "Hello"}] messages = [{"role": "user", "content": "Hello"}]
minimal_tool = [ minimal_tool = [
{ {
@@ -277,170 +282,147 @@ class APIDemo:
"function": {"name": "test_function", "description": "Test function"}, "function": {"name": "test_function", "description": "Test function"},
} }
] ]
config = ChatCompletionConfig(
model=self.model,
messages=messages,
max_tokens=10,
tools=minimal_tool,
tool_choice="none", # Don't actually call the tool
)
try: try:
response = self.client.call_chat_completions(config) _ = await call_chat_completions(
client=self.client,
model=self.model,
messages=messages,
tools=minimal_tool,
tool_choice="none",
max_tokens=10
)
return True return True
except Exception as e: except Exception as e:
log.error(f"Error: Endpoint does not support tool calling: {e}") log.error("Endpoint does not support tool calling: %s", e)
return False return False
def demo_completions(self) -> None: async def demo_ls_tool(self) -> None:
"""Demo: test basic completions endpoint""" """Ask to list files using function calling, then provide final analysis"""
print("=" * 60)
print("COMPLETIONS DEMO")
print("=" * 60)
config = CompletionConfig(
model=self.model, prompt=COMPLETIONS_PROMPT, stream=False
)
log.info(
f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'"
)
response = self.client.call_completions(config)
if isinstance(response, dict):
print("\nResponse:")
print(json.dumps(response, indent=2))
else:
log.error("Unexpected response format")
def demo_chat(self, use_streaming: bool = True) -> None:
"""
Demo: test chat completions endpoint with optional streaming
"""
print("=" * 60)
print(
f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}"
)
print("=" * 60)
config = ChatCompletionConfig(
model=self.model,
messages=[{"role": "user", "content": CHAT_PROMPT}],
stream=use_streaming,
)
log.info(f"Testing chat completions with model '{self.model}'...")
response = self.client.call_chat_completions(config)
if use_streaming:
try:
self.handle_streaming_response(response, show_reasoning=True)
except Exception as e:
log.error(f"\nError during streaming: {e}")
import traceback
traceback.print_exc()
return
else:
if isinstance(response, dict):
choice = response.get("choices", [{}])[0]
message = choice.get("message", {})
content = message.get("content", "")
reasoning = message.get("reasoning_content", "") or message.get(
"reasoning", ""
)
if reasoning:
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
print(f"\n💬 Assistant: {content}")
print(f"\nFull Response:")
print(json.dumps(response, indent=2))
else:
log.error("Unexpected response format")
def demo_ls_tool(self) -> None:
"""Demo: ask LLM to list files in the current directory and describe what it sees"""
print("=" * 60) print("=" * 60)
print("TOOL USE DEMO: List Directory Contents") print("TOOL USE DEMO: List Directory Contents")
print("=" * 60) print("=" * 60)
# Test if tools are supported first if not await self.test_tool_support():
if not self.test_tool_support():
return return
# Request with tool available messages: List[Dict[str, Any]] = [{"role": "user", "content": TOOLS_PROMPT}]
messages = [{"role": "user", "content": TOOLS_PROMPT}]
config = ChatCompletionConfig( # First pass: let the model decide tools, stream tool_calls and partial content
stream = await stream_chat_completions(
client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
tools=self.tool_manager.get_ls_tool_definition(), tools=self.tool_manager.get_ls_tool_definition(),
tool_choice="auto", tool_choice="auto",
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
) )
log.info(f"Making initial request with tool using model '{self.model}'...") assistant_content_buf: List[str] = []
response = self.client.call_chat_completions(config) tool_calls_state: Dict[int, Dict[str, Any]] = {}
printed_reasoning = False
printed_answer = False
if not isinstance(response, dict): async for chunk in stream:
raise ValueError("Expected dict response for tool use") choice = (chunk.get("choices") or [{}])[0]
delta = choice.get("delta", {})
choice = response.get("choices", [{}])[0] rc = delta.get("reasoning_content")
message = choice.get("message", {}) if rc:
if not printed_reasoning:
printed_reasoning = True
print("🧠 Reasoning: ", end="", flush=True)
print(rc, end="", flush=True)
print(f"Assistant response: {message.get('content', 'No content')}") content_part = delta.get("content")
if content_part:
assistant_content_buf.append(content_part)
if not printed_answer:
printed_answer = True
print("\n💬 Response: ", end="", flush=True)
print(content_part, end="", flush=True)
# Check for tool calls if "tool_calls" in delta and delta["tool_calls"]:
tool_calls = message.get("tool_calls") for tc_delta in delta["tool_calls"]:
if not tool_calls: _merge_tool_call_delta(tool_calls_state, tc_delta)
raise ValueError(
"No tool calls made - model may not support function calling"
)
print(f"Tool calls detected: {len(tool_calls)}") # If no tool calls, were done.
if not tool_calls_state:
print("\n(No tool calls were made.)")
return
# Execute the tool call # Build assistant message with tool_calls
for tool_call in tool_calls: assistant_message = {
function_name = tool_call["function"]["name"] "role": "assistant",
print(f"Executing tool: {function_name}") "content": "".join(assistant_content_buf) if assistant_content_buf else None,
"tool_calls": _tool_state_to_message_tool_calls(tool_calls_state),
}
messages.append(assistant_message)
tool_result = self.tool_manager.execute_tool_call(tool_call) # Execute tools and feed results back
print(f"Tool result:\n{tool_result}") for tc in assistant_message["tool_calls"]:
tool_name = (tc.get("function") or {}).get("name")
call_id = tc.get("id")
raw_args = (tc.get("function") or {}).get("arguments") or "{}"
# Add tool result and continue conversation try:
messages.append(message) # Add assistant's message with tool call args = json.loads(raw_args) if raw_args.strip() else {}
messages.append( except Exception as e:
{ tool_result = json.dumps({"error": f"Argument parse failed: {str(e)}", "raw_arguments": raw_args})
"role": "tool", messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result})
"tool_call_id": tool_call["id"], continue
"content": tool_result,
}
)
# Get final response try:
final_config = ChatCompletionConfig( if tool_name == "list_files":
model=self.model, tool_result = self.tool_manager.list_files()
messages=messages, else:
tools=self.tool_manager.get_ls_tool_definition(), tool_result = json.dumps({"error": f"Unknown tool '{tool_name}'"})
) except Exception as e:
tool_result = json.dumps({"error": f"Tool '{tool_name}' failed: {str(e)}"})
print("Getting final response...") print("\n[Tool executed]", tool_name)
final_response = self.client.call_chat_completions(final_config) print(tool_result[:500] + ("..." if len(tool_result) > 500 else ""))
messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result})
if isinstance(final_response, dict): # Second pass: get final streamed answer after tool results
final_choice = final_response.get("choices", [{}])[0] stream2 = await stream_chat_completions(
final_message = final_choice.get("message", {}) client=self.client,
final_content = final_message.get("content", "") model=self.model,
messages=messages,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
print("\n" + "=" * 60) final_buf = []
print("FINAL LLM ANALYSIS:") printed_reasoning2 = False
print("=" * 60) printed_answer2 = False
print(final_content)
print("=" * 60)
def interactive_chat(self) -> None: async for chunk in stream2:
choice = (chunk.get("choices") or [{}])[0]
delta = choice.get("delta", {})
rc2 = delta.get("reasoning_content")
if rc2:
if not printed_reasoning2:
printed_reasoning2 = True
print("\n🧠 Reasoning (post-tools): ", end="", flush=True)
print(rc2, end="", flush=True)
c2 = delta.get("content")
if c2:
final_buf.append(c2)
if not printed_answer2:
printed_answer2 = True
print("\n💬 Response (final): ", end="", flush=True)
print(c2, end="", flush=True)
print("\n" + "=" * 60)
print("FINAL LLM ANALYSIS:")
print("=" * 60)
print("".join(final_buf))
print("=" * 60)
async def interactive_chat(self) -> None:
"""Interactive chat session with streaming""" """Interactive chat session with streaming"""
print("=" * 60) print("=" * 60)
print("INTERACTIVE STREAMING CHAT") print("INTERACTIVE STREAMING CHAT")
@@ -449,7 +431,7 @@ class APIDemo:
print("Type 'quit' to exit, 'clear' to clear history") print("Type 'quit' to exit, 'clear' to clear history")
print() print()
messages = [] messages: List[Dict[str, Any]] = []
while True: while True:
try: try:
@@ -467,16 +449,15 @@ class APIDemo:
messages.append({"role": "user", "content": user_input}) messages.append({"role": "user", "content": user_input})
config = ChatCompletionConfig(
model=self.model, messages=messages, stream=True, temperature=0.7
)
print("Assistant: ", end="", flush=True) print("Assistant: ", end="", flush=True)
stream = await stream_chat_completions(
response = self.client.call_chat_completions(config) client=self.client,
assistant_content = self.handle_streaming_response( model=self.model,
response, show_reasoning=True messages=messages,
max_tokens=MAX_TOKENS,
temperature=0.7
) )
assistant_content = await self.handle_streaming_response(stream, show_reasoning=True)
# Add assistant response to conversation history # Add assistant response to conversation history
messages.append({"role": "assistant", "content": assistant_content}) messages.append({"role": "assistant", "content": assistant_content})
@@ -485,115 +466,64 @@ class APIDemo:
print("\n👋 Chat interrupted. Goodbye!") print("\n👋 Chat interrupted. Goodbye!")
break break
except Exception as e: except Exception as e:
log.error(f"\nError: {e}") log.error("\nError: %s", e)
continue continue
def main(): # ---------------------- CLI ----------------------
"""Main function with CLI switches for different tests""" def build_arg_parser() -> argparse.ArgumentParser:
from lib.test_utils import test_args p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
p.add_argument("--model", required=True, help="Model to use for requests (required)")
p.add_argument("--endpoint", default="my-vllm-endpoint", help="Vast endpoint name (default: my-vllm-endpoint)")
# Add mandatory model argument modes = p.add_mutually_exclusive_group(required=False)
test_args.add_argument( modes.add_argument("--completion", action="store_true", help="Test completions endpoint")
"--model", required=True, help="Model to use for requests (required)" modes.add_argument("--chat", action="store_true", help="Test chat completions endpoint (non-streaming)")
) modes.add_argument("--chat-stream", action="store_true", help="Test chat completions endpoint with streaming")
modes.add_argument("--tools", action="store_true", help="Test function calling with ls tool (non-streaming+streamed phases)")
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming chat session")
return p
# Add test mode arguments
test_args.add_argument(
"--completion", action="store_true", help="Test completions endpoint"
)
test_args.add_argument(
"--chat",
action="store_true",
help="Test chat completions endpoint (non-streaming)",
)
test_args.add_argument(
"--chat-stream",
action="store_true",
help="Test chat completions endpoint with streaming",
)
test_args.add_argument(
"--tools",
action="store_true",
help="Test function calling with ls tool (non-streaming)",
)
test_args.add_argument(
"--interactive",
action="store_true",
help="Start interactive streaming chat session",
)
args = test_args.parse_args() async def main_async():
args = build_arg_parser().parse_args()
# Check that only one test mode is selected selected = sum([args.completion, args.chat, args.chat_stream, args.tools, args.interactive])
test_modes = [ if selected == 0:
args.completion,
args.chat,
args.chat_stream,
args.tools,
args.interactive,
]
selected_count = sum(test_modes)
if selected_count == 0:
print("Please specify exactly one test mode:") print("Please specify exactly one test mode:")
print(" --completion : Test completions endpoint") print(" --completion : Test completions endpoint")
print(" --chat : Test chat completions endpoint (non-streaming)") print(" --chat : Test chat completions endpoint (non-streaming)")
print(" --chat-stream : Test chat completions endpoint with streaming") print(" --chat-stream : Test chat completions endpoint with streaming")
print(" --tools : Test function calling with ls tool (non-streaming)") print(" --tools : Test function calling with ls tool")
print(" --interactive : Start interactive streaming chat session") print(" --interactive : Start interactive streaming chat session")
print( print(f"\nExample: python {os.path.basename(sys.argv[0])} --model Qwen/Qwen3-8B --chat-stream --endpoint my-vllm-endpoint")
f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT"
)
sys.exit(1) sys.exit(1)
elif selected_count > 1: elif selected > 1:
print("Please specify exactly one test mode") print("Please specify exactly one test mode")
sys.exit(1) sys.exit(1)
print(f"Using model: {args.model}")
print("=" * 60)
try: try:
endpoint_api_key = Endpoint.get_endpoint_api_key( async with Serverless() as client:
endpoint_name=args.endpoint_group_name, demo = APIDemo(client, args.model, ToolManager())
account_api_key=args.api_key,
instance=args.instance,
)
if not endpoint_api_key: if args.completion:
log.error( await demo.demo_completions()
f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting." elif args.chat:
) await demo.demo_chat(use_streaming=False)
sys.exit(1) elif args.chat_stream:
await demo.demo_chat(use_streaming=True)
# Create the core API client elif args.tools:
client = APIClient( await demo.demo_ls_tool()
endpoint_group_name=args.endpoint_group_name, elif args.interactive:
api_key=args.api_key, await demo.interactive_chat()
server_url=Endpoint.get_autoscaler_server_url(args.instance),
endpoint_api_key=endpoint_api_key,
)
# Create tool manager and demo (passing the model parameter)
tool_manager = ToolManager()
demo = APIDemo(client, args.model, tool_manager)
print(f"Using model: {args.model}")
print("=" * 60)
# Run the selected test
if args.completion:
demo.demo_completions()
elif args.chat:
demo.demo_chat(use_streaming=False)
elif args.chat_stream:
demo.demo_chat(use_streaming=True)
elif args.tools:
demo.demo_ls_tool()
elif args.interactive:
demo.interactive_chat()
except Exception as e: except Exception as e:
log.error(f"Error during test: {e}", exc_info=True) log.error("Error during test: %s", e, exc_info=True)
sys.exit(1) sys.exit(1)
if __name__ == "__main__": if __name__ == "__main__":
main() asyncio.run(main_async())
+49 -113
View File
@@ -1,125 +1,61 @@
import logging from vastai import Serverless
import sys import asyncio
import json
from urllib.parse import urljoin
import requests
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
logging.basicConfig( ENDPOINT_NAME = "my-tgi-endpoint" # Change this to match your endpoint name
level=logging.DEBUG, MAX_TOKENS = 1024
format="%(asctime)s[%(levelname)-5s] %(message)s", PROMPT = "Think step by step: Tell me about the Python programming language."
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
async def call_generate(client: Serverless) -> None:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> None: payload = {
WORKER_ENDPOINT = "/generate" "inputs": PROMPT,
COST = 100 "parameters": {
route_payload = { "max_new_tokens": MAX_TOKENS,
"endpoint": endpoint_group_name, "temperature": 0.7,
"api_key": api_key, "return_full_text": False
"cost": COST, }
} }
response = requests.post(
urljoin(server_url, "/route/"),
json=route_payload,
timeout=4,
)
response.raise_for_status() # Raise an exception for bad status codes
message = response.json()
url = message["url"]
auth_data = dict( resp = await endpoint.request("/generate", payload, cost=MAX_TOKENS)
signature=message["signature"],
cost=message["cost"],
endpoint=message["endpoint"],
reqnum=message["reqnum"],
url=url,
)
payload = dict(inputs="tell me about cats", parameters=dict(max_new_tokens=500)) print(resp["response"]["generated_text"])
req_data = dict(payload=payload, auth_data=auth_data)
url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {url}")
response = requests.post(
url,
json=req_data,
verify=get_cert_file_path(),
)
response.raise_for_status()
res = response.json()
print(res)
def call_generate_stream( async def call_generate_stream(client: Serverless) -> None:
endpoint_group_name: str, api_key: str, server_url: str endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
) -> None:
WORKER_ENDPOINT = "/generate_stream" payload = {
COST = 100 "inputs": PROMPT,
route_payload = { "parameters": {
"endpoint": endpoint_group_name, "max_new_tokens": MAX_TOKENS,
"api_key": api_key, "temperature": 0.7,
"cost": COST, "do_sample": True,
"return_full_text": False,
}
} }
response = requests.post(
urljoin(server_url, "/route/"),
json=route_payload,
timeout=4,
)
response.raise_for_status() # Raise an exception for bad status codes
message = response.json()
url = message["url"]
print(f"url: {url}")
auth_data = dict(
signature=message["signature"],
cost=message["cost"],
endpoint=message["endpoint"],
reqnum=message["reqnum"],
url=message["url"],
)
payload = dict(inputs="tell me about dogs", parameters=dict(max_new_tokens=500))
req_data = dict(payload=payload, auth_data=auth_data)
url = urljoin(url, WORKER_ENDPOINT)
response = requests.post(url, json=req_data, stream=True)
response.raise_for_status() # Raise an exception for bad status codes
for line in response.iter_lines():
payload = line.decode().lstrip("data:").rstrip()
if payload:
try:
data = json.loads(payload)
print(data["token"]["text"], end="")
sys.stdout.flush()
except (json.JSONDecodeError, KeyError) as e:
log.warning(f"Failed to parse streaming response: {e}")
continue
print()
resp = await endpoint.request(
"/generate_stream",
payload,
cost=MAX_TOKENS,
stream=True,
)
stream = resp["response"]
printed_answer = False
async for event in stream:
tok = (event.get("token") or {}).get("text")
if tok:
if not printed_answer:
printed_answer = True
print("Answer:\n", end="", flush=True)
print(tok, end="", flush=True)
async def main():
async with Serverless() as client:
await call_generate(client)
await call_generate_stream(client)
if __name__ == "__main__": if __name__ == "__main__":
from lib.test_utils import test_args asyncio.run(main())
args = test_args.parse_args()
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if endpoint_api_key:
try:
call_generate(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
call_generate_stream(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
except Exception as e:
log.error(f"Error during API call: {e}")
else:
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ")