From e9ba1b03e4bdbc12f6a4cca8038a7236bce89659 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Wed, 8 Oct 2025 16:54:18 -0700 Subject: [PATCH] Use delete_requests and track request_idxs --- lib/backend.py | 54 ++++++++++++++++--------- lib/data_types.py | 40 +++++++++++++++---- lib/metrics.py | 100 ++++++++++++++++++++++++++++++++++++---------- 3 files changed, 146 insertions(+), 48 deletions(-) diff --git a/lib/backend.py b/lib/backend.py index 76b1ec5..a62ab6c 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -12,6 +12,7 @@ from distutils.util import strtobool from anyio import open_file from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector +import asyncio import requests from Crypto.Signature import pkcs1_15 @@ -25,6 +26,7 @@ from lib.data_types import ( LogAction, ApiPayload_T, JsonDataException, + RequestMetrics ) MSG_HISTORY_LEN = 100 @@ -53,6 +55,7 @@ class Backend: EndpointHandler # this endpoint handler will be used for benchmarking ) log_actions: List[Tuple[LogAction, str]] + max_wait_time: float = 10.0 reqnum = -1 msg_history = [] sem: Semaphore = dataclasses.field(default_factory=Semaphore) @@ -128,53 +131,53 @@ class Backend: except json.JSONDecodeError: return web.json_response(dict(error="invalid JSON"), status=422) workload = payload.count_workload() + request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created") async def cancel_api_call_if_disconnected() -> web.Response: await request.wait_for_disconnection() - log.debug(f"request with reqnum: {auth_data.reqnum} was canceled") - self.metrics._request_canceled(workload=workload) - return web.Response(status=500) + log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled") + self.metrics._request_canceled(request_metrics) + raise asyncio.CancelledError async def make_request() -> Union[web.Response, web.StreamResponse]: - log.debug(f"got request, {auth_data.reqnum}") - self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum) + log.debug(f"got request, {request_metrics.reqnum}") + self.metrics._request_start(request_metrics) if self.allow_parallel_requests is False: - log.debug(f"Waiting to aquire Sem for reqnum:{auth_data.reqnum}") + log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}") await self.sem.acquire() log.debug( - f"Sem acquired for reqnum:{auth_data.reqnum}, starting request..." + f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..." ) else: - log.debug(f"Starting request for reqnum:{auth_data.reqnum}") + log.debug(f"Starting request for reqnum:{request_metrics.reqnum}") try: response = await self.__call_api(handler=handler, payload=payload) status_code = response.status log.debug( " ".join( [ - f"request with reqnum:{auth_data.reqnum}", + f"request with reqnum:{request_metrics.reqnum}", f"returned status code: {status_code},", ] ) ) res = await handler.generate_client_response(request, response) - self.metrics._request_success(workload=workload) + self.metrics._request_success(request_metrics) return res except requests.exceptions.RequestException as e: log.debug(f"[backend] Request error: {e}") - self.metrics._request_errored(workload=workload) + self.metrics._request_errored(request_metrics) return web.Response(status=500) - finally: - self.metrics._request_end( - workload=workload, - reqnum=auth_data.reqnum, - ) - self.sem.release() ########### if self.__check_signature(auth_data) is False: + self.metrics._request_reject(request_metrics) return web.Response(status=401) + + if self.metrics.model_metrics.wait_time > self.max_wait_time: + self.metrics._request_reject(request_metrics) + return web.Response(status=500) try: done, pending = await wait( @@ -185,10 +188,23 @@ class Backend: return_when=FIRST_COMPLETED, ) [task.cancel() for task in pending] - return done.pop().result() + done_task = done.pop() + try: + return done_task.result() + except Exception as e: + log.debug(f"Request task raised exception: {e}") + return web.Response(status=500) + except asyncio.CancelledError: + # Client is gone. Do not write a response; just unwind. + return web.Response(status=499) except Exception as e: log.debug(f"Exception in main handler loop {e}") return web.Response(status=500) + finally: + # Always release the semaphore if it was acquired + if not self.allow_parallel_requests: + self.sem.release() + self.metrics._request_end(request_metrics) @cached_property def healthcheck_session(self): @@ -229,7 +245,7 @@ class Backend: async def _start_tracking(self) -> None: await gather( - self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck() + self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop() ) def backend_errored(self, msg: str) -> None: diff --git a/lib/data_types.py b/lib/data_types.py index ed8b9f4..ec72204 100644 --- a/lib/data_types.py +++ b/lib/data_types.py @@ -70,6 +70,7 @@ class AuthData: endpoint: str reqnum: int url: str + request_idx: int @classmethod def from_json_msg(cls, json_msg: Dict[str, Any]): @@ -196,6 +197,14 @@ class SystemMetrics: self.model_loading_time = None +@dataclass +class RequestMetrics: + """Tracks metrics for an active request.""" + request_idx: int + reqnum: int + workload: float + status: str + @dataclass class ModelMetrics: """Model specific metrics""" @@ -205,12 +214,14 @@ class ModelMetrics: workload_received: float workload_cancelled: float workload_errored: float + workload_rejected: float # these are not workload_pending: float error_msg: Optional[str] max_throughput: float requests_recieved: Set[int] = field(default_factory=set) - requests_working: Set[int] = field(default_factory=set) + requests_working: dict[int, RequestMetrics] = field(default_factory=dict) + requests_deleting: list[RequestMetrics] = field(default_factory=list) last_update: float = field(default_factory=time.time) @classmethod @@ -220,19 +231,30 @@ class ModelMetrics: workload_served=0.0, workload_cancelled=0.0, workload_errored=0.0, + workload_rejected=0.0, workload_received=0.0, error_msg=None, max_throughput=0.0, ) - - @property - def cur_perf(self) -> float: - return max(self.workload_served / (time.time() - self.last_update), 0.0) - + @property def workload_processing(self) -> float: return max(self.workload_received - self.workload_cancelled, 0.0) + @property + def wait_time(self) -> float: + if (len(self.requests_working) == 0): + return 0.0 + return sum([request.workload for request in self.requests_working.values()]) / self.max_throughput + + @property + def cur_load(self) -> float: + return sum([request.workload for request in self.requests_working.values()]) + + @property + def working_request_idxs(self) -> list[int]: + return [req.request_idx for req in self.requests_working.values()] + def set_errored(self, error_msg): self.reset() self.error_msg = error_msg @@ -242,16 +264,19 @@ class ModelMetrics: self.workload_received = 0 self.workload_cancelled = 0 self.workload_errored = 0 + self.workload_rejected = 0 self.last_update = time.time() @dataclass -class AutoScalaerData: +class AutoScalerData: """Data that is reported to autoscaler""" id: int loadtime: float cur_load: float + rej_load: float + new_load: float error_msg: str max_perf: float cur_perf: float @@ -260,6 +285,7 @@ class AutoScalaerData: num_requests_working: int num_requests_recieved: int additional_disk_usage: float + working_request_idxs: list[int] url: str diff --git a/lib/metrics.py b/lib/metrics.py index 166706b..053a465 100644 --- a/lib/metrics.py +++ b/lib/metrics.py @@ -8,10 +8,11 @@ from functools import cache import requests -from lib.data_types import AutoScalaerData, SystemMetrics, ModelMetrics +from lib.data_types import AutoScalerData, SystemMetrics, ModelMetrics, RequestMetrics from typing import Awaitable, NoReturn, List METRICS_UPDATE_INTERVAL = 1 +DELETE_REQUESTS_INTERVAL = 1 log = logging.getLogger(__file__) @@ -27,6 +28,7 @@ def get_url() -> str: @dataclass class Metrics: last_metric_update: float = 0.0 + last_request_served: float = 0.0 update_pending: bool = False id: int = field(default_factory=lambda: int(os.environ["CONTAINER_ID"])) report_addr: List[str] = field( @@ -36,41 +38,65 @@ class Metrics: system_metrics: SystemMetrics = field(default_factory=SystemMetrics.empty) model_metrics: ModelMetrics = field(default_factory=ModelMetrics.empty) - def _request_start(self, workload: float, reqnum: int) -> None: + def _request_start(self, request: RequestMetrics) -> None: """ this function is called prior to forwarding a request to a model API. """ log.debug("request start") - self.model_metrics.workload_pending += workload - self.model_metrics.workload_received += workload - self.model_metrics.requests_recieved.add(reqnum) - self.model_metrics.requests_working.add(reqnum) + request.status = "Started" + self.model_metrics.workload_pending += request.workload + self.model_metrics.workload_received += request.workload + self.model_metrics.requests_recieved.add(request.reqnum) + self.model_metrics.requests_working[request.reqnum] = request + self.update_pending = True - def _request_end(self, workload: float, reqnum: int) -> None: + def _request_end(self, request: RequestMetrics) -> None: """ this function is called after handling of a request ends, regardless of the outcome """ - self.model_metrics.workload_pending -= workload - self.model_metrics.requests_working.discard(reqnum) + self.model_metrics.workload_pending -= request.workload + self.model_metrics.requests_working.pop(request.reqnum, None) + self.model_metrics.requests_deleting.append(request) + self.last_request_served = time.time() - def _request_success(self, workload: float) -> None: + def _request_success(self, request: RequestMetrics) -> None: """ this function is called after a response from model API is received and forwarded. """ - self.model_metrics.workload_served += workload + self.model_metrics.workload_served += request.workload + request.status = "Success" self.update_pending = True - def _request_errored(self, workload: float) -> None: + def _request_errored(self, request: RequestMetrics) -> None: """ this function is called if model API returns an error """ - self.model_metrics.workload_errored += workload + self.model_metrics.workload_errored += request.workload + request.status = "Error" + self.update_pending = True - def _request_canceled(self, workload: float) -> None: + def _request_canceled(self, request: RequestMetrics) -> None: """ this function is called if client drops connection before model API has responded """ - self.model_metrics.workload_cancelled += workload + self.model_metrics.workload_cancelled += request.workload + request.status = "Cancelled" + + def _request_reject(self, request: RequestMetrics): + """ + this function is called if the current wait time for the model is above max_wait_time + """ + self.model_metrics.requests_recieved.add(request.reqnum) + self.model_metrics.requests_deleting.append(request) + self.model_metrics.workload_rejected += request.workload + request.status = "Rejected" + self.update_pending = True + + async def _send_delete_requests_loop(self) -> Awaitable[NoReturn]: + while True: + await sleep(DELETE_REQUESTS_INTERVAL) + if len(self.model_metrics.requests_deleting) > 0: + self.__send_delete_requests_and_reset() async def _send_metrics_loop(self) -> Awaitable[NoReturn]: while True: @@ -78,10 +104,10 @@ class Metrics: elapsed = time.time() - self.last_metric_update if self.system_metrics.model_is_loaded is False and elapsed >= 10: log.debug(f"sending loading model metrics after {int(elapsed)}s wait") - self.__send_metrics_and_reset(elapsed) + self.__send_metrics_and_reset() elif self.update_pending or elapsed > 10: log.debug(f"sending loaded model metrics after {int(elapsed)}s wait") - self.__send_metrics_and_reset(elapsed) + self.__send_metrics_and_reset() def _model_loaded(self, max_throughput: float) -> None: self.system_metrics.model_loading_time = ( @@ -96,19 +122,49 @@ class Metrics: #######################################Private####################################### - def __send_metrics_and_reset(self, elapsed): + def __send_delete_requests_and_reset(self): - def compute_autoscaler_data() -> AutoScalaerData: - return AutoScalaerData( + def send_data(report_addr: str) -> bool: + data = { + "worker_id": self.id, + "request_idxs": [r.request_idx for r in self.model_metrics.requests_deleting] + } + full_path = report_addr.rstrip("/") + "/delete_requests/" + for attempt in range(1, 4): + try: + res = requests.post(full_path, json=data, timeout=1) + res.raise_for_status() + return True + except requests.Timeout: + log.debug(f"delete_requests timed out") + except Exception as e: + log.debug(f"delete_requests failed with error: {e}") + time.sleep(2) + log.debug(f"retrying delete_request, attempt: {attempt}") + + for report_addr in self.report_addr: + success = send_data(report_addr) + if success is True: + self.model_metrics.requests_deleting.clear() + break + + + def __send_metrics_and_reset(self): + + def compute_autoscaler_data() -> AutoScalerData: + return AutoScalerData( id=self.id, loadtime=(self.system_metrics.model_loading_time or 0.0), - cur_load=(self.model_metrics.workload_processing / elapsed), + new_load=self.model_metrics.workload_processing, + cur_load=self.model_metrics.cur_load, + rej_load=self.model_metrics.workload_rejected, max_perf=self.model_metrics.max_throughput, - cur_perf=self.model_metrics.cur_perf, + cur_perf=self.model_metrics.workload_served, error_msg=self.model_metrics.error_msg or "", num_requests_working=len(self.model_metrics.requests_working), num_requests_recieved=len(self.model_metrics.requests_recieved), additional_disk_usage=self.system_metrics.additional_disk_usage, + working_request_idxs=self.model_metrics.working_request_idxs, cur_capacity=0, max_capacity=0, url=self.url,