Use delete_requests and track request_idxs

This commit is contained in:
Lucas Armand
2025-10-08 16:54:18 -07:00
parent 4fdc314fd9
commit e9ba1b03e4
3 changed files with 146 additions and 48 deletions
+35 -19
View File
@@ -12,6 +12,7 @@ from distutils.util import strtobool
from anyio import open_file
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
import asyncio
import requests
from Crypto.Signature import pkcs1_15
@@ -25,6 +26,7 @@ from lib.data_types import (
LogAction,
ApiPayload_T,
JsonDataException,
RequestMetrics
)
MSG_HISTORY_LEN = 100
@@ -53,6 +55,7 @@ class Backend:
EndpointHandler # this endpoint handler will be used for benchmarking
)
log_actions: List[Tuple[LogAction, str]]
max_wait_time: float = 10.0
reqnum = -1
msg_history = []
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
@@ -128,54 +131,54 @@ class Backend:
except json.JSONDecodeError:
return web.json_response(dict(error="invalid JSON"), status=422)
workload = payload.count_workload()
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
async def cancel_api_call_if_disconnected() -> web.Response:
await request.wait_for_disconnection()
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
self.metrics._request_canceled(workload=workload)
return web.Response(status=500)
log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled")
self.metrics._request_canceled(request_metrics)
raise asyncio.CancelledError
async def make_request() -> Union[web.Response, web.StreamResponse]:
log.debug(f"got request, {auth_data.reqnum}")
self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum)
log.debug(f"got request, {request_metrics.reqnum}")
self.metrics._request_start(request_metrics)
if self.allow_parallel_requests is False:
log.debug(f"Waiting to aquire Sem for reqnum:{auth_data.reqnum}")
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
await self.sem.acquire()
log.debug(
f"Sem acquired for reqnum:{auth_data.reqnum}, starting request..."
f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..."
)
else:
log.debug(f"Starting request for reqnum:{auth_data.reqnum}")
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
try:
response = await self.__call_api(handler=handler, payload=payload)
status_code = response.status
log.debug(
" ".join(
[
f"request with reqnum:{auth_data.reqnum}",
f"request with reqnum:{request_metrics.reqnum}",
f"returned status code: {status_code},",
]
)
)
res = await handler.generate_client_response(request, response)
self.metrics._request_success(workload=workload)
self.metrics._request_success(request_metrics)
return res
except requests.exceptions.RequestException as e:
log.debug(f"[backend] Request error: {e}")
self.metrics._request_errored(workload=workload)
self.metrics._request_errored(request_metrics)
return web.Response(status=500)
finally:
self.metrics._request_end(
workload=workload,
reqnum=auth_data.reqnum,
)
self.sem.release()
###########
if self.__check_signature(auth_data) is False:
self.metrics._request_reject(request_metrics)
return web.Response(status=401)
if self.metrics.model_metrics.wait_time > self.max_wait_time:
self.metrics._request_reject(request_metrics)
return web.Response(status=500)
try:
done, pending = await wait(
[
@@ -185,10 +188,23 @@ class Backend:
return_when=FIRST_COMPLETED,
)
[task.cancel() for task in pending]
return done.pop().result()
done_task = done.pop()
try:
return done_task.result()
except Exception as e:
log.debug(f"Request task raised exception: {e}")
return web.Response(status=500)
except asyncio.CancelledError:
# Client is gone. Do not write a response; just unwind.
return web.Response(status=499)
except Exception as e:
log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500)
finally:
# Always release the semaphore if it was acquired
if not self.allow_parallel_requests:
self.sem.release()
self.metrics._request_end(request_metrics)
@cached_property
def healthcheck_session(self):
@@ -229,7 +245,7 @@ class Backend:
async def _start_tracking(self) -> None:
await gather(
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck()
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop()
)
def backend_errored(self, msg: str) -> None:
+32 -6
View File
@@ -70,6 +70,7 @@ class AuthData:
endpoint: str
reqnum: int
url: str
request_idx: int
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]):
@@ -196,6 +197,14 @@ class SystemMetrics:
self.model_loading_time = None
@dataclass
class RequestMetrics:
"""Tracks metrics for an active request."""
request_idx: int
reqnum: int
workload: float
status: str
@dataclass
class ModelMetrics:
"""Model specific metrics"""
@@ -205,12 +214,14 @@ class ModelMetrics:
workload_received: float
workload_cancelled: float
workload_errored: float
workload_rejected: float
# these are not
workload_pending: float
error_msg: Optional[str]
max_throughput: float
requests_recieved: Set[int] = field(default_factory=set)
requests_working: Set[int] = field(default_factory=set)
requests_working: dict[int, RequestMetrics] = field(default_factory=dict)
requests_deleting: list[RequestMetrics] = field(default_factory=list)
last_update: float = field(default_factory=time.time)
@classmethod
@@ -220,19 +231,30 @@ class ModelMetrics:
workload_served=0.0,
workload_cancelled=0.0,
workload_errored=0.0,
workload_rejected=0.0,
workload_received=0.0,
error_msg=None,
max_throughput=0.0,
)
@property
def cur_perf(self) -> float:
return max(self.workload_served / (time.time() - self.last_update), 0.0)
@property
def workload_processing(self) -> float:
return max(self.workload_received - self.workload_cancelled, 0.0)
@property
def wait_time(self) -> float:
if (len(self.requests_working) == 0):
return 0.0
return sum([request.workload for request in self.requests_working.values()]) / self.max_throughput
@property
def cur_load(self) -> float:
return sum([request.workload for request in self.requests_working.values()])
@property
def working_request_idxs(self) -> list[int]:
return [req.request_idx for req in self.requests_working.values()]
def set_errored(self, error_msg):
self.reset()
self.error_msg = error_msg
@@ -242,16 +264,19 @@ class ModelMetrics:
self.workload_received = 0
self.workload_cancelled = 0
self.workload_errored = 0
self.workload_rejected = 0
self.last_update = time.time()
@dataclass
class AutoScalaerData:
class AutoScalerData:
"""Data that is reported to autoscaler"""
id: int
loadtime: float
cur_load: float
rej_load: float
new_load: float
error_msg: str
max_perf: float
cur_perf: float
@@ -260,6 +285,7 @@ class AutoScalaerData:
num_requests_working: int
num_requests_recieved: int
additional_disk_usage: float
working_request_idxs: list[int]
url: str
+78 -22
View File
@@ -8,10 +8,11 @@ from functools import cache
import requests
from lib.data_types import AutoScalaerData, SystemMetrics, ModelMetrics
from lib.data_types import AutoScalerData, SystemMetrics, ModelMetrics, RequestMetrics
from typing import Awaitable, NoReturn, List
METRICS_UPDATE_INTERVAL = 1
DELETE_REQUESTS_INTERVAL = 1
log = logging.getLogger(__file__)
@@ -27,6 +28,7 @@ def get_url() -> str:
@dataclass
class Metrics:
last_metric_update: float = 0.0
last_request_served: float = 0.0
update_pending: bool = False
id: int = field(default_factory=lambda: int(os.environ["CONTAINER_ID"]))
report_addr: List[str] = field(
@@ -36,41 +38,65 @@ class Metrics:
system_metrics: SystemMetrics = field(default_factory=SystemMetrics.empty)
model_metrics: ModelMetrics = field(default_factory=ModelMetrics.empty)
def _request_start(self, workload: float, reqnum: int) -> None:
def _request_start(self, request: RequestMetrics) -> None:
"""
this function is called prior to forwarding a request to a model API.
"""
log.debug("request start")
self.model_metrics.workload_pending += workload
self.model_metrics.workload_received += workload
self.model_metrics.requests_recieved.add(reqnum)
self.model_metrics.requests_working.add(reqnum)
request.status = "Started"
self.model_metrics.workload_pending += request.workload
self.model_metrics.workload_received += request.workload
self.model_metrics.requests_recieved.add(request.reqnum)
self.model_metrics.requests_working[request.reqnum] = request
self.update_pending = True
def _request_end(self, workload: float, reqnum: int) -> None:
def _request_end(self, request: RequestMetrics) -> None:
"""
this function is called after handling of a request ends, regardless of the outcome
"""
self.model_metrics.workload_pending -= workload
self.model_metrics.requests_working.discard(reqnum)
self.model_metrics.workload_pending -= request.workload
self.model_metrics.requests_working.pop(request.reqnum, None)
self.model_metrics.requests_deleting.append(request)
self.last_request_served = time.time()
def _request_success(self, workload: float) -> None:
def _request_success(self, request: RequestMetrics) -> None:
"""
this function is called after a response from model API is received and forwarded.
"""
self.model_metrics.workload_served += workload
self.model_metrics.workload_served += request.workload
request.status = "Success"
self.update_pending = True
def _request_errored(self, workload: float) -> None:
def _request_errored(self, request: RequestMetrics) -> None:
"""
this function is called if model API returns an error
"""
self.model_metrics.workload_errored += workload
self.model_metrics.workload_errored += request.workload
request.status = "Error"
self.update_pending = True
def _request_canceled(self, workload: float) -> None:
def _request_canceled(self, request: RequestMetrics) -> None:
"""
this function is called if client drops connection before model API has responded
"""
self.model_metrics.workload_cancelled += workload
self.model_metrics.workload_cancelled += request.workload
request.status = "Cancelled"
def _request_reject(self, request: RequestMetrics):
"""
this function is called if the current wait time for the model is above max_wait_time
"""
self.model_metrics.requests_recieved.add(request.reqnum)
self.model_metrics.requests_deleting.append(request)
self.model_metrics.workload_rejected += request.workload
request.status = "Rejected"
self.update_pending = True
async def _send_delete_requests_loop(self) -> Awaitable[NoReturn]:
while True:
await sleep(DELETE_REQUESTS_INTERVAL)
if len(self.model_metrics.requests_deleting) > 0:
self.__send_delete_requests_and_reset()
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
while True:
@@ -78,10 +104,10 @@ class Metrics:
elapsed = time.time() - self.last_metric_update
if self.system_metrics.model_is_loaded is False and elapsed >= 10:
log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
self.__send_metrics_and_reset(elapsed)
self.__send_metrics_and_reset()
elif self.update_pending or elapsed > 10:
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
self.__send_metrics_and_reset(elapsed)
self.__send_metrics_and_reset()
def _model_loaded(self, max_throughput: float) -> None:
self.system_metrics.model_loading_time = (
@@ -96,19 +122,49 @@ class Metrics:
#######################################Private#######################################
def __send_metrics_and_reset(self, elapsed):
def __send_delete_requests_and_reset(self):
def compute_autoscaler_data() -> AutoScalaerData:
return AutoScalaerData(
def send_data(report_addr: str) -> bool:
data = {
"worker_id": self.id,
"request_idxs": [r.request_idx for r in self.model_metrics.requests_deleting]
}
full_path = report_addr.rstrip("/") + "/delete_requests/"
for attempt in range(1, 4):
try:
res = requests.post(full_path, json=data, timeout=1)
res.raise_for_status()
return True
except requests.Timeout:
log.debug(f"delete_requests timed out")
except Exception as e:
log.debug(f"delete_requests failed with error: {e}")
time.sleep(2)
log.debug(f"retrying delete_request, attempt: {attempt}")
for report_addr in self.report_addr:
success = send_data(report_addr)
if success is True:
self.model_metrics.requests_deleting.clear()
break
def __send_metrics_and_reset(self):
def compute_autoscaler_data() -> AutoScalerData:
return AutoScalerData(
id=self.id,
loadtime=(self.system_metrics.model_loading_time or 0.0),
cur_load=(self.model_metrics.workload_processing / elapsed),
new_load=self.model_metrics.workload_processing,
cur_load=self.model_metrics.cur_load,
rej_load=self.model_metrics.workload_rejected,
max_perf=self.model_metrics.max_throughput,
cur_perf=self.model_metrics.cur_perf,
cur_perf=self.model_metrics.workload_served,
error_msg=self.model_metrics.error_msg or "",
num_requests_working=len(self.model_metrics.requests_working),
num_requests_recieved=len(self.model_metrics.requests_recieved),
additional_disk_usage=self.system_metrics.additional_disk_usage,
working_request_idxs=self.model_metrics.working_request_idxs,
cur_capacity=0,
max_capacity=0,
url=self.url,