Compare commits

...

92 Commits

Author SHA1 Message Date
Lucas Armand 7be8aa6397 pin pycares 2025-12-10 17:38:03 -08:00
Colter-Downing 138fc3ac47 Merge pull request #71 from vast-ai/AUTO-comfyui-updates
Auto comfyui updates
2025-12-04 10:55:12 -08:00
Colter Downing 222ac2a0dd default endpoint name 2025-12-04 10:54:55 -08:00
Colter Downing 40aed9b5f8 adding s3 as an option 2025-12-04 10:52:57 -08:00
Colter Downing d4d36bf86e done with comfy updates 2025-12-03 20:45:55 -08:00
Colter Downing e839cfc6e8 include view in API wrapper 2025-12-03 20:22:45 -08:00
Colter Downing f04138e13b update to be able to get images 2025-12-03 20:16:25 -08:00
Colter-Downing de3aa87c8f Merge pull request #70 from vast-ai/AUTO-tgi-client-edits
update tgi client
2025-12-03 18:40:01 -08:00
Colter Downing 6b5b1341a7 update tgi client 2025-12-03 18:38:42 -08:00
Colter-Downing 8be92c03de Merge pull request #69 from vast-ai/AUTO-874--fix-openai-worker-client
defaults to ENDPOINT_NAME and DEFAULT_MODEL but uses the flag first
2025-12-03 16:59:56 -08:00
Colter Downing adedb8ba90 defaults to ENDPOINT_NAME and DEFAULT_MODEL but uses the flag first if present 2025-12-03 16:57:28 -08:00
LucasArmandVast 2f543c01ad Merge pull request #68 from vast-ai/fix-vllm-concurrency
Increase model wait time for vLLM
2025-12-03 16:13:51 -05:00
Lucas Armand 0bcd2219ea Increase model wait time for vLLM 2025-12-03 12:38:52 -08:00
LucasArmandVast 0339b471c5 Merge pull request #66 from vast-ai/synthesis
PyWorker Error Handling
2025-11-25 16:02:26 -08:00
Lucas Armand e143162438 bumpy pyworker version 2025-11-25 16:01:23 -08:00
Lucas Armand 7986e51e9e early errors 2025-11-24 15:24:06 -08:00
Lucas Armand 9c6ab78503 Move model log line 2025-11-24 15:22:23 -08:00
Lucas Armand 45e0c7d9ca Move model log rotate to top 2025-11-24 15:02:33 -08:00
LucasArmandVast 7a792fd176 Merge pull request #64 from vast-ai/add-llama-log
add llama log
2025-11-21 10:24:27 -08:00
Lucas Armand e0449cb3c7 add llama log 2025-11-21 10:22:16 -08:00
Lucas Armand a4339bd3f1 hotfix: add f 2025-11-12 16:10:55 -08:00
Lucas Armand 2b26e5e20c hotfix: remove g 2025-11-12 16:01:57 -08:00
LucasArmandVast d3727d4fd7 Merge pull request #58 from vast-ai/update-client-scripts
Update client scripts
2025-11-12 10:22:42 -08:00
Lucas Armand a47c9d1ed0 remove test bugs 2025-11-11 18:13:46 -08:00
Lucas Armand 0b14562a63 dont exit on pyworker fail 2025-11-11 17:57:08 -08:00
Lucas Armand de9b50abb9 use set +e 2025-11-11 17:53:36 -08:00
Lucas Armand c510801723 fix 2025-11-11 17:49:34 -08:00
Lucas Armand a12523b1d2 Added bad code to tgi server to test 2025-11-11 17:41:12 -08:00
Lucas Armand eedf81c0a3 Updated readme and .gitignore 2025-11-11 17:18:40 -08:00
Lucas Armand 3adec1826d minor changes 2025-11-11 17:11:38 -08:00
Lucas Armand b55bfa9611 Updated clients, include vastai-sdk, handle non-UTF-8 2025-11-11 17:09:28 -08:00
LucasArmandVast 7db54f3bd7 Merge pull request #55 from vast-ai/use-mtoken
Use mtoken
2025-11-10 11:54:04 -08:00
LucasArmandVast d63a060202 Merge pull request #56 from vast-ai/obfuscate-mtoken
Obfuscate mtoken in logs
2025-11-10 11:53:17 -08:00
Lucas Armand c6521cb6d4 add ... 2025-11-07 10:10:35 -08:00
Lucas Armand b7fe4ebb91 Obfuscate mtoken in logs 2025-11-07 10:02:39 -08:00
Lucas Armand 8ae7b74605 bump version to 0.2.0 2025-11-05 13:32:21 -08:00
Lucas Armand 106067d716 bump version to 0.1.1 2025-11-04 17:15:59 -08:00
Lucas Armand f5134d4bf5 Fix spelling mistake 2025-11-04 16:59:39 -08:00
Lucas Armand 47e5460532 added mtoken 2025-11-04 15:55:14 -08:00
Colter-Downing ec2ac0a21a Merge pull request #52 from vast-ai/remove-sleeps-and-delays
Remove sleeps and delays
2025-10-30 11:53:39 -07:00
Abiola Akinnubi 2cde573c56 Merge pull request #48 from vast-ai/comfy-request-idx
Added request_idx to comfy auth_data
2025-10-30 11:27:35 -07:00
Abiola Akinnubi b2e4a5db0c Merge pull request #49 from vast-ai/unsecure_report_addr
Added caller for REPORT_ADDR to backend.py to use the report add
2025-10-30 10:39:46 -07:00
Abiola Akinnubi 7437028cb2 Added caller for REPORT_ADDR to backend.py 2025-10-29 18:02:17 -07:00
edgaratvast 02c8307af7 remove redis pubsub from pyworker (#53)
Co-authored-by: Edgar Lin <edgarlin2000@gmail.com>
2025-10-29 17:07:56 -07:00
Colter Downing 7c0f316eeb leave the env vars alone! 2025-10-29 11:36:46 -07:00
Colter Downing b4025a744f remove env var writing 2025-10-29 09:58:09 -07:00
Colter Downing d190308329 removed 5 sec sleep and warmup request on load 2025-10-29 09:57:46 -07:00
LucasArmandVast 9f5a432513 Merge pull request #51 from vast-ai/delete-reqs-hotfix
Redis subscriber queue patch
2025-10-28 16:07:28 -07:00
Lucas Armand e09f1fa953 patch for redis queue 2025-10-28 16:03:50 -07:00
edgaratvast ba6f1c2e4b Fix signature (#50)
* change order of fields in auth_data to match autoscaler for signature verification

* also ignore __request_id

* Revert "change order of fields in auth_data to match autoscaler for signature verification" so that it's alphabetical again

This reverts commit b8223879c9.

* enforce alphabetical json dumping of message for signature verification

---------

Co-authored-by: Edgar Lin <edgarlin2000@gmail.com>
2025-10-28 16:01:32 -07:00
Abiola Akinnubi 944f83fc03 Removed extra spaces from operator assignment 2025-10-28 21:03:52 +00:00
edgaratvast 298590fb88 Merge pull request #45 from vast-ai/new-pyworker
New PyWorker
2025-10-28 14:02:53 -07:00
Lucas Armand 814c3acd4c remove unused code 2025-10-28 13:43:57 -07:00
Lucas Armand 22bca74087 Prevent load time race 2025-10-27 18:25:21 -07:00
Lucas Armand 9c795e2a01 removed bad code 2025-10-27 17:03:13 -07:00
Lucas Armand 830b532781 Trying unified delete 2025-10-27 16:57:52 -07:00
LucasArmandVast d6a6e34c6b Merge branch 'main' into new-pyworker 2025-10-27 12:43:49 -07:00
Colter-Downing ac1e109c48 Merge pull request #47 from vast-ai/new-pyworker-vllm-prefix-cache
vLLM Prefix caching, benchmark bug fix, test load script
2025-10-27 12:30:34 -07:00
Colter Downing d6eb498ee4 catch the case where all benchmarks fail (sets error) 2025-10-27 12:01:55 -07:00
Abiola Akinnubi f56bbc0ebe Added request_idx to comfy auth_data 2025-10-27 03:17:06 +00:00
Colter Downing bcecd6df40 Suppress matplot debug logs 2025-10-25 16:18:02 -07:00
Lucas Armand 4d9bf2048c Fix 2025-10-24 15:44:38 -07:00
Lucas Armand 7788bc4a62 Added some debug logs 2025-10-24 15:41:00 -07:00
Lucas Armand 37ad3f8d46 asyncio in metrics 2025-10-23 10:18:31 -07:00
Rob Ballantyne 70d51bafe1 Merge pull request #36 from robballantyne/feat/comfyui-json-benchmark-workflow-from-file 2025-10-23 17:05:48 +01:00
Rob Ballantyne 63909736bb Merge pull request #4 from robballantyne/feat/comfyui-json-benchmark-workflow-from-file-no-silent-fail
Feat/comfyui json benchmark workflow from file no silent fail
2025-10-23 17:02:12 +01:00
Rob Ballantyne f4f7080df1 Re-add comment 2025-10-23 17:00:28 +01:00
Rob Ballantyne d51a338e8f log when benchmark file not used 2025-10-23 16:41:02 +01:00
Rob Ballantyne 92a04bd7af No silent fail if benchmark file is missing 2025-10-23 13:41:03 +01:00
Lucas Armand 0f13506938 Send success param 2025-10-22 10:18:59 -07:00
Lucas Armand 01e752d31f use more asyncio sleep 2025-10-21 18:52:13 -07:00
Lucas Armand 5edfa968ca async sleep 2025-10-21 18:49:48 -07:00
Lucas Armand 5b5ef7227a nvm moved it here 2025-10-21 18:20:11 -07:00
Lucas Armand 16990ff8ff move start request 2025-10-21 18:18:44 -07:00
Lucas Armand 9748176366 fixed semaphore acquire bool 2025-10-21 18:12:23 -07:00
Lucas Armand b39193ae70 check for sem acquire 2025-10-21 18:02:14 -07:00
Lucas Armand 9a6ca5d412 added versioning 2025-10-21 15:42:43 -07:00
Lucas Armand e9ba1b03e4 Use delete_requests and track request_idxs 2025-10-21 11:59:35 -07:00
LucasArmandVast c98d661513 Merge pull request #39 from vast-ai/remove-time-divide
PyWorker fixes for cur_load and acks bug
2025-10-13 10:06:22 -07:00
Lucas Armand f6fd1c6ac1 merge 2025-10-09 18:15:55 -07:00
Lucas Armand 055e346c8c Send metrics on request start 2025-10-09 10:13:50 -07:00
Lucas Armand 1cedb28acf Removed division by elapsed time, since autoscaler cur_load in units of workload 2025-10-08 16:54:18 -07:00
Rob Ballantyne ec25dda3ad Merge branch 'vast-ai:main' into feat/comfyui-json-benchmark-workflow-from-file 2025-10-08 14:49:32 +01:00
Colter-Downing 0397af719d Merge pull request #37 from robballantyne/bugfix/healthcheck-endpoint
Fix healthcheck endpoint URL

Tested and merged by Colter
2025-10-06 15:11:27 -07:00
Rob Ballantyne 4fdc314fd9 Fix healthcheck endpoint URL 2025-10-06 22:16:09 +01:00
Rob Ballantyne 3786cf978d Add awareness of errors thrown by the provisioning script 2025-10-05 23:14:59 +01:00
Rob Ballantyne a86d4bcf9c Import json 2025-10-05 23:05:33 +01:00
Rob Ballantyne e9b6a14a5e Import Path 2025-10-05 22:59:19 +01:00
Rob Ballantyne cadac033e1 Enables use of custom workflow for benchmarking
Retains existing method is misc/benchmark.json is nopt present
2025-10-05 22:53:22 +01:00
Colter-Downing 639d82f5b4 Merge pull request #35 from vast-ai/AUTO-664--Healthcheck-error
Fix healthcheck with separate session
2025-10-02 12:51:19 -07:00
Colter Downing 25db78e39d Fix healthcheck with separate session 2025-10-01 18:04:31 -07:00
Scott-Laytart 4e2f2311d0 Merge pull request #33 from vast-ai/comfy-blind-fix-override
undo the fix for comfy yesterday.
2025-09-03 11:50:07 -07:00
23 changed files with 2205 additions and 904 deletions
+2 -1
View File
@@ -2,4 +2,5 @@
.envrc
__pycache__
bin/
lib64
lib64
.venv
+4 -3
View File
@@ -39,11 +39,12 @@ reporting these metrics to the autoscaler.
If you are using a Vast.ai template that includes PyWorker integration (marked as autoscaler compatible), it should work out of the box. The template will typically start the appropriate PyWorker server automatically. Here's a few:
* **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=72d8dcb41ea3a58e06c741e2c725bc00)
* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=ad72c8bf7cf695c3c9ddf0eaf6da0447)
* **vLLM:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=63ae93902bf3978bea033782592b784d)
* **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=6fa6bd5bdf5f0df63db80e40b086037d)
* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=e6748878ba688e765e3e9fca29541938)
Currently available workers:
* `hello_world`: A simple example worker for a basic LLM server.
* `openai`: A simple example worker for a basic vLLM server.
* `comfyui`: A worker for the ComfyUI image generation backend.
* `tgi`: A worker for the Text Generation Inference backend.
+108 -61
View File
@@ -12,6 +12,7 @@ from distutils.util import strtobool
from anyio import open_file
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
import asyncio
import requests
from Crypto.Signature import pkcs1_15
@@ -25,8 +26,12 @@ from lib.data_types import (
LogAction,
ApiPayload_T,
JsonDataException,
RequestMetrics,
BenchmarkResult
)
VERSION = "0.2.1"
MSG_HISTORY_LEN = 100
log = logging.getLogger(__file__)
@@ -53,15 +58,25 @@ class Backend:
EndpointHandler # this endpoint handler will be used for benchmarking
)
log_actions: List[Tuple[LogAction, str]]
max_wait_time: float = 10.0
reqnum = -1
version = VERSION
msg_history = []
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
unsecured: bool = dataclasses.field(
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
)
report_addr: str = dataclasses.field(
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
)
mtoken: str = dataclasses.field(
default_factory=lambda: os.environ.get("MASTER_TOKEN", "")
)
def __post_init__(self):
self.metrics = Metrics()
self.metrics._set_version(self.version)
self.metrics._set_mtoken(self.mtoken)
self._total_pubkey_fetch_errors = 0
self._pubkey = self._fetch_pubkey()
self.__start_healthcheck: bool = False
@@ -96,23 +111,19 @@ class Backend:
#######################################Private#######################################
def _fetch_pubkey(self):
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
result = subprocess.check_output(command, universal_newlines=True)
log.debug("public key:")
log.debug(result)
key = None
for _ in range(5):
try:
key = RSA.import_key(result)
break
except ValueError as e:
log.debug(f"Error downloading key: {e}")
time.sleep(15)
if key is None:
self._total_pubkey_fetch_errors += 1
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
self.backend_errored("Failed to get autoscaler pubkey")
return key
report_addr = self.report_addr.rstrip("/")
command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"]
try:
result = subprocess.check_output(command, universal_newlines=True)
log.debug("public key:")
log.debug(result)
key = RSA.import_key(result)
if key is not None:
return key
except (ValueError , subprocess.CalledProcessError) as e:
log.debug(f"Error downloading key: {e}")
self.backend_errored("Failed to get autoscaler pubkey")
async def __handle_request(
self,
@@ -128,55 +139,56 @@ class Backend:
except json.JSONDecodeError:
return web.json_response(dict(error="invalid JSON"), status=422)
workload = payload.count_workload()
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
async def cancel_api_call_if_disconnected() -> web.Response:
await request.wait_for_disconnection()
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
self.metrics._request_canceled(workload=workload)
return web.Response(status=500)
log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled")
self.metrics._request_canceled(request_metrics)
raise asyncio.CancelledError
async def make_request() -> Union[web.Response, web.StreamResponse]:
log.debug(f"got request, {auth_data.reqnum}")
self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum)
if self.allow_parallel_requests is False:
log.debug(f"Waiting to aquire Sem for reqnum:{auth_data.reqnum}")
await self.sem.acquire()
log.debug(
f"Sem acquired for reqnum:{auth_data.reqnum}, starting request..."
)
else:
log.debug(f"Starting request for reqnum:{auth_data.reqnum}")
try:
response = await self.__call_api(handler=handler, payload=payload)
status_code = response.status
log.debug(
" ".join(
[
f"request with reqnum:{auth_data.reqnum}",
f"request with reqnum:{request_metrics.reqnum}",
f"returned status code: {status_code},",
]
)
)
res = await handler.generate_client_response(request, response)
self.metrics._request_success(workload=workload)
self.metrics._request_success(request_metrics)
return res
except requests.exceptions.RequestException as e:
log.debug(f"[backend] Request error: {e}")
self.metrics._request_errored(workload=workload)
self.metrics._request_errored(request_metrics)
return web.Response(status=500)
finally:
self.metrics._request_end(
workload=workload,
reqnum=auth_data.reqnum,
)
self.sem.release()
###########
if self.__check_signature(auth_data) is False:
self.metrics._request_reject(request_metrics)
return web.Response(status=401)
if self.metrics.model_metrics.wait_time > self.max_wait_time:
self.metrics._request_reject(request_metrics)
return web.Response(status=429)
acquired = False
try:
self.metrics._request_start(request_metrics)
if self.allow_parallel_requests is False:
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
await self.sem.acquire()
acquired = True
log.debug(
f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..."
)
else:
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
done, pending = await wait(
[
create_task(make_request()),
@@ -184,24 +196,52 @@ class Backend:
],
return_when=FIRST_COMPLETED,
)
[task.cancel() for task in pending]
return done.pop().result()
for t in pending:
t.cancel()
await asyncio.gather(*pending, return_exceptions=True)
done_task = done.pop()
try:
return done_task.result()
except Exception as e:
log.debug(f"Request task raised exception: {e}")
return web.Response(status=500)
except asyncio.CancelledError:
# Client is gone. Do not write a response; just unwind.
return web.Response(status=499)
except Exception as e:
log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500)
finally:
# Always release the semaphore if it was acquired
if acquired:
self.sem.release()
self.metrics._request_end(request_metrics)
@cached_property
def healthcheck_session(self):
"""Dedicated session for healthchecks to avoid conflicts with API session"""
log.debug("creating dedicated healthcheck session")
connector = TCPConnector(
force_close=True, # Keep this for isolation
enable_cleanup_closed=True,
)
timeout = ClientTimeout(total=10) # Reasonable timeout for healthchecks
return ClientSession(timeout=timeout, connector=connector)
async def __healthcheck(self):
health_check_url = self.benchmark_handler.healthcheck_endpoint
if health_check_url is None:
log.debug("No healthcheck endpoint defined, skipping healthcheck")
return
while True:
await sleep(10)
if self.__start_healthcheck is False:
continue
try:
log.debug(f"Performing healthcheck on {health_check_url}")
async with self.session.get(health_check_url) as response:
async with self.healthcheck_session.get(health_check_url) as response:
if response.status == 200:
log.debug("Healthcheck successful")
elif response.status == 503:
@@ -210,7 +250,6 @@ class Backend:
f"Healthcheck failed with status: {response.status}"
)
else:
# endpoint not ready yet so bail
log.debug(f"Healthcheck Endpoint not ready: {response.status}")
except Exception as e:
log.debug(f"Healthcheck failed with exception: {e}")
@@ -218,7 +257,7 @@ class Backend:
async def _start_tracking(self) -> None:
await gather(
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck()
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop()
)
def backend_errored(self, msg: str) -> None:
@@ -250,7 +289,7 @@ class Backend:
message = {
key: value
for (key, value) in (dataclasses.asdict(auth_data).items())
if key != "signature"
if key != "signature" and key != "__request_id"
}
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
log.debug(
@@ -260,7 +299,7 @@ class Backend:
elif message in self.msg_history:
log.debug(f"message: {message} already in message history")
return False
elif verify_signature(json.dumps(message, indent=4), auth_data.signature):
elif verify_signature(json.dumps(message, indent=4, sort_keys=True), auth_data.signature):
self.reqnum = max(auth_data.reqnum, self.reqnum)
self.msg_history.append(message)
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
@@ -279,10 +318,10 @@ class Backend:
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
log.debug("already ran benchmark")
# trigger model load
payload = self.benchmark_handler.make_benchmark_payload()
_ = await self.__call_api(
handler=self.benchmark_handler, payload=payload
)
# payload = self.benchmark_handler.make_benchmark_payload()
# _ = await self.__call_api(
# handler=self.benchmark_handler, payload=payload
# )
return float(f.readline())
except FileNotFoundError:
pass
@@ -297,18 +336,26 @@ class Backend:
for run in range(1, self.benchmark_handler.benchmark_runs + 1):
start = time.time()
tasks = []
total_workload = 0
benchmark_requests = []
for _ in range(concurrent_requests):
for i in range(concurrent_requests):
payload = self.benchmark_handler.make_benchmark_payload()
total_workload += payload.count_workload()
tasks.append(
self.__call_api(handler=self.benchmark_handler, payload=payload)
workload = payload.count_workload()
task = self.__call_api(handler=self.benchmark_handler, payload=payload)
benchmark_requests.append(
BenchmarkResult(request_idx=i, workload=workload, task=task)
)
responses = await gather(*tasks)
responses = await gather(*[br.task for br in benchmark_requests])
for br, response in zip(benchmark_requests, responses):
br.response = response
total_workload = sum(br.workload for br in benchmark_requests if br.is_successful)
time_elapsed = time.time() - start
successful_responses = sum([1 for br in benchmark_requests if br.is_successful])
if successful_responses == 0:
self.backend_errored("No successful responses from benchmark")
log.debug(f"benchmark failed: {successful_responses}/{concurrent_requests} successful responses")
throughput = total_workload / time_elapsed
sum_throughput += throughput
@@ -322,7 +369,7 @@ class Backend:
f"Run: {run}, concurrent_requests: {concurrent_requests}",
f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s",
f"Throughput: {throughput} workload/s",
f"Successful responses: {len([r for r in responses if r.status == 200])}",
f"Successful responses: {successful_responses}/{concurrent_requests}",
"#" * 60,
]
)
@@ -349,7 +396,7 @@ class Backend:
)
# some backends need a few seconds after logging successful startup before
# they can begin accepting requests
await sleep(5)
# await sleep(5)
try:
max_throughput = await run_benchmark()
self.__start_healthcheck = True
@@ -370,13 +417,13 @@ class Backend:
async def tail_log():
log.debug(f"tailing file: {self.model_log_file}")
async with await open_file(self.model_log_file) as f:
async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore') as f:
while True:
line = await f.readline()
if line:
await handle_log_line(line.rstrip())
else:
time.sleep(LOG_POLL_INTERVAL)
await asyncio.sleep(LOG_POLL_INTERVAL)
###########
+52 -11
View File
@@ -3,7 +3,7 @@ import logging
from dataclasses import dataclass, field
from enum import Enum
from abc import ABC, abstractmethod
from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type
from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type, Awaitable
from aiohttp import web, ClientResponse
import inspect
@@ -65,10 +65,11 @@ class ApiPayload(ABC):
class AuthData:
"""data used to authenticate requester"""
signature: str
cost: str
endpoint: str
reqnum: int
request_idx: int
signature: str
url: str
@classmethod
@@ -189,13 +190,34 @@ class SystemMetrics:
self.additional_disk_usage = disk_usage - self.last_disk_usage
self.last_disk_usage = disk_usage
def reset(self):
def reset(self, expected: float | None) -> None:
# autoscaler excepts model_loading_time to be populated only once, when the instance has
# finished benchmarking and is ready to receive requests. This applies to restarted instances
# as well: they should send model_loading_time once when they are done loading
self.model_loading_time = None
if self.model_loading_time == expected:
self.model_loading_time = None
@dataclass
class RequestMetrics:
"""Tracks metrics for an active request."""
request_idx: int
reqnum: int
workload: float
status: str
success: bool = False
@dataclass
class BenchmarkResult:
request_idx: int
workload: float
task: Awaitable[ClientResponse]
response: Optional[ClientResponse] = None
@property
def is_successful(self) -> bool:
return self.response is not None and self.response.status == 200
@dataclass
class ModelMetrics:
"""Model specific metrics"""
@@ -205,12 +227,14 @@ class ModelMetrics:
workload_received: float
workload_cancelled: float
workload_errored: float
workload_rejected: float
# these are not
workload_pending: float
error_msg: Optional[str]
max_throughput: float
requests_recieved: Set[int] = field(default_factory=set)
requests_working: Set[int] = field(default_factory=set)
requests_working: dict[int, RequestMetrics] = field(default_factory=dict)
requests_deleting: list[RequestMetrics] = field(default_factory=list)
last_update: float = field(default_factory=time.time)
@classmethod
@@ -220,19 +244,30 @@ class ModelMetrics:
workload_served=0.0,
workload_cancelled=0.0,
workload_errored=0.0,
workload_rejected=0.0,
workload_received=0.0,
error_msg=None,
max_throughput=0.0,
)
@property
def cur_perf(self) -> float:
return max(self.workload_served / (time.time() - self.last_update), 0.0)
@property
def workload_processing(self) -> float:
return max(self.workload_received - self.workload_cancelled, 0.0)
@property
def wait_time(self) -> float:
if (len(self.requests_working) == 0):
return 0.0
return sum([request.workload for request in self.requests_working.values()]) / max(self.max_throughput, 0.00001)
@property
def cur_load(self) -> float:
return sum([request.workload for request in self.requests_working.values()])
@property
def working_request_idxs(self) -> list[int]:
return [req.request_idx for req in self.requests_working.values()]
def set_errored(self, error_msg):
self.reset()
self.error_msg = error_msg
@@ -242,16 +277,21 @@ class ModelMetrics:
self.workload_received = 0
self.workload_cancelled = 0
self.workload_errored = 0
self.workload_rejected = 0
self.last_update = time.time()
@dataclass
class AutoScalaerData:
class AutoScalerData:
"""Data that is reported to autoscaler"""
id: int
mtoken: str
version: str
loadtime: float
cur_load: float
rej_load: float
new_load: float
error_msg: str
max_perf: float
cur_perf: float
@@ -260,6 +300,7 @@ class AutoScalaerData:
num_requests_working: int
num_requests_recieved: int
additional_disk_usage: float
working_request_idxs: list[int]
url: str
+170 -39
View File
@@ -5,13 +5,14 @@ import json
from asyncio import sleep
from dataclasses import dataclass, asdict, field
from functools import cache
import asyncio
from aiohttp import ClientSession, ClientTimeout, TCPConnector, ClientResponseError
import requests
from lib.data_types import AutoScalaerData, SystemMetrics, ModelMetrics
from lib.data_types import AutoScalerData, SystemMetrics, ModelMetrics, RequestMetrics
from typing import Awaitable, NoReturn, List
METRICS_UPDATE_INTERVAL = 1
DELETE_REQUESTS_INTERVAL = 1
log = logging.getLogger(__file__)
@@ -26,7 +27,10 @@ def get_url() -> str:
@dataclass
class Metrics:
version: str = "0"
mtoken: str = ""
last_metric_update: float = 0.0
last_request_served: float = 0.0
update_pending: bool = False
id: int = field(default_factory=lambda: int(os.environ["CONTAINER_ID"]))
report_addr: List[str] = field(
@@ -35,42 +39,84 @@ class Metrics:
url: str = field(default_factory=get_url)
system_metrics: SystemMetrics = field(default_factory=SystemMetrics.empty)
model_metrics: ModelMetrics = field(default_factory=ModelMetrics.empty)
_session: ClientSession | None = field(default=None, init=False, repr=False)
def _request_start(self, workload: float, reqnum: int) -> None:
async def http(self) -> ClientSession:
if self._session is None:
self._session = ClientSession(
timeout=ClientTimeout(total=10),
connector=TCPConnector(limit=8, limit_per_host=4, force_close=True, enable_cleanup_closed=True)
)
return self._session
async def aclose(self) -> None:
if self._session is not None:
await self._session.close()
self._session = None
def _request_start(self, request: RequestMetrics) -> None:
"""
this function is called prior to forwarding a request to a model API.
"""
log.debug("request start")
self.model_metrics.workload_pending += workload
self.model_metrics.workload_received += workload
self.model_metrics.requests_recieved.add(reqnum)
self.model_metrics.requests_working.add(reqnum)
request.status = "Started"
self.model_metrics.workload_pending += request.workload
self.model_metrics.workload_received += request.workload
self.model_metrics.requests_recieved.add(request.reqnum)
self.model_metrics.requests_working[request.reqnum] = request
self.update_pending = True
def _request_end(self, workload: float, reqnum: int) -> None:
def _request_end(self, request: RequestMetrics) -> None:
"""
this function is called after handling of a request ends, regardless of the outcome
"""
self.model_metrics.workload_pending -= workload
self.model_metrics.requests_working.discard(reqnum)
self.model_metrics.workload_pending -= request.workload
self.model_metrics.requests_working.pop(request.reqnum, None)
self.model_metrics.requests_deleting.append(request)
self.last_request_served = time.time()
def _request_success(self, workload: float) -> None:
def _request_success(self, request: RequestMetrics) -> None:
"""
this function is called after a response from model API is received and forwarded.
"""
self.model_metrics.workload_served += workload
self.model_metrics.workload_served += request.workload
request.status = "Success"
request.success = True
self.update_pending = True
def _request_errored(self, workload: float) -> None:
def _request_errored(self, request: RequestMetrics) -> None:
"""
this function is called if model API returns an error
"""
self.model_metrics.workload_errored += workload
self.model_metrics.workload_errored += request.workload
request.status = "Error"
request.success = False
self.update_pending = True
def _request_canceled(self, workload: float) -> None:
def _request_canceled(self, request: RequestMetrics) -> None:
"""
this function is called if client drops connection before model API has responded
"""
self.model_metrics.workload_cancelled += workload
self.model_metrics.workload_cancelled += request.workload
request.success = True
request.status = "Cancelled"
def _request_reject(self, request: RequestMetrics):
"""
this function is called if the current wait time for the model is above max_wait_time
"""
self.model_metrics.requests_recieved.add(request.reqnum)
self.model_metrics.requests_deleting.append(request)
self.model_metrics.workload_rejected += request.workload
request.success = False
request.status = "Rejected"
self.update_pending = True
async def _send_delete_requests_loop(self) -> Awaitable[NoReturn]:
while True:
await sleep(DELETE_REQUESTS_INTERVAL)
if len(self.model_metrics.requests_deleting) > 0:
await self.__send_delete_requests_and_reset()
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
while True:
@@ -78,10 +124,10 @@ class Metrics:
elapsed = time.time() - self.last_metric_update
if self.system_metrics.model_is_loaded is False and elapsed >= 10:
log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
self.__send_metrics_and_reset(elapsed)
await self.__send_metrics_and_reset()
elif self.update_pending or elapsed > 10:
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
self.__send_metrics_and_reset(elapsed)
await self.__send_metrics_and_reset()
def _model_loaded(self, max_throughput: float) -> None:
self.system_metrics.model_loading_time = (
@@ -94,49 +140,130 @@ class Metrics:
self.model_metrics.set_errored(error_msg)
self.system_metrics.model_is_loaded = True
def _set_version(self, version: str) -> None:
self.version = version
def _set_mtoken(self, mtoken: str) -> None:
self.mtoken = mtoken
#######################################Private#######################################
def __send_metrics_and_reset(self, elapsed):
async def __send_delete_requests_and_reset(self):
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
data = {
"worker_id": self.id,
"mtoken": self.mtoken,
"request_idxs": idxs,
"success": success_flag,
}
log.debug(
f"Deleting requests that {'succeeded' if success_flag else 'failed'}: {data['request_idxs']}"
)
full_path = report_addr.rstrip("/") + "/delete_requests/"
for attempt in range(1, 4):
try:
session = await self.http()
async with session.post(full_path, json=data) as res:
log.debug(f"delete_requests response: {res.status}")
res.raise_for_status()
return True
except asyncio.TimeoutError:
log.debug("delete_requests timed out")
except (ClientResponseError, Exception) as e:
log.debug(f"delete_requests failed with error: {e}")
await asyncio.sleep(2)
log.debug(f"retrying delete_request, attempt: {attempt}")
return False
def compute_autoscaler_data() -> AutoScalaerData:
return AutoScalaerData(
# Take a snapshot of what we plan to send this tick.
# New arrivals after this snapshot will remain in the queue for the next tick.
snapshot = list(self.model_metrics.requests_deleting)
success_idxs = [r.request_idx for r in snapshot if r.success is True]
failed_idxs = [r.request_idx for r in snapshot if r.success is False]
if not success_idxs and not failed_idxs:
return # nothing to do
for report_addr in self.report_addr:
# TODO: Add a Redis subscriber queue for delete_requests
if report_addr == "https://cloud.vast.ai/api/v0":
# Patch: ignore the Redis API report_addr
continue
sent_success = True
sent_failed = True
if success_idxs:
sent_success = await post(report_addr, success_idxs, True)
if failed_idxs:
sent_failed = await post(report_addr, failed_idxs, False)
if sent_success and sent_failed:
# Remove only the items we actually sent from the live queue.
sent_set = set(success_idxs) | set(failed_idxs)
self.model_metrics.requests_deleting[:] = [
r for r in self.model_metrics.requests_deleting
if r.request_idx not in sent_set
]
break
async def __send_metrics_and_reset(self):
loadtime_snapshot = self.system_metrics.model_loading_time
def compute_autoscaler_data() -> AutoScalerData:
return AutoScalerData(
id=self.id,
loadtime=(self.system_metrics.model_loading_time or 0.0),
cur_load=(self.model_metrics.workload_processing / elapsed),
mtoken=self.mtoken,
version=self.version,
loadtime=(loadtime_snapshot or 0.0),
new_load=self.model_metrics.workload_processing,
cur_load=self.model_metrics.cur_load,
rej_load=self.model_metrics.workload_rejected,
max_perf=self.model_metrics.max_throughput,
cur_perf=self.model_metrics.cur_perf,
cur_perf=self.model_metrics.workload_served,
error_msg=self.model_metrics.error_msg or "",
num_requests_working=len(self.model_metrics.requests_working),
num_requests_recieved=len(self.model_metrics.requests_recieved),
additional_disk_usage=self.system_metrics.additional_disk_usage,
working_request_idxs=self.model_metrics.working_request_idxs,
cur_capacity=0,
max_capacity=0,
url=self.url,
)
def send_data(report_addr: str) -> bool:
async def send_data(report_addr: str) -> bool:
data = compute_autoscaler_data()
full_path = report_addr.rstrip("/") + "/worker_status/"
log_data = asdict(data)
def obfuscate(secret: str) -> str:
if secret is None:
return ""
return secret[:7] + "..." if len(secret) > 7 else ("*" * len(secret))
log_data["mtoken"] = obfuscate(log_data.get("mtoken"))
log.debug(
"\n".join(
[
"#" * 60,
f"sending data to autoscaler",
f"{json.dumps((asdict(data)), indent=2)}",
f"{json.dumps(log_data, indent=2)}",
"#" * 60,
]
)
)
full_path = report_addr.rstrip("/") + "/worker_status/"
for attempt in range(1, 4):
try:
res = requests.post(full_path, json=asdict(data), timeout=1)
res.raise_for_status()
session = await self.http()
async with session.post(full_path, json=asdict(data)) as res:
res.raise_for_status()
return True
except requests.Timeout:
except asyncio.TimeoutError:
log.debug(f"autoscaler status update timed out")
except Exception as e:
except (ClientResponseError, Exception) as e:
log.debug(f"autoscaler status update failed with error: {e}")
time.sleep(2)
await asyncio.sleep(2)
log.debug(f"retrying autoscaler status update, attempt: {attempt}")
log.debug(f"failed to send update through {report_addr}")
return False
@@ -145,11 +272,15 @@ class Metrics:
self.system_metrics.update_disk_usage()
sent = False
for report_addr in self.report_addr:
success = send_data(report_addr)
if success is True:
if await send_data(report_addr):
sent = True
break
self.update_pending = False
self.model_metrics.reset()
self.system_metrics.reset()
self.last_metric_update = time.time()
if sent:
# clear the one-shot loadtime only if we actually sent *this* value
self.system_metrics.reset(expected=loadtime_snapshot)
self.update_pending = False
self.model_metrics.reset()
self.last_metric_update = time.time()
+45 -25
View File
@@ -3,38 +3,58 @@ import logging
from typing import List
import ssl
from asyncio import run, gather
import asyncio
from lib.backend import Backend
from lib.metrics import Metrics
from aiohttp import web
log = logging.getLogger(__file__)
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
log.debug("getting certificate...")
use_ssl = os.environ.get("USE_SSL", "false") == "true"
if use_ssl is True:
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context.load_cert_chain(
certfile="/etc/instance.crt",
keyfile="/etc/instance.key",
)
else:
ssl_context = None
try:
log.debug("getting certificate...")
use_ssl = os.environ.get("USE_SSL", "false") == "true"
if use_ssl is True:
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context.load_cert_chain(
certfile="/etc/instance.crt",
keyfile="/etc/instance.key",
)
else:
ssl_context = None
async def main():
log.debug("starting server...")
app = web.Application()
app.add_routes(routes)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(
runner,
ssl_context=ssl_context,
port=int(os.environ["WORKER_PORT"]),
**kwargs
)
await gather(site.start(), backend._start_tracking())
async def main():
log.debug("starting server...")
app = web.Application()
app.add_routes(routes)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(
runner,
ssl_context=ssl_context,
port=int(os.environ["WORKER_PORT"]),
**kwargs
)
await gather(site.start(), backend._start_tracking())
run(main())
run(main())
except Exception as e:
err_msg = f"PyWorker failed to launch: {e}"
log.error(err_msg)
async def beacon():
metrics = Metrics()
metrics._set_version(getattr(backend, "version", "0"))
metrics._set_mtoken(getattr(backend, "mtoken", ""))
try:
while True:
metrics._model_errored(err_msg)
await metrics._Metrics__send_metrics_and_reset()
await asyncio.sleep(10)
finally:
await metrics.aclose()
run(beacon())
+6 -6
View File
@@ -292,12 +292,12 @@ def test_load_cmd(
args = arg_parser.parse_args()
if hasattr(args, "comfy_model"):
os.environ["COMFY_MODEL"] = args.comfy_model
server_url = dict(
prod="https://run.vast.ai",
alpha="https://run-alpha.vast.ai",
candidate="https://run-candidate.vast.ai",
local="http://localhost:8080",
)[args.instance]
server_url = {
"prod": "https://run.vast.ai",
"alpha": "https://run-alpha.vast.ai",
"candidate": "https://run-candidate.vast.ai",
"local": "http://localhost:8080",
}.get(args.instance, "http://localhost:8080")
run_test(
num_requests=args.num_requests,
requests_per_second=args.requests_per_second,
+4 -1
View File
@@ -1,4 +1,6 @@
aiohttp[speedups]==3.10.1
aiohttp~=3.10.1
aiodns~=3.6.0
pycares~=4.11.0
anyio~=4.4
lib~=4.0
nltk~=3.9
@@ -8,3 +10,4 @@ Requests~=2.32
transformers~=4.52
utils==1.0.*
hf_transfer>=0.1.9
vastai-sdk>=0.2.0
+48 -6
View File
@@ -9,7 +9,7 @@ ENV_PATH="$WORKSPACE_DIR/worker-env"
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
REPORT_ADDR="${REPORT_ADDR:-https://cloud.vast.ai/api/v0,https://run.vast.ai}"
REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}"
USE_SSL="${USE_SSL:-true}"
WORKER_PORT="${WORKER_PORT:-3000}"
mkdir -p "$WORKSPACE_DIR"
@@ -41,6 +41,14 @@ echo_var DEBUG_LOG
echo_var PYWORKER_LOG
echo_var MODEL_LOG
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
if [ -e "$MODEL_LOG" ]; then
echo "Rotating model log at $MODEL_LOG to $MODEL_LOG.old"
cat "$MODEL_LOG" >> "$MODEL_LOG.old"
: > "$MODEL_LOG"
fi
# Populate /etc/environment with quoted values
if ! grep -q "VAST" /etc/environment; then
env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do
@@ -124,9 +132,43 @@ cd "$SERVER_DIR"
echo "launching PyWorker server"
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
set +e
python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG"
PY_STATUS=${PIPESTATUS[0]}
set -e
(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
echo "launching PyWorker server done"
if [ "${PY_STATUS}" -ne 0 ]; then
echo "PyWorker exited with status ${PY_STATUS}; notifying autoscaler..."
ERROR_MSG="PyWorker exited: code ${PY_STATUS}"
MTOKEN="${MASTER_TOKEN:-}"
VERSION="${PYWORKER_VERSION:-0}"
IFS=',' read -r -a REPORT_ADDRS <<< "${REPORT_ADDR}"
for addr in "${REPORT_ADDRS[@]}"; do
curl -sS -X POST -H 'Content-Type: application/json' \
-d "$(cat <<JSON
{
"id": ${CONTAINER_ID:-0},
"mtoken": "${MTOKEN}",
"version": "${VERSION}",
"loadtime": 0,
"new_load": 0,
"cur_load": 0,
"rej_load": 0,
"max_perf": 0,
"cur_perf": 0,
"error_msg": "${ERROR_MSG}",
"num_requests_working": 0,
"num_requests_recieved": 0,
"additional_disk_usage": 0,
"working_request_idxs": [],
"cur_capacity": 0,
"max_capacity": 0,
"url": "${URL}"
}
JSON
)" "${addr%/}/worker_status/" || true
done
fi
echo "launching PyWorker server done"
+43 -5
View File
@@ -1,5 +1,6 @@
import logging
from typing import Any, Dict, Optional
import time
from typing import Any, Dict, Optional, Tuple
import requests
@@ -16,6 +17,38 @@ class Endpoint:
Utility class for handling endpoint operations.
"""
@staticmethod
def get_endpoint_info(
endpoint_name: str, account_api_key: str, instance: str
) -> Optional[Dict[str, Any]]:
headers = {"Authorization": f"Bearer {account_api_key}"}
url = f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}"
# Retry a few times to smooth over transient propagation/network delays
for attempt in range(4):
try:
response = requests.get(url, headers=headers, timeout=8)
if response.status_code != 200:
# brief backoff and retry
time.sleep(0.3 * (attempt + 1))
continue
try:
data = response.json()
except Exception:
# JSON parse failed; backoff and retry
time.sleep(0.3 * (attempt + 1))
continue
result = data.get("results", []) if isinstance(data, dict) else []
endpoint = next(
(item for item in result if item.get("endpoint_name") == endpoint_name),
None,
)
if endpoint and endpoint.get("id") and endpoint.get("api_key"):
return {"id": endpoint.get("id"), "api_key": endpoint.get("api_key")}
except Exception:
# network or other transient error; retry
time.sleep(0.3 * (attempt + 1))
return None
@staticmethod
def get_autoscaler_server_url(instance: str) -> str:
endpoints = {
@@ -23,7 +56,10 @@ class Endpoint:
"candidate": "run-candidate",
"prod": "run",
}
return f"https://{endpoints[instance]}.vast.ai/"
host = endpoints.get(instance)
if host:
return f"https://{host}.vast.ai/"
return "http://localhost:8080"
@staticmethod
def get_server_url(instance: str) -> str:
@@ -32,7 +68,8 @@ class Endpoint:
"candidate": "candidate",
"prod": "console",
}
return f"https://{endpoints[instance]}.vast.ai/api/v0/endptjobs/"
host = endpoints.get(instance, "alpha")
return f"https://{host}.vast.ai/api/v0/endptjobs/"
@staticmethod
def get_endpoint_api_key(
@@ -55,6 +92,7 @@ class Endpoint:
response = requests.get(
f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}",
headers=headers,
timeout=8,
)
if response.status_code != 200:
@@ -64,14 +102,14 @@ class Endpoint:
try:
data = response.json()
except requests.exceptions.JSONDecodeError as e:
except Exception as e:
log.debug(f"Failed to parse JSON response: {e}")
return None
result = data.get("results", [])
endpoint: Optional[Dict[str, Any]] = next(
(item for item in result if item["endpoint_name"] == endpoint_name),
(item for item in result if item.get("endpoint_name") == endpoint_name),
None,
)
if not endpoint:
+107 -13
View File
@@ -1,8 +1,16 @@
# ComfyUI PyWorker
This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture.
This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
The cost for each request has a static value of `1`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node.
The cost for each request has a static value of `1`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node.
## Instance Setup
1. Pick a template
- [ComfyUI (Serverless)](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=ComfyUI%20(Serverless))
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
## Requirements
@@ -10,11 +18,105 @@ This worker requires both [ComfyUI](https://github.com/comfyanonymous/ComfyUI) a
A docker image is provided but you may use any if the above requirements are met.
## Client
The client demonstrates how to use the Vast Serverless SDK to generate images, save them locally, and optionally upload to S3-compatible storage.
### Setup
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
```bash
git clone https://github.com/vast-ai/pyworker
cd pyworker
pip install uv
uv venv -p 3.12
source .venv/bin/activate
uv pip install -r requirements.txt
```
2. Set your API key:
```bash
export VAST_API_KEY=<your_api_key>
```
### Usage
```bash
# Default prompt
python -m workers.comfyui-json.client
# Custom prompt
python -m workers.comfyui-json.client --prompt "a cat sitting on a rainbow"
# With options
python -m workers.comfyui-json.client --prompt "sunset" --width 1024 --height 1024 --steps 30
# Using a custom workflow file
python -m workers.comfyui-json.client --workflow my_workflow.json
# With S3 upload
python -m workers.comfyui-json.client --s3
```
### CLI Flags
| Flag | Default | Description |
|------|---------|-------------|
| `--endpoint` | `my-comfyui-endpoint` | Vast endpoint name |
| `--prompt` | (default) | Text prompt for image generation |
| `--workflow` | (none) | Path to custom workflow JSON file |
| `--width` | 512 | Image width in pixels |
| `--height` | 512 | Image height in pixels |
| `--steps` | 20 | Number of denoising steps |
| `--seed` | (random) | Random seed for reproducibility |
| `--s3` | (disabled) | Upload generated images to S3 |
### Output
Images are saved to `./generated_images/comfy_{seed}.png`.
### S3 Upload (Optional)
You can optionally upload generated images to an S3-compatible storage service (AWS S3, Cloudflare R2, Backblaze B2, etc.) by using the `--s3` flag.
**1. Set environment variables:**
```bash
export S3_ENDPOINT_URL="https://your-account.r2.cloudflarestorage.com"
export S3_BUCKET_NAME="my-bucket"
export S3_ACCESS_KEY_ID="your-access-key-id"
export S3_SECRET_ACCESS_KEY="your-secret-access-key"
```
**2. Run with S3 upload enabled:**
```bash
python -m workers.comfyui-json.client --prompt "a beautiful landscape" --s3
```
Images will be saved locally AND uploaded to `s3://{bucket}/comfyui/{filename}`.
**Note:** Requires `boto3` (`pip install boto3`).
## Benchmarking
A simple image generation benchmark runs when each worker initializes to validate GPU performance and identify underperforming machines.
### Custom Benchmark Workflows
The benchmark uses Stable Diffusion v1.5 with ComfyUI's default text-to-image workflow. Configure the benchmark complexity and duration using these variables:
You can provide a custom ComfyUI workflow for benchmarking by creating `workers/comfyui-json/misc/benchmark.json`. This allows you to test performance using your preferred models and workflow complexity.
**Ways to provide the benchmark file:**
- Fork this repository and add your `benchmark.json` file
- Write the file during worker provisioning (onstart script or setup phase)
An example file is provided in the repository. To ensure varied generations, use the placeholder `__RANDOM_INT__` in place of static seed values - it will be replaced with a random integer for each benchmark run.
### Default Benchmark (Fallback)
If `benchmark.json` is not available, a simple image generation benchmark runs when each worker initializes. This validates GPU performance and helps identify underperforming machines.
The default benchmark uses Stable Diffusion v1.5 with ComfyUI's standard text-to-image workflow. Configure it using these environment variables:
| Environment Variable | Default Value | Description |
| -------------------- | ------------- | ----------- |
@@ -24,7 +126,7 @@ The benchmark uses Stable Diffusion v1.5 with ComfyUI's default text-to-image wo
Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns.
### Calibrating Benchmark Duration
#### Calibrating Fallback Benchmark Duration
To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements.
@@ -200,11 +302,3 @@ WEBHOOK_TIMEOUT=30 # Webhook timeout in seconds
}
}
```
## Client Libraries
See the test client examples for implementation details on how to integrate with the ComfyUI worker.
---
See Vast's serverless documentation for more details on how to use ComfyUI with autoscaler.
+293 -136
View File
@@ -1,155 +1,312 @@
import logging
import os
import sys
import json
import uuid
import random
from urllib.parse import urljoin
import json
import asyncio
import logging
import argparse
import aiohttp
import requests
from vastai import Serverless
from lib.test_utils import print_truncate_res
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from .data_types import count_workload
# ---------------------- Config ----------------------
DEFAULT_PROMPT = "a beautiful sunset over mountains, digital art, highly detailed"
ENDPOINT_NAME = "my-comfyui-endpoint"
DEFAULT_WIDTH = 512
DEFAULT_HEIGHT = 512
DEFAULT_STEPS = 20
COST = 100 # Fixed cost for ComfyUI requests
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
# Optional S3 Configuration (from environment variables)
S3_ENDPOINT_URL = os.getenv("S3_ENDPOINT_URL")
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
S3_ACCESS_KEY_ID = os.getenv("S3_ACCESS_KEY_ID")
S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY")
logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")
log = logging.getLogger(__name__)
def call_text2image_workflow(
endpoint_group_name: str, api_key: str, server_url: str
) -> None:
"""Simple Text2Image using the new modifier-based approach"""
def make_request(url: str, payload: dict, timeout: int = None, verify=True, context: str = "request"):
"""Helper function for making requests with consistent error handling"""
try:
response = requests.post(
url,
json=payload,
timeout=timeout,
verify=verify
)
response.raise_for_status()
return response.json()
except requests.exceptions.HTTPError as http_err:
log.error(f"HTTP error occurred during {context}: {http_err}")
log.error(f"Status Code: {response.status_code}")
log.error("Response content:", response.text)
return None
except requests.exceptions.Timeout:
log.error(f"Timeout occurred during {context}: {url}")
return None
except requests.exceptions.ConnectionError:
log.error(f"Connection error occurred during {context}: {url}")
return None
except json.JSONDecodeError as json_err:
log.error(f"Failed to decode JSON response during {context}: {json_err}")
if 'response' in locals():
print("Response content:", response.text)
return None
except Exception as err:
log.error(f"An unexpected error occurred during {context}: {err}")
if 'response' in locals():
log.error("Response content (if available):", response.text)
return None
WORKER_ENDPOINT = "/generate/sync"
def get_s3_client():
"""Create and return an S3 client configured for the S3-compatible endpoint"""
try:
import boto3
from botocore.config import Config
except ImportError:
log.error("boto3 is required for S3 uploads. Install with: pip install boto3")
return None
# This worker has concurrency = 1. All workloads have cost value 1.0
COST = count_workload()
# Route to get worker URL
route_payload = {
"endpoint": endpoint_group_name,
"api_key": api_key,
"cost": COST,
}
# First request - get routing information
route_response = make_request(
url=urljoin(server_url, "/route/"),
payload=route_payload,
timeout=4,
context="route request"
if not all([S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY]):
log.error("S3 environment variables not fully configured. Required:")
log.error(" S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY")
return None
return boto3.client(
"s3",
endpoint_url=S3_ENDPOINT_URL,
aws_access_key_id=S3_ACCESS_KEY_ID,
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
config=Config(signature_version="s3v4"),
)
if route_response is None:
return None
if "url" not in route_response or not route_response["url"]:
log.error("Error: No worker in 'Ready' state. Please wait while the serverless engine removes errored workers or finishes loading new workers.")
return None
if "status" in route_response:
print(f"Autoscaler status: {route_response['status']}")
return None
# Extract data from route response
url = route_response["url"]
auth_data = dict(
signature=route_response["signature"],
cost=route_response["cost"],
endpoint=route_response["endpoint"],
reqnum=route_response["reqnum"],
url=route_response["url"],
)
# Build the payload for the worker request
worker_payload = {
# ---------------------- API Functions ----------------------
async def call_generate(
client: Serverless,
*,
endpoint_name: str,
prompt: str,
width: int,
height: int,
steps: int,
seed: int,
) -> dict:
"""Generate image using Text2Image modifier"""
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
"request_id": str(uuid.uuid4()),
"modifier": "Text2Image",
"modifications": {
"prompt": "a beautiful landscape with mountains and lakes",
"width": 1024,
"height": 1024,
"steps": 20,
"seed": random.randint(0, 2**32 - 1)
"prompt": prompt,
"width": width,
"height": height,
"steps": steps,
"seed": seed,
},
"workflow_json": {} # Empty since using modifier approach
}
}
req_data = dict(payload=worker_payload, auth_data=auth_data)
worker_url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {worker_url}")
# Second request - call the worker endpoint
worker_response = make_request(
url=worker_url,
payload=req_data,
verify=get_cert_file_path(),
context="worker request"
)
return worker_response
return await endpoint.request("/generate/sync", payload, cost=COST)
async def call_generate_workflow(
client: Serverless,
*,
endpoint_name: str,
workflow_json: dict,
) -> dict:
"""Generate using custom workflow JSON"""
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
"request_id": str(uuid.uuid4()),
"workflow_json": workflow_json,
}
}
return await endpoint.request("/generate/sync", payload, cost=COST)
# ---------------------- Demo Class ----------------------
class APIDemo:
def __init__(self, client: Serverless, endpoint_name: str, upload_s3: bool = False):
self.client = client
self.endpoint_name = endpoint_name
self.upload_s3 = upload_s3
self.s3_client = get_s3_client() if upload_s3 else None
if upload_s3 and not self.s3_client:
log.warning("S3 upload requested but client creation failed. Images will only be saved locally.")
def extract_filename(self, response: dict) -> str | None:
"""Extract the generated image filename from ComfyUI response"""
if "comfyui_response" in response:
for data in response["comfyui_response"].values():
if isinstance(data, dict) and "outputs" in data:
for node_output in data["outputs"].values():
if "images" in node_output and node_output["images"]:
return node_output["images"][0].get("filename")
return None
async def save_image(self, worker_url: str, filename: str, local_name: str) -> str | None:
"""Fetch and save image locally from the worker, optionally upload to S3"""
os.makedirs("generated_images", exist_ok=True)
return await self._fetch_image(worker_url, filename, local_name)
def _upload_to_s3(self, local_path: str, s3_key: str) -> str | None:
"""Upload a local file to S3 and return the S3 URL"""
if not self.s3_client:
return None
try:
self.s3_client.upload_file(
local_path,
S3_BUCKET_NAME,
s3_key,
ExtraArgs={"ContentType": "image/png"}
)
s3_url = f"{S3_ENDPOINT_URL}/{S3_BUCKET_NAME}/{s3_key}"
print(f" ☁️ Uploaded to S3: {s3_key}")
return s3_url
except Exception as e:
log.error(f"Failed to upload to S3: {e}")
return None
async def _fetch_image(self, worker_url: str, filename: str, local_name: str) -> str | None:
"""Fetch image from worker's /view endpoint and save locally"""
if not worker_url:
return None
try:
url = f"{worker_url}/view"
params = {"filename": filename, "type": "output"}
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params, ssl=False) as resp:
if resp.status == 200:
path = f"generated_images/{local_name}"
image_data = await resp.read()
with open(path, "wb") as f:
f.write(image_data)
print(f" 💾 Saved: {path}")
# Upload to S3 if enabled
if self.upload_s3 and self.s3_client:
s3_key = f"comfyui/{local_name}"
self._upload_to_s3(path, s3_key)
return path
return None
except Exception:
return None
async def demo_prompt(
self,
prompt: str,
width: int,
height: int,
steps: int,
seed: int | None,
):
"""Demo: Generate image from text prompt"""
print("=" * 60)
print("COMFYUI TEXT-TO-IMAGE DEMO")
print("=" * 60)
if seed is None:
seed = random.randint(0, 2**32 - 1)
print(f"Prompt: {prompt[:100]}..." if len(prompt) > 100 else f"Prompt: {prompt}")
print(f"Size: {width}x{height}, Steps: {steps}, Seed: {seed}")
print("\n🎨 Generating image...")
response = await call_generate(
self.client,
endpoint_name=self.endpoint_name,
prompt=prompt,
width=width,
height=height,
steps=steps,
seed=seed,
)
print("\n✅ Generation complete!")
# Get worker URL for fetching images
worker_url = response.get("url", "")
print(f"Worker URL: {worker_url}")
# Fetch and save image
if "response" in response:
filename = self.extract_filename(response["response"])
if filename:
path = await self.save_image(worker_url, filename, f"comfy_{seed}.png")
if not path:
print(f"❌ Failed to fetch image")
else:
print("❌ No image in response")
else:
print("❌ Unexpected response format")
async def demo_workflow(self, workflow_file: str):
"""Demo: Generate using custom workflow file"""
print("=" * 60)
print("COMFYUI CUSTOM WORKFLOW DEMO")
print("=" * 60)
if not os.path.exists(workflow_file):
log.error(f"Workflow file not found: {workflow_file}")
return
with open(workflow_file, "r") as f:
workflow_json = json.load(f)
print(f"Workflow: {workflow_file}")
print("\n🎨 Generating...")
response = await call_generate_workflow(
self.client,
endpoint_name=self.endpoint_name,
workflow_json=workflow_json,
)
print("\n✅ Generation complete!")
worker_url = response.get("url", "")
if "response" in response:
filename = self.extract_filename(response["response"])
if filename:
path = await self.save_image(worker_url, filename, "workflow.png")
if not path:
print(f"❌ Failed to fetch image")
else:
print("❌ No image in response")
else:
print("❌ Unexpected response format")
# ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast ComfyUI-JSON Demo (Serverless SDK)")
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
p.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, metavar="TEXT",
help=f"Prompt text (default: '{DEFAULT_PROMPT[:30]}...')")
p.add_argument("--workflow", type=str, metavar="FILE", help="Use custom workflow JSON file instead")
p.add_argument("--width", type=int, default=DEFAULT_WIDTH, help=f"Image width (default: {DEFAULT_WIDTH})")
p.add_argument("--height", type=int, default=DEFAULT_HEIGHT, help=f"Image height (default: {DEFAULT_HEIGHT})")
p.add_argument("--steps", type=int, default=DEFAULT_STEPS, help=f"Steps (default: {DEFAULT_STEPS})")
p.add_argument("--seed", type=int, default=None, help="Seed (default: random)")
p.add_argument("--s3", action="store_true",
help="Upload generated images to S3 (requires S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY env vars)")
return p
async def main_async():
args = build_arg_parser().parse_args()
print("=" * 60)
print(f"Using endpoint: {args.endpoint}")
if args.s3:
print(f"S3 upload: enabled (bucket: {S3_BUCKET_NAME})")
try:
async with Serverless() as client:
demo = APIDemo(client, args.endpoint, upload_s3=args.s3)
if args.workflow:
await demo.demo_workflow(workflow_file=args.workflow)
else:
await demo.demo_prompt(
prompt=args.prompt,
width=args.width,
height=args.height,
steps=args.steps,
seed=args.seed,
)
except AttributeError as e:
if "API key" in str(e):
log.error("API key missing. Set VAST_API_KEY environment variable.")
else:
log.error(f"Error: {e}")
sys.exit(1)
except Exception as e:
log.error(f"Error: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
from lib.test_utils import test_args
args = test_args.parse_args()
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if endpoint_api_key:
result = call_text2image_workflow(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
if result is None:
log.error("Text2Image workflow failed")
else:
print(result)
else:
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name}")
asyncio.run(main_async())
+28 -4
View File
@@ -5,12 +5,13 @@ import dataclasses
from typing import Dict, Any
from functools import cache
from math import ceil
from pathlib import Path
import json
import logging
from lib.data_types import ApiPayload, JsonDataException
with open("workers/comfyui/misc/test_prompts.txt", "r") as f:
test_prompts = f.readlines()
log = logging.getLogger(__file__)
def count_workload() -> float:
# Always 100.0 where there is a single instance of ComfyUI handling requests
@@ -24,9 +25,32 @@ class ComfyWorkflowData(ApiPayload):
@classmethod
def for_test(cls):
"""
Use the variables available to simulate workflows of the required running time
If the user has provided a benchmark workflow we can use it here to properly gauge performance.
Otherwise, use the variables available to simulate workflows of the required running time
Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090)
"""
# Try to load benchmark.json
benchmark_file = Path("workers/comfyui-json/misc/benchmark.json")
if benchmark_file.exists():
try:
with open(benchmark_file, "r") as f:
benchmark_workflow = json.load(f)
return cls(
input={
"request_id": f"test-{random.randint(1000, 99999)}",
"workflow_json": benchmark_workflow
}
)
except (json.JSONDecodeError, IOError):
# JSON is malformed or file can't be read, fall through to default
log.error(f"Failed to benchmark using {benchmark_file}")
# Fallback: read prompts and construct payload
log.info("Using fallback method for benchmarking")
with open("workers/comfyui-json/misc/test_prompts.txt", "r") as f:
test_prompts = f.readlines()
test_prompt = random.choice(test_prompts).rstrip()
return cls(
input={
@@ -0,0 +1,107 @@
{
"3": {
"inputs": {
"seed": "__RANDOM_INT__",
"steps": 20,
"cfg": 8,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1,
"model": [
"4",
0
],
"positive": [
"6",
0
],
"negative": [
"7",
0
],
"latent_image": [
"5",
0
]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"4": {
"inputs": {
"ckpt_name": "v1-5-pruned-emaonly-fp16.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"5": {
"inputs": {
"width": 512,
"height": 512,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"6": {
"inputs": {
"text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"7": {
"inputs": {
"text": "text, watermark",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"8": {
"inputs": {
"samples": [
"3",
0
],
"vae": [
"4",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"8",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
}
}
+35 -1
View File
@@ -4,6 +4,7 @@ import dataclasses
import base64
from typing import Optional, Union, Type
import aiohttp
from aiohttp import web, ClientResponse
from lib.backend import Backend, LogAction
@@ -13,12 +14,14 @@ from .data_types import ComfyWorkflowData
MODEL_SERVER_URL = os.getenv("MODEL_SERVER_URL", "http://127.0.0.1:18288")
COMFYUI_URL = os.getenv("COMFYUI_URL", "http://127.0.0.1:18188") # Raw ComfyUI server
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: "
MODEL_SERVER_ERROR_LOG_MSGS = [
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
"Value not in list: ", # This error is emitted when the model file is not there at all
"[ERROR] Provisioning Script failed", # Error inserted by provisioning script if models/nodes fail to download
]
@@ -70,7 +73,7 @@ class ComfyWorkflowHandler(EndpointHandler[ComfyWorkflowData]):
@property
def healthcheck_endpoint(self) -> Optional[str]:
return "/health"
return f"{MODEL_SERVER_URL}/health"
@classmethod
def payload_cls(cls) -> Type[ComfyWorkflowData]:
@@ -107,8 +110,39 @@ async def handle_ping(_):
return web.Response(body="pong")
async def handle_view(request: web.Request) -> web.Response:
"""Proxy /view requests to raw ComfyUI server to fetch generated images"""
# Forward query params to raw ComfyUI (not the API wrapper)
query_string = request.query_string
url = f"{COMFYUI_URL}/view?{query_string}"
log.debug(f"Proxying /view request to: {url}")
try:
async with aiohttp.ClientSession() as session:
async with session.get(url) as resp:
if resp.status == 200:
content = await resp.read()
return web.Response(
body=content,
status=200,
content_type=resp.content_type or "image/png"
)
else:
text = await resp.text()
return web.Response(
text=text,
status=resp.status,
content_type="text/plain"
)
except Exception as e:
log.error(f"Error proxying /view: {e}")
return web.Response(text=str(e), status=500)
routes = [
web.post("/generate/sync", backend.create_handler(ComfyWorkflowHandler())),
web.get("/view", handle_view),
web.get("/ping", handle_ping),
]
+6 -12
View File
@@ -7,20 +7,13 @@ from lib.test_utils import print_truncate_res
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
"""
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
"""
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
from vastai import Serverless
def call_default_workflow(
endpoint_group_name: str, api_key: str, server_url: str
) -> None:
ENDPOINT_NAME = "my-comfyui-endpoint"
COST = 100 # Use a constant cost for image generation
def call_default_workflow(client: Serverless) -> None:
WORKER_ENDPOINT = "/prompt"
COST = 100
route_payload = {
@@ -82,6 +75,7 @@ def call_custom_workflow_for_sd3(
endpoint=message["endpoint"],
reqnum=message["reqnum"],
url=message["url"],
request_idx=message["request_idx"],
)
workflow = {
"3": {
+33 -26
View File
@@ -8,14 +8,13 @@ This is the base PyWorker for OpenAI compatible inference servers. See the [Ser
This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker.
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20%2B%20Qwen%2FQwen3-8B%20(Serverless)) (recommended)
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20(Serverless)) (recommended)
- [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless))
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20%2B%20Qwen3-8B%20(Serverless))
All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected.
2. Follow the [getting started guide](https://docs.vast.ai/serverless/getting-started) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
## Client Setup (Demo)
@@ -34,38 +33,20 @@ uv pip install -r requirements.txt
Several examples have been provided in the client to help you get started with your own implementation.
### Completions
Call to `/v1/completions` with json response
First, set your API key as an environment variable:
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
export VAST_API_KEY=<your_api_key>
```
### Chat Completion (json)
Call to `/v1/chat/completions` with json response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
```
The `--model` and `--endpoint` flags are optional. If not provided, they default to `Qwen/Qwen3-8B` and `my-vllm-endpoint` respectively.
### Chat Completion (streaming)
Call to `/v1/chat/completions` with streaming response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
```
### Tool Use (json)
Call to `/v1/chat/completions` with tool and json response.
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
python -m workers.openai.client --chat-stream --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Interactive Chat (streaming)
@@ -75,6 +56,32 @@ Interactive session with calls to `/v1/chat/completions`.
Type `clear` to clear the chat history or `quit` to exit.
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
python -m workers.openai.client --interactive --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Chat Completion (json)
Call to `/v1/chat/completions` with json response
```bash
python -m workers.openai.client --chat --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Tool Use (json)
Call to `/v1/chat/completions` with tool and json response.
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
```bash
python -m workers.openai.client --tools --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Completions
Call to `/v1/completions` with json response
```bash
python -m workers.openai.client --completion --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
+375 -429
View File
@@ -1,14 +1,15 @@
import logging
import sys
import json
import os
import sys
import subprocess
from urllib.parse import urljoin
from typing import Dict, Any, Optional, Iterator, Union, List
import requests
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from .data_types.client import CompletionConfig, ChatCompletionConfig
import argparse
from typing import Any, Dict, List, Optional
from vastai import Serverless
import asyncio
# ---------------------- Logging ----------------------
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
@@ -16,135 +17,20 @@ logging.basicConfig(
)
log = logging.getLogger(__file__)
COMPLETIONS_PROMPT = "the capital of USA is"
# ---------------------- Prompts ----------------------
COMPLETIONS_PROMPT = "Zebras are primarily grazers and can subsist on lower-quality vegetation. They are preyed on mainly by"
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?"
class APIClient:
"""Lightweight client focused solely on API communication"""
# Remove the generic WORKER_ENDPOINT since we're now going direct
DEFAULT_COST = 100
DEFAULT_TIMEOUT = 4
def __init__(
self,
endpoint_group_name: str,
api_key: str,
server_url: str,
endpoint_api_key: str,
):
self.endpoint_group_name = endpoint_group_name
self.api_key = api_key
self.server_url = server_url
self.endpoint_api_key = endpoint_api_key
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
"""Get worker URL and auth data from routing service"""
if not self.endpoint_api_key:
raise ValueError("No valid endpoint API key available")
route_payload = {
"endpoint": self.endpoint_group_name,
"api_key": self.endpoint_api_key,
"cost": cost,
}
response = requests.post(
urljoin(self.server_url, "/route/"),
json=route_payload,
timeout=self.DEFAULT_TIMEOUT,
)
response.raise_for_status()
return response.json()
def _create_auth_data(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""Create auth data from routing response"""
return {
"signature": message["signature"],
"cost": message["cost"],
"endpoint": message["endpoint"],
"reqnum": message["reqnum"],
"url": message["url"],
}
def _make_request(
self,
payload: Dict[str, Any],
endpoint: str,
method: str = "POST",
stream: bool = False,
) -> Union[Dict[str, Any], Iterator[str]]:
"""Make request directly to the specific worker endpoint"""
# Get worker URL and auth data
cost = payload.get("max_tokens", self.DEFAULT_COST)
message = self._get_worker_url(cost=cost)
worker_url = message["url"]
auth_data = self._create_auth_data(message)
req_data = {"payload": {"input": payload}, "auth_data": auth_data}
url = urljoin(worker_url, endpoint)
log.debug(f"Making direct request to: {url}")
log.debug(f"Payload: {req_data}")
# Make the request using the specified method
if method.upper() == "POST":
response = requests.post(
url, json=req_data, stream=stream, verify=get_cert_file_path()
)
elif method.upper() == "GET":
response = requests.get(
url, params=req_data, stream=stream, verify=get_cert_file_path()
)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status()
if stream:
return self._handle_streaming_response(response)
else:
return response.json()
def _handle_streaming_response(self, response: requests.Response) -> Iterator[str]:
"""Handle streaming response and yield tokens"""
try:
for line in response.iter_lines(decode_unicode=True):
if line:
if line.startswith("data: "):
data_str = line[6:]
if data_str.strip() == "[DONE]":
break
try:
data = json.loads(data_str)
yield data # Yield the full chunk
except json.JSONDecodeError:
continue
except Exception as e:
log.error(f"Error handling streaming response: {e}")
raise
def call_completions(
self, config: CompletionConfig
) -> Union[Dict[str, Any], Iterator[str]]:
payload = config.to_dict()
return self._make_request(
payload=payload, endpoint="/v1/completions", stream=config.stream
)
def call_chat_completions(
self, config: ChatCompletionConfig
) -> Union[Dict[str, Any], Iterator[str]]:
payload = config.to_dict()
return self._make_request(
payload=payload, endpoint="/v1/chat/completions", stream=config.stream
)
TOOLS_PROMPT = (
"Can you list the files in the current working directory and tell me what you see? "
"What do you think this directory might be for?"
)
ENDPOINT_NAME = "my-vllm-endpoint" # change this to your vLLM endpoint name
DEFAULT_MODEL = "Qwen/Qwen3-8B" # must support tool calling
MAX_TOKENS = 1024
DEFAULT_TEMPERATURE = 0.7
# ---------------------- Tooling ----------------------
class ToolManager:
"""Handles tool definitions and execution"""
@@ -164,7 +50,7 @@ class ToolManager:
@staticmethod
def get_ls_tool_definition() -> List[Dict[str, Any]]:
"""Get the ls tool definition"""
"""OpenAI-compatible tool schema"""
return [
{
"type": "function",
@@ -178,98 +64,228 @@ class ToolManager:
def execute_tool_call(self, tool_call: Dict[str, Any]) -> str:
"""Execute a tool call and return the result"""
function_name = tool_call["function"]["name"]
function_name = (tool_call.get("function") or {}).get("name")
if function_name == "list_files":
return self.list_files()
else:
raise ValueError(f"Unknown tool function: {function_name}")
raise ValueError(f"Unknown tool function: {function_name}")
# ----- Helpers to handle streamed tool_calls assembly -----
def _merge_tool_call_delta(state: Dict[int, Dict[str, Any]], tc_delta: Dict[str, Any]) -> None:
"""
OpenAI-style streaming sends partial tool_calls with an index and partial fields.
We merge into a per-index state dict until the assistant message finishes.
"""
idx = tc_delta.get("index")
if idx is None:
return
entry = state.setdefault(idx, {"id": None, "function": {"name": None, "arguments": ""}, "type": "function"})
if tc_delta.get("id"):
entry["id"] = tc_delta["id"]
fn_delta = tc_delta.get("function") or {}
if "name" in fn_delta and fn_delta["name"]:
entry["function"]["name"] = fn_delta["name"]
if "arguments" in fn_delta and fn_delta["arguments"]:
entry["function"]["arguments"] += fn_delta["arguments"]
def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[Dict[str, Any]]:
return [state[i] for i in sorted(state.keys())]
# ---- OpenAI-compatible calls (non-streaming) ----
async def call_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
"model": model,
"prompt": prompt,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
}
}
log.debug("POST /v1/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
return resp["response"]
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
"model": model,
"messages": messages,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
}
}
log.debug("POST /v1/chat/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"])
return resp["response"]
# ---- Streaming variants ----
async def stream_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs):
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
"model": model,
"prompt": prompt,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"stream": True,
**({"stop": kwargs["stop"]} if "stop" in kwargs else {}),
}
}
log.debug("STREAM /v1/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
return resp["response"] # async generator
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs):
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
"model": model,
"messages": messages,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"stream": True,
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
}
}
log.debug("STREAM /v1/chat/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
return resp["response"] # async generator
# ---------------------- Demo Runner ----------------------
class APIDemo:
"""Demo and testing functionality for the API client"""
def __init__(
self, client: APIClient, model: str, tool_manager: Optional[ToolManager] = None
):
def __init__(self, client: Serverless, model: str, endpoint_name: str, tool_manager: Optional[ToolManager] = None):
self.client = client
self.model = model
self.endpoint_name = endpoint_name
self.tool_manager = tool_manager or ToolManager()
def handle_streaming_response(
self, response_stream, show_reasoning: bool = True
) -> str:
"""
Handle streaming chat response and display all output.
"""
# ----- Streaming handler -----
async def handle_streaming_response(self, stream, show_reasoning: bool = True) -> str:
full_response = ""
reasoning_content = ""
reasoning_started = False
content_started = False
printed_reasoning = False
printed_answer = False
finish_reason = None
for chunk in response_stream:
# Normalize the chunk
if isinstance(chunk, str):
chunk = chunk.strip()
if chunk.startswith("data: "):
chunk = chunk[6:].strip()
if chunk in ["[DONE]", ""]:
continue
try:
parsed_chunk = json.loads(chunk)
except json.JSONDecodeError:
continue
elif isinstance(chunk, dict):
parsed_chunk = chunk
else:
continue
async for chunk in stream:
choice = (chunk.get("choices") or [{}])[0]
delta = choice.get("delta", {})
# Track finish reason
if choice.get("finish_reason"):
finish_reason = choice.get("finish_reason")
# Parse delta from the chunk
choices = parsed_chunk.get("choices", [])
if not choices:
continue
delta = choices[0].get("delta", {})
reasoning_token = delta.get("reasoning_content", "")
content_token = delta.get("content", "")
# Print reasoning token if applicable
if show_reasoning and reasoning_token:
if not reasoning_started:
# reasoning tokens
rc = delta.get("reasoning_content")
if rc and show_reasoning:
if not printed_reasoning:
print("\n🧠 Reasoning: ", end="", flush=True)
reasoning_started = True
print(f"\033[90m{reasoning_token}\033[0m", end="", flush=True)
reasoning_content += reasoning_token
printed_reasoning = True
print(rc, end="", flush=True)
reasoning_content += rc
# Print content token
if content_token:
if not content_started:
if show_reasoning and reasoning_started:
print(f"\n💬 Response: ", end="", flush=True)
# content tokens
content_part = delta.get("content")
if content_part:
if not printed_answer:
if show_reasoning and printed_reasoning:
print("\n💬 Response: ", end="", flush=True)
else:
print("Assistant: ", end="", flush=True)
content_started = True
print(content_token, end="", flush=True)
full_response += content_token
print() # Ensure newline after response
printed_answer = True
print(content_part, end="", flush=True)
full_response += content_part
print() # newline
if show_reasoning:
if reasoning_started or content_started:
if printed_reasoning or printed_answer:
print("\nStreaming completed.")
if reasoning_started:
if printed_reasoning:
print(f"Reasoning tokens: {len(reasoning_content.split())}")
if content_started:
if printed_answer:
print(f"Response tokens: {len(full_response.split())}")
if finish_reason:
print(f"Finish reason: {finish_reason}")
return full_response
async def demo_completions(self) -> None:
print("=" * 60)
print("COMPLETIONS DEMO")
print("=" * 60)
def test_tool_support(self) -> bool:
"""Test if the endpoint supports function calling"""
log.debug("Testing endpoint tool calling support...")
response = await call_completions(
client=self.client,
model=self.model,
prompt=COMPLETIONS_PROMPT,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
print("\nResponse:")
print(json.dumps(response, indent=2))
# Try a simple request with minimal tools to test support
async def demo_chat(self, use_streaming: bool = True) -> None:
print("=" * 60)
print(f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}")
print("=" * 60)
messages = [{"role": "user", "content": CHAT_PROMPT}]
if use_streaming:
stream = await stream_chat_completions(
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE
)
try:
await self.handle_streaming_response(stream, show_reasoning=True)
except Exception as e:
log.error("\nError during streaming: %s", e, exc_info=True)
else:
response = await call_chat_completions(
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE
)
choice = (response.get("choices") or [{}])[0]
message = choice.get("message", {})
content = message.get("content", "")
reasoning = message.get("reasoning_content", "") or message.get("reasoning", "")
if reasoning:
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
print(f"\n💬 Assistant: {content}")
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
async def test_tool_support(self) -> bool:
"""Probe that tool schema is accepted (no actual call)"""
messages = [{"role": "user", "content": "Hello"}]
minimal_tool = [
{
@@ -277,179 +293,158 @@ class APIDemo:
"function": {"name": "test_function", "description": "Test function"},
}
]
config = ChatCompletionConfig(
model=self.model,
messages=messages,
max_tokens=10,
tools=minimal_tool,
tool_choice="none", # Don't actually call the tool
)
try:
response = self.client.call_chat_completions(config)
_ = await call_chat_completions(
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
tools=minimal_tool,
tool_choice="none",
max_tokens=10
)
return True
except Exception as e:
log.error(f"Error: Endpoint does not support tool calling: {e}")
log.error("Endpoint does not support tool calling: %s", e)
return False
def demo_completions(self) -> None:
"""Demo: test basic completions endpoint"""
print("=" * 60)
print("COMPLETIONS DEMO")
print("=" * 60)
config = CompletionConfig(
model=self.model, prompt=COMPLETIONS_PROMPT, stream=False
)
log.info(
f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'"
)
response = self.client.call_completions(config)
if isinstance(response, dict):
print("\nResponse:")
print(json.dumps(response, indent=2))
else:
log.error("Unexpected response format")
def demo_chat(self, use_streaming: bool = True) -> None:
"""
Demo: test chat completions endpoint with optional streaming
"""
print("=" * 60)
print(
f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}"
)
print("=" * 60)
config = ChatCompletionConfig(
model=self.model,
messages=[{"role": "user", "content": CHAT_PROMPT}],
stream=use_streaming,
)
log.info(f"Testing chat completions with model '{self.model}'...")
response = self.client.call_chat_completions(config)
if use_streaming:
try:
self.handle_streaming_response(response, show_reasoning=True)
except Exception as e:
log.error(f"\nError during streaming: {e}")
import traceback
traceback.print_exc()
return
else:
if isinstance(response, dict):
choice = response.get("choices", [{}])[0]
message = choice.get("message", {})
content = message.get("content", "")
reasoning = message.get("reasoning_content", "") or message.get(
"reasoning", ""
)
if reasoning:
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
print(f"\n💬 Assistant: {content}")
print(f"\nFull Response:")
print(json.dumps(response, indent=2))
else:
log.error("Unexpected response format")
def demo_ls_tool(self) -> None:
"""Demo: ask LLM to list files in the current directory and describe what it sees"""
async def demo_ls_tool(self) -> None:
"""Ask to list files using function calling, then provide final analysis"""
print("=" * 60)
print("TOOL USE DEMO: List Directory Contents")
print("=" * 60)
# Test if tools are supported first
if not self.test_tool_support():
if not await self.test_tool_support():
return
# Request with tool available
messages = [{"role": "user", "content": TOOLS_PROMPT}]
messages: List[Dict[str, Any]] = [{"role": "user", "content": TOOLS_PROMPT}]
config = ChatCompletionConfig(
# First pass: let the model decide tools, stream tool_calls and partial content
stream = await stream_chat_completions(
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
tools=self.tool_manager.get_ls_tool_definition(),
tool_choice="auto",
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
log.info(f"Making initial request with tool using model '{self.model}'...")
response = self.client.call_chat_completions(config)
assistant_content_buf: List[str] = []
tool_calls_state: Dict[int, Dict[str, Any]] = {}
printed_reasoning = False
printed_answer = False
if not isinstance(response, dict):
raise ValueError("Expected dict response for tool use")
async for chunk in stream:
choice = (chunk.get("choices") or [{}])[0]
delta = choice.get("delta", {})
choice = response.get("choices", [{}])[0]
message = choice.get("message", {})
rc = delta.get("reasoning_content")
if rc:
if not printed_reasoning:
printed_reasoning = True
print("🧠 Reasoning: ", end="", flush=True)
print(rc, end="", flush=True)
print(f"Assistant response: {message.get('content', 'No content')}")
content_part = delta.get("content")
if content_part:
assistant_content_buf.append(content_part)
if not printed_answer:
printed_answer = True
print("\n💬 Response: ", end="", flush=True)
print(content_part, end="", flush=True)
# Check for tool calls
tool_calls = message.get("tool_calls")
if not tool_calls:
raise ValueError(
"No tool calls made - model may not support function calling"
)
if "tool_calls" in delta and delta["tool_calls"]:
for tc_delta in delta["tool_calls"]:
_merge_tool_call_delta(tool_calls_state, tc_delta)
print(f"Tool calls detected: {len(tool_calls)}")
# If no tool calls, were done.
if not tool_calls_state:
print("\n(No tool calls were made.)")
return
# Execute the tool call
for tool_call in tool_calls:
function_name = tool_call["function"]["name"]
print(f"Executing tool: {function_name}")
# Build assistant message with tool_calls
assistant_message = {
"role": "assistant",
"content": "".join(assistant_content_buf) if assistant_content_buf else None,
"tool_calls": _tool_state_to_message_tool_calls(tool_calls_state),
}
messages.append(assistant_message)
tool_result = self.tool_manager.execute_tool_call(tool_call)
print(f"Tool result:\n{tool_result}")
# Execute tools and feed results back
for tc in assistant_message["tool_calls"]:
tool_name = (tc.get("function") or {}).get("name")
call_id = tc.get("id")
raw_args = (tc.get("function") or {}).get("arguments") or "{}"
# Add tool result and continue conversation
messages.append(message) # Add assistant's message with tool call
messages.append(
{
"role": "tool",
"tool_call_id": tool_call["id"],
"content": tool_result,
}
)
try:
args = json.loads(raw_args) if raw_args.strip() else {}
except Exception as e:
tool_result = json.dumps({"error": f"Argument parse failed: {str(e)}", "raw_arguments": raw_args})
messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result})
continue
# Get final response
final_config = ChatCompletionConfig(
model=self.model,
messages=messages,
tools=self.tool_manager.get_ls_tool_definition(),
)
try:
if tool_name == "list_files":
tool_result = self.tool_manager.list_files()
else:
tool_result = json.dumps({"error": f"Unknown tool '{tool_name}'"})
except Exception as e:
tool_result = json.dumps({"error": f"Tool '{tool_name}' failed: {str(e)}"})
print("Getting final response...")
final_response = self.client.call_chat_completions(final_config)
print("\n[Tool executed]", tool_name)
print(tool_result[:500] + ("..." if len(tool_result) > 500 else ""))
messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result})
if isinstance(final_response, dict):
final_choice = final_response.get("choices", [{}])[0]
final_message = final_choice.get("message", {})
final_content = final_message.get("content", "")
# Second pass: get final streamed answer after tool results
stream2 = await stream_chat_completions(
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
print("\n" + "=" * 60)
print("FINAL LLM ANALYSIS:")
print("=" * 60)
print(final_content)
print("=" * 60)
final_buf = []
printed_reasoning2 = False
printed_answer2 = False
def interactive_chat(self) -> None:
async for chunk in stream2:
choice = (chunk.get("choices") or [{}])[0]
delta = choice.get("delta", {})
rc2 = delta.get("reasoning_content")
if rc2:
if not printed_reasoning2:
printed_reasoning2 = True
print("\n🧠 Reasoning (post-tools): ", end="", flush=True)
print(rc2, end="", flush=True)
c2 = delta.get("content")
if c2:
final_buf.append(c2)
if not printed_answer2:
printed_answer2 = True
print("\n💬 Response (final): ", end="", flush=True)
print(c2, end="", flush=True)
print("\n" + "=" * 60)
print("FINAL LLM ANALYSIS:")
print("=" * 60)
print("".join(final_buf))
print("=" * 60)
async def interactive_chat(self) -> None:
"""Interactive chat session with streaming"""
print("=" * 60)
print("INTERACTIVE STREAMING CHAT")
print("=" * 60)
print(f"Using model: {self.model}")
print("Type 'quit' to exit, 'clear' to clear history")
print()
messages = []
messages: List[Dict[str, Any]] = []
while True:
try:
@@ -467,16 +462,16 @@ class APIDemo:
messages.append({"role": "user", "content": user_input})
config = ChatCompletionConfig(
model=self.model, messages=messages, stream=True, temperature=0.7
)
print("Assistant: ", end="", flush=True)
response = self.client.call_chat_completions(config)
assistant_content = self.handle_streaming_response(
response, show_reasoning=True
stream = await stream_chat_completions(
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=0.7
)
assistant_content = await self.handle_streaming_response(stream, show_reasoning=True)
# Add assistant response to conversation history
messages.append({"role": "assistant", "content": assistant_content})
@@ -485,115 +480,66 @@ class APIDemo:
print("\n👋 Chat interrupted. Goodbye!")
break
except Exception as e:
log.error(f"\nError: {e}")
log.error("\nError: %s", e)
continue
def main():
"""Main function with CLI switches for different tests"""
from lib.test_utils import test_args
# ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
p.add_argument("--model", default=DEFAULT_MODEL, help=f"Model to use for requests (default: {DEFAULT_MODEL})")
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
# Add mandatory model argument
test_args.add_argument(
"--model", required=True, help="Model to use for requests (required)"
)
modes = p.add_mutually_exclusive_group(required=False)
modes.add_argument("--completion", action="store_true", help="Test completions endpoint")
modes.add_argument("--chat", action="store_true", help="Test chat completions endpoint (non-streaming)")
modes.add_argument("--chat-stream", action="store_true", help="Test chat completions endpoint with streaming")
modes.add_argument("--tools", action="store_true", help="Test function calling with ls tool (non-streaming+streamed phases)")
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming chat session")
return p
# Add test mode arguments
test_args.add_argument(
"--completion", action="store_true", help="Test completions endpoint"
)
test_args.add_argument(
"--chat",
action="store_true",
help="Test chat completions endpoint (non-streaming)",
)
test_args.add_argument(
"--chat-stream",
action="store_true",
help="Test chat completions endpoint with streaming",
)
test_args.add_argument(
"--tools",
action="store_true",
help="Test function calling with ls tool (non-streaming)",
)
test_args.add_argument(
"--interactive",
action="store_true",
help="Start interactive streaming chat session",
)
args = test_args.parse_args()
async def main_async():
args = build_arg_parser().parse_args()
# Check that only one test mode is selected
test_modes = [
args.completion,
args.chat,
args.chat_stream,
args.tools,
args.interactive,
]
selected_count = sum(test_modes)
if selected_count == 0:
selected = sum([args.completion, args.chat, args.chat_stream, args.tools, args.interactive])
if selected == 0:
print("Please specify exactly one test mode:")
print(" --completion : Test completions endpoint")
print(" --chat : Test chat completions endpoint (non-streaming)")
print(" --chat-stream : Test chat completions endpoint with streaming")
print(" --tools : Test function calling with ls tool (non-streaming)")
print(" --tools : Test function calling with ls tool")
print(" --interactive : Start interactive streaming chat session")
print(
f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT"
)
print(f"\nExample: python {os.path.basename(sys.argv[0])} --model Qwen/Qwen3-8B --chat-stream --endpoint my-vllm-endpoint")
sys.exit(1)
elif selected_count > 1:
elif selected > 1:
print("Please specify exactly one test mode")
sys.exit(1)
print("=" * 60)
print(f"Using model: {args.model}")
print(f"Using endpoint: {args.endpoint}")
try:
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
async with Serverless() as client:
demo = APIDemo(client, args.model, args.endpoint, ToolManager())
if not endpoint_api_key:
log.error(
f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting."
)
sys.exit(1)
# Create the core API client
client = APIClient(
endpoint_group_name=args.endpoint_group_name,
api_key=args.api_key,
server_url=Endpoint.get_autoscaler_server_url(args.instance),
endpoint_api_key=endpoint_api_key,
)
# Create tool manager and demo (passing the model parameter)
tool_manager = ToolManager()
demo = APIDemo(client, args.model, tool_manager)
print(f"Using model: {args.model}")
print("=" * 60)
# Run the selected test
if args.completion:
demo.demo_completions()
elif args.chat:
demo.demo_chat(use_streaming=False)
elif args.chat_stream:
demo.demo_chat(use_streaming=True)
elif args.tools:
demo.demo_ls_tool()
elif args.interactive:
demo.interactive_chat()
if args.completion:
await demo.demo_completions()
elif args.chat:
await demo.demo_chat(use_streaming=False)
elif args.chat_stream:
await demo.demo_chat(use_streaming=True)
elif args.tools:
await demo.demo_ls_tool()
elif args.interactive:
await demo.interactive_chat()
except Exception as e:
log.error(f"Error during test: {e}", exc_info=True)
log.error("Error during test: %s", e, exc_info=True)
sys.exit(1)
if __name__ == "__main__":
main()
asyncio.run(main_async())
+29 -4
View File
@@ -119,14 +119,25 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
class CompletionsData(GenericData):
@classmethod
def for_test(cls) -> "CompletionsData":
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
genus Equus with horses and asses, the three groups being the only living members of the family
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
woodlands, shrublands, and mountainous areas.
Please answer the following question based on the above context."""
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
model = os.environ.get("MODEL_NAME")
if not model:
raise ValueError("MODEL_NAME environment variable not set")
test_input = {
"model": model,
"prompt": prompt,
"prompt": f"{system_prompt}\n\n{unique_question}",
"temperature": 0.7,
"max_tokens": 500,
}
@@ -153,7 +164,18 @@ class ChatCompletionsData(GenericData):
@classmethod
def for_test(cls) -> "ChatCompletionsData":
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
genus Equus with horses and asses, the three groups being the only living members of the family
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
woodlands, shrublands, and mountainous areas.
Please answer the following question based on the above context."""
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
model = os.environ.get("MODEL_NAME")
if not model:
raise ValueError("MODEL_NAME environment variable not set")
@@ -161,7 +183,10 @@ class ChatCompletionsData(GenericData):
# Chat completions use messages format instead of prompt
test_input = {
"model": model,
"messages": [{"role": "user", "content": prompt}],
"messages": [
{"role": "system", "content": system_prompt}, # Shared prefix
{"role": "user", "content": unique_question} # Unique per request
],
"temperature": 0.7,
"max_tokens": 500,
}
+2
View File
@@ -11,6 +11,7 @@ MODEL_SERVER_START_LOG_MSG = [
"llama runner started", # Ollama
'"message":"Connected","target":"text_generation_router"', # TGI
'"message":"Connected","target":"text_generation_router::server"', # TGI
"main: model loaded" # llama.cpp
]
MODEL_SERVER_ERROR_LOG_MSGS = [
@@ -34,6 +35,7 @@ backend = Backend(
model_server_url=os.environ["MODEL_SERVER_URL"],
model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=True,
max_wait_time=600.0,
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
log_actions=[
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
+414 -8
View File
@@ -1,8 +1,395 @@
from lib.test_utils import test_load_cmd, test_args
from lib.test_utils import test_args
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from lib.data_types import AuthData
from .data_types.server import CompletionsData
import os
WORKER_ENDPOINT = "/v1/completions"
import os
import time
import threading
import requests
from dataclasses import dataclass
from collections import Counter
from urllib.parse import urljoin, urlparse
import re
# Headless plotting
import matplotlib
matplotlib.use("Agg")
import logging
logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)
import matplotlib.pyplot as plt
import numpy as np
from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED
from requests.adapters import HTTPAdapter
def get_incremented_path(path: str) -> str:
base, ext = os.path.splitext(path)
if not os.path.exists(path):
return path
i = 1
while os.path.exists(f"{base}-{i}{ext}"):
i += 1
return f"{base}-{i}{ext}"
WORKER_ENDPOINT = "/v1/completions" # This will return the full text output at once. Latency metrics reflect that (ie not measuring TTFT)
@dataclass
class ReqResult:
worker_url: str
route_ms: float
worker_ms: float
total_ms: float
ok: bool
error: str = ""
status_code: int = 0
t_start: float = 0.0
t_end: float = 0.0
workload: float = 0.0
def do_one(endpoint_name: str,
endpoint_id: int,
endpoint_api_key: str,
server_url: str,
worker_endpoint: str,
payload,
results_list,
t0,
status_samples,
route_session,
worker_session):
try:
workload = payload.count_workload()
route_payload = {"endpoint": endpoint_name, "api_key": endpoint_api_key, "cost": workload}
headers = {"Authorization": f"Bearer {endpoint_api_key}"}
start = time.time()
r0 = route_session.post(urljoin(server_url, "/route/"), json=route_payload, headers=headers, timeout=4)
t_after_route = time.time()
if r0.status_code != 200:
results_list.append(ReqResult(worker_url="",
route_ms=(t_after_route - start) * 1000.0,
worker_ms=0.0,
total_ms=(t_after_route - start) * 1000.0,
ok=False,
error=f"route error {r0.reason} {r0.text}",
status_code=r0.status_code,
t_start=start - t0,
t_end=t_after_route - t0,
workload=workload))
return
msg = r0.json()
# 1) Check if we got a worker back from route
worker_url = msg.get("url", "")
if not worker_url:
status = msg.get("status", "")
m = re.search(r"total workers:\s*(\d+).*loading workers:\s*(\d+).*standby workers:\s*(\d+).*error workers:\s*(\d+)", status, re.I | re.S)
if m:
tot, loading, standby, err = map(int, m.groups())
idle = max(tot - loading - standby - err, 0)
status_samples.append((time.time() - t0, idle))
# 2) If we got a worker, send the request
if worker_url:
req = dict(payload=payload.__dict__, auth_data=AuthData.from_json_msg(msg).__dict__)
t_before_worker = time.time()
r1 = worker_session.post(
urljoin(worker_url, worker_endpoint),
json=req,
verify=get_cert_file_path(),
timeout=(4, 120),
)
t_after_worker = time.time()
if r1.status_code != 200:
results_list.append(ReqResult(worker_url=worker_url,
route_ms=(t_after_route - start) * 1000.0,
worker_ms=(t_after_worker - t_before_worker) * 1000.0,
total_ms=(t_after_worker - start) * 1000.0,
ok=False,
error=f"worker inference error {r1.reason} {r1.text}",
status_code=r1.status_code,
t_start=start - t0,
t_end=t_after_worker - t0,
workload=workload))
return
# Success case
results_list.append(ReqResult(worker_url=worker_url,
route_ms=(t_after_route - start) * 1000.0,
worker_ms=(t_after_worker - t_before_worker) * 1000.0,
total_ms=(t_after_worker - start) * 1000.0,
ok=True,
error="",
status_code=200,
t_start=start - t0,
t_end=t_after_worker - t0,
workload=workload))
# 3) If so, sample via /get_endpoint_workers/ for eligible (idle) worker tracking
if worker_url:
try:
r_status = route_session.post(
urljoin(server_url, "/get_endpoint_workers/"),
json={"id": endpoint_id},
headers={"Authorization": f"Bearer {endpoint_api_key}"},
timeout=3,
)
if r_status.status_code == 200:
workers = r_status.json()
idle = 0
for w in workers:
st = str(w.get("status", "")).lower()
if (st in ("idle")):
idle += 1
status_samples.append((time.time() - t0, idle))
except Exception:
pass
except Exception as e:
t = time.time()
results_list.append(ReqResult(worker_url="",
route_ms=0.0,
worker_ms=0.0,
total_ms=0.0,
ok=False,
error=f"unknown error {e}",
status_code=0,
t_start=t - t0,
t_end=t - t0,
workload=0.0))
def run_load_with_metrics(num_requests: int,
requests_per_second: float,
endpoint_group_name: str,
account_api_key: str,
server_url: str,
worker_endpoint: str,
instance: str,
out_path: str):
ep_info = Endpoint.get_endpoint_info(endpoint_name=endpoint_group_name,
account_api_key=account_api_key,
instance=instance)
if not ep_info or not ep_info.get("api_key") or not ep_info.get("id"):
print(f"Endpoint {endpoint_group_name} not found for API key")
return
endpoint_id = int(ep_info["id"])
endpoint_api_key = ep_info["api_key"]
t0 = time.time()
results = []
status_samples = []
max_concurrency = int(os.environ.get("MAX_CONCURRENCY", "8192"))
submit_queue_factor = 2 # cap queued tasks to reduce memory
# Shared HTTP sessions with connection pooling (persistent connections)
def make_session(pool_connections: int, pool_maxsize: int) -> requests.Session:
sess = requests.Session()
adapter = HTTPAdapter(pool_connections=pool_connections, pool_maxsize=pool_maxsize, max_retries=0)
sess.mount("https://", adapter)
sess.mount("http://", adapter)
return sess
# Router: mostly single host, small connection pool is sufficient
route_session = make_session(pool_connections=1, pool_maxsize=max_concurrency)
# Workers: many hosts; allow many pools and per-host concurrency up to max_concurrency
worker_session = make_session(pool_connections=64, pool_maxsize=max_concurrency // 8)
# Fire requests using a thread pool, scheduling at requested RPS
inflight = set()
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
for i in range(num_requests):
# Pace submissions to RPS
target_time = t0 + i / max(requests_per_second, 1e-9)
sleep_s = target_time - time.time()
if sleep_s > 0:
time.sleep(min(sleep_s, 0.5)) # sleep in chunks to stay responsive
payload = CompletionsData.for_test()
fut = executor.submit(
do_one,
endpoint_group_name,
endpoint_id,
endpoint_api_key,
server_url,
worker_endpoint,
payload,
results,
t0,
status_samples,
route_session,
worker_session,
)
inflight.add(fut)
# Prevent unbounded queue growth
if len(inflight) >= max_concurrency * submit_queue_factor:
done, not_done = wait(inflight, return_when=FIRST_COMPLETED)
inflight = not_done
# Wait for all outstanding tasks
if inflight:
wait(inflight)
# Close sessions
try:
route_session.close()
finally:
worker_session.close()
# Aggregate results
oks = [r for r in results if r.ok]
errs = [r for r in results if not r.ok]
total_reqs = len(results)
succ = len(oks)
total_ms = np.array([r.total_ms for r in oks]) if succ else np.array([])
worker_ms = np.array([r.worker_ms for r in oks]) if succ else np.array([])
route_ms = np.array([r.route_ms for r in oks]) if succ else np.array([])
avg_total = float(np.mean(total_ms)) if succ else 0.0
avg_worker = float(np.mean(worker_ms)) if succ else 0.0
avg_route = float(np.mean(route_ms)) if succ else 0.0
p50_total, p95_total = (float(np.percentile(total_ms, 50)), float(np.percentile(total_ms, 95))) if succ else (0.0, 0.0)
# Distribution over workers (by host:port)
hosts = [urlparse(r.worker_url).netloc for r in oks if r.worker_url]
dist = Counter(hosts)
# Idle over time (mode per second)
idle_ts, idle_vals = [], []
if status_samples:
buckets = {}
for ts, idle in status_samples:
k = int(ts)
buckets.setdefault(k, []).append(idle)
keys = sorted(buckets.keys())
idle_ts = keys
# Use the most frequent sampled value per second (mode) to keep integer counts
idle_vals = []
for k in keys:
vals_k = [int(v) for v in buckets[k]]
if vals_k:
cnt = Counter(vals_k)
idle_vals.append(cnt.most_common(1)[0][0])
else:
idle_vals.append(0)
print(f"\nResults: total={total_reqs} success={succ} errors={len(errs)}")
print(f"Avg latency (ms): {avg_total:.1f} p50: {p50_total:.1f} p95: {p95_total:.1f}")
print(f"Avg route latency (ms): {avg_route:.1f} Avg worker latency (ms): {avg_worker:.1f}")
if errs:
print("Sample errors:")
for e in errs[:5]:
print(f" {e.status_code} {e.error}")
# Plot: 2x3 grid
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
fig.suptitle(f"Load test: {endpoint_group_name} n={total_reqs}, rps={requests_per_second}, success={succ}")
# Dist per worker
ax0 = axes[0, 0]
if dist:
items = sorted(dist.items(), key=lambda kv: kv[1], reverse=True)
labels, counts = zip(*items)
ax0.bar(range(len(labels)), counts)
ax0.set_xticks(range(len(labels)))
ax0.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
ax0.set_title("Request distribution over workers")
ax0.set_ylabel("count")
# Latency histogram (total)
ax1 = axes[0, 1]
if succ:
ax1.hist(total_ms, bins=30)
ax1.set_title("Total latency (ms)")
ax1.set_xlabel("ms")
ax1.set_ylabel("freq")
# Eligible workers over time
ax_idle = axes[0, 2]
if idle_ts:
ax_idle.plot(idle_ts, idle_vals, "-o", ms=3)
ax_idle.set_title("Eligible workers over time")
ax_idle.set_xlabel("time (s)")
ax_idle.set_ylabel("eligible count")
# Throughput over time (completions/sec)
ax_idle = axes[1, 0]
ax_idle.clear()
if succ:
per_sec = {}
for r in oks:
s = int(r.t_end)
per_sec[s] = per_sec.get(s, 0) + 1
ts = sorted(per_sec.keys())
vals = [per_sec[t] for t in ts]
ax_idle.plot(ts, vals, "-o", ms=3)
ax_idle.set_title("Completions per second")
ax_idle.set_xlabel("time (s)")
ax_idle.set_ylabel("completions / sec")
# Summary text
ax3 = axes[1, 1]
ax3.axis("off")
text = (
f"Total requests: {total_reqs}\n"
f"Success: {succ} Errors: {len(errs)}\n"
f"Avg total latency: {avg_total:.1f} ms\n"
f"p50: {p50_total:.1f} ms p95: {p95_total:.1f} ms\n"
f"Avg route latency: {avg_route:.1f} ms\n"
f"Avg worker latency: {avg_worker:.1f} ms\n"
f"300 errors: {len([r for r in errs if r.status_code >= 300 and r.status_code < 400])}\n"
f"429 errors: {len([r for r in errs if r.status_code == 429])}\n"
f"500 errors: {len([r for r in errs if r.status_code >= 500])}\n"
f"Other errors: {len([r for r in errs if r.status_code not in [300, 429, 500]])}\n"
)
ax3.set_title("Summary")
ax3.text(0.02, 0.98, text, va="top", ha="left", fontsize=11, transform=ax3.transAxes)
# Error count over time
ax_errors = axes[1, 2]
all_end_times = [int(r.t_end) for r in results if r.t_end > 0]
if all_end_times:
min_second = min(all_end_times)
max_second = max(all_end_times)
# Count errors per second
errors_per_second = {}
for result in errs:
second = int(result.t_end)
errors_per_second[second] = errors_per_second.get(second, 0) + 1
# Create complete timeline including zeros
time_seconds = list(range(min_second, max_second + 1))
error_counts = [errors_per_second.get(sec, 0) for sec in time_seconds]
ax_errors.plot(time_seconds, error_counts, "-o", ms=3)
ax_errors.set_title("Errors per second")
ax_errors.set_xlabel("time (s)")
ax_errors.set_ylabel("errors / sec")
# Ensure unique output path and create directory if needed
final_out_path = get_incremented_path(out_path)
out_dir = os.path.dirname(final_out_path)
if out_dir:
os.makedirs(out_dir, exist_ok=True)
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.savefig(final_out_path, dpi=120)
print(f"Saved report to: {final_out_path}")
# Per-worker latency boxplot (top 12 by volume)
groups = {}
for r in oks:
host = urlparse(r.worker_url).netloc
groups.setdefault(host, []).append(r.total_ms)
items = sorted(groups.items(), key=lambda kv: len(kv[1]), reverse=True)[:12]
if items:
labels, data = zip(*items)
fig2, axb = plt.subplots(1, 1, figsize=(12, 5))
axb.boxplot(data, showfliers=False)
axb.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
axb.set_title("Per-worker latency (ms)")
axb.set_ylabel("ms")
plt.tight_layout()
extra_out = get_incremented_path(os.path.splitext(out_path)[0] + "-workers.png")
plt.savefig(extra_out, dpi=120)
fig2.tight_layout()
fig2.savefig(extra_out, dpi=120)
print(f"Saved worker latency plot to: {extra_out}")
if __name__ == "__main__":
# Check if MODEL_NAME environment variable is set
@@ -16,13 +403,32 @@ if __name__ == "__main__":
help="Model to use for completions request (required if MODEL_NAME env var not set)",
)
# Parse known args to get model early, before test_load_cmd adds its args
# Parse known args to get model early, before adding load args
known_args, _ = test_args.parse_known_args()
# Set environment variable if model was provided
if hasattr(known_args, "model") and known_args.model:
os.environ["MODEL_NAME"] = known_args.model
print(f"Set MODEL_NAME environment variable to: {known_args.model}")
# Now call test_load_cmd normally - it will add its own args and re-parse
test_load_cmd(CompletionsData, WORKER_ENDPOINT, arg_parser=test_args)
# Load test args
test_args.add_argument("-n", dest="num_requests", type=int, required=True, help="total number of requests")
test_args.add_argument("-rps", dest="requests_per_second", type=float, required=True, help="requests per second")
test_args.add_argument("--out", dest="out_path", type=str, default="load_test_report.png", help="path to save the report image")
args = test_args.parse_args()
server_url = {
"prod": "https://run.vast.ai",
"alpha": "https://run-alpha.vast.ai",
"candidate": "https://run-candidate.vast.ai",
"local": "http://localhost:8080"
}.get(args.instance, "http://localhost:8080")
run_load_with_metrics(
num_requests=args.num_requests,
requests_per_second=args.requests_per_second,
endpoint_group_name=args.endpoint_group_name,
account_api_key=args.api_key,
server_url=server_url,
worker_endpoint=WORKER_ENDPOINT,
instance=args.instance,
out_path=args.out_path,
)
+93 -9
View File
@@ -1,19 +1,103 @@
This is the base PyWorker for TGI, designed to create PyWorkers that can utilize various LLMs. It offers two primary endpoints:
# HuggingFace TGI PyWorker
1. `generate`: Generates the LLM's response to a given prompt in a single request.
2. `generate_stream`: Streams the LLM's response token by token.
This is the base PyWorker for HuggingFace Text Generation Inference (TGI) servers. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
Both endpoints use the following API payload format:
## Instance Setup
1. Pick a template
This worker is compatible with any TGI backend. We have a template you can use or you can create your own.
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20(Serverless))
The template can be configured via the template interface. You may want to change the model or startup arguments.
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
## Client Setup (Demo)
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
```bash
git clone https://github.com/vast-ai/pyworker
cd pyworker
pip install uv
uv venv -p 3.12
source .venv/bin/activate
uv pip install -r requirements.txt
```
## Using the Test Client
The test client demonstrates both streaming and non-streaming generation using TGI's native API.
First, set your API key as an environment variable:
```bash
export VAST_API_KEY=<your_api_key>
```
The `--endpoint` flag is optional. If not provided, it defaults to `my-tgi-endpoint`.
### Generate (Streaming)
Call to `/generate_stream` with streaming response:
```bash
python -m workers.tgi.client --generate-stream --endpoint <ENDPOINT_NAME>
```
### Generate (Non-Streaming)
Call to `/generate` with json response:
```bash
python -m workers.tgi.client --generate --endpoint <ENDPOINT_NAME>
```
### Interactive Session (Streaming)
Interactive session with streaming responses. Type `quit` to exit.
```bash
python -m workers.tgi.client --interactive --endpoint <ENDPOINT_NAME>
```
## API Endpoints
TGI provides two primary endpoints:
### Generate (Non-Streaming)
`/generate` - Returns the complete response in a single request.
```json
{
"inputs": "PROMPT",
"inputs": "Your prompt here",
"parameters": {
"max_new_tokens": 250
"max_new_tokens": 1024,
"temperature": 0.7,
"return_full_text": false
}
}
```
Note that the max_new_tokens parameter, rather than the prompt size, impacts performance. For example, if an
instance is benchmarked to process 100 tokens per second, a request with max_new_tokens = 200 will take
approximately 2 seconds to complete.
### Generate Stream (Streaming)
`/generate_stream` - Streams the response token by token.
```json
{
"inputs": "Your prompt here",
"parameters": {
"max_new_tokens": 1024,
"temperature": 0.7,
"do_sample": true,
"return_full_text": false
}
}
```
## Performance Notes
The `max_new_tokens` parameter (not the prompt size) primarily impacts performance. For example, if an instance is benchmarked to process 100 tokens per second, a request with `max_new_tokens = 200` will take approximately 2 seconds to complete.
+201 -104
View File
@@ -1,11 +1,13 @@
import logging
import sys
import json
from urllib.parse import urljoin
import requests
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
import os
import sys
import argparse
from vastai import Serverless
import asyncio
# ---------------------- Logging ----------------------
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
@@ -13,113 +15,208 @@ logging.basicConfig(
)
log = logging.getLogger(__file__)
# ---------------------- Defaults ----------------------
DEFAULT_PROMPT = "Think step by step: Tell me about the Python programming language."
def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> None:
WORKER_ENDPOINT = "/generate"
COST = 100
route_payload = {
"endpoint": endpoint_group_name,
"api_key": api_key,
"cost": COST,
ENDPOINT_NAME = "TGI-Prod2" # change this to your TGI endpoint name
MAX_TOKENS = 1024
DEFAULT_TEMPERATURE = 0.7
# ---------------------- API Calls ----------------------
async def call_generate(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs) -> dict:
"""Non-streaming generation via /generate endpoint"""
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"inputs": prompt,
"parameters": {
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"return_full_text": False,
}
}
response = requests.post(
urljoin(server_url, "/route/"),
json=route_payload,
timeout=4,
)
response.raise_for_status() # Raise an exception for bad status codes
message = response.json()
url = message["url"]
auth_data = dict(
signature=message["signature"],
cost=message["cost"],
endpoint=message["endpoint"],
reqnum=message["reqnum"],
url=url,
)
payload = dict(inputs="tell me about cats", parameters=dict(max_new_tokens=500))
req_data = dict(payload=payload, auth_data=auth_data)
url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {url}")
response = requests.post(
url,
json=req_data,
verify=get_cert_file_path(),
)
response.raise_for_status()
res = response.json()
print(res)
log.debug("POST /generate %s", json.dumps(payload)[:500])
resp = await endpoint.request("/generate", payload, cost=payload["parameters"]["max_new_tokens"])
return resp["response"]
def call_generate_stream(
endpoint_group_name: str, api_key: str, server_url: str
) -> None:
WORKER_ENDPOINT = "/generate_stream"
COST = 100
route_payload = {
"endpoint": endpoint_group_name,
"api_key": api_key,
"cost": COST,
async def call_generate_stream(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs):
"""Streaming generation via /generate_stream endpoint"""
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"inputs": prompt,
"parameters": {
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"do_sample": True,
"return_full_text": False,
}
}
response = requests.post(
urljoin(server_url, "/route/"),
json=route_payload,
timeout=4,
log.debug("STREAM /generate_stream %s", json.dumps(payload)[:500])
resp = await endpoint.request(
"/generate_stream",
payload,
cost=payload["parameters"]["max_new_tokens"],
stream=True,
)
response.raise_for_status() # Raise an exception for bad status codes
message = response.json()
url = message["url"]
print(f"url: {url}")
auth_data = dict(
signature=message["signature"],
cost=message["cost"],
endpoint=message["endpoint"],
reqnum=message["reqnum"],
url=message["url"],
)
payload = dict(inputs="tell me about dogs", parameters=dict(max_new_tokens=500))
req_data = dict(payload=payload, auth_data=auth_data)
url = urljoin(url, WORKER_ENDPOINT)
response = requests.post(url, json=req_data, stream=True)
response.raise_for_status() # Raise an exception for bad status codes
for line in response.iter_lines():
payload = line.decode().lstrip("data:").rstrip()
if payload:
return resp["response"] # async generator
# ---------------------- Demo Runner ----------------------
class APIDemo:
"""Demo and testing functionality for the TGI API client"""
def __init__(self, client: Serverless, endpoint_name: str):
self.client = client
self.endpoint_name = endpoint_name
async def handle_streaming_response(self, stream) -> str:
"""Process streaming response and print tokens"""
full_response = ""
printed_answer = False
async for event in stream:
tok = (event.get("token") or {}).get("text")
if tok:
if not printed_answer:
printed_answer = True
print("\n💬 Response: ", end="", flush=True)
print(tok, end="", flush=True)
full_response += tok
print() # newline
if printed_answer:
print(f"\nStreaming completed. Response tokens: {len(full_response.split())}")
return full_response
async def demo_generate(self) -> None:
"""Demo non-streaming generation"""
print("=" * 60)
print("GENERATE DEMO (NON-STREAMING)")
print("=" * 60)
response = await call_generate(
client=self.client,
endpoint_name=self.endpoint_name,
prompt=DEFAULT_PROMPT,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
print(f"\n💬 Response: {response.get('generated_text', '')}")
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
async def demo_generate_stream(self) -> None:
"""Demo streaming generation"""
print("=" * 60)
print("GENERATE DEMO (STREAMING)")
print("=" * 60)
stream = await call_generate_stream(
client=self.client,
endpoint_name=self.endpoint_name,
prompt=DEFAULT_PROMPT,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
try:
await self.handle_streaming_response(stream)
except Exception as e:
log.error("\nError during streaming: %s", e, exc_info=True)
async def interactive_chat(self) -> None:
"""Interactive session with streaming generation"""
print("=" * 60)
print("INTERACTIVE STREAMING SESSION")
print("=" * 60)
print(f"Using endpoint: {self.endpoint_name}")
print("Type 'quit' to exit")
print()
while True:
try:
data = json.loads(payload)
print(data["token"]["text"], end="")
sys.stdout.flush()
except (json.JSONDecodeError, KeyError) as e:
log.warning(f"Failed to parse streaming response: {e}")
user_input = input("You: ").strip()
if user_input.lower() == "quit":
print("👋 Goodbye!")
break
elif not user_input:
continue
print("Assistant: ", end="", flush=True)
stream = await call_generate_stream(
client=self.client,
endpoint_name=self.endpoint_name,
prompt=user_input,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
full_response = ""
async for event in stream:
tok = (event.get("token") or {}).get("text")
if tok:
print(tok, end="", flush=True)
full_response += tok
print() # newline
except KeyboardInterrupt:
print("\n👋 Session interrupted. Goodbye!")
break
except Exception as e:
log.error("\nError: %s", e)
continue
print()
# ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast TGI Demo (Serverless SDK)")
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
modes = p.add_mutually_exclusive_group(required=False)
modes.add_argument("--generate", action="store_true", help="Test generate endpoint (non-streaming)")
modes.add_argument("--generate-stream", action="store_true", help="Test generate endpoint with streaming")
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming session")
return p
async def main_async():
args = build_arg_parser().parse_args()
selected = sum([args.generate, args.generate_stream, args.interactive])
if selected == 0:
print("Please specify exactly one test mode:")
print(" --generate : Test generate endpoint (non-streaming)")
print(" --generate-stream : Test generate endpoint with streaming")
print(" --interactive : Start interactive streaming session")
print(f"\nExample: python {os.path.basename(sys.argv[0])} --generate-stream --endpoint my-tgi-endpoint")
sys.exit(1)
elif selected > 1:
print("Please specify exactly one test mode")
sys.exit(1)
print("=" * 60)
print(f"Using endpoint: {args.endpoint}")
try:
async with Serverless() as client:
demo = APIDemo(client, args.endpoint)
if args.generate:
await demo.demo_generate()
elif args.generate_stream:
await demo.demo_generate_stream()
elif args.interactive:
await demo.interactive_chat()
except Exception as e:
log.error("Error during test: %s", e, exc_info=True)
sys.exit(1)
if __name__ == "__main__":
from lib.test_utils import test_args
args = test_args.parse_args()
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if endpoint_api_key:
try:
call_generate(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
call_generate_stream(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
except Exception as e:
log.error(f"Error during API call: {e}")
else:
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ")
asyncio.run(main_async())