Use delete_requests and track request_idxs

This commit is contained in:
Lucas Armand
2025-10-08 16:54:18 -07:00
parent 4fdc314fd9
commit e9ba1b03e4
3 changed files with 146 additions and 48 deletions
+35 -19
View File
@@ -12,6 +12,7 @@ from distutils.util import strtobool
from anyio import open_file
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
import asyncio
import requests
from Crypto.Signature import pkcs1_15
@@ -25,6 +26,7 @@ from lib.data_types import (
LogAction,
ApiPayload_T,
JsonDataException,
RequestMetrics
)
MSG_HISTORY_LEN = 100
@@ -53,6 +55,7 @@ class Backend:
EndpointHandler # this endpoint handler will be used for benchmarking
)
log_actions: List[Tuple[LogAction, str]]
max_wait_time: float = 10.0
reqnum = -1
msg_history = []
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
@@ -128,53 +131,53 @@ class Backend:
except json.JSONDecodeError:
return web.json_response(dict(error="invalid JSON"), status=422)
workload = payload.count_workload()
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
async def cancel_api_call_if_disconnected() -> web.Response:
await request.wait_for_disconnection()
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
self.metrics._request_canceled(workload=workload)
return web.Response(status=500)
log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled")
self.metrics._request_canceled(request_metrics)
raise asyncio.CancelledError
async def make_request() -> Union[web.Response, web.StreamResponse]:
log.debug(f"got request, {auth_data.reqnum}")
self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum)
log.debug(f"got request, {request_metrics.reqnum}")
self.metrics._request_start(request_metrics)
if self.allow_parallel_requests is False:
log.debug(f"Waiting to aquire Sem for reqnum:{auth_data.reqnum}")
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
await self.sem.acquire()
log.debug(
f"Sem acquired for reqnum:{auth_data.reqnum}, starting request..."
f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..."
)
else:
log.debug(f"Starting request for reqnum:{auth_data.reqnum}")
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
try:
response = await self.__call_api(handler=handler, payload=payload)
status_code = response.status
log.debug(
" ".join(
[
f"request with reqnum:{auth_data.reqnum}",
f"request with reqnum:{request_metrics.reqnum}",
f"returned status code: {status_code},",
]
)
)
res = await handler.generate_client_response(request, response)
self.metrics._request_success(workload=workload)
self.metrics._request_success(request_metrics)
return res
except requests.exceptions.RequestException as e:
log.debug(f"[backend] Request error: {e}")
self.metrics._request_errored(workload=workload)
self.metrics._request_errored(request_metrics)
return web.Response(status=500)
finally:
self.metrics._request_end(
workload=workload,
reqnum=auth_data.reqnum,
)
self.sem.release()
###########
if self.__check_signature(auth_data) is False:
self.metrics._request_reject(request_metrics)
return web.Response(status=401)
if self.metrics.model_metrics.wait_time > self.max_wait_time:
self.metrics._request_reject(request_metrics)
return web.Response(status=500)
try:
done, pending = await wait(
@@ -185,10 +188,23 @@ class Backend:
return_when=FIRST_COMPLETED,
)
[task.cancel() for task in pending]
return done.pop().result()
done_task = done.pop()
try:
return done_task.result()
except Exception as e:
log.debug(f"Request task raised exception: {e}")
return web.Response(status=500)
except asyncio.CancelledError:
# Client is gone. Do not write a response; just unwind.
return web.Response(status=499)
except Exception as e:
log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500)
finally:
# Always release the semaphore if it was acquired
if not self.allow_parallel_requests:
self.sem.release()
self.metrics._request_end(request_metrics)
@cached_property
def healthcheck_session(self):
@@ -229,7 +245,7 @@ class Backend:
async def _start_tracking(self) -> None:
await gather(
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck()
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop()
)
def backend_errored(self, msg: str) -> None: