Compare commits

..

22 Commits

Author SHA1 Message Date
Colter Downing e756f61b9a graphing errors over time 2025-10-25 12:14:27 -07:00
Colter Downing 8cb98c84f9 non vibe coded test_load 2025-10-24 19:08:36 -07:00
Colter Downing e251afda2b improved test load 2025-10-24 12:53:35 -07:00
Lucas Armand 74bd932327 Suppress matplot debug logs 2025-10-24 12:30:20 -07:00
Lucas Armand 37ad3f8d46 asyncio in metrics 2025-10-23 10:18:31 -07:00
Lucas Armand 0f13506938 Send success param 2025-10-22 10:18:59 -07:00
Lucas Armand 01e752d31f use more asyncio sleep 2025-10-21 18:52:13 -07:00
Lucas Armand 5edfa968ca async sleep 2025-10-21 18:49:48 -07:00
Lucas Armand 5b5ef7227a nvm moved it here 2025-10-21 18:20:11 -07:00
Lucas Armand 16990ff8ff move start request 2025-10-21 18:18:44 -07:00
Lucas Armand 9748176366 fixed semaphore acquire bool 2025-10-21 18:12:23 -07:00
Lucas Armand b39193ae70 check for sem acquire 2025-10-21 18:02:14 -07:00
Lucas Armand 9a6ca5d412 added versioning 2025-10-21 15:42:43 -07:00
Lucas Armand e9ba1b03e4 Use delete_requests and track request_idxs 2025-10-21 11:59:35 -07:00
Rob Ballantyne 4fdc314fd9 Fix healthcheck endpoint URL 2025-10-06 22:16:09 +01:00
Colter-Downing 639d82f5b4 Merge pull request #35 from vast-ai/AUTO-664--Healthcheck-error
Fix healthcheck with separate session
2025-10-02 12:51:19 -07:00
Scott-Laytart 4e2f2311d0 Merge pull request #33 from vast-ai/comfy-blind-fix-override
undo the fix for comfy yesterday.
2025-09-03 11:50:07 -07:00
abiola-vastai 38782d89bc undo the fix for comfy yesterday. 2025-09-03 17:12:35 +00:00
Scott-Laytart 0185216ccb Merge pull request #32 from vast-ai/blindhotfix_comfy_ui_default_port
Blind hotfix to see if comfy UI default is needed. if it does work we…
2025-09-02 18:26:25 -07:00
abiola-vastai b20d9e714c Blind hotfix to see if comfy UI default is needed. if it does work we would revert back. 2025-09-03 01:20:09 +00:00
Rob Ballantyne b1eb65d75d Merge pull request #31 from vast-ai/bugfix/startup-script-20250901
Update uv venv creation command
2025-09-01 18:19:17 +01:00
Rob Ballantyne 1d09d7fe96 Update uv venv creation command 2025-09-01 16:55:20 +01:00
6 changed files with 300 additions and 121 deletions
+50 -26
View File
@@ -12,6 +12,7 @@ from distutils.util import strtobool
from anyio import open_file from anyio import open_file
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
import asyncio
import requests import requests
from Crypto.Signature import pkcs1_15 from Crypto.Signature import pkcs1_15
@@ -25,8 +26,11 @@ from lib.data_types import (
LogAction, LogAction,
ApiPayload_T, ApiPayload_T,
JsonDataException, JsonDataException,
RequestMetrics
) )
VERSION = "0.1.0"
MSG_HISTORY_LEN = 100 MSG_HISTORY_LEN = 100
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
@@ -53,7 +57,9 @@ class Backend:
EndpointHandler # this endpoint handler will be used for benchmarking EndpointHandler # this endpoint handler will be used for benchmarking
) )
log_actions: List[Tuple[LogAction, str]] log_actions: List[Tuple[LogAction, str]]
max_wait_time: float = 10.0
reqnum = -1 reqnum = -1
version = VERSION
msg_history = [] msg_history = []
sem: Semaphore = dataclasses.field(default_factory=Semaphore) sem: Semaphore = dataclasses.field(default_factory=Semaphore)
unsecured: bool = dataclasses.field( unsecured: bool = dataclasses.field(
@@ -62,6 +68,7 @@ class Backend:
def __post_init__(self): def __post_init__(self):
self.metrics = Metrics() self.metrics = Metrics()
self.metrics._set_version(self.version)
self._total_pubkey_fetch_errors = 0 self._total_pubkey_fetch_errors = 0
self._pubkey = self._fetch_pubkey() self._pubkey = self._fetch_pubkey()
self.__start_healthcheck: bool = False self.__start_healthcheck: bool = False
@@ -128,55 +135,56 @@ class Backend:
except json.JSONDecodeError: except json.JSONDecodeError:
return web.json_response(dict(error="invalid JSON"), status=422) return web.json_response(dict(error="invalid JSON"), status=422)
workload = payload.count_workload() workload = payload.count_workload()
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
async def cancel_api_call_if_disconnected() -> web.Response: async def cancel_api_call_if_disconnected() -> web.Response:
await request.wait_for_disconnection() await request.wait_for_disconnection()
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled") log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled")
self.metrics._request_canceled(workload=workload) self.metrics._request_canceled(request_metrics)
return web.Response(status=500) raise asyncio.CancelledError
async def make_request() -> Union[web.Response, web.StreamResponse]: async def make_request() -> Union[web.Response, web.StreamResponse]:
log.debug(f"got request, {auth_data.reqnum}")
self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum)
if self.allow_parallel_requests is False:
log.debug(f"Waiting to aquire Sem for reqnum:{auth_data.reqnum}")
await self.sem.acquire()
log.debug(
f"Sem acquired for reqnum:{auth_data.reqnum}, starting request..."
)
else:
log.debug(f"Starting request for reqnum:{auth_data.reqnum}")
try: try:
response = await self.__call_api(handler=handler, payload=payload) response = await self.__call_api(handler=handler, payload=payload)
status_code = response.status status_code = response.status
log.debug( log.debug(
" ".join( " ".join(
[ [
f"request with reqnum:{auth_data.reqnum}", f"request with reqnum:{request_metrics.reqnum}",
f"returned status code: {status_code},", f"returned status code: {status_code},",
] ]
) )
) )
res = await handler.generate_client_response(request, response) res = await handler.generate_client_response(request, response)
self.metrics._request_success(workload=workload) self.metrics._request_success(request_metrics)
return res return res
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
log.debug(f"[backend] Request error: {e}") log.debug(f"[backend] Request error: {e}")
self.metrics._request_errored(workload=workload) self.metrics._request_errored(request_metrics)
return web.Response(status=500) return web.Response(status=500)
finally:
self.metrics._request_end(
workload=workload,
reqnum=auth_data.reqnum,
)
self.sem.release()
########### ###########
if self.__check_signature(auth_data) is False: if self.__check_signature(auth_data) is False:
self.metrics._request_reject(request_metrics)
return web.Response(status=401) return web.Response(status=401)
if self.metrics.model_metrics.wait_time > self.max_wait_time:
self.metrics._request_reject(request_metrics)
return web.Response(status=429)
acquired = False
try: try:
self.metrics._request_start(request_metrics)
if self.allow_parallel_requests is False:
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
await self.sem.acquire()
acquired = True
log.debug(
f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..."
)
else:
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
done, pending = await wait( done, pending = await wait(
[ [
create_task(make_request()), create_task(make_request()),
@@ -184,11 +192,27 @@ class Backend:
], ],
return_when=FIRST_COMPLETED, return_when=FIRST_COMPLETED,
) )
[task.cancel() for task in pending] for t in pending:
return done.pop().result() t.cancel()
await asyncio.gather(*pending, return_exceptions=True)
done_task = done.pop()
try:
return done_task.result()
except Exception as e:
log.debug(f"Request task raised exception: {e}")
return web.Response(status=500)
except asyncio.CancelledError:
# Client is gone. Do not write a response; just unwind.
return web.Response(status=499)
except Exception as e: except Exception as e:
log.debug(f"Exception in main handler loop {e}") log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500) return web.Response(status=500)
finally:
# Always release the semaphore if it was acquired
if acquired:
self.sem.release()
self.metrics._request_end(request_metrics)
@cached_property @cached_property
def healthcheck_session(self): def healthcheck_session(self):
@@ -229,7 +253,7 @@ class Backend:
async def _start_tracking(self) -> None: async def _start_tracking(self) -> None:
await gather( await gather(
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck() self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop()
) )
def backend_errored(self, msg: str) -> None: def backend_errored(self, msg: str) -> None:
@@ -387,7 +411,7 @@ class Backend:
if line: if line:
await handle_log_line(line.rstrip()) await handle_log_line(line.rstrip())
else: else:
time.sleep(LOG_POLL_INTERVAL) await asyncio.sleep(LOG_POLL_INTERVAL)
########### ###########
+35 -7
View File
@@ -70,6 +70,7 @@ class AuthData:
endpoint: str endpoint: str
reqnum: int reqnum: int
url: str url: str
request_idx: int
@classmethod @classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]): def from_json_msg(cls, json_msg: Dict[str, Any]):
@@ -196,6 +197,15 @@ class SystemMetrics:
self.model_loading_time = None self.model_loading_time = None
@dataclass
class RequestMetrics:
"""Tracks metrics for an active request."""
request_idx: int
reqnum: int
workload: float
status: str
success: bool = False
@dataclass @dataclass
class ModelMetrics: class ModelMetrics:
"""Model specific metrics""" """Model specific metrics"""
@@ -205,12 +215,14 @@ class ModelMetrics:
workload_received: float workload_received: float
workload_cancelled: float workload_cancelled: float
workload_errored: float workload_errored: float
workload_rejected: float
# these are not # these are not
workload_pending: float workload_pending: float
error_msg: Optional[str] error_msg: Optional[str]
max_throughput: float max_throughput: float
requests_recieved: Set[int] = field(default_factory=set) requests_recieved: Set[int] = field(default_factory=set)
requests_working: Set[int] = field(default_factory=set) requests_working: dict[int, RequestMetrics] = field(default_factory=dict)
requests_deleting: list[RequestMetrics] = field(default_factory=list)
last_update: float = field(default_factory=time.time) last_update: float = field(default_factory=time.time)
@classmethod @classmethod
@@ -220,19 +232,30 @@ class ModelMetrics:
workload_served=0.0, workload_served=0.0,
workload_cancelled=0.0, workload_cancelled=0.0,
workload_errored=0.0, workload_errored=0.0,
workload_rejected=0.0,
workload_received=0.0, workload_received=0.0,
error_msg=None, error_msg=None,
max_throughput=0.0, max_throughput=0.0,
) )
@property
def cur_perf(self) -> float:
return max(self.workload_served / (time.time() - self.last_update), 0.0)
@property @property
def workload_processing(self) -> float: def workload_processing(self) -> float:
return max(self.workload_received - self.workload_cancelled, 0.0) return max(self.workload_received - self.workload_cancelled, 0.0)
@property
def wait_time(self) -> float:
if (len(self.requests_working) == 0):
return 0.0
return sum([request.workload for request in self.requests_working.values()]) / self.max_throughput
@property
def cur_load(self) -> float:
return sum([request.workload for request in self.requests_working.values()])
@property
def working_request_idxs(self) -> list[int]:
return [req.request_idx for req in self.requests_working.values()]
def set_errored(self, error_msg): def set_errored(self, error_msg):
self.reset() self.reset()
self.error_msg = error_msg self.error_msg = error_msg
@@ -242,16 +265,20 @@ class ModelMetrics:
self.workload_received = 0 self.workload_received = 0
self.workload_cancelled = 0 self.workload_cancelled = 0
self.workload_errored = 0 self.workload_errored = 0
self.workload_rejected = 0
self.last_update = time.time() self.last_update = time.time()
@dataclass @dataclass
class AutoScalaerData: class AutoScalerData:
"""Data that is reported to autoscaler""" """Data that is reported to autoscaler"""
id: int id: int
version: str
loadtime: float loadtime: float
cur_load: float cur_load: float
rej_load: float
new_load: float
error_msg: str error_msg: str
max_perf: float max_perf: float
cur_perf: float cur_perf: float
@@ -260,6 +287,7 @@ class AutoScalaerData:
num_requests_working: int num_requests_working: int
num_requests_recieved: int num_requests_recieved: int
additional_disk_usage: float additional_disk_usage: float
working_request_idxs: list[int]
url: str url: str
+113 -31
View File
@@ -5,13 +5,14 @@ import json
from asyncio import sleep from asyncio import sleep
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict, field
from functools import cache from functools import cache
import asyncio
from aiohttp import ClientSession, ClientTimeout, TCPConnector, ClientResponseError
import requests from lib.data_types import AutoScalerData, SystemMetrics, ModelMetrics, RequestMetrics
from lib.data_types import AutoScalaerData, SystemMetrics, ModelMetrics
from typing import Awaitable, NoReturn, List from typing import Awaitable, NoReturn, List
METRICS_UPDATE_INTERVAL = 1 METRICS_UPDATE_INTERVAL = 1
DELETE_REQUESTS_INTERVAL = 1
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
@@ -26,7 +27,9 @@ def get_url() -> str:
@dataclass @dataclass
class Metrics: class Metrics:
version: str = "0"
last_metric_update: float = 0.0 last_metric_update: float = 0.0
last_request_served: float = 0.0
update_pending: bool = False update_pending: bool = False
id: int = field(default_factory=lambda: int(os.environ["CONTAINER_ID"])) id: int = field(default_factory=lambda: int(os.environ["CONTAINER_ID"]))
report_addr: List[str] = field( report_addr: List[str] = field(
@@ -35,42 +38,84 @@ class Metrics:
url: str = field(default_factory=get_url) url: str = field(default_factory=get_url)
system_metrics: SystemMetrics = field(default_factory=SystemMetrics.empty) system_metrics: SystemMetrics = field(default_factory=SystemMetrics.empty)
model_metrics: ModelMetrics = field(default_factory=ModelMetrics.empty) model_metrics: ModelMetrics = field(default_factory=ModelMetrics.empty)
_session: ClientSession | None = field(default=None, init=False, repr=False)
def _request_start(self, workload: float, reqnum: int) -> None: async def http(self) -> ClientSession:
if self._session is None:
self._session = ClientSession(
timeout=ClientTimeout(total=10),
connector=TCPConnector(limit=8, limit_per_host=4, force_close=True, enable_cleanup_closed=True)
)
return self._session
async def aclose(self) -> None:
if self._session is not None:
await self._session.close()
self._session = None
def _request_start(self, request: RequestMetrics) -> None:
""" """
this function is called prior to forwarding a request to a model API. this function is called prior to forwarding a request to a model API.
""" """
log.debug("request start") log.debug("request start")
self.model_metrics.workload_pending += workload request.status = "Started"
self.model_metrics.workload_received += workload self.model_metrics.workload_pending += request.workload
self.model_metrics.requests_recieved.add(reqnum) self.model_metrics.workload_received += request.workload
self.model_metrics.requests_working.add(reqnum) self.model_metrics.requests_recieved.add(request.reqnum)
self.model_metrics.requests_working[request.reqnum] = request
self.update_pending = True
def _request_end(self, workload: float, reqnum: int) -> None: def _request_end(self, request: RequestMetrics) -> None:
""" """
this function is called after handling of a request ends, regardless of the outcome this function is called after handling of a request ends, regardless of the outcome
""" """
self.model_metrics.workload_pending -= workload self.model_metrics.workload_pending -= request.workload
self.model_metrics.requests_working.discard(reqnum) self.model_metrics.requests_working.pop(request.reqnum, None)
self.model_metrics.requests_deleting.append(request)
self.last_request_served = time.time()
def _request_success(self, workload: float) -> None: def _request_success(self, request: RequestMetrics) -> None:
""" """
this function is called after a response from model API is received and forwarded. this function is called after a response from model API is received and forwarded.
""" """
self.model_metrics.workload_served += workload self.model_metrics.workload_served += request.workload
request.status = "Success"
request.success = True
self.update_pending = True self.update_pending = True
def _request_errored(self, workload: float) -> None: def _request_errored(self, request: RequestMetrics) -> None:
""" """
this function is called if model API returns an error this function is called if model API returns an error
""" """
self.model_metrics.workload_errored += workload self.model_metrics.workload_errored += request.workload
request.status = "Error"
request.success = False
self.update_pending = True
def _request_canceled(self, workload: float) -> None: def _request_canceled(self, request: RequestMetrics) -> None:
""" """
this function is called if client drops connection before model API has responded this function is called if client drops connection before model API has responded
""" """
self.model_metrics.workload_cancelled += workload self.model_metrics.workload_cancelled += request.workload
request.success = True
request.status = "Cancelled"
def _request_reject(self, request: RequestMetrics):
"""
this function is called if the current wait time for the model is above max_wait_time
"""
self.model_metrics.requests_recieved.add(request.reqnum)
self.model_metrics.requests_deleting.append(request)
self.model_metrics.workload_rejected += request.workload
request.success = False
request.status = "Rejected"
self.update_pending = True
async def _send_delete_requests_loop(self) -> Awaitable[NoReturn]:
while True:
await sleep(DELETE_REQUESTS_INTERVAL)
if len(self.model_metrics.requests_deleting) > 0:
await self.__send_delete_requests_and_reset()
async def _send_metrics_loop(self) -> Awaitable[NoReturn]: async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
while True: while True:
@@ -78,10 +123,10 @@ class Metrics:
elapsed = time.time() - self.last_metric_update elapsed = time.time() - self.last_metric_update
if self.system_metrics.model_is_loaded is False and elapsed >= 10: if self.system_metrics.model_is_loaded is False and elapsed >= 10:
log.debug(f"sending loading model metrics after {int(elapsed)}s wait") log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
self.__send_metrics_and_reset(elapsed) await self.__send_metrics_and_reset()
elif self.update_pending or elapsed > 10: elif self.update_pending or elapsed > 10:
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait") log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
self.__send_metrics_and_reset(elapsed) await self.__send_metrics_and_reset()
def _model_loaded(self, max_throughput: float) -> None: def _model_loaded(self, max_throughput: float) -> None:
self.system_metrics.model_loading_time = ( self.system_metrics.model_loading_time = (
@@ -94,27 +139,63 @@ class Metrics:
self.model_metrics.set_errored(error_msg) self.model_metrics.set_errored(error_msg)
self.system_metrics.model_is_loaded = True self.system_metrics.model_is_loaded = True
def _set_version(self, version: str) -> None:
self.version = version
#######################################Private####################################### #######################################Private#######################################
def __send_metrics_and_reset(self, elapsed): async def __send_delete_requests_and_reset(self):
def compute_autoscaler_data() -> AutoScalaerData: async def send_data(report_addr: str, success: bool) -> bool:
return AutoScalaerData( data = {
"worker_id": self.id,
"request_idxs": [r.request_idx for r in self.model_metrics.requests_deleting if r.success == success],
"success": success
}
full_path = report_addr.rstrip("/") + "/delete_requests/"
for attempt in range(1, 4):
try:
session = await self.http()
async with session.post(full_path, json=data) as res:
res.raise_for_status()
return True
except asyncio.TimeoutError:
log.debug(f"delete_requests timed out")
except (ClientResponseError, Exception) as e:
log.debug(f"delete_requests failed with error: {e}")
await asyncio.sleep(2)
log.debug(f"retrying delete_request, attempt: {attempt}")
for report_addr in self.report_addr:
success = await send_data(report_addr, success=True) and await send_data(report_addr, success=False)
if success is True:
self.model_metrics.requests_deleting.clear()
break
async def __send_metrics_and_reset(self):
def compute_autoscaler_data() -> AutoScalerData:
return AutoScalerData(
id=self.id, id=self.id,
version=self.version,
loadtime=(self.system_metrics.model_loading_time or 0.0), loadtime=(self.system_metrics.model_loading_time or 0.0),
cur_load=(self.model_metrics.workload_processing / elapsed), new_load=self.model_metrics.workload_processing,
cur_load=self.model_metrics.cur_load,
rej_load=self.model_metrics.workload_rejected,
max_perf=self.model_metrics.max_throughput, max_perf=self.model_metrics.max_throughput,
cur_perf=self.model_metrics.cur_perf, cur_perf=self.model_metrics.workload_served,
error_msg=self.model_metrics.error_msg or "", error_msg=self.model_metrics.error_msg or "",
num_requests_working=len(self.model_metrics.requests_working), num_requests_working=len(self.model_metrics.requests_working),
num_requests_recieved=len(self.model_metrics.requests_recieved), num_requests_recieved=len(self.model_metrics.requests_recieved),
additional_disk_usage=self.system_metrics.additional_disk_usage, additional_disk_usage=self.system_metrics.additional_disk_usage,
working_request_idxs=self.model_metrics.working_request_idxs,
cur_capacity=0, cur_capacity=0,
max_capacity=0, max_capacity=0,
url=self.url, url=self.url,
) )
def send_data(report_addr: str) -> bool: async def send_data(report_addr: str) -> bool:
data = compute_autoscaler_data() data = compute_autoscaler_data()
full_path = report_addr.rstrip("/") + "/worker_status/" full_path = report_addr.rstrip("/") + "/worker_status/"
log.debug( log.debug(
@@ -129,14 +210,15 @@ class Metrics:
) )
for attempt in range(1, 4): for attempt in range(1, 4):
try: try:
res = requests.post(full_path, json=asdict(data), timeout=1) session = await self.http()
res.raise_for_status() async with session.post(full_path, json=asdict(data)) as res:
res.raise_for_status()
return True return True
except requests.Timeout: except asyncio.TimeoutError:
log.debug(f"autoscaler status update timed out") log.debug(f"autoscaler status update timed out")
except Exception as e: except (ClientResponseError, Exception) as e:
log.debug(f"autoscaler status update failed with error: {e}") log.debug(f"autoscaler status update failed with error: {e}")
time.sleep(2) await asyncio.sleep(2)
log.debug(f"retrying autoscaler status update, attempt: {attempt}") log.debug(f"retrying autoscaler status update, attempt: {attempt}")
log.debug(f"failed to send update through {report_addr}") log.debug(f"failed to send update through {report_addr}")
return False return False
@@ -146,7 +228,7 @@ class Metrics:
self.system_metrics.update_disk_usage() self.system_metrics.update_disk_usage()
for report_addr in self.report_addr: for report_addr in self.report_addr:
success = send_data(report_addr) success = await send_data(report_addr)
if success is True: if success is True:
break break
self.update_pending = False self.update_pending = False
+2 -2
View File
@@ -59,12 +59,12 @@ then
fi fi
# Fork testing # Fork testing
git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR" [[ ! -d $SERVER_DIR ]] && git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"
if [[ -n ${PYWORKER_REF:-} ]]; then if [[ -n ${PYWORKER_REF:-} ]]; then
(cd "$SERVER_DIR" && git checkout "$PYWORKER_REF") (cd "$SERVER_DIR" && git checkout "$PYWORKER_REF")
fi fi
uv venv --managed-python "$ENV_PATH" -p 3.10 uv venv --python-preference only-managed "$ENV_PATH" -p 3.10
source "$ENV_PATH/bin/activate" source "$ENV_PATH/bin/activate"
uv pip install -r "${SERVER_DIR}/requirements.txt" uv pip install -r "${SERVER_DIR}/requirements.txt"
+1 -1
View File
@@ -70,7 +70,7 @@ class ComfyWorkflowHandler(EndpointHandler[ComfyWorkflowData]):
@property @property
def healthcheck_endpoint(self) -> Optional[str]: def healthcheck_endpoint(self) -> Optional[str]:
return "/health" return f"{MODEL_SERVER_URL}/health"
@classmethod @classmethod
def payload_cls(cls) -> Type[ComfyWorkflowData]: def payload_cls(cls) -> Type[ComfyWorkflowData]:
+99 -54
View File
@@ -42,6 +42,7 @@ class ReqResult:
total_ms: float total_ms: float
ok: bool ok: bool
error: str = "" error: str = ""
status_code: int = 0
t_start: float = 0.0 t_start: float = 0.0
t_end: float = 0.0 t_end: float = 0.0
workload: float = 0.0 workload: float = 0.0
@@ -58,31 +59,72 @@ def do_one(endpoint_name: str,
route_session, route_session,
worker_session): worker_session):
try: try:
u = payload.count_workload() workload = payload.count_workload()
route_payload = {"endpoint": endpoint_name, "api_key": endpoint_api_key, "cost": u} route_payload = {"endpoint": endpoint_name, "api_key": endpoint_api_key, "cost": workload}
headers = {"Authorization": f"Bearer {endpoint_api_key}"} headers = {"Authorization": f"Bearer {endpoint_api_key}"}
start = time.time() start = time.time()
r0 = route_session.post(urljoin(server_url, "/route/"), json=route_payload, headers=headers, timeout=4) r0 = route_session.post(urljoin(server_url, "/route/"), json=route_payload, headers=headers, timeout=4)
t_after_route = time.time() t_after_route = time.time()
if r0.status_code != 200: if r0.status_code != 200:
results_list.append(ReqResult("", (t_after_route - start) * 1000.0, 0.0, (t_after_route - start) * 1000.0, False, results_list.append(ReqResult(worker_url="",
f"route {r0.status_code} {r0.text}")) route_ms=(t_after_route - start) * 1000.0,
worker_ms=0.0,
total_ms=(t_after_route - start) * 1000.0,
ok=False,
error=f"route error {r0.reason} {r0.text}",
status_code=r0.status_code,
t_start=start - t0,
t_end=t_after_route - t0,
workload=workload))
return return
msg = r0.json() msg = r0.json()
# 1) "Status" is in the response when no worker is ready # 1) Check if we got a worker back from route
worker_sampled = True worker_url = msg.get("url", "")
status = msg.get("status", "") if not worker_url:
if status:
m = re.search(r"total workers:\s*(\d+).*loading workers:\s*(\d+).*standby workers:\s*(\d+).*error workers:\s*(\d+)", status, re.I | re.S) m = re.search(r"total workers:\s*(\d+).*loading workers:\s*(\d+).*standby workers:\s*(\d+).*error workers:\s*(\d+)", status, re.I | re.S)
if m: if m:
tot, loading, standby, err = map(int, m.groups()) tot, loading, standby, err = map(int, m.groups())
idle = max(tot - loading - standby - err, 0) idle = max(tot - loading - standby - err, 0)
status_samples.append((time.time() - t0, idle)) status_samples.append((time.time() - t0, idle))
worker_sampled = False
# 2) Otherwise (successful request), sample via /get_endpoint_workers/ for eligible (idle) worker tracking # 2) If we got a worker, send the request
if worker_sampled: if worker_url:
req = dict(payload=payload.__dict__, auth_data=AuthData.from_json_msg(msg).__dict__)
t_before_worker = time.time()
r1 = worker_session.post(
urljoin(worker_url, worker_endpoint),
json=req,
verify=get_cert_file_path(),
timeout=(4, 120),
)
t_after_worker = time.time()
if r1.status_code != 200:
results_list.append(ReqResult(worker_url=worker_url,
route_ms=(t_after_route - start) * 1000.0,
worker_ms=(t_after_worker - t_before_worker) * 1000.0,
total_ms=(t_after_worker - start) * 1000.0,
ok=False,
error=f"worker inference error {r1.reason} {r1.text}",
status_code=r1.status_code,
t_start=start - t0,
t_end=t_after_worker - t0,
workload=workload))
return
# Success case
results_list.append(ReqResult(worker_url=worker_url,
route_ms=(t_after_route - start) * 1000.0,
worker_ms=(t_after_worker - t_before_worker) * 1000.0,
total_ms=(t_after_worker - start) * 1000.0,
ok=True,
error="",
status_code=200,
t_start=start - t0,
t_end=t_after_worker - t0,
workload=workload))
# 3) If so, sample via /get_endpoint_workers/ for eligible (idle) worker tracking
if worker_url:
try: try:
r_status = route_session.post( r_status = route_session.post(
urljoin(server_url, "/get_endpoint_workers/"), urljoin(server_url, "/get_endpoint_workers/"),
@@ -100,29 +142,18 @@ def do_one(endpoint_name: str,
status_samples.append((time.time() - t0, idle)) status_samples.append((time.time() - t0, idle))
except Exception: except Exception:
pass pass
# 3) Send the request
worker_address = msg["url"]
req = dict(payload=payload.__dict__, auth_data=AuthData.from_json_msg(msg).__dict__)
t1 = time.time()
# Use explicit connect/read timeouts to avoid long hangs
r1 = worker_session.post(
urljoin(worker_address, worker_endpoint),
json=req,
verify=get_cert_file_path(),
timeout=(4, 120),
)
t2 = time.time()
if r1.status_code != 200:
results_list.append(ReqResult(worker_address, (t_after_route - start) * 1000.0, (t2 - t1) * 1000.0,
(t2 - start) * 1000.0, False,
f"infer {r1.status_code} {r1.text}"))
return
results_list.append(ReqResult(worker_address, (t_after_route - start) * 1000.0, (t2 - t1) * 1000.0, (t2 - start) * 1000.0,
True, "", t_start=start - t0, t_end=t2 - t0, workload=u))
except Exception as e: except Exception as e:
t = time.time() t = time.time()
results_list.append(ReqResult("", (t - start) * 1000.0, 0.0, (t - start) * 1000.0, False, str(e))) results_list.append(ReqResult(worker_url="",
route_ms=0.0,
worker_ms=0.0,
total_ms=0.0,
ok=False,
error=f"unknown error {e}",
status_code=0,
t_start=t - t0,
t_end=t - t0,
workload=0.0))
def run_load_with_metrics(num_requests: int, def run_load_with_metrics(num_requests: int,
requests_per_second: float, requests_per_second: float,
@@ -132,7 +163,7 @@ def run_load_with_metrics(num_requests: int,
worker_endpoint: str, worker_endpoint: str,
instance: str, instance: str,
out_path: str): out_path: str):
# Resolve endpoint id + endpoint-scoped API key
ep_info = Endpoint.get_endpoint_info(endpoint_name=endpoint_group_name, ep_info = Endpoint.get_endpoint_info(endpoint_name=endpoint_group_name,
account_api_key=account_api_key, account_api_key=account_api_key,
instance=instance) instance=instance)
@@ -145,8 +176,7 @@ def run_load_with_metrics(num_requests: int,
t0 = time.time() t0 = time.time()
results = [] results = []
status_samples = [] status_samples = []
# Concurrency control max_concurrency = int(os.environ.get("MAX_CONCURRENCY", "8192"))
max_concurrency = int(os.environ.get("MAX_CONCURRENCY", "1024"))
submit_queue_factor = 2 # cap queued tasks to reduce memory submit_queue_factor = 2 # cap queued tasks to reduce memory
# Shared HTTP sessions with connection pooling (persistent connections) # Shared HTTP sessions with connection pooling (persistent connections)
@@ -158,9 +188,9 @@ def run_load_with_metrics(num_requests: int,
return sess return sess
# Router: mostly single host, small connection pool is sufficient # Router: mostly single host, small connection pool is sufficient
route_session = make_session(pool_connections=8, pool_maxsize=max_concurrency) route_session = make_session(pool_connections=1, pool_maxsize=max_concurrency)
# Workers: many hosts; allow many pools and per-host concurrency up to max_concurrency # Workers: many hosts; allow many pools and per-host concurrency up to max_concurrency
worker_session = make_session(pool_connections=max(256, max_concurrency), pool_maxsize=max_concurrency) worker_session = make_session(pool_connections=64, pool_maxsize=max_concurrency // 8)
# Fire requests using a thread pool, scheduling at requested RPS # Fire requests using a thread pool, scheduling at requested RPS
inflight = set() inflight = set()
@@ -209,11 +239,12 @@ def run_load_with_metrics(num_requests: int,
total_ms = np.array([r.total_ms for r in oks]) if succ else np.array([]) total_ms = np.array([r.total_ms for r in oks]) if succ else np.array([])
worker_ms = np.array([r.worker_ms for r in oks]) if succ else np.array([]) worker_ms = np.array([r.worker_ms for r in oks]) if succ else np.array([])
#route_ms = np.array([r.route_ms for r in oks]) if succ else np.array([]) route_ms = np.array([r.route_ms for r in oks]) if succ else np.array([])
avg_total = float(np.mean(total_ms)) if succ else 0.0 avg_total = float(np.mean(total_ms)) if succ else 0.0
avg_worker = float(np.mean(worker_ms)) if succ else 0.0
avg_route = float(np.mean(route_ms)) if succ else 0.0
p50_total, p95_total = (float(np.percentile(total_ms, 50)), float(np.percentile(total_ms, 95))) if succ else (0.0, 0.0) p50_total, p95_total = (float(np.percentile(total_ms, 50)), float(np.percentile(total_ms, 95))) if succ else (0.0, 0.0)
total_compute_time_ms = float(np.sum(worker_ms)) if succ else 0.0
# Distribution over workers (by host:port) # Distribution over workers (by host:port)
hosts = [urlparse(r.worker_url).netloc for r in oks if r.worker_url] hosts = [urlparse(r.worker_url).netloc for r in oks if r.worker_url]
@@ -240,11 +271,11 @@ def run_load_with_metrics(num_requests: int,
print(f"\nResults: total={total_reqs} success={succ} errors={len(errs)}") print(f"\nResults: total={total_reqs} success={succ} errors={len(errs)}")
print(f"Avg latency (ms): {avg_total:.1f} p50: {p50_total:.1f} p95: {p95_total:.1f}") print(f"Avg latency (ms): {avg_total:.1f} p50: {p50_total:.1f} p95: {p95_total:.1f}")
print(f"Total compute time (sum worker latency, s): {total_compute_time_ms/1000.0:.2f}") print(f"Avg route latency (ms): {avg_route:.1f} Avg worker latency (ms): {avg_worker:.1f}")
if errs: if errs:
print("Sample errors:") print("Sample errors:")
for e in errs[:5]: for e in errs[:5]:
print(f" {e.error}") print(f" {e.status_code} {e.error}")
# Plot: 2x3 grid # Plot: 2x3 grid
fig, axes = plt.subplots(2, 3, figsize=(15, 8)) fig, axes = plt.subplots(2, 3, figsize=(15, 8))
@@ -264,7 +295,7 @@ def run_load_with_metrics(num_requests: int,
# Latency histogram (total) # Latency histogram (total)
ax1 = axes[0, 1] ax1 = axes[0, 1]
if succ: if succ:
ax1.hist(total_ms, bins=30, color="#4e79a7") ax1.hist(total_ms, bins=30)
ax1.set_title("Total latency (ms)") ax1.set_title("Total latency (ms)")
ax1.set_xlabel("ms") ax1.set_xlabel("ms")
ax1.set_ylabel("freq") ax1.set_ylabel("freq")
@@ -290,7 +321,7 @@ def run_load_with_metrics(num_requests: int,
ax_idle.plot(ts, vals, "-o", ms=3) ax_idle.plot(ts, vals, "-o", ms=3)
ax_idle.set_title("Completions per second") ax_idle.set_title("Completions per second")
ax_idle.set_xlabel("time (s)") ax_idle.set_xlabel("time (s)")
ax_idle.set_ylabel("req/s") ax_idle.set_ylabel("completions / sec")
# Summary text # Summary text
ax3 = axes[1, 1] ax3 = axes[1, 1]
@@ -298,22 +329,36 @@ def run_load_with_metrics(num_requests: int,
text = ( text = (
f"Total requests: {total_reqs}\n" f"Total requests: {total_reqs}\n"
f"Success: {succ} Errors: {len(errs)}\n" f"Success: {succ} Errors: {len(errs)}\n"
f"Avg latency: {avg_total:.1f} ms\n" f"Avg total latency: {avg_total:.1f} ms\n"
f"p50: {p50_total:.1f} ms p95: {p95_total:.1f} ms\n" f"p50: {p50_total:.1f} ms p95: {p95_total:.1f} ms\n"
f"Total compute time: {total_compute_time_ms/1000.0:.2f} s" f"Avg route latency: {avg_route:.1f} ms\n"
f"Avg worker latency: {avg_worker:.1f} ms\n"
f"300 errors: {len([r for r in errs if r.status_code >= 300 and r.status_code < 400])}\n"
f"429 errors: {len([r for r in errs if r.status_code == 429])}\n"
f"500 errors: {len([r for r in errs if r.status_code >= 500])}\n"
f"Other errors: {len([r for r in errs if r.status_code not in [300, 429, 500]])}\n"
) )
ax3.set_title("Summary") ax3.set_title("Summary")
ax3.text(0.02, 0.98, text, va="top", ha="left", fontsize=11, transform=ax3.transAxes) ax3.text(0.02, 0.98, text, va="top", ha="left", fontsize=11, transform=ax3.transAxes)
# Latency CDF (total_ms) # Error count over time
ax_cdf = axes[1, 2] ax_errors = axes[1, 2]
if succ: all_end_times = [int(r.t_end) for r in results if r.t_end > 0]
x = np.sort(total_ms) if all_end_times:
y = np.linspace(0, 1, len(x), endpoint=True) min_second = min(all_end_times)
ax_cdf.plot(x, y) max_second = max(all_end_times)
ax_cdf.set_title("Latency CDF") # Count errors per second
ax_cdf.set_xlabel("ms") errors_per_second = {}
ax_cdf.set_ylabel("fraction ≤ x") for result in errs:
second = int(result.t_end)
errors_per_second[second] = errors_per_second.get(second, 0) + 1
# Create complete timeline including zeros
time_seconds = list(range(min_second, max_second + 1))
error_counts = [errors_per_second.get(sec, 0) for sec in time_seconds]
ax_errors.plot(time_seconds, error_counts, "-o", ms=3)
ax_errors.set_title("Errors per second")
ax_errors.set_xlabel("time (s)")
ax_errors.set_ylabel("errors / sec")
# Ensure unique output path and create directory if needed # Ensure unique output path and create directory if needed
final_out_path = get_incremented_path(out_path) final_out_path = get_incremented_path(out_path)