From e9ba1b03e4bdbc12f6a4cca8038a7236bce89659 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Wed, 8 Oct 2025 16:54:18 -0700 Subject: [PATCH 01/18] Use delete_requests and track request_idxs --- lib/backend.py | 54 ++++++++++++++++--------- lib/data_types.py | 40 +++++++++++++++---- lib/metrics.py | 100 ++++++++++++++++++++++++++++++++++++---------- 3 files changed, 146 insertions(+), 48 deletions(-) diff --git a/lib/backend.py b/lib/backend.py index 76b1ec5..a62ab6c 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -12,6 +12,7 @@ from distutils.util import strtobool from anyio import open_file from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector +import asyncio import requests from Crypto.Signature import pkcs1_15 @@ -25,6 +26,7 @@ from lib.data_types import ( LogAction, ApiPayload_T, JsonDataException, + RequestMetrics ) MSG_HISTORY_LEN = 100 @@ -53,6 +55,7 @@ class Backend: EndpointHandler # this endpoint handler will be used for benchmarking ) log_actions: List[Tuple[LogAction, str]] + max_wait_time: float = 10.0 reqnum = -1 msg_history = [] sem: Semaphore = dataclasses.field(default_factory=Semaphore) @@ -128,53 +131,53 @@ class Backend: except json.JSONDecodeError: return web.json_response(dict(error="invalid JSON"), status=422) workload = payload.count_workload() + request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created") async def cancel_api_call_if_disconnected() -> web.Response: await request.wait_for_disconnection() - log.debug(f"request with reqnum: {auth_data.reqnum} was canceled") - self.metrics._request_canceled(workload=workload) - return web.Response(status=500) + log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled") + self.metrics._request_canceled(request_metrics) + raise asyncio.CancelledError async def make_request() -> Union[web.Response, web.StreamResponse]: - log.debug(f"got request, {auth_data.reqnum}") - self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum) + log.debug(f"got request, {request_metrics.reqnum}") + self.metrics._request_start(request_metrics) if self.allow_parallel_requests is False: - log.debug(f"Waiting to aquire Sem for reqnum:{auth_data.reqnum}") + log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}") await self.sem.acquire() log.debug( - f"Sem acquired for reqnum:{auth_data.reqnum}, starting request..." + f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..." ) else: - log.debug(f"Starting request for reqnum:{auth_data.reqnum}") + log.debug(f"Starting request for reqnum:{request_metrics.reqnum}") try: response = await self.__call_api(handler=handler, payload=payload) status_code = response.status log.debug( " ".join( [ - f"request with reqnum:{auth_data.reqnum}", + f"request with reqnum:{request_metrics.reqnum}", f"returned status code: {status_code},", ] ) ) res = await handler.generate_client_response(request, response) - self.metrics._request_success(workload=workload) + self.metrics._request_success(request_metrics) return res except requests.exceptions.RequestException as e: log.debug(f"[backend] Request error: {e}") - self.metrics._request_errored(workload=workload) + self.metrics._request_errored(request_metrics) return web.Response(status=500) - finally: - self.metrics._request_end( - workload=workload, - reqnum=auth_data.reqnum, - ) - self.sem.release() ########### if self.__check_signature(auth_data) is False: + self.metrics._request_reject(request_metrics) return web.Response(status=401) + + if self.metrics.model_metrics.wait_time > self.max_wait_time: + self.metrics._request_reject(request_metrics) + return web.Response(status=500) try: done, pending = await wait( @@ -185,10 +188,23 @@ class Backend: return_when=FIRST_COMPLETED, ) [task.cancel() for task in pending] - return done.pop().result() + done_task = done.pop() + try: + return done_task.result() + except Exception as e: + log.debug(f"Request task raised exception: {e}") + return web.Response(status=500) + except asyncio.CancelledError: + # Client is gone. Do not write a response; just unwind. + return web.Response(status=499) except Exception as e: log.debug(f"Exception in main handler loop {e}") return web.Response(status=500) + finally: + # Always release the semaphore if it was acquired + if not self.allow_parallel_requests: + self.sem.release() + self.metrics._request_end(request_metrics) @cached_property def healthcheck_session(self): @@ -229,7 +245,7 @@ class Backend: async def _start_tracking(self) -> None: await gather( - self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck() + self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop() ) def backend_errored(self, msg: str) -> None: diff --git a/lib/data_types.py b/lib/data_types.py index ed8b9f4..ec72204 100644 --- a/lib/data_types.py +++ b/lib/data_types.py @@ -70,6 +70,7 @@ class AuthData: endpoint: str reqnum: int url: str + request_idx: int @classmethod def from_json_msg(cls, json_msg: Dict[str, Any]): @@ -196,6 +197,14 @@ class SystemMetrics: self.model_loading_time = None +@dataclass +class RequestMetrics: + """Tracks metrics for an active request.""" + request_idx: int + reqnum: int + workload: float + status: str + @dataclass class ModelMetrics: """Model specific metrics""" @@ -205,12 +214,14 @@ class ModelMetrics: workload_received: float workload_cancelled: float workload_errored: float + workload_rejected: float # these are not workload_pending: float error_msg: Optional[str] max_throughput: float requests_recieved: Set[int] = field(default_factory=set) - requests_working: Set[int] = field(default_factory=set) + requests_working: dict[int, RequestMetrics] = field(default_factory=dict) + requests_deleting: list[RequestMetrics] = field(default_factory=list) last_update: float = field(default_factory=time.time) @classmethod @@ -220,19 +231,30 @@ class ModelMetrics: workload_served=0.0, workload_cancelled=0.0, workload_errored=0.0, + workload_rejected=0.0, workload_received=0.0, error_msg=None, max_throughput=0.0, ) - - @property - def cur_perf(self) -> float: - return max(self.workload_served / (time.time() - self.last_update), 0.0) - + @property def workload_processing(self) -> float: return max(self.workload_received - self.workload_cancelled, 0.0) + @property + def wait_time(self) -> float: + if (len(self.requests_working) == 0): + return 0.0 + return sum([request.workload for request in self.requests_working.values()]) / self.max_throughput + + @property + def cur_load(self) -> float: + return sum([request.workload for request in self.requests_working.values()]) + + @property + def working_request_idxs(self) -> list[int]: + return [req.request_idx for req in self.requests_working.values()] + def set_errored(self, error_msg): self.reset() self.error_msg = error_msg @@ -242,16 +264,19 @@ class ModelMetrics: self.workload_received = 0 self.workload_cancelled = 0 self.workload_errored = 0 + self.workload_rejected = 0 self.last_update = time.time() @dataclass -class AutoScalaerData: +class AutoScalerData: """Data that is reported to autoscaler""" id: int loadtime: float cur_load: float + rej_load: float + new_load: float error_msg: str max_perf: float cur_perf: float @@ -260,6 +285,7 @@ class AutoScalaerData: num_requests_working: int num_requests_recieved: int additional_disk_usage: float + working_request_idxs: list[int] url: str diff --git a/lib/metrics.py b/lib/metrics.py index 166706b..053a465 100644 --- a/lib/metrics.py +++ b/lib/metrics.py @@ -8,10 +8,11 @@ from functools import cache import requests -from lib.data_types import AutoScalaerData, SystemMetrics, ModelMetrics +from lib.data_types import AutoScalerData, SystemMetrics, ModelMetrics, RequestMetrics from typing import Awaitable, NoReturn, List METRICS_UPDATE_INTERVAL = 1 +DELETE_REQUESTS_INTERVAL = 1 log = logging.getLogger(__file__) @@ -27,6 +28,7 @@ def get_url() -> str: @dataclass class Metrics: last_metric_update: float = 0.0 + last_request_served: float = 0.0 update_pending: bool = False id: int = field(default_factory=lambda: int(os.environ["CONTAINER_ID"])) report_addr: List[str] = field( @@ -36,41 +38,65 @@ class Metrics: system_metrics: SystemMetrics = field(default_factory=SystemMetrics.empty) model_metrics: ModelMetrics = field(default_factory=ModelMetrics.empty) - def _request_start(self, workload: float, reqnum: int) -> None: + def _request_start(self, request: RequestMetrics) -> None: """ this function is called prior to forwarding a request to a model API. """ log.debug("request start") - self.model_metrics.workload_pending += workload - self.model_metrics.workload_received += workload - self.model_metrics.requests_recieved.add(reqnum) - self.model_metrics.requests_working.add(reqnum) + request.status = "Started" + self.model_metrics.workload_pending += request.workload + self.model_metrics.workload_received += request.workload + self.model_metrics.requests_recieved.add(request.reqnum) + self.model_metrics.requests_working[request.reqnum] = request + self.update_pending = True - def _request_end(self, workload: float, reqnum: int) -> None: + def _request_end(self, request: RequestMetrics) -> None: """ this function is called after handling of a request ends, regardless of the outcome """ - self.model_metrics.workload_pending -= workload - self.model_metrics.requests_working.discard(reqnum) + self.model_metrics.workload_pending -= request.workload + self.model_metrics.requests_working.pop(request.reqnum, None) + self.model_metrics.requests_deleting.append(request) + self.last_request_served = time.time() - def _request_success(self, workload: float) -> None: + def _request_success(self, request: RequestMetrics) -> None: """ this function is called after a response from model API is received and forwarded. """ - self.model_metrics.workload_served += workload + self.model_metrics.workload_served += request.workload + request.status = "Success" self.update_pending = True - def _request_errored(self, workload: float) -> None: + def _request_errored(self, request: RequestMetrics) -> None: """ this function is called if model API returns an error """ - self.model_metrics.workload_errored += workload + self.model_metrics.workload_errored += request.workload + request.status = "Error" + self.update_pending = True - def _request_canceled(self, workload: float) -> None: + def _request_canceled(self, request: RequestMetrics) -> None: """ this function is called if client drops connection before model API has responded """ - self.model_metrics.workload_cancelled += workload + self.model_metrics.workload_cancelled += request.workload + request.status = "Cancelled" + + def _request_reject(self, request: RequestMetrics): + """ + this function is called if the current wait time for the model is above max_wait_time + """ + self.model_metrics.requests_recieved.add(request.reqnum) + self.model_metrics.requests_deleting.append(request) + self.model_metrics.workload_rejected += request.workload + request.status = "Rejected" + self.update_pending = True + + async def _send_delete_requests_loop(self) -> Awaitable[NoReturn]: + while True: + await sleep(DELETE_REQUESTS_INTERVAL) + if len(self.model_metrics.requests_deleting) > 0: + self.__send_delete_requests_and_reset() async def _send_metrics_loop(self) -> Awaitable[NoReturn]: while True: @@ -78,10 +104,10 @@ class Metrics: elapsed = time.time() - self.last_metric_update if self.system_metrics.model_is_loaded is False and elapsed >= 10: log.debug(f"sending loading model metrics after {int(elapsed)}s wait") - self.__send_metrics_and_reset(elapsed) + self.__send_metrics_and_reset() elif self.update_pending or elapsed > 10: log.debug(f"sending loaded model metrics after {int(elapsed)}s wait") - self.__send_metrics_and_reset(elapsed) + self.__send_metrics_and_reset() def _model_loaded(self, max_throughput: float) -> None: self.system_metrics.model_loading_time = ( @@ -96,19 +122,49 @@ class Metrics: #######################################Private####################################### - def __send_metrics_and_reset(self, elapsed): + def __send_delete_requests_and_reset(self): - def compute_autoscaler_data() -> AutoScalaerData: - return AutoScalaerData( + def send_data(report_addr: str) -> bool: + data = { + "worker_id": self.id, + "request_idxs": [r.request_idx for r in self.model_metrics.requests_deleting] + } + full_path = report_addr.rstrip("/") + "/delete_requests/" + for attempt in range(1, 4): + try: + res = requests.post(full_path, json=data, timeout=1) + res.raise_for_status() + return True + except requests.Timeout: + log.debug(f"delete_requests timed out") + except Exception as e: + log.debug(f"delete_requests failed with error: {e}") + time.sleep(2) + log.debug(f"retrying delete_request, attempt: {attempt}") + + for report_addr in self.report_addr: + success = send_data(report_addr) + if success is True: + self.model_metrics.requests_deleting.clear() + break + + + def __send_metrics_and_reset(self): + + def compute_autoscaler_data() -> AutoScalerData: + return AutoScalerData( id=self.id, loadtime=(self.system_metrics.model_loading_time or 0.0), - cur_load=(self.model_metrics.workload_processing / elapsed), + new_load=self.model_metrics.workload_processing, + cur_load=self.model_metrics.cur_load, + rej_load=self.model_metrics.workload_rejected, max_perf=self.model_metrics.max_throughput, - cur_perf=self.model_metrics.cur_perf, + cur_perf=self.model_metrics.workload_served, error_msg=self.model_metrics.error_msg or "", num_requests_working=len(self.model_metrics.requests_working), num_requests_recieved=len(self.model_metrics.requests_recieved), additional_disk_usage=self.system_metrics.additional_disk_usage, + working_request_idxs=self.model_metrics.working_request_idxs, cur_capacity=0, max_capacity=0, url=self.url, From 9a6ca5d412d56fe48d8fae0da0e980f8a71c1dcc Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Tue, 21 Oct 2025 15:42:43 -0700 Subject: [PATCH 02/18] added versioning --- lib/backend.py | 4 ++++ lib/data_types.py | 1 + lib/metrics.py | 5 +++++ 3 files changed, 10 insertions(+) diff --git a/lib/backend.py b/lib/backend.py index a62ab6c..070f2c5 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -29,6 +29,8 @@ from lib.data_types import ( RequestMetrics ) +VERSION = "0.1.0" + MSG_HISTORY_LEN = 100 log = logging.getLogger(__file__) @@ -57,6 +59,7 @@ class Backend: log_actions: List[Tuple[LogAction, str]] max_wait_time: float = 10.0 reqnum = -1 + version = VERSION msg_history = [] sem: Semaphore = dataclasses.field(default_factory=Semaphore) unsecured: bool = dataclasses.field( @@ -65,6 +68,7 @@ class Backend: def __post_init__(self): self.metrics = Metrics() + self.metrics._set_version(self.version) self._total_pubkey_fetch_errors = 0 self._pubkey = self._fetch_pubkey() self.__start_healthcheck: bool = False diff --git a/lib/data_types.py b/lib/data_types.py index ec72204..43213de 100644 --- a/lib/data_types.py +++ b/lib/data_types.py @@ -273,6 +273,7 @@ class AutoScalerData: """Data that is reported to autoscaler""" id: int + version: str loadtime: float cur_load: float rej_load: float diff --git a/lib/metrics.py b/lib/metrics.py index 053a465..f7dfaef 100644 --- a/lib/metrics.py +++ b/lib/metrics.py @@ -27,6 +27,7 @@ def get_url() -> str: @dataclass class Metrics: + version: str = "0" last_metric_update: float = 0.0 last_request_served: float = 0.0 update_pending: bool = False @@ -120,6 +121,9 @@ class Metrics: self.model_metrics.set_errored(error_msg) self.system_metrics.model_is_loaded = True + def _set_version(self, version: str) -> None: + self.version = version + #######################################Private####################################### def __send_delete_requests_and_reset(self): @@ -154,6 +158,7 @@ class Metrics: def compute_autoscaler_data() -> AutoScalerData: return AutoScalerData( id=self.id, + version=self.version, loadtime=(self.system_metrics.model_loading_time or 0.0), new_load=self.model_metrics.workload_processing, cur_load=self.model_metrics.cur_load, From b39193ae708d2f240ecd35e0a91c5324d0ce6118 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Tue, 21 Oct 2025 18:02:14 -0700 Subject: [PATCH 03/18] check for sem acquire --- lib/backend.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/backend.py b/lib/backend.py index 070f2c5..34bb631 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -136,6 +136,7 @@ class Backend: return web.json_response(dict(error="invalid JSON"), status=422) workload = payload.count_workload() request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created") + acquired = False async def cancel_api_call_if_disconnected() -> web.Response: await request.wait_for_disconnection() @@ -149,6 +150,7 @@ class Backend: if self.allow_parallel_requests is False: log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}") await self.sem.acquire() + acquired = True log.debug( f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..." ) @@ -206,7 +208,7 @@ class Backend: return web.Response(status=500) finally: # Always release the semaphore if it was acquired - if not self.allow_parallel_requests: + if acquired: self.sem.release() self.metrics._request_end(request_metrics) From 9748176366040de00c24a4e08f710d432408a8a5 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Tue, 21 Oct 2025 18:12:23 -0700 Subject: [PATCH 04/18] fixed semaphore acquire bool --- lib/backend.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/lib/backend.py b/lib/backend.py index 34bb631..30bc282 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -136,7 +136,6 @@ class Backend: return web.json_response(dict(error="invalid JSON"), status=422) workload = payload.count_workload() request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created") - acquired = False async def cancel_api_call_if_disconnected() -> web.Response: await request.wait_for_disconnection() @@ -147,15 +146,6 @@ class Backend: async def make_request() -> Union[web.Response, web.StreamResponse]: log.debug(f"got request, {request_metrics.reqnum}") self.metrics._request_start(request_metrics) - if self.allow_parallel_requests is False: - log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}") - await self.sem.acquire() - acquired = True - log.debug( - f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..." - ) - else: - log.debug(f"Starting request for reqnum:{request_metrics.reqnum}") try: response = await self.__call_api(handler=handler, payload=payload) status_code = response.status @@ -185,7 +175,17 @@ class Backend: self.metrics._request_reject(request_metrics) return web.Response(status=500) + acquired = False try: + if self.allow_parallel_requests is False: + log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}") + await self.sem.acquire() + acquired = True + log.debug( + f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..." + ) + else: + log.debug(f"Starting request for reqnum:{request_metrics.reqnum}") done, pending = await wait( [ create_task(make_request()), @@ -193,7 +193,10 @@ class Backend: ], return_when=FIRST_COMPLETED, ) - [task.cancel() for task in pending] + for t in pending: + t.cancel() + await asyncio.gather(*pending, return_exceptions=True) + done_task = done.pop() try: return done_task.result() From 16990ff8ffc4b70882c67f949d66edf2bbce6b9b Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Tue, 21 Oct 2025 18:18:44 -0700 Subject: [PATCH 05/18] move start request --- lib/backend.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lib/backend.py b/lib/backend.py index 30bc282..1d4072d 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -144,8 +144,6 @@ class Backend: raise asyncio.CancelledError async def make_request() -> Union[web.Response, web.StreamResponse]: - log.debug(f"got request, {request_metrics.reqnum}") - self.metrics._request_start(request_metrics) try: response = await self.__call_api(handler=handler, payload=payload) status_code = response.status @@ -186,6 +184,7 @@ class Backend: ) else: log.debug(f"Starting request for reqnum:{request_metrics.reqnum}") + self.metrics._request_start(request_metrics) done, pending = await wait( [ create_task(make_request()), From 5b5ef7227a7c07e3a4dd58ed6620ce428bc350af Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Tue, 21 Oct 2025 18:20:11 -0700 Subject: [PATCH 06/18] nvm moved it here --- lib/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/backend.py b/lib/backend.py index 1d4072d..29b6b2b 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -175,6 +175,7 @@ class Backend: acquired = False try: + self.metrics._request_start(request_metrics) if self.allow_parallel_requests is False: log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}") await self.sem.acquire() @@ -184,7 +185,6 @@ class Backend: ) else: log.debug(f"Starting request for reqnum:{request_metrics.reqnum}") - self.metrics._request_start(request_metrics) done, pending = await wait( [ create_task(make_request()), From 5edfa968ca2be0c5732eefcf1fc394f8020073c8 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Tue, 21 Oct 2025 18:49:48 -0700 Subject: [PATCH 07/18] async sleep --- lib/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/backend.py b/lib/backend.py index 29b6b2b..3d4ef92 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -411,7 +411,7 @@ class Backend: if line: await handle_log_line(line.rstrip()) else: - time.sleep(LOG_POLL_INTERVAL) + await asyncio.sleep(LOG_POLL_INTERVAL) ########### From 01e752d31ff2e80bf4ae54f0ba172daa67da7638 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Tue, 21 Oct 2025 18:52:13 -0700 Subject: [PATCH 08/18] use more asyncio sleep --- lib/metrics.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/metrics.py b/lib/metrics.py index f7dfaef..76f2d16 100644 --- a/lib/metrics.py +++ b/lib/metrics.py @@ -5,7 +5,7 @@ import json from asyncio import sleep from dataclasses import dataclass, asdict, field from functools import cache - +import asyncio import requests from lib.data_types import AutoScalerData, SystemMetrics, ModelMetrics, RequestMetrics @@ -143,7 +143,7 @@ class Metrics: log.debug(f"delete_requests timed out") except Exception as e: log.debug(f"delete_requests failed with error: {e}") - time.sleep(2) + asyncio.sleep(2) log.debug(f"retrying delete_request, attempt: {attempt}") for report_addr in self.report_addr: @@ -197,7 +197,7 @@ class Metrics: log.debug(f"autoscaler status update timed out") except Exception as e: log.debug(f"autoscaler status update failed with error: {e}") - time.sleep(2) + asyncio.sleep(2) log.debug(f"retrying autoscaler status update, attempt: {attempt}") log.debug(f"failed to send update through {report_addr}") return False From 0f135069384e593a99d5cbf637a6dbd372c3109a Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Wed, 22 Oct 2025 10:18:59 -0700 Subject: [PATCH 09/18] Send success param --- lib/data_types.py | 1 + lib/metrics.py | 11 ++++++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/lib/data_types.py b/lib/data_types.py index 43213de..d2cf0c2 100644 --- a/lib/data_types.py +++ b/lib/data_types.py @@ -204,6 +204,7 @@ class RequestMetrics: reqnum: int workload: float status: str + success: bool = False @dataclass class ModelMetrics: diff --git a/lib/metrics.py b/lib/metrics.py index 76f2d16..40dcb9a 100644 --- a/lib/metrics.py +++ b/lib/metrics.py @@ -66,6 +66,7 @@ class Metrics: """ self.model_metrics.workload_served += request.workload request.status = "Success" + request.success = True self.update_pending = True def _request_errored(self, request: RequestMetrics) -> None: @@ -74,6 +75,7 @@ class Metrics: """ self.model_metrics.workload_errored += request.workload request.status = "Error" + request.success = False self.update_pending = True def _request_canceled(self, request: RequestMetrics) -> None: @@ -81,6 +83,7 @@ class Metrics: this function is called if client drops connection before model API has responded """ self.model_metrics.workload_cancelled += request.workload + request.success = True request.status = "Cancelled" def _request_reject(self, request: RequestMetrics): @@ -90,6 +93,7 @@ class Metrics: self.model_metrics.requests_recieved.add(request.reqnum) self.model_metrics.requests_deleting.append(request) self.model_metrics.workload_rejected += request.workload + request.success = False request.status = "Rejected" self.update_pending = True @@ -128,10 +132,11 @@ class Metrics: def __send_delete_requests_and_reset(self): - def send_data(report_addr: str) -> bool: + def send_data(report_addr: str, success: bool) -> bool: data = { "worker_id": self.id, - "request_idxs": [r.request_idx for r in self.model_metrics.requests_deleting] + "request_idxs": [r.request_idx for r in self.model_metrics.requests_deleting if r.success == success], + "success": success } full_path = report_addr.rstrip("/") + "/delete_requests/" for attempt in range(1, 4): @@ -147,7 +152,7 @@ class Metrics: log.debug(f"retrying delete_request, attempt: {attempt}") for report_addr in self.report_addr: - success = send_data(report_addr) + success = send_data(report_addr, success=True) and send_data(report_addr, success=False) if success is True: self.model_metrics.requests_deleting.clear() break From 37ad3f8d4628415c911eb98a59757d2b24835a96 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Thu, 23 Oct 2025 10:18:31 -0700 Subject: [PATCH 10/18] asyncio in metrics --- lib/backend.py | 2 +- lib/metrics.py | 56 ++++++++++++++++++++++++++++++++------------------ 2 files changed, 37 insertions(+), 21 deletions(-) diff --git a/lib/backend.py b/lib/backend.py index 3d4ef92..dc1f52c 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -171,7 +171,7 @@ class Backend: if self.metrics.model_metrics.wait_time > self.max_wait_time: self.metrics._request_reject(request_metrics) - return web.Response(status=500) + return web.Response(status=429) acquired = False try: diff --git a/lib/metrics.py b/lib/metrics.py index 40dcb9a..45f44f4 100644 --- a/lib/metrics.py +++ b/lib/metrics.py @@ -6,7 +6,7 @@ from asyncio import sleep from dataclasses import dataclass, asdict, field from functools import cache import asyncio -import requests +from aiohttp import ClientSession, ClientTimeout, TCPConnector, ClientResponseError from lib.data_types import AutoScalerData, SystemMetrics, ModelMetrics, RequestMetrics from typing import Awaitable, NoReturn, List @@ -38,6 +38,20 @@ class Metrics: url: str = field(default_factory=get_url) system_metrics: SystemMetrics = field(default_factory=SystemMetrics.empty) model_metrics: ModelMetrics = field(default_factory=ModelMetrics.empty) + _session: ClientSession | None = field(default=None, init=False, repr=False) + + async def http(self) -> ClientSession: + if self._session is None: + self._session = ClientSession( + timeout=ClientTimeout(total=10), + connector=TCPConnector(limit=8, limit_per_host=4, force_close=True, enable_cleanup_closed=True) + ) + return self._session + + async def aclose(self) -> None: + if self._session is not None: + await self._session.close() + self._session = None def _request_start(self, request: RequestMetrics) -> None: """ @@ -101,7 +115,7 @@ class Metrics: while True: await sleep(DELETE_REQUESTS_INTERVAL) if len(self.model_metrics.requests_deleting) > 0: - self.__send_delete_requests_and_reset() + await self.__send_delete_requests_and_reset() async def _send_metrics_loop(self) -> Awaitable[NoReturn]: while True: @@ -109,10 +123,10 @@ class Metrics: elapsed = time.time() - self.last_metric_update if self.system_metrics.model_is_loaded is False and elapsed >= 10: log.debug(f"sending loading model metrics after {int(elapsed)}s wait") - self.__send_metrics_and_reset() + await self.__send_metrics_and_reset() elif self.update_pending or elapsed > 10: log.debug(f"sending loaded model metrics after {int(elapsed)}s wait") - self.__send_metrics_and_reset() + await self.__send_metrics_and_reset() def _model_loaded(self, max_throughput: float) -> None: self.system_metrics.model_loading_time = ( @@ -130,9 +144,9 @@ class Metrics: #######################################Private####################################### - def __send_delete_requests_and_reset(self): + async def __send_delete_requests_and_reset(self): - def send_data(report_addr: str, success: bool) -> bool: + async def send_data(report_addr: str, success: bool) -> bool: data = { "worker_id": self.id, "request_idxs": [r.request_idx for r in self.model_metrics.requests_deleting if r.success == success], @@ -141,24 +155,25 @@ class Metrics: full_path = report_addr.rstrip("/") + "/delete_requests/" for attempt in range(1, 4): try: - res = requests.post(full_path, json=data, timeout=1) - res.raise_for_status() + session = await self.http() + async with session.post(full_path, json=data) as res: + res.raise_for_status() return True - except requests.Timeout: + except asyncio.TimeoutError: log.debug(f"delete_requests timed out") - except Exception as e: + except (ClientResponseError, Exception) as e: log.debug(f"delete_requests failed with error: {e}") - asyncio.sleep(2) + await asyncio.sleep(2) log.debug(f"retrying delete_request, attempt: {attempt}") for report_addr in self.report_addr: - success = send_data(report_addr, success=True) and send_data(report_addr, success=False) + success = await send_data(report_addr, success=True) and await send_data(report_addr, success=False) if success is True: self.model_metrics.requests_deleting.clear() break - def __send_metrics_and_reset(self): + async def __send_metrics_and_reset(self): def compute_autoscaler_data() -> AutoScalerData: return AutoScalerData( @@ -180,7 +195,7 @@ class Metrics: url=self.url, ) - def send_data(report_addr: str) -> bool: + async def send_data(report_addr: str) -> bool: data = compute_autoscaler_data() full_path = report_addr.rstrip("/") + "/worker_status/" log.debug( @@ -195,14 +210,15 @@ class Metrics: ) for attempt in range(1, 4): try: - res = requests.post(full_path, json=asdict(data), timeout=1) - res.raise_for_status() + session = await self.http() + async with session.post(full_path, json=asdict(data)) as res: + res.raise_for_status() return True - except requests.Timeout: + except asyncio.TimeoutError: log.debug(f"autoscaler status update timed out") - except Exception as e: + except (ClientResponseError, Exception) as e: log.debug(f"autoscaler status update failed with error: {e}") - asyncio.sleep(2) + await asyncio.sleep(2) log.debug(f"retrying autoscaler status update, attempt: {attempt}") log.debug(f"failed to send update through {report_addr}") return False @@ -212,7 +228,7 @@ class Metrics: self.system_metrics.update_disk_usage() for report_addr in self.report_addr: - success = send_data(report_addr) + success = await send_data(report_addr) if success is True: break self.update_pending = False From 7788bc4a62f6687f1d3664021c9b090ec4bb4967 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Fri, 24 Oct 2025 15:41:00 -0700 Subject: [PATCH 11/18] Added some debug logs --- lib/metrics.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/metrics.py b/lib/metrics.py index 45f44f4..6936a3b 100644 --- a/lib/metrics.py +++ b/lib/metrics.py @@ -152,11 +152,13 @@ class Metrics: "request_idxs": [r.request_idx for r in self.model_metrics.requests_deleting if r.success == success], "success": success } + log.debug(f"Deleting requests that { "succeeded" if success else "failed"}: {data["request_idxs"]}") full_path = report_addr.rstrip("/") + "/delete_requests/" for attempt in range(1, 4): try: session = await self.http() async with session.post(full_path, json=data) as res: + log.debug(f"delete_requests response: {res.status}") res.raise_for_status() return True except asyncio.TimeoutError: From 4d9bf2048cf05ca2b219bed72a8afd903fc24fac Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Fri, 24 Oct 2025 15:44:38 -0700 Subject: [PATCH 12/18] Fix --- lib/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/metrics.py b/lib/metrics.py index 6936a3b..3022a8e 100644 --- a/lib/metrics.py +++ b/lib/metrics.py @@ -152,7 +152,7 @@ class Metrics: "request_idxs": [r.request_idx for r in self.model_metrics.requests_deleting if r.success == success], "success": success } - log.debug(f"Deleting requests that { "succeeded" if success else "failed"}: {data["request_idxs"]}") + log.debug(f"Deleting requests that {'succeeded' if success else 'failed'}: {data['request_idxs']}") full_path = report_addr.rstrip("/") + "/delete_requests/" for attempt in range(1, 4): try: From bcecd6df4049c1ee74390521e4a4d5a827655862 Mon Sep 17 00:00:00 2001 From: Colter Downing Date: Sat, 25 Oct 2025 16:18:02 -0700 Subject: [PATCH 13/18] Suppress matplot debug logs --- lib/backend.py | 24 +- lib/data_types.py | 13 +- lib/test_utils.py | 12 +- utils/endpoint_util.py | 48 +++- workers/openai/data_types/server.py | 33 ++- workers/openai/test_load.py | 422 +++++++++++++++++++++++++++- 6 files changed, 519 insertions(+), 33 deletions(-) diff --git a/lib/backend.py b/lib/backend.py index dc1f52c..4a99ac6 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -26,7 +26,8 @@ from lib.data_types import ( LogAction, ApiPayload_T, JsonDataException, - RequestMetrics + RequestMetrics, + BenchmarkResult ) VERSION = "0.1.0" @@ -332,18 +333,23 @@ class Backend: for run in range(1, self.benchmark_handler.benchmark_runs + 1): start = time.time() - tasks = [] - total_workload = 0 + benchmark_requests = [] - for _ in range(concurrent_requests): + for i in range(concurrent_requests): payload = self.benchmark_handler.make_benchmark_payload() - total_workload += payload.count_workload() - tasks.append( - self.__call_api(handler=self.benchmark_handler, payload=payload) + workload = payload.count_workload() + task = self.__call_api(handler=self.benchmark_handler, payload=payload) + benchmark_requests.append( + BenchmarkResult(request_idx=i, workload=workload, task=task) ) - responses = await gather(*tasks) + responses = await gather(*[br.task for br in benchmark_requests]) + for br, response in zip(benchmark_requests, responses): + br.response = response + + total_workload = sum(br.workload for br in benchmark_requests if br.is_successful) time_elapsed = time.time() - start + successful_responses = sum([1 for br in benchmark_requests if br.is_successful]) throughput = total_workload / time_elapsed sum_throughput += throughput @@ -357,7 +363,7 @@ class Backend: f"Run: {run}, concurrent_requests: {concurrent_requests}", f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s", f"Throughput: {throughput} workload/s", - f"Successful responses: {len([r for r in responses if r.status == 200])}", + f"Successful responses: {successful_responses}/{concurrent_requests}", "#" * 60, ] ) diff --git a/lib/data_types.py b/lib/data_types.py index d2cf0c2..389ed18 100644 --- a/lib/data_types.py +++ b/lib/data_types.py @@ -3,7 +3,7 @@ import logging from dataclasses import dataclass, field from enum import Enum from abc import ABC, abstractmethod -from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type +from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type, Awaitable from aiohttp import web, ClientResponse import inspect @@ -206,6 +206,17 @@ class RequestMetrics: status: str success: bool = False +@dataclass +class BenchmarkResult: + request_idx: int + workload: float + task: Awaitable[ClientResponse] + response: Optional[ClientResponse] = None + + @property + def is_successful(self) -> bool: + return self.response is not None and self.response.status == 200 + @dataclass class ModelMetrics: """Model specific metrics""" diff --git a/lib/test_utils.py b/lib/test_utils.py index 8635027..d64a4b6 100644 --- a/lib/test_utils.py +++ b/lib/test_utils.py @@ -292,12 +292,12 @@ def test_load_cmd( args = arg_parser.parse_args() if hasattr(args, "comfy_model"): os.environ["COMFY_MODEL"] = args.comfy_model - server_url = dict( - prod="https://run.vast.ai", - alpha="https://run-alpha.vast.ai", - candidate="https://run-candidate.vast.ai", - local="http://localhost:8080", - )[args.instance] + server_url = { + "prod": "https://run.vast.ai", + "alpha": "https://run-alpha.vast.ai", + "candidate": "https://run-candidate.vast.ai", + "local": "http://localhost:8080", + }.get(args.instance, "http://localhost:8080") run_test( num_requests=args.num_requests, requests_per_second=args.requests_per_second, diff --git a/utils/endpoint_util.py b/utils/endpoint_util.py index 37930af..927262e 100644 --- a/utils/endpoint_util.py +++ b/utils/endpoint_util.py @@ -1,5 +1,6 @@ import logging -from typing import Any, Dict, Optional +import time +from typing import Any, Dict, Optional, Tuple import requests @@ -16,6 +17,38 @@ class Endpoint: Utility class for handling endpoint operations. """ + @staticmethod + def get_endpoint_info( + endpoint_name: str, account_api_key: str, instance: str + ) -> Optional[Dict[str, Any]]: + headers = {"Authorization": f"Bearer {account_api_key}"} + url = f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}" + # Retry a few times to smooth over transient propagation/network delays + for attempt in range(4): + try: + response = requests.get(url, headers=headers, timeout=8) + if response.status_code != 200: + # brief backoff and retry + time.sleep(0.3 * (attempt + 1)) + continue + try: + data = response.json() + except Exception: + # JSON parse failed; backoff and retry + time.sleep(0.3 * (attempt + 1)) + continue + result = data.get("results", []) if isinstance(data, dict) else [] + endpoint = next( + (item for item in result if item.get("endpoint_name") == endpoint_name), + None, + ) + if endpoint and endpoint.get("id") and endpoint.get("api_key"): + return {"id": endpoint.get("id"), "api_key": endpoint.get("api_key")} + except Exception: + # network or other transient error; retry + time.sleep(0.3 * (attempt + 1)) + return None + @staticmethod def get_autoscaler_server_url(instance: str) -> str: endpoints = { @@ -23,7 +56,10 @@ class Endpoint: "candidate": "run-candidate", "prod": "run", } - return f"https://{endpoints[instance]}.vast.ai/" + host = endpoints.get(instance) + if host: + return f"https://{host}.vast.ai/" + return "http://localhost:8080" @staticmethod def get_server_url(instance: str) -> str: @@ -32,7 +68,8 @@ class Endpoint: "candidate": "candidate", "prod": "console", } - return f"https://{endpoints[instance]}.vast.ai/api/v0/endptjobs/" + host = endpoints.get(instance, "alpha") + return f"https://{host}.vast.ai/api/v0/endptjobs/" @staticmethod def get_endpoint_api_key( @@ -55,6 +92,7 @@ class Endpoint: response = requests.get( f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}", headers=headers, + timeout=8, ) if response.status_code != 200: @@ -64,14 +102,14 @@ class Endpoint: try: data = response.json() - except requests.exceptions.JSONDecodeError as e: + except Exception as e: log.debug(f"Failed to parse JSON response: {e}") return None result = data.get("results", []) endpoint: Optional[Dict[str, Any]] = next( - (item for item in result if item["endpoint_name"] == endpoint_name), + (item for item in result if item.get("endpoint_name") == endpoint_name), None, ) if not endpoint: diff --git a/workers/openai/data_types/server.py b/workers/openai/data_types/server.py index 92f204b..e549864 100644 --- a/workers/openai/data_types/server.py +++ b/workers/openai/data_types/server.py @@ -119,14 +119,25 @@ class GenericHandler(EndpointHandler[GenericData], ABC): class CompletionsData(GenericData): @classmethod def for_test(cls) -> "CompletionsData": - prompt = " ".join(random.choices(WORD_LIST, k=int(250))) + system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base: + + Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines + with distinctive black-and-white striped coats. There are three living species: Grévy's zebra + (Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the + genus Equus with horses and asses, the three groups being the only living members of the family + Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern + and southern Africa and can be found in a variety of habitats such as savannahs, grasslands, + woodlands, shrublands, and mountainous areas. + + Please answer the following question based on the above context.""" + unique_question = " ".join(random.choices(WORD_LIST, k=int(100))) model = os.environ.get("MODEL_NAME") if not model: raise ValueError("MODEL_NAME environment variable not set") test_input = { "model": model, - "prompt": prompt, + "prompt": f"{system_prompt}\n\n{unique_question}", "temperature": 0.7, "max_tokens": 500, } @@ -153,7 +164,18 @@ class ChatCompletionsData(GenericData): @classmethod def for_test(cls) -> "ChatCompletionsData": - prompt = " ".join(random.choices(WORD_LIST, k=int(250))) + system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base: + + Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines + with distinctive black-and-white striped coats. There are three living species: Grévy's zebra + (Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the + genus Equus with horses and asses, the three groups being the only living members of the family + Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern + and southern Africa and can be found in a variety of habitats such as savannahs, grasslands, + woodlands, shrublands, and mountainous areas. + + Please answer the following question based on the above context.""" + unique_question = " ".join(random.choices(WORD_LIST, k=int(100))) model = os.environ.get("MODEL_NAME") if not model: raise ValueError("MODEL_NAME environment variable not set") @@ -161,7 +183,10 @@ class ChatCompletionsData(GenericData): # Chat completions use messages format instead of prompt test_input = { "model": model, - "messages": [{"role": "user", "content": prompt}], + "messages": [ + {"role": "system", "content": system_prompt}, # Shared prefix + {"role": "user", "content": unique_question} # Unique per request + ], "temperature": 0.7, "max_tokens": 500, } diff --git a/workers/openai/test_load.py b/workers/openai/test_load.py index 0c45524..9cb5f37 100644 --- a/workers/openai/test_load.py +++ b/workers/openai/test_load.py @@ -1,8 +1,395 @@ -from lib.test_utils import test_load_cmd, test_args +from lib.test_utils import test_args +from utils.endpoint_util import Endpoint +from utils.ssl import get_cert_file_path +from lib.data_types import AuthData from .data_types.server import CompletionsData -import os -WORKER_ENDPOINT = "/v1/completions" +import os +import time +import threading +import requests +from dataclasses import dataclass +from collections import Counter +from urllib.parse import urljoin, urlparse +import re + +# Headless plotting +import matplotlib +matplotlib.use("Agg") +import logging +logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING) +import matplotlib.pyplot as plt +import numpy as np +from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED +from requests.adapters import HTTPAdapter + +def get_incremented_path(path: str) -> str: + base, ext = os.path.splitext(path) + if not os.path.exists(path): + return path + i = 1 + while os.path.exists(f"{base}-{i}{ext}"): + i += 1 + return f"{base}-{i}{ext}" + +WORKER_ENDPOINT = "/v1/completions" # This will return the full text output at once. Latency metrics reflect that (ie not measuring TTFT) + +@dataclass +class ReqResult: + worker_url: str + route_ms: float + worker_ms: float + total_ms: float + ok: bool + error: str = "" + status_code: int = 0 + t_start: float = 0.0 + t_end: float = 0.0 + workload: float = 0.0 + +def do_one(endpoint_name: str, + endpoint_id: int, + endpoint_api_key: str, + server_url: str, + worker_endpoint: str, + payload, + results_list, + t0, + status_samples, + route_session, + worker_session): + try: + workload = payload.count_workload() + route_payload = {"endpoint": endpoint_name, "api_key": endpoint_api_key, "cost": workload} + headers = {"Authorization": f"Bearer {endpoint_api_key}"} + start = time.time() + r0 = route_session.post(urljoin(server_url, "/route/"), json=route_payload, headers=headers, timeout=4) + t_after_route = time.time() + if r0.status_code != 200: + results_list.append(ReqResult(worker_url="", + route_ms=(t_after_route - start) * 1000.0, + worker_ms=0.0, + total_ms=(t_after_route - start) * 1000.0, + ok=False, + error=f"route error {r0.reason} {r0.text}", + status_code=r0.status_code, + t_start=start - t0, + t_end=t_after_route - t0, + workload=workload)) + return + msg = r0.json() + + # 1) Check if we got a worker back from route + worker_url = msg.get("url", "") + if not worker_url: + status = msg.get("status", "") + m = re.search(r"total workers:\s*(\d+).*loading workers:\s*(\d+).*standby workers:\s*(\d+).*error workers:\s*(\d+)", status, re.I | re.S) + if m: + tot, loading, standby, err = map(int, m.groups()) + idle = max(tot - loading - standby - err, 0) + status_samples.append((time.time() - t0, idle)) + + # 2) If we got a worker, send the request + if worker_url: + req = dict(payload=payload.__dict__, auth_data=AuthData.from_json_msg(msg).__dict__) + t_before_worker = time.time() + r1 = worker_session.post( + urljoin(worker_url, worker_endpoint), + json=req, + verify=get_cert_file_path(), + timeout=(4, 120), + ) + t_after_worker = time.time() + if r1.status_code != 200: + results_list.append(ReqResult(worker_url=worker_url, + route_ms=(t_after_route - start) * 1000.0, + worker_ms=(t_after_worker - t_before_worker) * 1000.0, + total_ms=(t_after_worker - start) * 1000.0, + ok=False, + error=f"worker inference error {r1.reason} {r1.text}", + status_code=r1.status_code, + t_start=start - t0, + t_end=t_after_worker - t0, + workload=workload)) + return + # Success case + results_list.append(ReqResult(worker_url=worker_url, + route_ms=(t_after_route - start) * 1000.0, + worker_ms=(t_after_worker - t_before_worker) * 1000.0, + total_ms=(t_after_worker - start) * 1000.0, + ok=True, + error="", + status_code=200, + t_start=start - t0, + t_end=t_after_worker - t0, + workload=workload)) + + # 3) If so, sample via /get_endpoint_workers/ for eligible (idle) worker tracking + if worker_url: + try: + r_status = route_session.post( + urljoin(server_url, "/get_endpoint_workers/"), + json={"id": endpoint_id}, + headers={"Authorization": f"Bearer {endpoint_api_key}"}, + timeout=3, + ) + if r_status.status_code == 200: + workers = r_status.json() + idle = 0 + for w in workers: + st = str(w.get("status", "")).lower() + if (st in ("idle")): + idle += 1 + status_samples.append((time.time() - t0, idle)) + except Exception: + pass + except Exception as e: + t = time.time() + results_list.append(ReqResult(worker_url="", + route_ms=0.0, + worker_ms=0.0, + total_ms=0.0, + ok=False, + error=f"unknown error {e}", + status_code=0, + t_start=t - t0, + t_end=t - t0, + workload=0.0)) + +def run_load_with_metrics(num_requests: int, + requests_per_second: float, + endpoint_group_name: str, + account_api_key: str, + server_url: str, + worker_endpoint: str, + instance: str, + out_path: str): + + ep_info = Endpoint.get_endpoint_info(endpoint_name=endpoint_group_name, + account_api_key=account_api_key, + instance=instance) + if not ep_info or not ep_info.get("api_key") or not ep_info.get("id"): + print(f"Endpoint {endpoint_group_name} not found for API key") + return + endpoint_id = int(ep_info["id"]) + endpoint_api_key = ep_info["api_key"] + + t0 = time.time() + results = [] + status_samples = [] + max_concurrency = int(os.environ.get("MAX_CONCURRENCY", "8192")) + submit_queue_factor = 2 # cap queued tasks to reduce memory + + # Shared HTTP sessions with connection pooling (persistent connections) + def make_session(pool_connections: int, pool_maxsize: int) -> requests.Session: + sess = requests.Session() + adapter = HTTPAdapter(pool_connections=pool_connections, pool_maxsize=pool_maxsize, max_retries=0) + sess.mount("https://", adapter) + sess.mount("http://", adapter) + return sess + + # Router: mostly single host, small connection pool is sufficient + route_session = make_session(pool_connections=1, pool_maxsize=max_concurrency) + # Workers: many hosts; allow many pools and per-host concurrency up to max_concurrency + worker_session = make_session(pool_connections=64, pool_maxsize=max_concurrency // 8) + + # Fire requests using a thread pool, scheduling at requested RPS + inflight = set() + with ThreadPoolExecutor(max_workers=max_concurrency) as executor: + for i in range(num_requests): + # Pace submissions to RPS + target_time = t0 + i / max(requests_per_second, 1e-9) + sleep_s = target_time - time.time() + if sleep_s > 0: + time.sleep(min(sleep_s, 0.5)) # sleep in chunks to stay responsive + + payload = CompletionsData.for_test() + fut = executor.submit( + do_one, + endpoint_group_name, + endpoint_id, + endpoint_api_key, + server_url, + worker_endpoint, + payload, + results, + t0, + status_samples, + route_session, + worker_session, + ) + inflight.add(fut) + # Prevent unbounded queue growth + if len(inflight) >= max_concurrency * submit_queue_factor: + done, not_done = wait(inflight, return_when=FIRST_COMPLETED) + inflight = not_done + # Wait for all outstanding tasks + if inflight: + wait(inflight) + # Close sessions + try: + route_session.close() + finally: + worker_session.close() + + # Aggregate results + oks = [r for r in results if r.ok] + errs = [r for r in results if not r.ok] + total_reqs = len(results) + succ = len(oks) + + total_ms = np.array([r.total_ms for r in oks]) if succ else np.array([]) + worker_ms = np.array([r.worker_ms for r in oks]) if succ else np.array([]) + route_ms = np.array([r.route_ms for r in oks]) if succ else np.array([]) + + avg_total = float(np.mean(total_ms)) if succ else 0.0 + avg_worker = float(np.mean(worker_ms)) if succ else 0.0 + avg_route = float(np.mean(route_ms)) if succ else 0.0 + p50_total, p95_total = (float(np.percentile(total_ms, 50)), float(np.percentile(total_ms, 95))) if succ else (0.0, 0.0) + + # Distribution over workers (by host:port) + hosts = [urlparse(r.worker_url).netloc for r in oks if r.worker_url] + dist = Counter(hosts) + + # Idle over time (mode per second) + idle_ts, idle_vals = [], [] + if status_samples: + buckets = {} + for ts, idle in status_samples: + k = int(ts) + buckets.setdefault(k, []).append(idle) + keys = sorted(buckets.keys()) + idle_ts = keys + # Use the most frequent sampled value per second (mode) to keep integer counts + idle_vals = [] + for k in keys: + vals_k = [int(v) for v in buckets[k]] + if vals_k: + cnt = Counter(vals_k) + idle_vals.append(cnt.most_common(1)[0][0]) + else: + idle_vals.append(0) + + print(f"\nResults: total={total_reqs} success={succ} errors={len(errs)}") + print(f"Avg latency (ms): {avg_total:.1f} p50: {p50_total:.1f} p95: {p95_total:.1f}") + print(f"Avg route latency (ms): {avg_route:.1f} Avg worker latency (ms): {avg_worker:.1f}") + if errs: + print("Sample errors:") + for e in errs[:5]: + print(f" {e.status_code} {e.error}") + + # Plot: 2x3 grid + fig, axes = plt.subplots(2, 3, figsize=(15, 8)) + fig.suptitle(f"Load test: {endpoint_group_name} n={total_reqs}, rps={requests_per_second}, success={succ}") + + # Dist per worker + ax0 = axes[0, 0] + if dist: + items = sorted(dist.items(), key=lambda kv: kv[1], reverse=True) + labels, counts = zip(*items) + ax0.bar(range(len(labels)), counts) + ax0.set_xticks(range(len(labels))) + ax0.set_xticklabels(labels, rotation=45, ha="right", fontsize=8) + ax0.set_title("Request distribution over workers") + ax0.set_ylabel("count") + + # Latency histogram (total) + ax1 = axes[0, 1] + if succ: + ax1.hist(total_ms, bins=30) + ax1.set_title("Total latency (ms)") + ax1.set_xlabel("ms") + ax1.set_ylabel("freq") + + # Eligible workers over time + ax_idle = axes[0, 2] + if idle_ts: + ax_idle.plot(idle_ts, idle_vals, "-o", ms=3) + ax_idle.set_title("Eligible workers over time") + ax_idle.set_xlabel("time (s)") + ax_idle.set_ylabel("eligible count") + + # Throughput over time (completions/sec) + ax_idle = axes[1, 0] + ax_idle.clear() + if succ: + per_sec = {} + for r in oks: + s = int(r.t_end) + per_sec[s] = per_sec.get(s, 0) + 1 + ts = sorted(per_sec.keys()) + vals = [per_sec[t] for t in ts] + ax_idle.plot(ts, vals, "-o", ms=3) + ax_idle.set_title("Completions per second") + ax_idle.set_xlabel("time (s)") + ax_idle.set_ylabel("completions / sec") + + # Summary text + ax3 = axes[1, 1] + ax3.axis("off") + text = ( + f"Total requests: {total_reqs}\n" + f"Success: {succ} Errors: {len(errs)}\n" + f"Avg total latency: {avg_total:.1f} ms\n" + f"p50: {p50_total:.1f} ms p95: {p95_total:.1f} ms\n" + f"Avg route latency: {avg_route:.1f} ms\n" + f"Avg worker latency: {avg_worker:.1f} ms\n" + f"300 errors: {len([r for r in errs if r.status_code >= 300 and r.status_code < 400])}\n" + f"429 errors: {len([r for r in errs if r.status_code == 429])}\n" + f"500 errors: {len([r for r in errs if r.status_code >= 500])}\n" + f"Other errors: {len([r for r in errs if r.status_code not in [300, 429, 500]])}\n" + ) + ax3.set_title("Summary") + ax3.text(0.02, 0.98, text, va="top", ha="left", fontsize=11, transform=ax3.transAxes) + + # Error count over time + ax_errors = axes[1, 2] + all_end_times = [int(r.t_end) for r in results if r.t_end > 0] + if all_end_times: + min_second = min(all_end_times) + max_second = max(all_end_times) + # Count errors per second + errors_per_second = {} + for result in errs: + second = int(result.t_end) + errors_per_second[second] = errors_per_second.get(second, 0) + 1 + # Create complete timeline including zeros + time_seconds = list(range(min_second, max_second + 1)) + error_counts = [errors_per_second.get(sec, 0) for sec in time_seconds] + ax_errors.plot(time_seconds, error_counts, "-o", ms=3) + ax_errors.set_title("Errors per second") + ax_errors.set_xlabel("time (s)") + ax_errors.set_ylabel("errors / sec") + + # Ensure unique output path and create directory if needed + final_out_path = get_incremented_path(out_path) + out_dir = os.path.dirname(final_out_path) + if out_dir: + os.makedirs(out_dir, exist_ok=True) + + plt.tight_layout(rect=[0, 0, 1, 0.96]) + plt.savefig(final_out_path, dpi=120) + print(f"Saved report to: {final_out_path}") + + # Per-worker latency boxplot (top 12 by volume) + groups = {} + for r in oks: + host = urlparse(r.worker_url).netloc + groups.setdefault(host, []).append(r.total_ms) + items = sorted(groups.items(), key=lambda kv: len(kv[1]), reverse=True)[:12] + if items: + labels, data = zip(*items) + fig2, axb = plt.subplots(1, 1, figsize=(12, 5)) + axb.boxplot(data, showfliers=False) + axb.set_xticklabels(labels, rotation=45, ha="right", fontsize=8) + axb.set_title("Per-worker latency (ms)") + axb.set_ylabel("ms") + plt.tight_layout() + extra_out = get_incremented_path(os.path.splitext(out_path)[0] + "-workers.png") + plt.savefig(extra_out, dpi=120) + fig2.tight_layout() + fig2.savefig(extra_out, dpi=120) + print(f"Saved worker latency plot to: {extra_out}") if __name__ == "__main__": # Check if MODEL_NAME environment variable is set @@ -16,13 +403,32 @@ if __name__ == "__main__": help="Model to use for completions request (required if MODEL_NAME env var not set)", ) - # Parse known args to get model early, before test_load_cmd adds its args + # Parse known args to get model early, before adding load args known_args, _ = test_args.parse_known_args() - - # Set environment variable if model was provided if hasattr(known_args, "model") and known_args.model: os.environ["MODEL_NAME"] = known_args.model print(f"Set MODEL_NAME environment variable to: {known_args.model}") - # Now call test_load_cmd normally - it will add its own args and re-parse - test_load_cmd(CompletionsData, WORKER_ENDPOINT, arg_parser=test_args) + # Load test args + test_args.add_argument("-n", dest="num_requests", type=int, required=True, help="total number of requests") + test_args.add_argument("-rps", dest="requests_per_second", type=float, required=True, help="requests per second") + test_args.add_argument("--out", dest="out_path", type=str, default="load_test_report.png", help="path to save the report image") + args = test_args.parse_args() + + server_url = { + "prod": "https://run.vast.ai", + "alpha": "https://run-alpha.vast.ai", + "candidate": "https://run-candidate.vast.ai", + "local": "http://localhost:8080" + }.get(args.instance, "http://localhost:8080") + + run_load_with_metrics( + num_requests=args.num_requests, + requests_per_second=args.requests_per_second, + endpoint_group_name=args.endpoint_group_name, + account_api_key=args.api_key, + server_url=server_url, + worker_endpoint=WORKER_ENDPOINT, + instance=args.instance, + out_path=args.out_path, + ) \ No newline at end of file From d6eb498ee4dd04549ce6070cc631eeff1f842764 Mon Sep 17 00:00:00 2001 From: Colter Downing Date: Mon, 27 Oct 2025 12:01:55 -0700 Subject: [PATCH 14/18] catch the case where all benchmarks fail (sets error) --- lib/backend.py | 3 +++ lib/data_types.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/lib/backend.py b/lib/backend.py index 4a99ac6..e55ce59 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -350,6 +350,9 @@ class Backend: total_workload = sum(br.workload for br in benchmark_requests if br.is_successful) time_elapsed = time.time() - start successful_responses = sum([1 for br in benchmark_requests if br.is_successful]) + if successful_responses == 0: + self.backend_errored("No successful responses from benchmark") + log.debug(f"benchmark failed: {successful_responses}/{concurrent_requests} successful responses") throughput = total_workload / time_elapsed sum_throughput += throughput diff --git a/lib/data_types.py b/lib/data_types.py index 389ed18..af1bbd5 100644 --- a/lib/data_types.py +++ b/lib/data_types.py @@ -257,7 +257,7 @@ class ModelMetrics: def wait_time(self) -> float: if (len(self.requests_working) == 0): return 0.0 - return sum([request.workload for request in self.requests_working.values()]) / self.max_throughput + return sum([request.workload for request in self.requests_working.values()]) / max(self.max_throughput, 0.00001) @property def cur_load(self) -> float: From 830b532781b14cbebe227600e988e116d9d249cd Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Mon, 27 Oct 2025 16:57:52 -0700 Subject: [PATCH 15/18] Trying unified delete --- lib/backend.py | 13 +++++++++++++ lib/metrics.py | 41 ++++++++++++++++++++++++++++++++--------- 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/lib/backend.py b/lib/backend.py index e55ce59..d2ac11c 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -59,6 +59,8 @@ class Backend: ) log_actions: List[Tuple[LogAction, str]] max_wait_time: float = 10.0 + request_queue = asyncio.Queue() + worker_task = asyncio.create_task(_worker()) reqnum = -1 version = VERSION msg_history = [] @@ -91,6 +93,17 @@ class Backend: timeout = ClientTimeout(total=None) return ClientSession(self.model_server_url, timeout=timeout, connector=connector) + async def _worker(self): + while True: + handler, request, fut = await self.request_queue.get() + try: + res = await self.__process_request(handler, request) + fut.set_result(res) + except Exception as e: + fut.set_exception(e) + finally: + self.request_queue.task_done() + def create_handler( self, handler: EndpointHandler[ApiPayload_T], diff --git a/lib/metrics.py b/lib/metrics.py index 3022a8e..de5490b 100644 --- a/lib/metrics.py +++ b/lib/metrics.py @@ -145,14 +145,15 @@ class Metrics: #######################################Private####################################### async def __send_delete_requests_and_reset(self): - - async def send_data(report_addr: str, success: bool) -> bool: + async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool: data = { "worker_id": self.id, - "request_idxs": [r.request_idx for r in self.model_metrics.requests_deleting if r.success == success], - "success": success + "request_idxs": idxs, + "success": success_flag, } - log.debug(f"Deleting requests that {'succeeded' if success else 'failed'}: {data['request_idxs']}") + log.debug( + f"Deleting requests that {'succeeded' if success_flag else 'failed'}: {data['request_idxs']}" + ) full_path = report_addr.rstrip("/") + "/delete_requests/" for attempt in range(1, 4): try: @@ -162,16 +163,38 @@ class Metrics: res.raise_for_status() return True except asyncio.TimeoutError: - log.debug(f"delete_requests timed out") + log.debug("delete_requests timed out") except (ClientResponseError, Exception) as e: log.debug(f"delete_requests failed with error: {e}") await asyncio.sleep(2) log.debug(f"retrying delete_request, attempt: {attempt}") + return False + + # Take a snapshot of what we plan to send this tick. + # New arrivals after this snapshot will remain in the queue for the next tick. + snapshot = list(self.model_metrics.requests_deleting) + success_idxs = [r.request_idx for r in snapshot if r.success is True] + failed_idxs = [r.request_idx for r in snapshot if r.success is False] + + if not success_idxs and not failed_idxs: + return # nothing to do for report_addr in self.report_addr: - success = await send_data(report_addr, success=True) and await send_data(report_addr, success=False) - if success is True: - self.model_metrics.requests_deleting.clear() + sent_success = True + sent_failed = True + + if success_idxs: + sent_success = await post(report_addr, success_idxs, True) + if failed_idxs: + sent_failed = await post(report_addr, failed_idxs, False) + + if sent_success and sent_failed: + # Remove only the items we actually sent from the live queue. + sent_set = set(success_idxs) | set(failed_idxs) + self.model_metrics.requests_deleting[:] = [ + r for r in self.model_metrics.requests_deleting + if r.request_idx not in sent_set + ] break From 9c795e2a014938e34627f07fd72546ed17db47af Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Mon, 27 Oct 2025 17:03:13 -0700 Subject: [PATCH 16/18] removed bad code --- lib/backend.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/lib/backend.py b/lib/backend.py index d2ac11c..8f0cae1 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -59,8 +59,6 @@ class Backend: ) log_actions: List[Tuple[LogAction, str]] max_wait_time: float = 10.0 - request_queue = asyncio.Queue() - worker_task = asyncio.create_task(_worker()) reqnum = -1 version = VERSION msg_history = [] From 22bca74087df00161c622c046619851361fa93df Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Mon, 27 Oct 2025 18:25:21 -0700 Subject: [PATCH 17/18] Prevent load time race --- lib/data_types.py | 5 +++-- lib/metrics.py | 20 +++++++++++++------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/lib/data_types.py b/lib/data_types.py index af1bbd5..ceadfed 100644 --- a/lib/data_types.py +++ b/lib/data_types.py @@ -190,11 +190,12 @@ class SystemMetrics: self.additional_disk_usage = disk_usage - self.last_disk_usage self.last_disk_usage = disk_usage - def reset(self): + def reset(self, expected: float | None) -> None: # autoscaler excepts model_loading_time to be populated only once, when the instance has # finished benchmarking and is ready to receive requests. This applies to restarted instances # as well: they should send model_loading_time once when they are done loading - self.model_loading_time = None + if self.model_loading_time == expected: + self.model_loading_time = None @dataclass diff --git a/lib/metrics.py b/lib/metrics.py index de5490b..6cb9c4f 100644 --- a/lib/metrics.py +++ b/lib/metrics.py @@ -200,11 +200,13 @@ class Metrics: async def __send_metrics_and_reset(self): + loadtime_snapshot = self.system_metrics.model_loading_time + def compute_autoscaler_data() -> AutoScalerData: return AutoScalerData( id=self.id, version=self.version, - loadtime=(self.system_metrics.model_loading_time or 0.0), + loadtime=(loadtime_snapshot or 0.0), new_load=self.model_metrics.workload_processing, cur_load=self.model_metrics.cur_load, rej_load=self.model_metrics.workload_rejected, @@ -252,11 +254,15 @@ class Metrics: self.system_metrics.update_disk_usage() + sent = False for report_addr in self.report_addr: - success = await send_data(report_addr) - if success is True: + if await send_data(report_addr): + sent = True break - self.update_pending = False - self.model_metrics.reset() - self.system_metrics.reset() - self.last_metric_update = time.time() + + if sent: + # clear the one-shot loadtime only if we actually sent *this* value + self.system_metrics.reset(expected=loadtime_snapshot) + self.update_pending = False + self.model_metrics.reset() + self.last_metric_update = time.time() From 814c3acd4ca87d7beee433eab8ebb3073c685077 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Tue, 28 Oct 2025 13:43:57 -0700 Subject: [PATCH 18/18] remove unused code --- lib/backend.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/lib/backend.py b/lib/backend.py index 8f0cae1..e55ce59 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -91,17 +91,6 @@ class Backend: timeout = ClientTimeout(total=None) return ClientSession(self.model_server_url, timeout=timeout, connector=connector) - async def _worker(self): - while True: - handler, request, fut = await self.request_queue.get() - try: - res = await self.__process_request(handler, request) - fut.set_result(res) - except Exception as e: - fut.set_exception(e) - finally: - self.request_queue.task_done() - def create_handler( self, handler: EndpointHandler[ApiPayload_T],