Compare commits

...

78 Commits

Author SHA1 Message Date
Abiola Akinnubi 7437028cb2 Added caller for REPORT_ADDR to backend.py 2025-10-29 18:02:17 -07:00
LucasArmandVast 9f5a432513 Merge pull request #51 from vast-ai/delete-reqs-hotfix
Redis subscriber queue patch
2025-10-28 16:07:28 -07:00
Lucas Armand e09f1fa953 patch for redis queue 2025-10-28 16:03:50 -07:00
edgaratvast ba6f1c2e4b Fix signature (#50)
* change order of fields in auth_data to match autoscaler for signature verification

* also ignore __request_id

* Revert "change order of fields in auth_data to match autoscaler for signature verification" so that it's alphabetical again

This reverts commit b8223879c9.

* enforce alphabetical json dumping of message for signature verification

---------

Co-authored-by: Edgar Lin <edgarlin2000@gmail.com>
2025-10-28 16:01:32 -07:00
edgaratvast 298590fb88 Merge pull request #45 from vast-ai/new-pyworker
New PyWorker
2025-10-28 14:02:53 -07:00
Lucas Armand 814c3acd4c remove unused code 2025-10-28 13:43:57 -07:00
Lucas Armand 22bca74087 Prevent load time race 2025-10-27 18:25:21 -07:00
Lucas Armand 9c795e2a01 removed bad code 2025-10-27 17:03:13 -07:00
Lucas Armand 830b532781 Trying unified delete 2025-10-27 16:57:52 -07:00
LucasArmandVast d6a6e34c6b Merge branch 'main' into new-pyworker 2025-10-27 12:43:49 -07:00
Colter-Downing ac1e109c48 Merge pull request #47 from vast-ai/new-pyworker-vllm-prefix-cache
vLLM Prefix caching, benchmark bug fix, test load script
2025-10-27 12:30:34 -07:00
Colter Downing d6eb498ee4 catch the case where all benchmarks fail (sets error) 2025-10-27 12:01:55 -07:00
Colter Downing bcecd6df40 Suppress matplot debug logs 2025-10-25 16:18:02 -07:00
Lucas Armand 4d9bf2048c Fix 2025-10-24 15:44:38 -07:00
Lucas Armand 7788bc4a62 Added some debug logs 2025-10-24 15:41:00 -07:00
Lucas Armand 37ad3f8d46 asyncio in metrics 2025-10-23 10:18:31 -07:00
Rob Ballantyne 70d51bafe1 Merge pull request #36 from robballantyne/feat/comfyui-json-benchmark-workflow-from-file 2025-10-23 17:05:48 +01:00
Rob Ballantyne 63909736bb Merge pull request #4 from robballantyne/feat/comfyui-json-benchmark-workflow-from-file-no-silent-fail
Feat/comfyui json benchmark workflow from file no silent fail
2025-10-23 17:02:12 +01:00
Rob Ballantyne f4f7080df1 Re-add comment 2025-10-23 17:00:28 +01:00
Rob Ballantyne d51a338e8f log when benchmark file not used 2025-10-23 16:41:02 +01:00
Rob Ballantyne 92a04bd7af No silent fail if benchmark file is missing 2025-10-23 13:41:03 +01:00
Lucas Armand 0f13506938 Send success param 2025-10-22 10:18:59 -07:00
Lucas Armand 01e752d31f use more asyncio sleep 2025-10-21 18:52:13 -07:00
Lucas Armand 5edfa968ca async sleep 2025-10-21 18:49:48 -07:00
Lucas Armand 5b5ef7227a nvm moved it here 2025-10-21 18:20:11 -07:00
Lucas Armand 16990ff8ff move start request 2025-10-21 18:18:44 -07:00
Lucas Armand 9748176366 fixed semaphore acquire bool 2025-10-21 18:12:23 -07:00
Lucas Armand b39193ae70 check for sem acquire 2025-10-21 18:02:14 -07:00
Lucas Armand 9a6ca5d412 added versioning 2025-10-21 15:42:43 -07:00
Lucas Armand e9ba1b03e4 Use delete_requests and track request_idxs 2025-10-21 11:59:35 -07:00
LucasArmandVast c98d661513 Merge pull request #39 from vast-ai/remove-time-divide
PyWorker fixes for cur_load and acks bug
2025-10-13 10:06:22 -07:00
Lucas Armand f6fd1c6ac1 merge 2025-10-09 18:15:55 -07:00
Lucas Armand 055e346c8c Send metrics on request start 2025-10-09 10:13:50 -07:00
Lucas Armand 1cedb28acf Removed division by elapsed time, since autoscaler cur_load in units of workload 2025-10-08 16:54:18 -07:00
Rob Ballantyne ec25dda3ad Merge branch 'vast-ai:main' into feat/comfyui-json-benchmark-workflow-from-file 2025-10-08 14:49:32 +01:00
Colter-Downing 0397af719d Merge pull request #37 from robballantyne/bugfix/healthcheck-endpoint
Fix healthcheck endpoint URL

Tested and merged by Colter
2025-10-06 15:11:27 -07:00
Rob Ballantyne 4fdc314fd9 Fix healthcheck endpoint URL 2025-10-06 22:16:09 +01:00
Rob Ballantyne 3786cf978d Add awareness of errors thrown by the provisioning script 2025-10-05 23:14:59 +01:00
Rob Ballantyne a86d4bcf9c Import json 2025-10-05 23:05:33 +01:00
Rob Ballantyne e9b6a14a5e Import Path 2025-10-05 22:59:19 +01:00
Rob Ballantyne cadac033e1 Enables use of custom workflow for benchmarking
Retains existing method is misc/benchmark.json is nopt present
2025-10-05 22:53:22 +01:00
Colter-Downing 639d82f5b4 Merge pull request #35 from vast-ai/AUTO-664--Healthcheck-error
Fix healthcheck with separate session
2025-10-02 12:51:19 -07:00
Colter Downing 25db78e39d Fix healthcheck with separate session 2025-10-01 18:04:31 -07:00
Scott-Laytart 4e2f2311d0 Merge pull request #33 from vast-ai/comfy-blind-fix-override
undo the fix for comfy yesterday.
2025-09-03 11:50:07 -07:00
abiola-vastai 38782d89bc undo the fix for comfy yesterday. 2025-09-03 17:12:35 +00:00
Scott-Laytart 0185216ccb Merge pull request #32 from vast-ai/blindhotfix_comfy_ui_default_port
Blind hotfix to see if comfy UI default is needed. if it does work we…
2025-09-02 18:26:25 -07:00
abiola-vastai b20d9e714c Blind hotfix to see if comfy UI default is needed. if it does work we would revert back. 2025-09-03 01:20:09 +00:00
Rob Ballantyne b1eb65d75d Merge pull request #31 from vast-ai/bugfix/startup-script-20250901
Update uv venv creation command
2025-09-01 18:19:17 +01:00
Rob Ballantyne 1d09d7fe96 Update uv venv creation command 2025-09-01 16:55:20 +01:00
Colter-Downing 1b37054dec Merge pull request #28 from vast-ai/bugfix/backend-timeout-infinite
Bugfix/backend timeout infinite
2025-08-28 11:22:33 -07:00
Colter-Downing 1a1e4174b8 Merge pull request #29 from vast-ai/bugfix/comfyui-json-cost-fix
Set cost to 100
2025-08-28 11:22:21 -07:00
Rob Ballantyne b8377c4081 Set cost to 100 2025-08-28 16:13:17 +01:00
Rob Ballantyne 1e4fa87437 Prevent timeout and allow long running connections 2025-08-28 15:48:57 +01:00
Rob Ballantyne 4c5fa03c7b adds import for ClientTimeout 2025-08-27 20:54:27 +01:00
Rob Ballantyne a8fe74f771 Remove default 300s timeout 2025-08-27 18:34:45 +01:00
Rob Ballantyne b482de8394 Merge pull request #27 from vast-ai/feat/comfyui-api-s3-webhook
Adds new ComfyUI worker

Upload assets to s3 compatible storage via intermediate API wrapper
2025-08-26 14:22:05 +01:00
Rob Ballantyne 703435d10e Improve MODEL_SERVER_START_* messages 2025-08-26 12:42:04 +01:00
Rob Ballantyne 947fc5eea4 Improve benchmarking explanation 2025-08-26 12:41:30 +01:00
Rob Ballantyne 7c1a544b19 Improve error reporting when no ready workers 2025-08-26 12:41:05 +01:00
Rob Ballantyne 16b414676e Use count_workload() function for cost 2025-08-25 18:31:10 +01:00
Rob Ballantyne ba74ac8136 Use cost value 1 for all jobs 2025-08-25 17:58:22 +01:00
Rob Ballantyne 92ff412679 Use MODEL_SERVER_URL environment variable 2025-08-25 17:57:32 +01:00
Rob Ballantyne fc75a64684 Use MODEL_SERVER_URL environment variable 2025-08-25 17:56:27 +01:00
Rob Ballantyne b00bef547c Ensure uv env script is present before sourcing 2025-08-22 17:08:42 +01:00
Rob Ballantyne 3f4acb29fa Improved client exception handling 2025-08-22 15:20:15 +01:00
Rob Ballantyne 58b078f908 Fix modifier class 2025-08-20 18:06:02 +01:00
Rob Ballantyne f9fdf04884 Fix signature 2025-08-20 13:27:29 +01:00
Rob Ballantyne 636f17d27f Fix workflow modifier class 2025-08-20 09:57:07 +01:00
Rob Ballantyne 08c88f7527 Improve testability 2025-08-20 09:34:09 +01:00
Rob Ballantyne 8797b504af Initial ComfyUI implementation with updated wrapper 2025-08-19 17:59:20 +01:00
Nader Arbabian cd946b0a9f update report_addr to use new webserver endpoint with AS fallback 2025-08-12 13:31:19 -07:00
Nader Arbabian c595b42410 for benchmarking, use concurrent requests (#26) 2025-08-11 12:39:28 -07:00
Nader Arbabian 0bf3247a34 fix completions and interactive client 2025-08-11 12:37:53 -07:00
Nader Arbabian 52ac4c0c1a fix endpoint_util not using the correct instance's endpoint 2025-08-11 12:05:58 -07:00
Nader Arbabian 8804e17201 download vast.ai's root certificate in order to make pyworker requests (#25) 2025-08-08 17:04:16 -07:00
Nader Arbabian 4016cf9a53 redo metrics tracking for requests, fixes bug wherere some requests were marked as pending, even though they had finished (#24) 2025-08-08 17:01:21 -07:00
Rob Ballantyne e0be45f39a Addresses breaking change in core pyworker (#22)
* Addresses breaking change in test_utils.py

Endpoint.get_endpoint_api_key() now requires instance

Moves the call to this function out of the APIClient and into main

* Ensure make_benchmark_payload has a value to calculate the workload

---------

Co-authored-by: Nader Arbabian <nader@vast.ai>
2025-07-18 16:11:10 -07:00
Nader Arbabian be2aafdb1f fix pyright errors + revert to old way of handling cancelled api requests (#23) 2025-07-17 16:59:06 -07:00
25 changed files with 1925 additions and 412 deletions
+145 -80
View File
@@ -5,13 +5,14 @@ import base64
import subprocess import subprocess
import dataclasses import dataclasses
import logging import logging
from asyncio import sleep, gather, Semaphore from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
from functools import cached_property from functools import cached_property
from distutils.util import strtobool from distutils.util import strtobool
from anyio import open_file from anyio import open_file
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
import asyncio
import requests import requests
from Crypto.Signature import pkcs1_15 from Crypto.Signature import pkcs1_15
@@ -25,8 +26,12 @@ from lib.data_types import (
LogAction, LogAction,
ApiPayload_T, ApiPayload_T,
JsonDataException, JsonDataException,
RequestMetrics,
BenchmarkResult
) )
VERSION = "0.1.0"
MSG_HISTORY_LEN = 100 MSG_HISTORY_LEN = 100
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
@@ -53,15 +58,21 @@ class Backend:
EndpointHandler # this endpoint handler will be used for benchmarking EndpointHandler # this endpoint handler will be used for benchmarking
) )
log_actions: List[Tuple[LogAction, str]] log_actions: List[Tuple[LogAction, str]]
max_wait_time: float = 10.0
reqnum = -1 reqnum = -1
version = VERSION
msg_history = [] msg_history = []
sem: Semaphore = dataclasses.field(default_factory=Semaphore) sem: Semaphore = dataclasses.field(default_factory=Semaphore)
unsecured: bool = dataclasses.field( unsecured: bool = dataclasses.field(
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))), default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
) )
report_addr: str = dataclasses.field(
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
)
def __post_init__(self): def __post_init__(self):
self.metrics = Metrics() self.metrics = Metrics()
self.metrics._set_version(self.version)
self._total_pubkey_fetch_errors = 0 self._total_pubkey_fetch_errors = 0
self._pubkey = self._fetch_pubkey() self._pubkey = self._fetch_pubkey()
self.__start_healthcheck: bool = False self.__start_healthcheck: bool = False
@@ -75,7 +86,13 @@ class Backend:
@cached_property @cached_property
def session(self): def session(self):
log.debug(f"starting session with {self.model_server_url}") log.debug(f"starting session with {self.model_server_url}")
return ClientSession(self.model_server_url) connector = TCPConnector(
force_close=True, # Required for long running jobs
enable_cleanup_closed=True,
)
timeout = ClientTimeout(total=None)
return ClientSession(self.model_server_url, timeout=timeout, connector=connector)
def create_handler( def create_handler(
self, self,
@@ -90,23 +107,19 @@ class Backend:
#######################################Private####################################### #######################################Private#######################################
def _fetch_pubkey(self): def _fetch_pubkey(self):
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"] report_addr = self.report_addr.rstrip("/")
result = subprocess.check_output(command, universal_newlines=True) command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"]
log.debug("public key:") try:
log.debug(result) result = subprocess.check_output(command, universal_newlines=True)
key = None log.debug("public key:")
for _ in range(5): log.debug(result)
try: key = RSA.import_key(result)
key = RSA.import_key(result) if key is not None:
break return key
except ValueError as e: except (ValueError , subprocess.CalledProcessError) as e:
log.debug(f"Error downloading key: {e}") log.debug(f"Error downloading key: {e}")
time.sleep(15) self.backend_errored("Failed to get autoscaler pubkey")
if key is None:
self._total_pubkey_fetch_errors += 1
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
self.backend_errored("Failed to get autoscaler pubkey")
return key
async def __handle_request( async def __handle_request(
self, self,
@@ -122,75 +135,109 @@ class Backend:
except json.JSONDecodeError: except json.JSONDecodeError:
return web.json_response(dict(error="invalid JSON"), status=422) return web.json_response(dict(error="invalid JSON"), status=422)
workload = payload.count_workload() workload = payload.count_workload()
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
async def cancel_api_call_if_disconnected() -> web.Response:
await request.wait_for_disconnection()
log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled")
self.metrics._request_canceled(request_metrics)
raise asyncio.CancelledError
async def make_request() -> Union[web.Response, web.StreamResponse]: async def make_request() -> Union[web.Response, web.StreamResponse]:
log.debug(f"got request, {auth_data.reqnum}")
self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum)
if self.allow_parallel_requests is False:
log.debug(f"Waiting to aquire Sem for reqnum:{auth_data.reqnum}")
await self.sem.acquire()
log.debug(
f"Sem acquired for reqnum:{auth_data.reqnum}, starting request..."
)
else:
log.debug(f"Starting request for reqnum:{auth_data.reqnum}")
try: try:
start_time = time.time()
response = await self.__call_api(handler=handler, payload=payload) response = await self.__call_api(handler=handler, payload=payload)
status_code = response.status status_code = response.status
log.debug( log.debug(
" ".join( " ".join(
[ [
f"request with reqnum:{auth_data.reqnum}", f"request with reqnum:{request_metrics.reqnum}",
f"returned status code: {status_code},", f"returned status code: {status_code},",
] ]
) )
) )
res = await handler.generate_client_response(request, response) res = await handler.generate_client_response(request, response)
self.metrics._request_end( self.metrics._request_success(request_metrics)
workload=workload,
req_response_time=time.time() - start_time,
reqnum=auth_data.reqnum,
)
return res return res
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
log.debug(f"[backend] Request error: {e}") log.debug(f"[backend] Request error: {e}")
self.metrics._request_errored( self.metrics._request_errored(request_metrics)
workload=workload, reqnum=auth_data.reqnum
)
return web.Response(status=500) return web.Response(status=500)
finally:
self.sem.release()
########### ###########
if self.__check_signature(auth_data) is False: if self.__check_signature(auth_data) is False:
self.metrics._request_reject(request_metrics)
return web.Response(status=401) return web.Response(status=401)
if self.metrics.model_metrics.wait_time > self.max_wait_time:
self.metrics._request_reject(request_metrics)
return web.Response(status=429)
acquired = False
try: try:
return await make_request() self.metrics._request_start(request_metrics)
if self.allow_parallel_requests is False:
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
await self.sem.acquire()
acquired = True
log.debug(
f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..."
)
else:
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
done, pending = await wait(
[
create_task(make_request()),
create_task(cancel_api_call_if_disconnected()),
],
return_when=FIRST_COMPLETED,
)
for t in pending:
t.cancel()
await asyncio.gather(*pending, return_exceptions=True)
done_task = done.pop()
try:
return done_task.result()
except Exception as e:
log.debug(f"Request task raised exception: {e}")
return web.Response(status=500)
except asyncio.CancelledError:
# Client is gone. Do not write a response; just unwind.
return web.Response(status=499)
except Exception as e: except Exception as e:
log.debug(f"Exception in main handler loop {e}") log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500) return web.Response(status=500)
finally: finally:
if request.task.cancelled(): # Always release the semaphore if it was acquired
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled") if acquired:
self.metrics._request_canceled( self.sem.release()
workload=workload, reqnum=auth_data.reqnum self.metrics._request_end(request_metrics)
)
@cached_property
def healthcheck_session(self):
"""Dedicated session for healthchecks to avoid conflicts with API session"""
log.debug("creating dedicated healthcheck session")
connector = TCPConnector(
force_close=True, # Keep this for isolation
enable_cleanup_closed=True,
)
timeout = ClientTimeout(total=10) # Reasonable timeout for healthchecks
return ClientSession(timeout=timeout, connector=connector)
async def __healthcheck(self): async def __healthcheck(self):
health_check_url = self.benchmark_handler.healthcheck_endpoint health_check_url = self.benchmark_handler.healthcheck_endpoint
if health_check_url is None: if health_check_url is None:
log.debug("No healthcheck endpoint defined, skipping healthcheck") log.debug("No healthcheck endpoint defined, skipping healthcheck")
return return
while True: while True:
await sleep(10) await sleep(10)
if self.__start_healthcheck is False: if self.__start_healthcheck is False:
continue continue
try: try:
log.debug(f"Performing healthcheck on {health_check_url}") log.debug(f"Performing healthcheck on {health_check_url}")
async with self.session.get(health_check_url) as response: async with self.healthcheck_session.get(health_check_url) as response:
if response.status == 200: if response.status == 200:
log.debug("Healthcheck successful") log.debug("Healthcheck successful")
elif response.status == 503: elif response.status == 503:
@@ -199,7 +246,6 @@ class Backend:
f"Healthcheck failed with status: {response.status}" f"Healthcheck failed with status: {response.status}"
) )
else: else:
# endpoint not ready yet so bail
log.debug(f"Healthcheck Endpoint not ready: {response.status}") log.debug(f"Healthcheck Endpoint not ready: {response.status}")
except Exception as e: except Exception as e:
log.debug(f"Healthcheck failed with exception: {e}") log.debug(f"Healthcheck failed with exception: {e}")
@@ -207,7 +253,7 @@ class Backend:
async def _start_tracking(self) -> None: async def _start_tracking(self) -> None:
await gather( await gather(
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck() self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop()
) )
def backend_errored(self, msg: str) -> None: def backend_errored(self, msg: str) -> None:
@@ -239,7 +285,7 @@ class Backend:
message = { message = {
key: value key: value
for (key, value) in (dataclasses.asdict(auth_data).items()) for (key, value) in (dataclasses.asdict(auth_data).items())
if key != "signature" if key != "signature" and key != "__request_id"
} }
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN): if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
log.debug( log.debug(
@@ -249,7 +295,7 @@ class Backend:
elif message in self.msg_history: elif message in self.msg_history:
log.debug(f"message: {message} already in message history") log.debug(f"message: {message} already in message history")
return False return False
elif verify_signature(json.dumps(message, indent=4), auth_data.signature): elif verify_signature(json.dumps(message, indent=4, sort_keys=True), auth_data.signature):
self.reqnum = max(auth_data.reqnum, self.reqnum) self.reqnum = max(auth_data.reqnum, self.reqnum)
self.msg_history.append(message) self.msg_history.append(message)
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:] self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
@@ -275,41 +321,60 @@ class Backend:
return float(f.readline()) return float(f.readline())
except FileNotFoundError: except FileNotFoundError:
pass pass
log.debug("Initial run to trigger model loading...")
payload = self.benchmark_handler.make_benchmark_payload()
await self.__call_api(handler=self.benchmark_handler, payload=payload)
max_throughput = 0 max_throughput = 0
last_throughput = 0
sum_throughput = 0 sum_throughput = 0
for run in range(self.benchmark_handler.benchmark_runs + 1): concurrent_requests = 10 if self.allow_parallel_requests else 1
for run in range(1, self.benchmark_handler.benchmark_runs + 1):
start = time.time() start = time.time()
payload = self.benchmark_handler.make_benchmark_payload() benchmark_requests = []
res = await self.__call_api(
handler=self.benchmark_handler, payload=payload for i in range(concurrent_requests):
) payload = self.benchmark_handler.make_benchmark_payload()
data = await res.json()
time_elapsed = time.time() - start
# first run triggers one-time loading of the model which is very slow, so we skip counting it
if run == 0:
continue
else:
workload = payload.count_workload() workload = payload.count_workload()
last_throughput = workload / time_elapsed task = self.__call_api(handler=self.benchmark_handler, payload=payload)
sum_throughput += last_throughput benchmark_requests.append(
max_throughput = max(max_throughput, last_throughput) BenchmarkResult(request_idx=i, workload=workload, task=task)
log.debug(
"\n".join(
[
"#" * 60,
f"Run: {run}, workload: {workload} time_elapsed: {time_elapsed}, throughput: {last_throughput}",
"",
f"response: {data}",
"#" * 60,
]
)
) )
responses = await gather(*[br.task for br in benchmark_requests])
for br, response in zip(benchmark_requests, responses):
br.response = response
total_workload = sum(br.workload for br in benchmark_requests if br.is_successful)
time_elapsed = time.time() - start
successful_responses = sum([1 for br in benchmark_requests if br.is_successful])
if successful_responses == 0:
self.backend_errored("No successful responses from benchmark")
log.debug(f"benchmark failed: {successful_responses}/{concurrent_requests} successful responses")
throughput = total_workload / time_elapsed
sum_throughput += throughput
max_throughput = max(max_throughput, throughput)
# Log results for debugging
log.debug(
"\n".join(
[
"#" * 60,
f"Run: {run}, concurrent_requests: {concurrent_requests}",
f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s",
f"Throughput: {throughput} workload/s",
f"Successful responses: {successful_responses}/{concurrent_requests}",
"#" * 60,
]
)
)
average_throughput = sum_throughput / self.benchmark_handler.benchmark_runs average_throughput = sum_throughput / self.benchmark_handler.benchmark_runs
log.debug( log.debug(
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}" f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
) )
# save max_throughput so we don't have to run benchmark again on restart of cold instances
with open(BENCHMARK_INDICATOR_FILE, "w") as f: with open(BENCHMARK_INDICATOR_FILE, "w") as f:
f.write(str(max_throughput)) f.write(str(max_throughput))
return max_throughput return max_throughput
@@ -354,7 +419,7 @@ class Backend:
if line: if line:
await handle_log_line(line.rstrip()) await handle_log_line(line.rstrip())
else: else:
time.sleep(LOG_POLL_INTERVAL) await asyncio.sleep(LOG_POLL_INTERVAL)
########### ###########
+54 -11
View File
@@ -3,12 +3,11 @@ import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type, Awaitable
from aiohttp import web, ClientResponse from aiohttp import web, ClientResponse
import inspect import inspect
import psutil import psutil
import requests
""" """
@@ -66,10 +65,11 @@ class ApiPayload(ABC):
class AuthData: class AuthData:
"""data used to authenticate requester""" """data used to authenticate requester"""
signature: str
cost: str cost: str
endpoint: str endpoint: str
reqnum: int reqnum: int
request_idx: int
signature: str
url: str url: str
@classmethod @classmethod
@@ -190,13 +190,34 @@ class SystemMetrics:
self.additional_disk_usage = disk_usage - self.last_disk_usage self.additional_disk_usage = disk_usage - self.last_disk_usage
self.last_disk_usage = disk_usage self.last_disk_usage = disk_usage
def reset(self): def reset(self, expected: float | None) -> None:
# autoscaler excepts model_loading_time to be populated only once, when the instance has # autoscaler excepts model_loading_time to be populated only once, when the instance has
# finished benchmarking and is ready to receive requests. This applies to restarted instances # finished benchmarking and is ready to receive requests. This applies to restarted instances
# as well: they should send model_loading_time once when they are done loading # as well: they should send model_loading_time once when they are done loading
self.model_loading_time = None if self.model_loading_time == expected:
self.model_loading_time = None
@dataclass
class RequestMetrics:
"""Tracks metrics for an active request."""
request_idx: int
reqnum: int
workload: float
status: str
success: bool = False
@dataclass
class BenchmarkResult:
request_idx: int
workload: float
task: Awaitable[ClientResponse]
response: Optional[ClientResponse] = None
@property
def is_successful(self) -> bool:
return self.response is not None and self.response.status == 200
@dataclass @dataclass
class ModelMetrics: class ModelMetrics:
"""Model specific metrics""" """Model specific metrics"""
@@ -206,13 +227,15 @@ class ModelMetrics:
workload_received: float workload_received: float
workload_cancelled: float workload_cancelled: float
workload_errored: float workload_errored: float
workload_pending: float workload_rejected: float
# these are not # these are not
cur_perf: float workload_pending: float
error_msg: Optional[str] error_msg: Optional[str]
max_throughput: float max_throughput: float
requests_recieved: Set[int] = field(default_factory=set) requests_recieved: Set[int] = field(default_factory=set)
requests_working: Set[int] = field(default_factory=set) requests_working: dict[int, RequestMetrics] = field(default_factory=dict)
requests_deleting: list[RequestMetrics] = field(default_factory=list)
last_update: float = field(default_factory=time.time)
@classmethod @classmethod
def empty(cls): def empty(cls):
@@ -221,16 +244,30 @@ class ModelMetrics:
workload_served=0.0, workload_served=0.0,
workload_cancelled=0.0, workload_cancelled=0.0,
workload_errored=0.0, workload_errored=0.0,
cur_perf=0.0, workload_rejected=0.0,
workload_received=0.0, workload_received=0.0,
error_msg=None, error_msg=None,
max_throughput=0.0, max_throughput=0.0,
) )
@property @property
def workload_processing(self) -> float: def workload_processing(self) -> float:
return max(self.workload_received - self.workload_cancelled, 0.0) return max(self.workload_received - self.workload_cancelled, 0.0)
@property
def wait_time(self) -> float:
if (len(self.requests_working) == 0):
return 0.0
return sum([request.workload for request in self.requests_working.values()]) / max(self.max_throughput, 0.00001)
@property
def cur_load(self) -> float:
return sum([request.workload for request in self.requests_working.values()])
@property
def working_request_idxs(self) -> list[int]:
return [req.request_idx for req in self.requests_working.values()]
def set_errored(self, error_msg): def set_errored(self, error_msg):
self.reset() self.reset()
self.error_msg = error_msg self.error_msg = error_msg
@@ -240,15 +277,20 @@ class ModelMetrics:
self.workload_received = 0 self.workload_received = 0
self.workload_cancelled = 0 self.workload_cancelled = 0
self.workload_errored = 0 self.workload_errored = 0
self.workload_rejected = 0
self.last_update = time.time()
@dataclass @dataclass
class AutoScalaerData: class AutoScalerData:
"""Data that is reported to autoscaler""" """Data that is reported to autoscaler"""
id: int id: int
version: str
loadtime: float loadtime: float
cur_load: float cur_load: float
rej_load: float
new_load: float
error_msg: str error_msg: str
max_perf: float max_perf: float
cur_perf: float cur_perf: float
@@ -257,6 +299,7 @@ class AutoScalaerData:
num_requests_working: int num_requests_working: int
num_requests_recieved: int num_requests_recieved: int
additional_disk_usage: float additional_disk_usage: float
working_request_idxs: list[int]
url: str url: str
+166 -46
View File
@@ -5,13 +5,14 @@ import json
from asyncio import sleep from asyncio import sleep
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict, field
from functools import cache from functools import cache
import asyncio
from aiohttp import ClientSession, ClientTimeout, TCPConnector, ClientResponseError
import requests from lib.data_types import AutoScalerData, SystemMetrics, ModelMetrics, RequestMetrics
from lib.data_types import AutoScalaerData, SystemMetrics, ModelMetrics
from typing import Awaitable, NoReturn, List from typing import Awaitable, NoReturn, List
METRICS_UPDATE_INTERVAL = 1 METRICS_UPDATE_INTERVAL = 1
DELETE_REQUESTS_INTERVAL = 1
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
@@ -26,7 +27,9 @@ def get_url() -> str:
@dataclass @dataclass
class Metrics: class Metrics:
version: str = "0"
last_metric_update: float = 0.0 last_metric_update: float = 0.0
last_request_served: float = 0.0
update_pending: bool = False update_pending: bool = False
id: int = field(default_factory=lambda: int(os.environ["CONTAINER_ID"])) id: int = field(default_factory=lambda: int(os.environ["CONTAINER_ID"]))
report_addr: List[str] = field( report_addr: List[str] = field(
@@ -35,44 +38,84 @@ class Metrics:
url: str = field(default_factory=get_url) url: str = field(default_factory=get_url)
system_metrics: SystemMetrics = field(default_factory=SystemMetrics.empty) system_metrics: SystemMetrics = field(default_factory=SystemMetrics.empty)
model_metrics: ModelMetrics = field(default_factory=ModelMetrics.empty) model_metrics: ModelMetrics = field(default_factory=ModelMetrics.empty)
_session: ClientSession | None = field(default=None, init=False, repr=False)
def _request_start(self, workload: float, reqnum: int) -> None: async def http(self) -> ClientSession:
if self._session is None:
self._session = ClientSession(
timeout=ClientTimeout(total=10),
connector=TCPConnector(limit=8, limit_per_host=4, force_close=True, enable_cleanup_closed=True)
)
return self._session
async def aclose(self) -> None:
if self._session is not None:
await self._session.close()
self._session = None
def _request_start(self, request: RequestMetrics) -> None:
""" """
this function is called prior to forwarding a request to a model API. this function is called prior to forwarding a request to a model API.
""" """
log.debug("request start") log.debug("request start")
self.model_metrics.workload_pending += workload request.status = "Started"
self.model_metrics.workload_received += workload self.model_metrics.workload_pending += request.workload
self.model_metrics.requests_recieved.add(reqnum) self.model_metrics.workload_received += request.workload
self.model_metrics.requests_working.add(reqnum) self.model_metrics.requests_recieved.add(request.reqnum)
self.model_metrics.requests_working[request.reqnum] = request
def _request_end(
self, workload: float, req_response_time: float, reqnum: int
) -> None:
"""
this function is called after a response from model API is received.
"""
self.model_metrics.workload_served += workload
self.model_metrics.workload_pending -= workload
self.model_metrics.requests_working.discard(reqnum)
self.model_metrics.cur_perf = workload / req_response_time
self.update_pending = True self.update_pending = True
def _request_errored(self, workload: float, reqnum: int) -> None: def _request_end(self, request: RequestMetrics) -> None:
"""
this function is called after handling of a request ends, regardless of the outcome
"""
self.model_metrics.workload_pending -= request.workload
self.model_metrics.requests_working.pop(request.reqnum, None)
self.model_metrics.requests_deleting.append(request)
self.last_request_served = time.time()
def _request_success(self, request: RequestMetrics) -> None:
"""
this function is called after a response from model API is received and forwarded.
"""
self.model_metrics.workload_served += request.workload
request.status = "Success"
request.success = True
self.update_pending = True
def _request_errored(self, request: RequestMetrics) -> None:
""" """
this function is called if model API returns an error this function is called if model API returns an error
""" """
self.model_metrics.workload_pending -= workload self.model_metrics.workload_errored += request.workload
self.model_metrics.workload_errored += workload request.status = "Error"
self.model_metrics.requests_working.discard(reqnum) request.success = False
self.update_pending = True
def _request_canceled(self, workload: float, reqnum: int) -> None: def _request_canceled(self, request: RequestMetrics) -> None:
""" """
this function is called if client drops connection before model API has responded this function is called if client drops connection before model API has responded
""" """
self.model_metrics.workload_pending -= workload self.model_metrics.workload_cancelled += request.workload
self.model_metrics.workload_cancelled += workload request.success = True
self.model_metrics.requests_working.discard(reqnum) request.status = "Cancelled"
def _request_reject(self, request: RequestMetrics):
"""
this function is called if the current wait time for the model is above max_wait_time
"""
self.model_metrics.requests_recieved.add(request.reqnum)
self.model_metrics.requests_deleting.append(request)
self.model_metrics.workload_rejected += request.workload
request.success = False
request.status = "Rejected"
self.update_pending = True
async def _send_delete_requests_loop(self) -> Awaitable[NoReturn]:
while True:
await sleep(DELETE_REQUESTS_INTERVAL)
if len(self.model_metrics.requests_deleting) > 0:
await self.__send_delete_requests_and_reset()
async def _send_metrics_loop(self) -> Awaitable[NoReturn]: async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
while True: while True:
@@ -80,10 +123,10 @@ class Metrics:
elapsed = time.time() - self.last_metric_update elapsed = time.time() - self.last_metric_update
if self.system_metrics.model_is_loaded is False and elapsed >= 10: if self.system_metrics.model_is_loaded is False and elapsed >= 10:
log.debug(f"sending loading model metrics after {int(elapsed)}s wait") log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
self.__send_metrics_and_reset(elapsed) await self.__send_metrics_and_reset()
elif self.update_pending or elapsed > 10: elif self.update_pending or elapsed > 10:
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait") log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
self.__send_metrics_and_reset(elapsed) await self.__send_metrics_and_reset()
def _model_loaded(self, max_throughput: float) -> None: def _model_loaded(self, max_throughput: float) -> None:
self.system_metrics.model_loading_time = ( self.system_metrics.model_loading_time = (
@@ -96,27 +139,94 @@ class Metrics:
self.model_metrics.set_errored(error_msg) self.model_metrics.set_errored(error_msg)
self.system_metrics.model_is_loaded = True self.system_metrics.model_is_loaded = True
def _set_version(self, version: str) -> None:
self.version = version
#######################################Private####################################### #######################################Private#######################################
def __send_metrics_and_reset(self, elapsed): async def __send_delete_requests_and_reset(self):
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
data = {
"worker_id": self.id,
"request_idxs": idxs,
"success": success_flag,
}
log.debug(
f"Deleting requests that {'succeeded' if success_flag else 'failed'}: {data['request_idxs']}"
)
full_path = report_addr.rstrip("/") + "/delete_requests/"
for attempt in range(1, 4):
try:
session = await self.http()
async with session.post(full_path, json=data) as res:
log.debug(f"delete_requests response: {res.status}")
res.raise_for_status()
return True
except asyncio.TimeoutError:
log.debug("delete_requests timed out")
except (ClientResponseError, Exception) as e:
log.debug(f"delete_requests failed with error: {e}")
await asyncio.sleep(2)
log.debug(f"retrying delete_request, attempt: {attempt}")
return False
def compute_autoscaler_data() -> AutoScalaerData: # Take a snapshot of what we plan to send this tick.
return AutoScalaerData( # New arrivals after this snapshot will remain in the queue for the next tick.
snapshot = list(self.model_metrics.requests_deleting)
success_idxs = [r.request_idx for r in snapshot if r.success is True]
failed_idxs = [r.request_idx for r in snapshot if r.success is False]
if not success_idxs and not failed_idxs:
return # nothing to do
for report_addr in self.report_addr:
# TODO: Add a Redis subscriber queue for delete_requests
if report_addr == "https://cloud.vast.ai/api/v0":
# Patch: ignore the Redis API report_addr
continue
sent_success = True
sent_failed = True
if success_idxs:
sent_success = await post(report_addr, success_idxs, True)
if failed_idxs:
sent_failed = await post(report_addr, failed_idxs, False)
if sent_success and sent_failed:
# Remove only the items we actually sent from the live queue.
sent_set = set(success_idxs) | set(failed_idxs)
self.model_metrics.requests_deleting[:] = [
r for r in self.model_metrics.requests_deleting
if r.request_idx not in sent_set
]
break
async def __send_metrics_and_reset(self):
loadtime_snapshot = self.system_metrics.model_loading_time
def compute_autoscaler_data() -> AutoScalerData:
return AutoScalerData(
id=self.id, id=self.id,
loadtime=(self.system_metrics.model_loading_time or 0.0), version=self.version,
cur_load=(self.model_metrics.workload_processing / elapsed), loadtime=(loadtime_snapshot or 0.0),
new_load=self.model_metrics.workload_processing,
cur_load=self.model_metrics.cur_load,
rej_load=self.model_metrics.workload_rejected,
max_perf=self.model_metrics.max_throughput, max_perf=self.model_metrics.max_throughput,
cur_perf=self.model_metrics.cur_perf, cur_perf=self.model_metrics.workload_served,
error_msg=self.model_metrics.error_msg or "", error_msg=self.model_metrics.error_msg or "",
num_requests_working=len(self.model_metrics.requests_working), num_requests_working=len(self.model_metrics.requests_working),
num_requests_recieved=len(self.model_metrics.requests_recieved), num_requests_recieved=len(self.model_metrics.requests_recieved),
additional_disk_usage=self.system_metrics.additional_disk_usage, additional_disk_usage=self.system_metrics.additional_disk_usage,
working_request_idxs=self.model_metrics.working_request_idxs,
cur_capacity=0, cur_capacity=0,
max_capacity=0, max_capacity=0,
url=self.url, url=self.url,
) )
def send_data(report_addr: str) -> None: async def send_data(report_addr: str) -> bool:
data = compute_autoscaler_data() data = compute_autoscaler_data()
full_path = report_addr.rstrip("/") + "/worker_status/" full_path = report_addr.rstrip("/") + "/worker_status/"
log.debug( log.debug(
@@ -131,22 +241,32 @@ class Metrics:
) )
for attempt in range(1, 4): for attempt in range(1, 4):
try: try:
requests.post(full_path, json=asdict(data), timeout=1) session = await self.http()
break async with session.post(full_path, json=asdict(data)) as res:
except requests.Timeout: res.raise_for_status()
return True
except asyncio.TimeoutError:
log.debug(f"autoscaler status update timed out") log.debug(f"autoscaler status update timed out")
except Exception as e: except (ClientResponseError, Exception) as e:
log.debug(f"autoscaler status update failed with error: {e}") log.debug(f"autoscaler status update failed with error: {e}")
time.sleep(2) await asyncio.sleep(2)
log.debug(f"retrying autoscaler status update, attempt: {attempt}") log.debug(f"retrying autoscaler status update, attempt: {attempt}")
log.debug(f"failed to send update through {report_addr}")
return False
########### ###########
self.system_metrics.update_disk_usage() self.system_metrics.update_disk_usage()
sent = False
for report_addr in self.report_addr: for report_addr in self.report_addr:
send_data(report_addr) if await send_data(report_addr):
self.update_pending = False sent = True
self.model_metrics.reset() break
self.system_metrics.reset()
self.last_metric_update = time.time() if sent:
# clear the one-shot loadtime only if we actually sent *this* value
self.system_metrics.reset(expected=loadtime_snapshot)
self.update_pending = False
self.model_metrics.reset()
self.last_metric_update = time.time()
+1 -1
View File
@@ -27,7 +27,7 @@ def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
log.debug("starting server...") log.debug("starting server...")
app = web.Application() app = web.Application()
app.add_routes(routes) app.add_routes(routes)
runner = web.AppRunner(app, handler_cancellation=True) runner = web.AppRunner(app)
await runner.setup() await runner.setup()
site = web.TCPSite( site = web.TCPSite(
runner, runner,
+9 -6
View File
@@ -10,6 +10,7 @@ from collections import Counter
from dataclasses import dataclass, field, asdict from dataclasses import dataclass, field, asdict
from urllib.parse import urljoin from urllib.parse import urljoin
from utils.endpoint_util import Endpoint from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
import requests import requests
from lib.data_types import AuthData, ApiPayload from lib.data_types import AuthData, ApiPayload
@@ -120,9 +121,11 @@ class ClientState:
self.url = worker_address self.url = worker_address
url = urljoin(worker_address, self.worker_endpoint) url = urljoin(worker_address, self.worker_endpoint)
self.status = ClientStatus.Generating self.status = ClientStatus.Generating
response = requests.post( response = requests.post(
url, url,
json=req_data, json=req_data,
verify=get_cert_file_path(),
) )
if response.status_code != 200: if response.status_code != 200:
self.infer_error.append( self.infer_error.append(
@@ -289,12 +292,12 @@ def test_load_cmd(
args = arg_parser.parse_args() args = arg_parser.parse_args()
if hasattr(args, "comfy_model"): if hasattr(args, "comfy_model"):
os.environ["COMFY_MODEL"] = args.comfy_model os.environ["COMFY_MODEL"] = args.comfy_model
server_url = dict( server_url = {
prod="https://run.vast.ai", "prod": "https://run.vast.ai",
alpha="https://run-alpha.vast.ai", "alpha": "https://run-alpha.vast.ai",
candidate="https://run-candidate.vast.ai", "candidate": "https://run-candidate.vast.ai",
local="http://localhost:8080", "local": "http://localhost:8080",
)[args.instance] }.get(args.instance, "http://localhost:8080")
run_test( run_test(
num_requests=args.num_requests, num_requests=args.num_requests,
requests_per_second=args.requests_per_second, requests_per_second=args.requests_per_second,
+2 -2
View File
@@ -1,4 +1,4 @@
aiohttp~=3.11 aiohttp[speedups]==3.10.1
anyio~=4.4 anyio~=4.4
lib~=4.0 lib~=4.0
nltk~=3.9 nltk~=3.9
@@ -6,5 +6,5 @@ psutil~=6.0
pycryptodome~=3.20 pycryptodome~=3.20
Requests~=2.32 Requests~=2.32
transformers~=4.52 transformers~=4.52
utils~=1.0 utils==1.0.*
hf_transfer>=0.1.9 hf_transfer>=0.1.9
+23 -10
View File
@@ -9,7 +9,7 @@ ENV_PATH="$WORKSPACE_DIR/worker-env"
DEBUG_LOG="$WORKSPACE_DIR/debug.log" DEBUG_LOG="$WORKSPACE_DIR/debug.log"
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log" PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}" REPORT_ADDR="${REPORT_ADDR:-https://cloud.vast.ai/api/v0,https://run.vast.ai}"
USE_SSL="${USE_SSL:-true}" USE_SSL="${USE_SSL:-true}"
WORKER_PORT="${WORKER_PORT:-3000}" WORKER_PORT="${WORKER_PORT:-3000}"
mkdir -p "$WORKSPACE_DIR" mkdir -p "$WORKSPACE_DIR"
@@ -41,24 +41,37 @@ echo_var DEBUG_LOG
echo_var PYWORKER_LOG echo_var PYWORKER_LOG
echo_var MODEL_LOG echo_var MODEL_LOG
env | grep _ >> /etc/environment; # Populate /etc/environment with quoted values
if ! grep -q "VAST" /etc/environment; then
env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do
name=${line%%=*}
value=${line#*=}
printf '%s="%s"\n' "$name" "$value"
done > /etc/environment
fi
if [ ! -d "$ENV_PATH" ] if [ ! -d "$ENV_PATH" ]
then then
echo "setting up venv" echo "setting up venv"
curl -LsSf https://astral.sh/uv/install.sh | sh if ! which uv; then
source ~/.local/bin/env curl -LsSf https://astral.sh/uv/install.sh | sh
git clone https://github.com/vast-ai/pyworker "$SERVER_DIR" source ~/.local/bin/env
fi
uv venv --managed-python "$WORKSPACE_DIR/worker-env" -p 3.10 # Fork testing
source "$WORKSPACE_DIR/worker-env/bin/activate" [[ ! -d $SERVER_DIR ]] && git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"
if [[ -n ${PYWORKER_REF:-} ]]; then
(cd "$SERVER_DIR" && git checkout "$PYWORKER_REF")
fi
uv pip install -r vast-pyworker/requirements.txt uv venv --python-preference only-managed "$ENV_PATH" -p 3.10
source "$ENV_PATH/bin/activate"
uv pip install -r "${SERVER_DIR}/requirements.txt"
touch ~/.no_auto_tmux touch ~/.no_auto_tmux
else else
source ~/.local/bin/env [[ -f ~/.local/bin/env ]] && source ~/.local/bin/env
source "$WORKSPACE_DIR/worker-env/bin/activate" source "$WORKSPACE_DIR/worker-env/bin/activate"
echo "environment activated" echo "environment activated"
echo "venv: $VIRTUAL_ENV" echo "venv: $VIRTUAL_ENV"
+61 -5
View File
@@ -1,5 +1,6 @@
import logging import logging
from typing import Any, Dict, Optional import time
from typing import Any, Dict, Optional, Tuple
import requests import requests
@@ -16,6 +17,60 @@ class Endpoint:
Utility class for handling endpoint operations. Utility class for handling endpoint operations.
""" """
@staticmethod
def get_endpoint_info(
endpoint_name: str, account_api_key: str, instance: str
) -> Optional[Dict[str, Any]]:
headers = {"Authorization": f"Bearer {account_api_key}"}
url = f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}"
# Retry a few times to smooth over transient propagation/network delays
for attempt in range(4):
try:
response = requests.get(url, headers=headers, timeout=8)
if response.status_code != 200:
# brief backoff and retry
time.sleep(0.3 * (attempt + 1))
continue
try:
data = response.json()
except Exception:
# JSON parse failed; backoff and retry
time.sleep(0.3 * (attempt + 1))
continue
result = data.get("results", []) if isinstance(data, dict) else []
endpoint = next(
(item for item in result if item.get("endpoint_name") == endpoint_name),
None,
)
if endpoint and endpoint.get("id") and endpoint.get("api_key"):
return {"id": endpoint.get("id"), "api_key": endpoint.get("api_key")}
except Exception:
# network or other transient error; retry
time.sleep(0.3 * (attempt + 1))
return None
@staticmethod
def get_autoscaler_server_url(instance: str) -> str:
endpoints = {
"alpha": "run-alpha",
"candidate": "run-candidate",
"prod": "run",
}
host = endpoints.get(instance)
if host:
return f"https://{host}.vast.ai/"
return "http://localhost:8080"
@staticmethod
def get_server_url(instance: str) -> str:
endpoints = {
"alpha": "alpha",
"candidate": "candidate",
"prod": "console",
}
host = endpoints.get(instance, "alpha")
return f"https://{host}.vast.ai/api/v0/endptjobs/"
@staticmethod @staticmethod
def get_endpoint_api_key( def get_endpoint_api_key(
endpoint_name: str, account_api_key: str, instance: str endpoint_name: str, account_api_key: str, instance: str
@@ -30,13 +85,14 @@ class Endpoint:
Returns: Returns:
Endpoint API key if successful, None otherwise Endpoint API key if successful, None otherwise
""" """
vast_console_url = "https://console.vast.ai/api/v0/endptjobs/"
headers = {"Authorization": f"Bearer {account_api_key}"} headers = {"Authorization": f"Bearer {account_api_key}"}
try: try:
log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}") log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}")
response = requests.get( response = requests.get(
f"{vast_console_url}?autoscaler_instance={instance}", headers=headers f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}",
headers=headers,
timeout=8,
) )
if response.status_code != 200: if response.status_code != 200:
@@ -46,14 +102,14 @@ class Endpoint:
try: try:
data = response.json() data = response.json()
except requests.exceptions.JSONDecodeError as e: except Exception as e:
log.debug(f"Failed to parse JSON response: {e}") log.debug(f"Failed to parse JSON response: {e}")
return None return None
result = data.get("results", []) result = data.get("results", [])
endpoint: Optional[Dict[str, Any]] = next( endpoint: Optional[Dict[str, Any]] = next(
(item for item in result if item["endpoint_name"] == endpoint_name), (item for item in result if item.get("endpoint_name") == endpoint_name),
None, None,
) )
if not endpoint: if not endpoint:
+15
View File
@@ -0,0 +1,15 @@
import tempfile
from functools import cache
import requests
@cache
def get_cert_file_path():
cert_url = "https://console.vast.ai/static/jvastai_root.cer"
response = requests.get(cert_url)
response.raise_for_status()
# Use a temporary file that is not deleted on close
with tempfile.NamedTemporaryFile(delete=False, suffix=".cer", mode="wb") as f:
f.write(response.content)
return f.name
+222
View File
@@ -0,0 +1,222 @@
# ComfyUI PyWorker
This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture.
The cost for each request has a static value of `1`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node.
## Requirements
This worker requires both [ComfyUI](https://github.com/comfyanonymous/ComfyUI) and [ComfyUI API Wrapper](https://github.com/ai-dock/comfyui-api-wrapper).
A docker image is provided but you may use any if the above requirements are met.
## Benchmarking
### Custom Benchmark Workflows
You can provide a custom ComfyUI workflow for benchmarking by creating `workers/comfyui-json/misc/benchmark.json`. This allows you to test performance using your preferred models and workflow complexity.
**Ways to provide the benchmark file:**
- Fork this repository and add your `benchmark.json` file
- Write the file during worker provisioning (onstart script or setup phase)
An example file is provided in the repository. To ensure varied generations, use the placeholder `__RANDOM_INT__` in place of static seed values - it will be replaced with a random integer for each benchmark run.
### Default Benchmark (Fallback)
If `benchmark.json` is not available, a simple image generation benchmark runs when each worker initializes. This validates GPU performance and helps identify underperforming machines.
The default benchmark uses Stable Diffusion v1.5 with ComfyUI's standard text-to-image workflow. Configure it using these environment variables:
| Environment Variable | Default Value | Description |
| -------------------- | ------------- | ----------- |
| BENCHMARK_TEST_WIDTH | 512 | Image width (pixels) |
| BENCHMARK_TEST_HEIGHT | 512 | Image height (pixels) |
| BENCHMARK_TEST_STEPS | 20 | Number of denoising steps |
Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns.
#### Calibrating Fallback Benchmark Duration
To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements.
**Example:** If your typical workflow should complete in 90 seconds on acceptable hardware:
```bash
# 1. Measure it/sec on your reference machine
# RTX 4090 typically achieves ~43 it/sec with SD1.5
# 2. Calculate required steps
# 90 seconds × 43 it/sec = 3870 steps
# 3. Configure benchmark
export BENCHMARK_TEST_STEPS=3870
# 4. Machines completing significantly slower than 90s indicate hardware issues
```
**Performance expectations:**
- Benchmark duration should remain consistent across identical GPU models
- Significant variation (>20%) may indicate thermal, power, or configuration issues
## Endpoint
The worker provides a single endpoint:
- `/generate/sync`: Processes ComfyUI workflows using either predefined modifiers or custom workflow JSON
## Request Format
The worker accepts requests in the following format. Choose either modifier mode OR custom workflow mode:
**Modifier Mode:**
```json
{
"input": {
"request_id": "uuid-string", // optional - UUID generated if not provided
"modifier": "RawWorkflow",
"modifications": {
"prompt": "a beautiful landscape",
"width": 1024,
"height": 1024,
"steps": 20,
"seed": 123456789
},
"s3": { ... }, // optional
"webhook": { ... } // optional
}
}
```
**Custom Workflow Mode:**
```json
{
"input": {
"request_id": "uuid-string", // optional - UUID generated if not provided
"workflow_json": {
// Complete ComfyUI workflow JSON
},
"s3": { ... }, // optional
"webhook": { ... } // optional
}
}
```
## Request Fields
### Required Fields
- **`input`**: Contains the main workflow data
- **`input.request_id`**: Unique identifier for the request
### Workflow Mode (Choose One)
You must provide either `modifier` OR `workflow_json`, but not both:
#### Option 1: Modifier Mode
- **`input.modifier`**: Name of the predefined workflow modifier (e.g., "Text2Image")
- **`input.modifications`**: Parameters to pass to the modifier
#### Option 2: Custom Workflow Mode
- **`input.workflow_json`**: Complete ComfyUI workflow JSON
### Optional Fields
- **`input.s3`**: S3 configuration for file storage
- **`input.webhook`**: Webhook configuration for notifications
These configurations can be provided in the request JSON or via environment variables. Request-level configuration takes precedence over environment variables.
#### S3 Configuration
**Via Request JSON:**
```json
"s3": {
"access_key_id": "your-s3-access-key",
"secret_access_key": "your-s3-secret-access-key",
"endpoint_url": "https://my-endpoint.backblaze.com",
"bucket_name": "your-bucket",
"region": "us-east-1"
}
```
**Via Environment Variables:**
```bash
S3_ACCESS_KEY_ID=your-key
S3_SECRET_ACCESS_KEY=your-secret
S3_BUCKET_NAME=your-bucket
S3_ENDPOINT_URL=https://s3.amazonaws.com
S3_REGION=us-east-1
```
#### Webhook Configuration
**Via Request JSON:**
```json
"webhook": {
"url": "your-webhook-url",
"extra_params": {
"custom_field": "value"
}
}
```
**Via Environment Variables:**
```bash
WEBHOOK_URL=https://your-webhook.com # Default webhook URL
WEBHOOK_TIMEOUT=30 # Webhook timeout in seconds
```
## Examples
### Basic Text-to-Image (Modifier Mode)
```json
{
"input": {
"modifier": "Text2Image",
"modifications": {
"prompt": "a cat sitting on a windowsill",
"width": 512,
"height": 512,
"steps": 20,
"seed": 42
}
}
}
```
### Custom Workflow Mode
```json
{
"input": {
"request_id": "67890", // optional - using custom ID for tracking
"workflow_json": {
"3": {
"inputs": {
"seed": 42,
"steps": 20,
"cfg": 8,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1,
"model": ["4", 0],
"positive": ["6", 0],
"negative": ["7", 0],
"latent_image": ["5", 0]
},
"class_type": "KSampler"
}
}
}
}
```
## Client Libraries
See the test client examples for implementation details on how to integrate with the ComfyUI worker.
---
See Vast's serverless documentation for more details on how to use ComfyUI with autoscaler.
View File
+155
View File
@@ -0,0 +1,155 @@
import logging
import uuid
import random
from urllib.parse import urljoin
import json
import requests
from lib.test_utils import print_truncate_res
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from .data_types import count_workload
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
def call_text2image_workflow(
endpoint_group_name: str, api_key: str, server_url: str
) -> None:
"""Simple Text2Image using the new modifier-based approach"""
def make_request(url: str, payload: dict, timeout: int = None, verify=True, context: str = "request"):
"""Helper function for making requests with consistent error handling"""
try:
response = requests.post(
url,
json=payload,
timeout=timeout,
verify=verify
)
response.raise_for_status()
return response.json()
except requests.exceptions.HTTPError as http_err:
log.error(f"HTTP error occurred during {context}: {http_err}")
log.error(f"Status Code: {response.status_code}")
log.error("Response content:", response.text)
return None
except requests.exceptions.Timeout:
log.error(f"Timeout occurred during {context}: {url}")
return None
except requests.exceptions.ConnectionError:
log.error(f"Connection error occurred during {context}: {url}")
return None
except json.JSONDecodeError as json_err:
log.error(f"Failed to decode JSON response during {context}: {json_err}")
if 'response' in locals():
print("Response content:", response.text)
return None
except Exception as err:
log.error(f"An unexpected error occurred during {context}: {err}")
if 'response' in locals():
log.error("Response content (if available):", response.text)
return None
WORKER_ENDPOINT = "/generate/sync"
# This worker has concurrency = 1. All workloads have cost value 1.0
COST = count_workload()
# Route to get worker URL
route_payload = {
"endpoint": endpoint_group_name,
"api_key": api_key,
"cost": COST,
}
# First request - get routing information
route_response = make_request(
url=urljoin(server_url, "/route/"),
payload=route_payload,
timeout=4,
context="route request"
)
if route_response is None:
return None
if "url" not in route_response or not route_response["url"]:
log.error("Error: No worker in 'Ready' state. Please wait while the serverless engine removes errored workers or finishes loading new workers.")
return None
if "status" in route_response:
print(f"Autoscaler status: {route_response['status']}")
return None
# Extract data from route response
url = route_response["url"]
auth_data = dict(
signature=route_response["signature"],
cost=route_response["cost"],
endpoint=route_response["endpoint"],
reqnum=route_response["reqnum"],
url=route_response["url"],
)
# Build the payload for the worker request
worker_payload = {
"input": {
"request_id": str(uuid.uuid4()),
"modifier": "Text2Image",
"modifications": {
"prompt": "a beautiful landscape with mountains and lakes",
"width": 1024,
"height": 1024,
"steps": 20,
"seed": random.randint(0, 2**32 - 1)
},
"workflow_json": {} # Empty since using modifier approach
}
}
req_data = dict(payload=worker_payload, auth_data=auth_data)
worker_url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {worker_url}")
# Second request - call the worker endpoint
worker_response = make_request(
url=worker_url,
payload=req_data,
verify=get_cert_file_path(),
context="worker request"
)
return worker_response
if __name__ == "__main__":
from lib.test_utils import test_args
args = test_args.parse_args()
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if endpoint_api_key:
result = call_text2image_workflow(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
if result is None:
log.error("Text2Image workflow failed")
else:
print(result)
else:
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name}")
+84
View File
@@ -0,0 +1,84 @@
import os
import sys
import random
import dataclasses
from typing import Dict, Any
from functools import cache
from math import ceil
from pathlib import Path
import json
import logging
from lib.data_types import ApiPayload, JsonDataException
log = logging.getLogger(__file__)
def count_workload() -> float:
# Always 100.0 where there is a single instance of ComfyUI handling requests
# Results will indicate % or a job completed per second. Avoids sub 0.1 sec performance indication
return 100.0
@dataclasses.dataclass
class ComfyWorkflowData(ApiPayload):
input: dict
@classmethod
def for_test(cls):
"""
If the user has provided a benchmark workflow we can use it here to properly gauge performance.
Otherwise, use the variables available to simulate workflows of the required running time
Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090)
"""
# Try to load benchmark.json
benchmark_file = Path("workers/comfyui-json/misc/benchmark.json")
if benchmark_file.exists():
try:
with open(benchmark_file, "r") as f:
benchmark_workflow = json.load(f)
return cls(
input={
"request_id": f"test-{random.randint(1000, 99999)}",
"workflow_json": benchmark_workflow
}
)
except (json.JSONDecodeError, IOError):
# JSON is malformed or file can't be read, fall through to default
log.error(f"Failed to benchmark using {benchmark_file}")
# Fallback: read prompts and construct payload
log.info("Using fallback method for benchmarking")
with open("workers/comfyui-json/misc/test_prompts.txt", "r") as f:
test_prompts = f.readlines()
test_prompt = random.choice(test_prompts).rstrip()
return cls(
input={
"request_id": f"test-{random.randint(1000, 99999)}",
"modifier": "Text2Image",
"modifications": {
"prompt": test_prompt,
"width": os.getenv('BENCHMARK_TEST_WIDTH', 512),
"height": os.getenv('BENCHMARK_TEST_HEIGHT', 512),
"steps": os.getenv('BENCHMARK_TEST_STEPS', 20),
"seed": random.randint(0, sys.maxsize),
}
}
)
def generate_payload_json(self) -> Dict[str, Any]:
# input is already a dict, just return it wrapped in the expected structure
return {"input": self.input}
def count_workload(self) -> float:
return count_workload()
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "ComfyWorkflowData":
# Extract required fields
if "input" not in json_msg:
raise JsonDataException({"input": "missing parameter"})
return cls(
input=json_msg["input"]
)
@@ -0,0 +1,107 @@
{
"3": {
"inputs": {
"seed": "__RANDOM_INT__",
"steps": 20,
"cfg": 8,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1,
"model": [
"4",
0
],
"positive": [
"6",
0
],
"negative": [
"7",
0
],
"latent_image": [
"5",
0
]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"4": {
"inputs": {
"ckpt_name": "v1-5-pruned-emaonly-fp16.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"5": {
"inputs": {
"width": 512,
"height": 512,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"6": {
"inputs": {
"text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"7": {
"inputs": {
"text": "text, watermark",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"8": {
"inputs": {
"samples": [
"3",
0
],
"vae": [
"4",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"8",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
}
}
@@ -0,0 +1,34 @@
cartoon character of a person with a hoodie , in style of cytus and deemo, ork, gold chains, realistic anime cat, dripping black goo, lineage revolution style, thug life, cute anthropomorphic bunny, balrog, arknights, aliased, very buff, black and red and yellow paint, painting illustration collage style, character composition in vector with white background
stardew valley, fine details
2D Vector Illustration of a child with soccer ball Art for Sublimation, Design Art, Chrome Art, Painting and Stunning Artwork, Highly Detailed Digital Painting, Airbrush Art, Highly Detailed Digital Artwork, Dramatic Artwork, stained antique yellow copper paint, digital airbrush art, detailed by Mark Brooks, Chicano airbrush art, Swagger! snake Culture
realistic futuristic city-downtown with short buildings, sunset
seascape by Ray Collins and artgerm, front view of a perfect wave, sunny background, ultra detailed water
inspired by realflow-cinema4d editor features, create image of a transparent luxury cup with ice fruits and mint, connected with white, yellow and pink cream, Slow - High Speed MO Photography, YouTube Video Screenshot, Abstract Clay, Transparent Cup , molecular gastronomy, wheel, 3D fluid,Simulation rendering, still video, 4k polymer clay futras photography, very surreal, Houdini Fluid Simulation, hyperrealistic CGI and FLUIDS & MULTIPHYSICS SIMULATION effect, with Somali Stain Lurex, Metallic Jacquard, Gold Thread, Mulberry Silk, Toub Saree, Warm background, a fantastic image worthy of an award.
biker with backpack on his back riding a motorcycle, Style by Ade Santora, Oilpunk, Cover photo, craig mullins style, on the cover of a magazine, Outdoor Magazine, inspired by Alex Petruk APe, image of a male biker, Cover of an award-winning magazine, the man has a backpack, photo for magazine, with a backpack, magazine cover
generate a collage-style illustration inspired by the Procreate raster graphic editor, photographic illustration with the theme, 2D vector, art for textile sublimation, containing surrealistic cartoon cat wearing a baseball cap and jeans standing in front of a poster, inspired by Sadao Watanabe, Doraemon, Japanese cartoon style, Eichiro Oda, Iconic high detail character, Director: Nakahara Nantenbō, Kastuhiro Otomo, image detailed, by Miyamoto, Hidetaka Miyazaki, Katsuhiro illustration, 8k, masterpiece, Minimize noise and grain in photo quality without lose quality and increase brightness and lighting,Symmetry and Alignment, Avoid asymmetrical shapes and out-of-focus points. Focus and Sharpness: Make sure the image is focused and sharp and encourages the viewer to see it as a work of art printed on fabric.
fantasy medieval village world inside a glass sphere , high detail, fantasy, realistic, light effect, hyper detail, volumetric lighting, cinematic, macro, depth of field, blur, red light and clouds from the back, highly detailed epic cinematic concept art cg render made in maya, blender and photoshop, octane render, excellent composition, dynamic dramatic cinematic lighting, aesthetic, very inspirational, world inside a glass sphere by james gurney by artgerm with james jean, joe fenton and tristan eaton by ross tran, fine details
Iron Man, (Arnold Tsang, Toru Nakayama), Masterpiece, Studio Quality, 6k , toa, toaair, 1boy, glowing, axe, mecha, science_fiction, solo, weapon, jungle , green_background, nature, outdoors, solo, tree, weapon, mask, dynamic lighting, detailed shading, digital texture painting
(Pope Francis) wearing leather jacket is a DJ in a nightclub, mixing live on stage, giant mixing table, a masterpiece
Pope Francis wearing biker (leather jacket), a masterpiece
Luke Skywalker ordering a burger and fries from the Death Star canteen.
I want to generate a group avatar for a Feishu group chat. The role of this group is daily software technical communication. Now the subject technology stacks that members of this group discuss daily include: algorithms, data structures, optimization, functional programming, and the programming languages often discussed are: TypeScript, Java, python, etc. I hope this avatar has a simple aesthetic, this avatar is a single person avatar
portrait Anime black girl cute-fine-face, pretty face, realistic shaded Perfect face, fine details. Anime. realistic shaded lighting by Ilya Kuvshinov Giuseppe Dangelico Pino and Michael Garmash and Rob Rey, IAMAG premiere, WLOP matte print, cute freckles, masterpiece
young Disney socialite wearing a beige miniskirt, dark brown turtleneck sweater, small neckless, cute-fine-face, anime. illustration, realistic shaded perfect face, brown hair, grey eyes, fine details, realistic shaded lighting by ilya kuvshinov giuseppe dangelico pino and michael garmash and rob rey, iamag premiere, wlop matte print, a masterpiece
Cute small cat sitting in a movie theater eating chicken wiggs watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
Cute small dog sitting in a movie theater eating popcorn watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
fox bracelet made of buckskin with fox features, rich details, fine carvings, studio lighting
crane buckskin bracelet with crane features, rich details, fine carvings, studio lighting
london luxurious interior living-room, light walls
Parisian luxurious interior penthouse bedroom, dark walls, wooden panels
cute girl, crop-top, blond hair, black glasses, stretching, with background by greg rutkowski makoto shinkai kyoto animation key art feminine mid shot
houses in front, houses background, straight houses, digital art, smooth, sharp focus, gravity falls style, doraemon style, shinchan style, anime style
Simplified technical drawing, Leonardo da Vinci, Mechanical Dinosaur Skeleton, Minimalistic annotations, Hand-drawn illustrations, Basic design and engineering, Wonder and curiosity
High quality 8K painting impressionist style of a Japanese modern city street with a girl on the foreground wearing a traditional wedding dress with a fox mask, staring at the sky, daylight
a landscape from the Moon with the Earth setting on the horizon, realistic, detailed
Isometric Atlantis city,great architecture with columns, great details, ornaments,seaweed, blue ambiance, 3D cartoon style, soft light, 45° view
A hyper realistic avatar of a guy riding on a black honda cbr 650r in leather suit,high detail, high quality,8K,photo realism
the street of amedieval fantasy town, at dawn, dark, highly detailed
overwhelmingly beautiful eagle framed with vector flowers, long shiny wavy flowing hair, polished, ultra detailed vector floral illustration mixed with hyper realism, muted pastel colors, vector floral details in background, muted colors, hyper detailed ultra intricate overwhelming realism in detailed complex scene with magical fantasy atmosphere, no signature, no watermark
a highly detailed matte painting of a man on a hill watching a rocket launch in the distance by studio ghibli, makoto shinkai, by artgerm, by wlop, by greg rutkowski, volumetric lighting, octane render, 4 k resolution, trending on artstation, masterpiece | hyperrealism| highly detailed| insanely detailed| intricate| cinematic lighting| depth of field
electronik robot and ofice ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
exquisitely intricately detailed illustration, of a small world with a lake and a rainbow, inside a closed glass jar.
+117
View File
@@ -0,0 +1,117 @@
import os
import logging
import dataclasses
import base64
from typing import Optional, Union, Type
from aiohttp import web, ClientResponse
from lib.backend import Backend, LogAction
from lib.data_types import EndpointHandler
from lib.server import start_server
from .data_types import ComfyWorkflowData
MODEL_SERVER_URL = os.getenv("MODEL_SERVER_URL", "http://127.0.0.1:18288")
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: "
MODEL_SERVER_ERROR_LOG_MSGS = [
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
"Value not in list: ", # This error is emitted when the model file is not there at all
"[ERROR] Provisioning Script failed", # Error inserted by provisioning script if models/nodes fail to download
]
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
async def generate_client_response(
client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
# Check if the response is actually streaming based on response headers/content-type
is_streaming_response = (
model_response.content_type == "text/event-stream"
or model_response.content_type == "application/x-ndjson"
or model_response.headers.get("Transfer-Encoding") == "chunked"
or "stream" in model_response.content_type.lower()
)
if is_streaming_response:
log.debug("Detected streaming response...")
res = web.StreamResponse()
res.content_type = model_response.content_type
await res.prepare(client_request)
async for chunk in model_response.content:
await res.write(chunk)
await res.write_eof()
log.debug("Done streaming response")
return res
else:
log.debug("Detected non-streaming response...")
content = await model_response.read()
return web.Response(
body=content,
status=model_response.status,
content_type=model_response.content_type
)
@dataclasses.dataclass
class ComfyWorkflowHandler(EndpointHandler[ComfyWorkflowData]):
@property
def endpoint(self) -> str:
return "/generate/sync"
@property
def healthcheck_endpoint(self) -> Optional[str]:
return f"{MODEL_SERVER_URL}/health"
@classmethod
def payload_cls(cls) -> Type[ComfyWorkflowData]:
return ComfyWorkflowData
def make_benchmark_payload(self) -> ComfyWorkflowData:
return ComfyWorkflowData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
return await generate_client_response(client_request, model_response)
backend = Backend(
model_server_url=MODEL_SERVER_URL,
model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=False,
benchmark_handler=ComfyWorkflowHandler(
benchmark_runs=3, benchmark_words=100
),
log_actions=[
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
(LogAction.Info, "Downloading:"),
*[
(LogAction.ModelError, error_msg)
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
],
],
)
async def handle_ping(_):
return web.Response(body="pong")
routes = [
web.post("/generate/sync", backend.create_handler(ComfyWorkflowHandler())),
web.get("/ping", handle_ping),
]
if __name__ == "__main__":
start_server(backend, routes)
+8
View File
@@ -0,0 +1,8 @@
from lib.test_utils import test_load_cmd, test_args
from .data_types import ComfyWorkflowData
WORKER_ENDPOINT = "/generate/sync"
if __name__ == "__main__":
test_load_cmd(ComfyWorkflowData, WORKER_ENDPOINT, arg_parser=test_args)
+3
View File
@@ -5,6 +5,7 @@ import requests
from lib.test_utils import print_truncate_res from lib.test_utils import print_truncate_res
from utils.endpoint_util import Endpoint from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
""" """
NOTE: this client example uses a custom comfy workflow compatible with SD3 only NOTE: this client example uses a custom comfy workflow compatible with SD3 only
@@ -51,6 +52,7 @@ def call_default_workflow(
response = requests.post( response = requests.post(
url, url,
json=req_data, json=req_data,
verify=get_cert_file_path(),
) )
response.raise_for_status() response.raise_for_status()
print_truncate_res(str(response.json())) print_truncate_res(str(response.json()))
@@ -141,6 +143,7 @@ def call_custom_workflow_for_sd3(
response = requests.post( response = requests.post(
url, url,
json=req_data, json=req_data,
verify=get_cert_file_path(),
) )
response.raise_for_status() response.raise_for_status()
print_truncate_res(str(response.json())) print_truncate_res(str(response.json()))
+1 -1
View File
@@ -13,7 +13,7 @@ from lib.server import start_server
from .data_types import DefaultComfyWorkflowData, CustomComfyWorkflowData from .data_types import DefaultComfyWorkflowData, CustomComfyWorkflowData
MODEL_SERVER_URL = "http://0.0.0.0:38188" MODEL_SERVER_URL = "http://127.0.0.1:18288" # API Wrapper Service
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded # This is the last log line that gets emitted once comfyui+extensions have been fully loaded
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: http://127.0.0.1:18188" MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: http://127.0.0.1:18188"
+193 -172
View File
@@ -6,6 +6,7 @@ from urllib.parse import urljoin
from typing import Dict, Any, Optional, Iterator, Union, List from typing import Dict, Any, Optional, Iterator, Union, List
import requests import requests
from utils.endpoint_util import Endpoint from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from .data_types.client import CompletionConfig, ChatCompletionConfig from .data_types.client import CompletionConfig, ChatCompletionConfig
logging.basicConfig( logging.basicConfig(
@@ -19,40 +20,37 @@ COMPLETIONS_PROMPT = "the capital of USA is"
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language." CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?" TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?"
class APIClient: class APIClient:
"""Lightweight client focused solely on API communication""" """Lightweight client focused solely on API communication"""
# Remove the generic WORKER_ENDPOINT since we're now going direct # Remove the generic WORKER_ENDPOINT since we're now going direct
DEFAULT_COST = 100 DEFAULT_COST = 100
DEFAULT_TIMEOUT = 4 DEFAULT_TIMEOUT = 4
def __init__(self, endpoint_group_name: str, api_key: str, server_url: str): def __init__(
self,
endpoint_group_name: str,
api_key: str,
server_url: str,
endpoint_api_key: str,
):
self.endpoint_group_name = endpoint_group_name self.endpoint_group_name = endpoint_group_name
self.api_key = api_key self.api_key = api_key
self.server_url = server_url self.server_url = server_url
self.endpoint_api_key = self._get_endpoint_api_key() self.endpoint_api_key = endpoint_api_key
def _get_endpoint_api_key(self) -> Optional[str]:
"""Get the endpoint API key"""
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=self.endpoint_group_name,
account_api_key=self.api_key,
)
if not endpoint_api_key:
log.error(f"Failed to get API key for endpoint {self.endpoint_group_name}")
return endpoint_api_key
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]: def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
"""Get worker URL and auth data from routing service""" """Get worker URL and auth data from routing service"""
if not self.endpoint_api_key: if not self.endpoint_api_key:
raise ValueError("No valid endpoint API key available") raise ValueError("No valid endpoint API key available")
route_payload = { route_payload = {
"endpoint": self.endpoint_group_name, "endpoint": self.endpoint_group_name,
"api_key": self.endpoint_api_key, "api_key": self.endpoint_api_key,
"cost": cost, "cost": cost,
} }
response = requests.post( response = requests.post(
urljoin(self.server_url, "/route/"), urljoin(self.server_url, "/route/"),
json=route_payload, json=route_payload,
@@ -60,7 +58,7 @@ class APIClient:
) )
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
def _create_auth_data(self, message: Dict[str, Any]) -> Dict[str, Any]: def _create_auth_data(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""Create auth data from routing response""" """Create auth data from routing response"""
return { return {
@@ -70,42 +68,46 @@ class APIClient:
"reqnum": message["reqnum"], "reqnum": message["reqnum"],
"url": message["url"], "url": message["url"],
} }
def _make_request(self, payload: Dict[str, Any], endpoint: str, method: str = "POST", def _make_request(
stream: bool = False) -> Union[Dict[str, Any], Iterator[str]]: self,
payload: Dict[str, Any],
endpoint: str,
method: str = "POST",
stream: bool = False,
) -> Union[Dict[str, Any], Iterator[str]]:
"""Make request directly to the specific worker endpoint""" """Make request directly to the specific worker endpoint"""
# Get worker URL and auth data # Get worker URL and auth data
cost = payload.get('max_tokens') cost = payload.get("max_tokens", self.DEFAULT_COST)
message = self._get_worker_url(cost=cost) message = self._get_worker_url(cost=cost)
worker_url = message["url"] worker_url = message["url"]
auth_data = self._create_auth_data(message) auth_data = self._create_auth_data(message)
req_data = { req_data = {"payload": {"input": payload}, "auth_data": auth_data}
"payload": {
"input": payload
},
"auth_data": auth_data
}
url = urljoin(worker_url, endpoint) url = urljoin(worker_url, endpoint)
log.debug(f"Making direct request to: {url}") log.debug(f"Making direct request to: {url}")
log.debug(f"Payload: {req_data}") log.debug(f"Payload: {req_data}")
# Make the request using the specified method # Make the request using the specified method
if method.upper() == "POST": if method.upper() == "POST":
response = requests.post(url, json=req_data, stream=stream) response = requests.post(
url, json=req_data, stream=stream, verify=get_cert_file_path()
)
elif method.upper() == "GET": elif method.upper() == "GET":
response = requests.get(url, params=req_data, stream=stream) response = requests.get(
url, params=req_data, stream=stream, verify=get_cert_file_path()
)
else: else:
raise ValueError(f"Unsupported HTTP method: {method}") raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status() response.raise_for_status()
if stream: if stream:
return self._handle_streaming_response(response) return self._handle_streaming_response(response)
else: else:
return response.json() return response.json()
def _handle_streaming_response(self, response: requests.Response) -> Iterator[str]: def _handle_streaming_response(self, response: requests.Response) -> Iterator[str]:
"""Handle streaming response and yield tokens""" """Handle streaming response and yield tokens"""
try: try:
@@ -124,61 +126,60 @@ class APIClient:
log.error(f"Error handling streaming response: {e}") log.error(f"Error handling streaming response: {e}")
raise raise
def call_completions(
def call_completions(self, config: CompletionConfig) -> Union[Dict[str, Any], Iterator[str]]: self, config: CompletionConfig
payload = config.to_dict() ) -> Union[Dict[str, Any], Iterator[str]]:
return self._make_request(
payload=payload,
endpoint="/v1/completions",
stream=config.stream
)
def call_chat_completions(self, config: ChatCompletionConfig) -> Union[Dict[str, Any], Iterator[str]]:
payload = config.to_dict() payload = config.to_dict()
return self._make_request( return self._make_request(
payload=payload, payload=payload, endpoint="/v1/completions", stream=config.stream
endpoint="/v1/chat/completions", )
stream=config.stream
def call_chat_completions(
self, config: ChatCompletionConfig
) -> Union[Dict[str, Any], Iterator[str]]:
payload = config.to_dict()
return self._make_request(
payload=payload, endpoint="/v1/chat/completions", stream=config.stream
) )
class ToolManager: class ToolManager:
"""Handles tool definitions and execution""" """Handles tool definitions and execution"""
@staticmethod @staticmethod
def list_files() -> str: def list_files() -> str:
"""Execute ls on current directory""" """Execute ls on current directory"""
try: try:
result = subprocess.run(['ls', '-la', '.'], capture_output=True, text=True, timeout=10) result = subprocess.run(
["ls", "-la", "."], capture_output=True, text=True, timeout=10
)
if result.returncode == 0: if result.returncode == 0:
return result.stdout return result.stdout
else: else:
return f"Error: {result.stderr}" return f"Error: {result.stderr}"
except Exception as e: except Exception as e:
return f"Error running ls: {e}" return f"Error running ls: {e}"
@staticmethod @staticmethod
def get_ls_tool_definition() -> List[Dict[str, Any]]: def get_ls_tool_definition() -> List[Dict[str, Any]]:
"""Get the ls tool definition""" """Get the ls tool definition"""
return [{ return [
"type": "function", {
"function": { "type": "function",
"name": "list_files", "function": {
"description": "List files and directories in the cwd", "name": "list_files",
"parameters": { "description": "List files and directories in the cwd",
"type": "object", "parameters": {"type": "object", "properties": {}, "required": []},
"properties": {}, },
"required": []
}
} }
}] ]
def execute_tool_call(self, tool_call: Dict[str, Any]) -> str: def execute_tool_call(self, tool_call: Dict[str, Any]) -> str:
"""Execute a tool call and return the result""" """Execute a tool call and return the result"""
function_name = tool_call["function"]["name"] function_name = tool_call["function"]["name"]
if function_name == "list_files": if function_name == "list_files":
return self.list_files() return self.list_files()
else: else:
@@ -187,13 +188,17 @@ class ToolManager:
class APIDemo: class APIDemo:
"""Demo and testing functionality for the API client""" """Demo and testing functionality for the API client"""
def __init__(self, client: APIClient, model: str, tool_manager: ToolManager = None): def __init__(
self, client: APIClient, model: str, tool_manager: Optional[ToolManager] = None
):
self.client = client self.client = client
self.model = model self.model = model
self.tool_manager = tool_manager or ToolManager() self.tool_manager = tool_manager or ToolManager()
def handle_streaming_response(self, response_stream, show_reasoning: bool = True) -> str: def handle_streaming_response(
self, response_stream, show_reasoning: bool = True
) -> str:
""" """
Handle streaming chat response and display all output. Handle streaming chat response and display all output.
""" """
@@ -260,178 +265,181 @@ class APIDemo:
return full_response return full_response
def test_tool_support(self) -> bool: def test_tool_support(self) -> bool:
"""Test if the endpoint supports function calling""" """Test if the endpoint supports function calling"""
log.debug("Testing endpoint tool calling support...") log.debug("Testing endpoint tool calling support...")
# Try a simple request with minimal tools to test support # Try a simple request with minimal tools to test support
messages = [{"role": "user", "content": "Hello"}] messages = [{"role": "user", "content": "Hello"}]
minimal_tool = [{ minimal_tool = [
"type": "function", {
"function": { "type": "function",
"name": "test_function", "function": {"name": "test_function", "description": "Test function"},
"description": "Test function"
} }
}] ]
config = ChatCompletionConfig( config = ChatCompletionConfig(
model=self.model, model=self.model,
messages=messages, messages=messages,
max_tokens=10, max_tokens=10,
tools=minimal_tool, tools=minimal_tool,
tool_choice="none" # Don't actually call the tool tool_choice="none", # Don't actually call the tool
) )
try: try:
response = self.client.call_chat_completions(config) response = self.client.call_chat_completions(config)
return True return True
except Exception as e: except Exception as e:
log.error(f"Error: Endpoint does not support tool calling: {e}") log.error(f"Error: Endpoint does not support tool calling: {e}")
return False return False
def demo_completions(self) -> None: def demo_completions(self) -> None:
"""Demo: test basic completions endpoint""" """Demo: test basic completions endpoint"""
print("=" * 60) print("=" * 60)
print("COMPLETIONS DEMO") print("COMPLETIONS DEMO")
print("=" * 60) print("=" * 60)
config = CompletionConfig( config = CompletionConfig(
model=self.model, model=self.model, prompt=COMPLETIONS_PROMPT, stream=False
prompt=COMPLETIONS_PROMPT, )
stream=False
log.info(
f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'"
) )
log.info(f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'")
response = self.client.call_completions(config) response = self.client.call_completions(config)
if isinstance(response, dict): if isinstance(response, dict):
print("\nResponse:") print("\nResponse:")
print(json.dumps(response, indent=2)) print(json.dumps(response, indent=2))
else: else:
log.error("Unexpected response format") log.error("Unexpected response format")
def demo_chat(self, use_streaming: bool = True) -> None: def demo_chat(self, use_streaming: bool = True) -> None:
""" """
Demo: test chat completions endpoint with optional streaming Demo: test chat completions endpoint with optional streaming
""" """
print("=" * 60) print("=" * 60)
print(f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}") print(
f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}"
)
print("=" * 60) print("=" * 60)
config = ChatCompletionConfig( config = ChatCompletionConfig(
model=self.model, model=self.model,
messages=[{"role": "user", "content": CHAT_PROMPT}], messages=[{"role": "user", "content": CHAT_PROMPT}],
stream=use_streaming, stream=use_streaming,
) )
log.info(f"Testing chat completions with model '{self.model}'...") log.info(f"Testing chat completions with model '{self.model}'...")
response = self.client.call_chat_completions(config) response = self.client.call_chat_completions(config)
if use_streaming: if use_streaming:
try: try:
self.handle_streaming_response(response, show_reasoning=True) self.handle_streaming_response(response, show_reasoning=True)
except Exception as e: except Exception as e:
log.error(f"\nError during streaming: {e}") log.error(f"\nError during streaming: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
return return
else: else:
if isinstance(response, dict): if isinstance(response, dict):
choice = response.get("choices", [{}])[0] choice = response.get("choices", [{}])[0]
message = choice.get("message", {}) message = choice.get("message", {})
content = message.get("content", "") content = message.get("content", "")
reasoning = message.get("reasoning_content", "") or message.get("reasoning", "") reasoning = message.get("reasoning_content", "") or message.get(
"reasoning", ""
)
if reasoning: if reasoning:
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m") print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
print(f"\n💬 Assistant: {content}") print(f"\n💬 Assistant: {content}")
print(f"\nFull Response:") print(f"\nFull Response:")
print(json.dumps(response, indent=2)) print(json.dumps(response, indent=2))
else: else:
log.error("Unexpected response format") log.error("Unexpected response format")
def demo_ls_tool(self) -> None: def demo_ls_tool(self) -> None:
"""Demo: ask LLM to list files in the current directory and describe what it sees""" """Demo: ask LLM to list files in the current directory and describe what it sees"""
print("=" * 60) print("=" * 60)
print("TOOL USE DEMO: List Directory Contents") print("TOOL USE DEMO: List Directory Contents")
print("=" * 60) print("=" * 60)
# Test if tools are supported first # Test if tools are supported first
if not self.test_tool_support(): if not self.test_tool_support():
return return
# Request with tool available # Request with tool available
messages = [ messages = [{"role": "user", "content": TOOLS_PROMPT}]
{"role": "user", "content": TOOLS_PROMPT}
]
config = ChatCompletionConfig( config = ChatCompletionConfig(
model=self.model, model=self.model,
messages=messages, messages=messages,
tools=self.tool_manager.get_ls_tool_definition(), tools=self.tool_manager.get_ls_tool_definition(),
tool_choice="auto" tool_choice="auto",
) )
log.info(f"Making initial request with tool using model '{self.model}'...") log.info(f"Making initial request with tool using model '{self.model}'...")
response = self.client.call_chat_completions(config) response = self.client.call_chat_completions(config)
if not isinstance(response, dict): if not isinstance(response, dict):
raise ValueError("Expected dict response for tool use") raise ValueError("Expected dict response for tool use")
choice = response.get("choices", [{}])[0] choice = response.get("choices", [{}])[0]
message = choice.get("message", {}) message = choice.get("message", {})
print(f"Assistant response: {message.get('content', 'No content')}") print(f"Assistant response: {message.get('content', 'No content')}")
# Check for tool calls # Check for tool calls
tool_calls = message.get("tool_calls") tool_calls = message.get("tool_calls")
if not tool_calls: if not tool_calls:
raise ValueError("No tool calls made - model may not support function calling") raise ValueError(
"No tool calls made - model may not support function calling"
)
print(f"Tool calls detected: {len(tool_calls)}") print(f"Tool calls detected: {len(tool_calls)}")
# Execute the tool call # Execute the tool call
for tool_call in tool_calls: for tool_call in tool_calls:
function_name = tool_call["function"]["name"] function_name = tool_call["function"]["name"]
print(f"Executing tool: {function_name}") print(f"Executing tool: {function_name}")
tool_result = self.tool_manager.execute_tool_call(tool_call) tool_result = self.tool_manager.execute_tool_call(tool_call)
print(f"Tool result:\n{tool_result}") print(f"Tool result:\n{tool_result}")
# Add tool result and continue conversation # Add tool result and continue conversation
messages.append(message) # Add assistant's message with tool call messages.append(message) # Add assistant's message with tool call
messages.append({ messages.append(
"role": "tool", {
"tool_call_id": tool_call["id"], "role": "tool",
"content": tool_result "tool_call_id": tool_call["id"],
}) "content": tool_result,
}
)
# Get final response # Get final response
final_config = ChatCompletionConfig( final_config = ChatCompletionConfig(
model=self.model, model=self.model,
messages=messages, messages=messages,
tools=self.tool_manager.get_ls_tool_definition() tools=self.tool_manager.get_ls_tool_definition(),
) )
print("Getting final response...") print("Getting final response...")
final_response = self.client.call_chat_completions(final_config) final_response = self.client.call_chat_completions(final_config)
if isinstance(final_response, dict): if isinstance(final_response, dict):
final_choice = final_response.get("choices", [{}])[0] final_choice = final_response.get("choices", [{}])[0]
final_message = final_choice.get("message", {}) final_message = final_choice.get("message", {})
final_content = final_message.get("content", "") final_content = final_message.get("content", "")
print("\n" + "=" * 60) print("\n" + "=" * 60)
print("FINAL LLM ANALYSIS:") print("FINAL LLM ANALYSIS:")
print("=" * 60) print("=" * 60)
print(final_content) print(final_content)
print("=" * 60) print("=" * 60)
def interactive_chat(self) -> None: def interactive_chat(self) -> None:
"""Interactive chat session with streaming""" """Interactive chat session with streaming"""
print("=" * 60) print("=" * 60)
@@ -440,40 +448,39 @@ class APIDemo:
print(f"Using model: {self.model}") print(f"Using model: {self.model}")
print("Type 'quit' to exit, 'clear' to clear history") print("Type 'quit' to exit, 'clear' to clear history")
print() print()
messages = [] messages = []
while True: while True:
try: try:
user_input = input("You: ").strip() user_input = input("You: ").strip()
if user_input.lower() == 'quit': if user_input.lower() == "quit":
print("👋 Goodbye!") print("👋 Goodbye!")
break break
elif user_input.lower() == 'clear': elif user_input.lower() == "clear":
messages = [] messages = []
print("Chat history cleared") print("Chat history cleared")
continue continue
elif not user_input: elif not user_input:
continue continue
messages.append({"role": "user", "content": user_input}) messages.append({"role": "user", "content": user_input})
config = ChatCompletionConfig( config = ChatCompletionConfig(
model=self.model, model=self.model, messages=messages, stream=True, temperature=0.7
messages=messages,
stream=True,
temperature=0.7
) )
print("Assistant: ", end="", flush=True) print("Assistant: ", end="", flush=True)
response = self.client.call_chat_completions(config) response = self.client.call_chat_completions(config)
assistant_content = self.handle_streaming_response(response, show_reasoning=True) assistant_content = self.handle_streaming_response(
response, show_reasoning=True
)
# Add assistant response to conversation history # Add assistant response to conversation history
messages.append({"role": "assistant", "content": assistant_content}) messages.append({"role": "assistant", "content": assistant_content})
except KeyboardInterrupt: except KeyboardInterrupt:
print("\n👋 Chat interrupted. Goodbye!") print("\n👋 Chat interrupted. Goodbye!")
break break
@@ -485,50 +492,49 @@ class APIDemo:
def main(): def main():
"""Main function with CLI switches for different tests""" """Main function with CLI switches for different tests"""
from lib.test_utils import test_args from lib.test_utils import test_args
# Add mandatory model argument # Add mandatory model argument
test_args.add_argument( test_args.add_argument(
"--model", "--model", required=True, help="Model to use for requests (required)"
required=True,
help="Model to use for requests (required)"
) )
# Add test mode arguments # Add test mode arguments
test_args.add_argument( test_args.add_argument(
"--completion", "--completion", action="store_true", help="Test completions endpoint"
action="store_true",
help="Test completions endpoint"
) )
test_args.add_argument( test_args.add_argument(
"--chat", "--chat",
action="store_true", action="store_true",
help="Test chat completions endpoint (non-streaming)" help="Test chat completions endpoint (non-streaming)",
) )
test_args.add_argument( test_args.add_argument(
"--chat-stream", "--chat-stream",
action="store_true", action="store_true",
help="Test chat completions endpoint with streaming" help="Test chat completions endpoint with streaming",
) )
test_args.add_argument( test_args.add_argument(
"--tools", "--tools",
action="store_true", action="store_true",
help="Test function calling with ls tool (non-streaming)" help="Test function calling with ls tool (non-streaming)",
) )
test_args.add_argument( test_args.add_argument(
"--interactive", "--interactive",
action="store_true", action="store_true",
help="Start interactive streaming chat session" help="Start interactive streaming chat session",
) )
args = test_args.parse_args() args = test_args.parse_args()
# Check that only one test mode is selected # Check that only one test mode is selected
test_modes = [ test_modes = [
args.completion, args.chat, args.chat_stream, args.completion,
args.tools, args.interactive args.chat,
args.chat_stream,
args.tools,
args.interactive,
] ]
selected_count = sum(test_modes) selected_count = sum(test_modes)
if selected_count == 0: if selected_count == 0:
print("Please specify exactly one test mode:") print("Please specify exactly one test mode:")
print(" --completion : Test completions endpoint") print(" --completion : Test completions endpoint")
@@ -536,27 +542,42 @@ def main():
print(" --chat-stream : Test chat completions endpoint with streaming") print(" --chat-stream : Test chat completions endpoint with streaming")
print(" --tools : Test function calling with ls tool (non-streaming)") print(" --tools : Test function calling with ls tool (non-streaming)")
print(" --interactive : Start interactive streaming chat session") print(" --interactive : Start interactive streaming chat session")
print(f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT") print(
f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT"
)
sys.exit(1) sys.exit(1)
elif selected_count > 1: elif selected_count > 1:
print("Please specify exactly one test mode") print("Please specify exactly one test mode")
sys.exit(1) sys.exit(1)
try: try:
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if not endpoint_api_key:
log.error(
f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting."
)
sys.exit(1)
# Create the core API client # Create the core API client
client = APIClient( client = APIClient(
endpoint_group_name=args.endpoint_group_name, endpoint_group_name=args.endpoint_group_name,
api_key=args.api_key, api_key=args.api_key,
server_url=args.server_url server_url=Endpoint.get_autoscaler_server_url(args.instance),
endpoint_api_key=endpoint_api_key,
) )
# Create tool manager and demo (passing the model parameter) # Create tool manager and demo (passing the model parameter)
tool_manager = ToolManager() tool_manager = ToolManager()
demo = APIDemo(client, args.model, tool_manager) demo = APIDemo(client, args.model, tool_manager)
print(f"Using model: {args.model}") print(f"Using model: {args.model}")
print("=" * 60) print("=" * 60)
# Run the selected test # Run the selected test
if args.completion: if args.completion:
demo.demo_completions() demo.demo_completions()
@@ -568,11 +589,11 @@ def main():
demo.demo_ls_tool() demo.demo_ls_tool()
elif args.interactive: elif args.interactive:
demo.interactive_chat() demo.interactive_chat()
except Exception as e: except Exception as e:
log.error(f"Error during test: {e}", exc_info=True) log.error(f"Error during test: {e}", exc_info=True)
sys.exit(1) sys.exit(1)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
+12 -8
View File
@@ -3,11 +3,13 @@ from dataclasses import dataclass, field, fields, is_dataclass
from typing import Optional, List, Dict, Any from typing import Optional, List, Dict, Any
class SerializableDataclass: class SerializableDataclass:
def _serialize_recursive(self, obj: Any) -> Any: def _serialize_recursive(self, obj: Any) -> Any:
if is_dataclass(obj): if is_dataclass(obj):
return {field.name: self._serialize_recursive(getattr(obj, field.name)) return {
for field in fields(obj)} field.name: self._serialize_recursive(getattr(obj, field.name))
for field in fields(obj)
}
elif isinstance(obj, dict): elif isinstance(obj, dict):
return {key: self._serialize_recursive(value) for key, value in obj.items()} return {key: self._serialize_recursive(value) for key, value in obj.items()}
elif isinstance(obj, (list, tuple)): elif isinstance(obj, (list, tuple)):
@@ -16,10 +18,10 @@ class SerializableDataclass:
return [self._serialize_recursive(item) for item in obj] return [self._serialize_recursive(item) for item in obj]
else: else:
return obj return obj
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
return self._serialize_recursive(self) return self._serialize_recursive(self)
def to_json(self, indent: int = 2) -> str: def to_json(self, indent: int = 2) -> str:
return json.dumps(self.to_dict(), indent=indent) return json.dumps(self.to_dict(), indent=indent)
@@ -27,6 +29,7 @@ class SerializableDataclass:
@dataclass @dataclass
class CompletionConfig(SerializableDataclass): class CompletionConfig(SerializableDataclass):
"""Configuration for completion requests""" """Configuration for completion requests"""
model: str model: str
prompt: str = "Hello" prompt: str = "Hello"
max_tokens: int = 256 max_tokens: int = 256
@@ -39,8 +42,9 @@ class CompletionConfig(SerializableDataclass):
@dataclass @dataclass
class ChatCompletionConfig(SerializableDataclass): class ChatCompletionConfig(SerializableDataclass):
"""Configuration for chat completion requests""" """Configuration for chat completion requests"""
model: str model: str
messages: list = None messages: list = field(default_factory=list)
max_tokens: int = 2096 max_tokens: int = 2096
temperature: float = 0.7 temperature: float = 0.7
top_k: int = 20 top_k: int = 20
@@ -48,7 +52,7 @@ class ChatCompletionConfig(SerializableDataclass):
stream: bool = False stream: bool = False
tools: Optional[List[Dict[str, Any]]] = field(default_factory=list) tools: Optional[List[Dict[str, Any]]] = field(default_factory=list)
tool_choice: str = "auto" tool_choice: str = "auto"
def __post_init__(self): def __post_init__(self):
if self.messages is None: if self.messages is None:
self.messages = [{"role": "user", "content": "Hello"}] self.messages = [{"role": "user", "content": "Hello"}]
+72 -42
View File
@@ -2,7 +2,7 @@ import os, json, random
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from lib.data_types import EndpointHandler, ApiPayload, JsonDataException from lib.data_types import EndpointHandler, ApiPayload, JsonDataException
from typing import Union, Type, Dict, Any from typing import Union, Type, Dict, Any, Optional
from aiohttp import web, ClientResponse from aiohttp import web, ClientResponse
import nltk import nltk
import logging import logging
@@ -10,41 +10,39 @@ import logging
nltk.download("words") nltk.download("words")
WORD_LIST = nltk.corpus.words.words() WORD_LIST = nltk.corpus.words.words()
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
""" """
Generic dataclass accepts any dictionary in input. Generic dataclass accepts any dictionary in input.
""" """
@dataclass @dataclass
class GenericData(ApiPayload, ABC): class GenericData(ApiPayload, ABC):
input: Dict[str, Any] input: Dict[str, Any]
@classmethod @classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GenericData": def from_dict(cls, data: Dict[str, Any]) -> "GenericData":
return cls( return cls(input=data["input"])
input=data["input"]
)
@classmethod @classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "GenericData": def from_json_msg(cls, json_msg: Dict[str, Any]) -> "GenericData":
errors = {} errors = {}
# Validate required parameters # Validate required parameters
required_params = ["input"] required_params = ["input"]
for param in required_params: for param in required_params:
if param not in json_msg: if param not in json_msg:
errors[param] = "missing parameter" errors[param] = "missing parameter"
if errors: if errors:
raise JsonDataException(errors) raise JsonDataException(errors)
try: try:
# Create clean data dict and delegate to from_dict # Create clean data dict and delegate to from_dict
clean_data = { clean_data = {"input": json_msg["input"]}
"input": json_msg["input"]
}
return cls.from_dict(clean_data) return cls.from_dict(clean_data)
except (json.JSONDecodeError, JsonDataException) as e: except (json.JSONDecodeError, JsonDataException) as e:
errors["parameters"] = str(e) errors["parameters"] = str(e)
raise JsonDataException(errors) raise JsonDataException(errors)
@@ -59,7 +57,8 @@ class GenericData(ApiPayload, ABC):
def count_workload(self) -> int: def count_workload(self) -> int:
return self.input.get("max_tokens", 0) return self.input.get("max_tokens", 0)
@dataclass @dataclass
class GenericHandler(EndpointHandler[GenericData], ABC): class GenericHandler(EndpointHandler[GenericData], ABC):
@@ -67,10 +66,10 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
@abstractmethod @abstractmethod
def endpoint(self) -> str: def endpoint(self) -> str:
pass pass
@property @property
def healthcheck_endpoint(self) -> str: def healthcheck_endpoint(self) -> Optional[str]:
return os.environ.get('MODEL_HEALTH_ENDPOINT') return os.environ.get("MODEL_HEALTH_ENDPOINT")
@classmethod @classmethod
def payload_cls(cls) -> Type[GenericData]: def payload_cls(cls) -> Type[GenericData]:
@@ -82,17 +81,17 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
async def generate_client_response( async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]: ) -> Union[web.Response, web.StreamResponse]:
match model_response.status: match model_response.status:
case 200: case 200:
# Check if the response is actually streaming based on response headers/content-type # Check if the response is actually streaming based on response headers/content-type
is_streaming_response = ( is_streaming_response = (
model_response.content_type == "text/event-stream" or model_response.content_type == "text/event-stream"
model_response.content_type == "application/x-ndjson" or or model_response.content_type == "application/x-ndjson"
model_response.headers.get("Transfer-Encoding") == "chunked" or or model_response.headers.get("Transfer-Encoding") == "chunked"
"stream" in model_response.content_type.lower() or "stream" in model_response.content_type.lower()
) )
if is_streaming_response: if is_streaming_response:
log.debug("Detected streaming response...") log.debug("Detected streaming response...")
res = web.StreamResponse() res = web.StreamResponse()
@@ -109,69 +108,100 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
return web.Response( return web.Response(
body=content, body=content,
status=200, status=200,
content_type=model_response.content_type content_type=model_response.content_type,
) )
case code: case code:
log.debug("SENDING RESPONSE: ERROR: unknown code") log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code) return web.Response(status=code)
@dataclass @dataclass
class CompletionsData(GenericData): class CompletionsData(GenericData):
@classmethod @classmethod
def for_test(cls) -> "CompletionsData": def for_test(cls) -> "CompletionsData":
prompt = " ".join(random.choices(WORD_LIST, k=int(250))) system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
genus Equus with horses and asses, the three groups being the only living members of the family
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
woodlands, shrublands, and mountainous areas.
Please answer the following question based on the above context."""
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
model = os.environ.get("MODEL_NAME") model = os.environ.get("MODEL_NAME")
if not model: if not model:
raise ValueError("MODEL_NAME environment variable not set") raise ValueError("MODEL_NAME environment variable not set")
test_input = { test_input = {
"model": model, "model": model,
"prompt": prompt, "prompt": f"{system_prompt}\n\n{unique_question}",
"temperature": 0.7 "temperature": 0.7,
"max_tokens": 500,
} }
return cls(input=test_input) return cls(input=test_input)
@dataclass @dataclass
class CompletionsHandler(GenericHandler): class CompletionsHandler(GenericHandler):
@property @property
def endpoint(self) -> str: def endpoint(self) -> str:
return "/v1/completions" return "/v1/completions"
@classmethod @classmethod
def payload_cls(cls) -> Type[CompletionsData]: def payload_cls(cls) -> Type[CompletionsData]:
return CompletionsData return CompletionsData
def make_benchmark_payload(self) -> CompletionsData: def make_benchmark_payload(self) -> CompletionsData:
return CompletionsData.for_test() return CompletionsData.for_test()
@dataclass @dataclass
class ChatCompletionsData(GenericData): class ChatCompletionsData(GenericData):
"""Chat completions-specific data implementation""" """Chat completions-specific data implementation"""
@classmethod @classmethod
def for_test(cls) -> "ChatCompletionsData": def for_test(cls) -> "ChatCompletionsData":
prompt = " ".join(random.choices(WORD_LIST, k=int(250))) system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
genus Equus with horses and asses, the three groups being the only living members of the family
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
woodlands, shrublands, and mountainous areas.
Please answer the following question based on the above context."""
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
model = os.environ.get("MODEL_NAME") model = os.environ.get("MODEL_NAME")
if not model: if not model:
raise ValueError("MODEL_NAME environment variable not set") raise ValueError("MODEL_NAME environment variable not set")
# Chat completions use messages format instead of prompt # Chat completions use messages format instead of prompt
test_input = { test_input = {
"model": model, "model": model,
"messages": [{"role": "user", "content": prompt}], "messages": [
"temperature": 0.7 {"role": "system", "content": system_prompt}, # Shared prefix
{"role": "user", "content": unique_question} # Unique per request
],
"temperature": 0.7,
"max_tokens": 500,
} }
return cls(input=test_input) return cls(input=test_input)
@dataclass
class ChatCompletionsHandler(GenericHandler): @dataclass
class ChatCompletionsHandler(GenericHandler):
@property @property
def endpoint(self) -> str: def endpoint(self) -> str:
return "/v1/chat/completions" return "/v1/chat/completions"
@classmethod @classmethod
def payload_cls(cls) -> Type[ChatCompletionsData]: def payload_cls(cls) -> Type[ChatCompletionsData]:
return ChatCompletionsData return ChatCompletionsData
def make_benchmark_payload(self) -> ChatCompletionsData: def make_benchmark_payload(self) -> ChatCompletionsData:
return ChatCompletionsData.for_test() return ChatCompletionsData.for_test()
+15 -13
View File
@@ -7,20 +7,20 @@ from lib.server import start_server
# This line indicates that the inference server is listening # This line indicates that the inference server is listening
MODEL_SERVER_START_LOG_MSG = [ MODEL_SERVER_START_LOG_MSG = [
"Application startup complete.", # vLLM "Application startup complete.", # vLLM
"llama runner started", # Ollama "llama runner started", # Ollama
'"message":"Connected","target":"text_generation_router"', # TGI '"message":"Connected","target":"text_generation_router"', # TGI
'"message":"Connected","target":"text_generation_router::server"', # TGI '"message":"Connected","target":"text_generation_router::server"', # TGI
] ]
MODEL_SERVER_ERROR_LOG_MSGS = [ MODEL_SERVER_ERROR_LOG_MSGS = [
"INFO exited: vllm", # vLLM "INFO exited: vllm", # vLLM
"RuntimeError: Engine", # vLLM "RuntimeError: Engine", # vLLM
"Error: pull model manifest:", # Ollama "Error: pull model manifest:", # Ollama
"stalled; retrying", # Ollama "stalled; retrying", # Ollama
"Error: WebserverFailed", # TGI "Error: WebserverFailed", # TGI
"Error: DownloadError", # TGI "Error: DownloadError", # TGI
"Error: ShardCannotStart", #TGI "Error: ShardCannotStart", # TGI
] ]
logging.basicConfig( logging.basicConfig(
@@ -31,8 +31,8 @@ logging.basicConfig(
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
backend = Backend( backend = Backend(
model_server_url=os.environ.get("MODEL_SERVER_URL"), model_server_url=os.environ["MODEL_SERVER_URL"],
model_log_file=os.environ.get("MODEL_LOG"), model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=True, allow_parallel_requests=True,
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256), benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
log_actions=[ log_actions=[
@@ -45,9 +45,11 @@ backend = Backend(
], ],
) )
async def handle_ping(_): async def handle_ping(_):
return web.Response(body="pong") return web.Response(body="pong")
routes = [ routes = [
web.post("/v1/completions", backend.create_handler(CompletionsHandler())), web.post("/v1/completions", backend.create_handler(CompletionsHandler())),
web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())), web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())),
+420 -14
View File
@@ -1,28 +1,434 @@
from lib.test_utils import test_load_cmd, test_args from lib.test_utils import test_args
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from lib.data_types import AuthData
from .data_types.server import CompletionsData from .data_types.server import CompletionsData
import os
WORKER_ENDPOINT = "/v1/completions" import os
import time
import threading
import requests
from dataclasses import dataclass
from collections import Counter
from urllib.parse import urljoin, urlparse
import re
# Headless plotting
import matplotlib
matplotlib.use("Agg")
import logging
logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)
import matplotlib.pyplot as plt
import numpy as np
from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED
from requests.adapters import HTTPAdapter
def get_incremented_path(path: str) -> str:
base, ext = os.path.splitext(path)
if not os.path.exists(path):
return path
i = 1
while os.path.exists(f"{base}-{i}{ext}"):
i += 1
return f"{base}-{i}{ext}"
WORKER_ENDPOINT = "/v1/completions" # This will return the full text output at once. Latency metrics reflect that (ie not measuring TTFT)
@dataclass
class ReqResult:
worker_url: str
route_ms: float
worker_ms: float
total_ms: float
ok: bool
error: str = ""
status_code: int = 0
t_start: float = 0.0
t_end: float = 0.0
workload: float = 0.0
def do_one(endpoint_name: str,
endpoint_id: int,
endpoint_api_key: str,
server_url: str,
worker_endpoint: str,
payload,
results_list,
t0,
status_samples,
route_session,
worker_session):
try:
workload = payload.count_workload()
route_payload = {"endpoint": endpoint_name, "api_key": endpoint_api_key, "cost": workload}
headers = {"Authorization": f"Bearer {endpoint_api_key}"}
start = time.time()
r0 = route_session.post(urljoin(server_url, "/route/"), json=route_payload, headers=headers, timeout=4)
t_after_route = time.time()
if r0.status_code != 200:
results_list.append(ReqResult(worker_url="",
route_ms=(t_after_route - start) * 1000.0,
worker_ms=0.0,
total_ms=(t_after_route - start) * 1000.0,
ok=False,
error=f"route error {r0.reason} {r0.text}",
status_code=r0.status_code,
t_start=start - t0,
t_end=t_after_route - t0,
workload=workload))
return
msg = r0.json()
# 1) Check if we got a worker back from route
worker_url = msg.get("url", "")
if not worker_url:
status = msg.get("status", "")
m = re.search(r"total workers:\s*(\d+).*loading workers:\s*(\d+).*standby workers:\s*(\d+).*error workers:\s*(\d+)", status, re.I | re.S)
if m:
tot, loading, standby, err = map(int, m.groups())
idle = max(tot - loading - standby - err, 0)
status_samples.append((time.time() - t0, idle))
# 2) If we got a worker, send the request
if worker_url:
req = dict(payload=payload.__dict__, auth_data=AuthData.from_json_msg(msg).__dict__)
t_before_worker = time.time()
r1 = worker_session.post(
urljoin(worker_url, worker_endpoint),
json=req,
verify=get_cert_file_path(),
timeout=(4, 120),
)
t_after_worker = time.time()
if r1.status_code != 200:
results_list.append(ReqResult(worker_url=worker_url,
route_ms=(t_after_route - start) * 1000.0,
worker_ms=(t_after_worker - t_before_worker) * 1000.0,
total_ms=(t_after_worker - start) * 1000.0,
ok=False,
error=f"worker inference error {r1.reason} {r1.text}",
status_code=r1.status_code,
t_start=start - t0,
t_end=t_after_worker - t0,
workload=workload))
return
# Success case
results_list.append(ReqResult(worker_url=worker_url,
route_ms=(t_after_route - start) * 1000.0,
worker_ms=(t_after_worker - t_before_worker) * 1000.0,
total_ms=(t_after_worker - start) * 1000.0,
ok=True,
error="",
status_code=200,
t_start=start - t0,
t_end=t_after_worker - t0,
workload=workload))
# 3) If so, sample via /get_endpoint_workers/ for eligible (idle) worker tracking
if worker_url:
try:
r_status = route_session.post(
urljoin(server_url, "/get_endpoint_workers/"),
json={"id": endpoint_id},
headers={"Authorization": f"Bearer {endpoint_api_key}"},
timeout=3,
)
if r_status.status_code == 200:
workers = r_status.json()
idle = 0
for w in workers:
st = str(w.get("status", "")).lower()
if (st in ("idle")):
idle += 1
status_samples.append((time.time() - t0, idle))
except Exception:
pass
except Exception as e:
t = time.time()
results_list.append(ReqResult(worker_url="",
route_ms=0.0,
worker_ms=0.0,
total_ms=0.0,
ok=False,
error=f"unknown error {e}",
status_code=0,
t_start=t - t0,
t_end=t - t0,
workload=0.0))
def run_load_with_metrics(num_requests: int,
requests_per_second: float,
endpoint_group_name: str,
account_api_key: str,
server_url: str,
worker_endpoint: str,
instance: str,
out_path: str):
ep_info = Endpoint.get_endpoint_info(endpoint_name=endpoint_group_name,
account_api_key=account_api_key,
instance=instance)
if not ep_info or not ep_info.get("api_key") or not ep_info.get("id"):
print(f"Endpoint {endpoint_group_name} not found for API key")
return
endpoint_id = int(ep_info["id"])
endpoint_api_key = ep_info["api_key"]
t0 = time.time()
results = []
status_samples = []
max_concurrency = int(os.environ.get("MAX_CONCURRENCY", "8192"))
submit_queue_factor = 2 # cap queued tasks to reduce memory
# Shared HTTP sessions with connection pooling (persistent connections)
def make_session(pool_connections: int, pool_maxsize: int) -> requests.Session:
sess = requests.Session()
adapter = HTTPAdapter(pool_connections=pool_connections, pool_maxsize=pool_maxsize, max_retries=0)
sess.mount("https://", adapter)
sess.mount("http://", adapter)
return sess
# Router: mostly single host, small connection pool is sufficient
route_session = make_session(pool_connections=1, pool_maxsize=max_concurrency)
# Workers: many hosts; allow many pools and per-host concurrency up to max_concurrency
worker_session = make_session(pool_connections=64, pool_maxsize=max_concurrency // 8)
# Fire requests using a thread pool, scheduling at requested RPS
inflight = set()
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
for i in range(num_requests):
# Pace submissions to RPS
target_time = t0 + i / max(requests_per_second, 1e-9)
sleep_s = target_time - time.time()
if sleep_s > 0:
time.sleep(min(sleep_s, 0.5)) # sleep in chunks to stay responsive
payload = CompletionsData.for_test()
fut = executor.submit(
do_one,
endpoint_group_name,
endpoint_id,
endpoint_api_key,
server_url,
worker_endpoint,
payload,
results,
t0,
status_samples,
route_session,
worker_session,
)
inflight.add(fut)
# Prevent unbounded queue growth
if len(inflight) >= max_concurrency * submit_queue_factor:
done, not_done = wait(inflight, return_when=FIRST_COMPLETED)
inflight = not_done
# Wait for all outstanding tasks
if inflight:
wait(inflight)
# Close sessions
try:
route_session.close()
finally:
worker_session.close()
# Aggregate results
oks = [r for r in results if r.ok]
errs = [r for r in results if not r.ok]
total_reqs = len(results)
succ = len(oks)
total_ms = np.array([r.total_ms for r in oks]) if succ else np.array([])
worker_ms = np.array([r.worker_ms for r in oks]) if succ else np.array([])
route_ms = np.array([r.route_ms for r in oks]) if succ else np.array([])
avg_total = float(np.mean(total_ms)) if succ else 0.0
avg_worker = float(np.mean(worker_ms)) if succ else 0.0
avg_route = float(np.mean(route_ms)) if succ else 0.0
p50_total, p95_total = (float(np.percentile(total_ms, 50)), float(np.percentile(total_ms, 95))) if succ else (0.0, 0.0)
# Distribution over workers (by host:port)
hosts = [urlparse(r.worker_url).netloc for r in oks if r.worker_url]
dist = Counter(hosts)
# Idle over time (mode per second)
idle_ts, idle_vals = [], []
if status_samples:
buckets = {}
for ts, idle in status_samples:
k = int(ts)
buckets.setdefault(k, []).append(idle)
keys = sorted(buckets.keys())
idle_ts = keys
# Use the most frequent sampled value per second (mode) to keep integer counts
idle_vals = []
for k in keys:
vals_k = [int(v) for v in buckets[k]]
if vals_k:
cnt = Counter(vals_k)
idle_vals.append(cnt.most_common(1)[0][0])
else:
idle_vals.append(0)
print(f"\nResults: total={total_reqs} success={succ} errors={len(errs)}")
print(f"Avg latency (ms): {avg_total:.1f} p50: {p50_total:.1f} p95: {p95_total:.1f}")
print(f"Avg route latency (ms): {avg_route:.1f} Avg worker latency (ms): {avg_worker:.1f}")
if errs:
print("Sample errors:")
for e in errs[:5]:
print(f" {e.status_code} {e.error}")
# Plot: 2x3 grid
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
fig.suptitle(f"Load test: {endpoint_group_name} n={total_reqs}, rps={requests_per_second}, success={succ}")
# Dist per worker
ax0 = axes[0, 0]
if dist:
items = sorted(dist.items(), key=lambda kv: kv[1], reverse=True)
labels, counts = zip(*items)
ax0.bar(range(len(labels)), counts)
ax0.set_xticks(range(len(labels)))
ax0.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
ax0.set_title("Request distribution over workers")
ax0.set_ylabel("count")
# Latency histogram (total)
ax1 = axes[0, 1]
if succ:
ax1.hist(total_ms, bins=30)
ax1.set_title("Total latency (ms)")
ax1.set_xlabel("ms")
ax1.set_ylabel("freq")
# Eligible workers over time
ax_idle = axes[0, 2]
if idle_ts:
ax_idle.plot(idle_ts, idle_vals, "-o", ms=3)
ax_idle.set_title("Eligible workers over time")
ax_idle.set_xlabel("time (s)")
ax_idle.set_ylabel("eligible count")
# Throughput over time (completions/sec)
ax_idle = axes[1, 0]
ax_idle.clear()
if succ:
per_sec = {}
for r in oks:
s = int(r.t_end)
per_sec[s] = per_sec.get(s, 0) + 1
ts = sorted(per_sec.keys())
vals = [per_sec[t] for t in ts]
ax_idle.plot(ts, vals, "-o", ms=3)
ax_idle.set_title("Completions per second")
ax_idle.set_xlabel("time (s)")
ax_idle.set_ylabel("completions / sec")
# Summary text
ax3 = axes[1, 1]
ax3.axis("off")
text = (
f"Total requests: {total_reqs}\n"
f"Success: {succ} Errors: {len(errs)}\n"
f"Avg total latency: {avg_total:.1f} ms\n"
f"p50: {p50_total:.1f} ms p95: {p95_total:.1f} ms\n"
f"Avg route latency: {avg_route:.1f} ms\n"
f"Avg worker latency: {avg_worker:.1f} ms\n"
f"300 errors: {len([r for r in errs if r.status_code >= 300 and r.status_code < 400])}\n"
f"429 errors: {len([r for r in errs if r.status_code == 429])}\n"
f"500 errors: {len([r for r in errs if r.status_code >= 500])}\n"
f"Other errors: {len([r for r in errs if r.status_code not in [300, 429, 500]])}\n"
)
ax3.set_title("Summary")
ax3.text(0.02, 0.98, text, va="top", ha="left", fontsize=11, transform=ax3.transAxes)
# Error count over time
ax_errors = axes[1, 2]
all_end_times = [int(r.t_end) for r in results if r.t_end > 0]
if all_end_times:
min_second = min(all_end_times)
max_second = max(all_end_times)
# Count errors per second
errors_per_second = {}
for result in errs:
second = int(result.t_end)
errors_per_second[second] = errors_per_second.get(second, 0) + 1
# Create complete timeline including zeros
time_seconds = list(range(min_second, max_second + 1))
error_counts = [errors_per_second.get(sec, 0) for sec in time_seconds]
ax_errors.plot(time_seconds, error_counts, "-o", ms=3)
ax_errors.set_title("Errors per second")
ax_errors.set_xlabel("time (s)")
ax_errors.set_ylabel("errors / sec")
# Ensure unique output path and create directory if needed
final_out_path = get_incremented_path(out_path)
out_dir = os.path.dirname(final_out_path)
if out_dir:
os.makedirs(out_dir, exist_ok=True)
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.savefig(final_out_path, dpi=120)
print(f"Saved report to: {final_out_path}")
# Per-worker latency boxplot (top 12 by volume)
groups = {}
for r in oks:
host = urlparse(r.worker_url).netloc
groups.setdefault(host, []).append(r.total_ms)
items = sorted(groups.items(), key=lambda kv: len(kv[1]), reverse=True)[:12]
if items:
labels, data = zip(*items)
fig2, axb = plt.subplots(1, 1, figsize=(12, 5))
axb.boxplot(data, showfliers=False)
axb.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
axb.set_title("Per-worker latency (ms)")
axb.set_ylabel("ms")
plt.tight_layout()
extra_out = get_incremented_path(os.path.splitext(out_path)[0] + "-workers.png")
plt.savefig(extra_out, dpi=120)
fig2.tight_layout()
fig2.savefig(extra_out, dpi=120)
print(f"Saved worker latency plot to: {extra_out}")
if __name__ == "__main__": if __name__ == "__main__":
# Check if MODEL_NAME environment variable is set # Check if MODEL_NAME environment variable is set
model_name_set = os.environ.get("MODEL_NAME") is not None model_name_set = os.environ.get("MODEL_NAME") is not None
# Add model argument - required only if MODEL_NAME is not set # Add model argument - required only if MODEL_NAME is not set
test_args.add_argument( test_args.add_argument(
"--model", "--model",
dest="model", dest="model",
required=not model_name_set, required=not model_name_set,
help="Model to use for completions request (required if MODEL_NAME env var not set)" help="Model to use for completions request (required if MODEL_NAME env var not set)",
) )
# Parse known args to get model early, before test_load_cmd adds its args # Parse known args to get model early, before adding load args
known_args, _ = test_args.parse_known_args() known_args, _ = test_args.parse_known_args()
if hasattr(known_args, "model") and known_args.model:
# Set environment variable if model was provided
if hasattr(known_args, 'model') and known_args.model:
os.environ["MODEL_NAME"] = known_args.model os.environ["MODEL_NAME"] = known_args.model
print(f"Set MODEL_NAME environment variable to: {known_args.model}") print(f"Set MODEL_NAME environment variable to: {known_args.model}")
# Now call test_load_cmd normally - it will add its own args and re-parse # Load test args
test_load_cmd(CompletionsData, WORKER_ENDPOINT, arg_parser=test_args) test_args.add_argument("-n", dest="num_requests", type=int, required=True, help="total number of requests")
test_args.add_argument("-rps", dest="requests_per_second", type=float, required=True, help="requests per second")
test_args.add_argument("--out", dest="out_path", type=str, default="load_test_report.png", help="path to save the report image")
args = test_args.parse_args()
server_url = {
"prod": "https://run.vast.ai",
"alpha": "https://run-alpha.vast.ai",
"candidate": "https://run-candidate.vast.ai",
"local": "http://localhost:8080"
}.get(args.instance, "http://localhost:8080")
run_load_with_metrics(
num_requests=args.num_requests,
requests_per_second=args.requests_per_second,
endpoint_group_name=args.endpoint_group_name,
account_api_key=args.api_key,
server_url=server_url,
worker_endpoint=WORKER_ENDPOINT,
instance=args.instance,
out_path=args.out_path,
)
+6 -1
View File
@@ -4,6 +4,7 @@ import json
from urllib.parse import urljoin from urllib.parse import urljoin
import requests import requests
from utils.endpoint_util import Endpoint from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
logging.basicConfig( logging.basicConfig(
level=logging.DEBUG, level=logging.DEBUG,
@@ -42,7 +43,11 @@ def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> No
req_data = dict(payload=payload, auth_data=auth_data) req_data = dict(payload=payload, auth_data=auth_data)
url = urljoin(url, WORKER_ENDPOINT) url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {url}") print(f"url: {url}")
response = requests.post(url, json=req_data) response = requests.post(
url,
json=req_data,
verify=get_cert_file_path(),
)
response.raise_for_status() response.raise_for_status()
res = response.json() res = response.json()
print(res) print(res)