Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0471f6b219 |
+67
-5
@@ -5,7 +5,7 @@ import base64
|
|||||||
import subprocess
|
import subprocess
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
|
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task, get_running_loop
|
||||||
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
|
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from distutils.util import strtobool
|
from distutils.util import strtobool
|
||||||
@@ -47,7 +47,7 @@ class Backend:
|
|||||||
This class is responsible for:
|
This class is responsible for:
|
||||||
1. Tailing logs and updating load time metrics
|
1. Tailing logs and updating load time metrics
|
||||||
2. Taking an EndpointHandler alongside incoming payload, preparing a json to be sent to the model, and
|
2. Taking an EndpointHandler alongside incoming payload, preparing a json to be sent to the model, and
|
||||||
sending the request. It also updates metrics as it makes those requests.
|
sending the request. It also updates metrics as it makes those requests.
|
||||||
3. Running a benchmark from an EndpointHandler
|
3. Running a benchmark from an EndpointHandler
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -74,6 +74,11 @@ class Backend:
|
|||||||
self._pubkey = self._fetch_pubkey()
|
self._pubkey = self._fetch_pubkey()
|
||||||
self.__start_healthcheck: bool = False
|
self.__start_healthcheck: bool = False
|
||||||
|
|
||||||
|
# NEW: FIFO queue + worker count
|
||||||
|
self.request_queue: "asyncio.Queue[tuple[EndpointHandler[ApiPayload_T], web.Request, asyncio.Future]]" = asyncio.Queue()
|
||||||
|
# If parallel allowed, let multiple workers drain the queue (order preserved by FIFO per worker; overall start order is FIFO).
|
||||||
|
self._num_workers: int = 1 if not self.allow_parallel_requests else int(os.environ.get("WORKERS", "4"))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pubkey(self) -> Optional[RSA.RsaKey]:
|
def pubkey(self) -> Optional[RSA.RsaKey]:
|
||||||
if self._pubkey is None:
|
if self._pubkey is None:
|
||||||
@@ -91,6 +96,22 @@ class Backend:
|
|||||||
timeout = ClientTimeout(total=None)
|
timeout = ClientTimeout(total=None)
|
||||||
return ClientSession(self.model_server_url, timeout=timeout, connector=connector)
|
return ClientSession(self.model_server_url, timeout=timeout, connector=connector)
|
||||||
|
|
||||||
|
async def _worker(self):
|
||||||
|
while True:
|
||||||
|
handler, request, fut = await self.request_queue.get()
|
||||||
|
try:
|
||||||
|
# Skip if already cancelled while waiting in the queue
|
||||||
|
if fut.cancelled():
|
||||||
|
continue
|
||||||
|
res = await self.__process_enqueued_request(handler, request)
|
||||||
|
if not fut.cancelled():
|
||||||
|
fut.set_result(res)
|
||||||
|
except Exception as e:
|
||||||
|
if not fut.cancelled():
|
||||||
|
fut.set_exception(e)
|
||||||
|
finally:
|
||||||
|
self.request_queue.task_done()
|
||||||
|
|
||||||
def create_handler(
|
def create_handler(
|
||||||
self,
|
self,
|
||||||
handler: EndpointHandler[ApiPayload_T],
|
handler: EndpointHandler[ApiPayload_T],
|
||||||
@@ -127,7 +148,36 @@ class Backend:
|
|||||||
handler: EndpointHandler[ApiPayload_T],
|
handler: EndpointHandler[ApiPayload_T],
|
||||||
request: web.Request,
|
request: web.Request,
|
||||||
) -> Union[web.Response, web.StreamResponse]:
|
) -> Union[web.Response, web.StreamResponse]:
|
||||||
"""use this function to forward requests to the model endpoint"""
|
"""use this function to enqueue requests for FIFO processing"""
|
||||||
|
loop = get_running_loop()
|
||||||
|
fut: asyncio.Future = loop.create_future()
|
||||||
|
|
||||||
|
# If the client disconnects while waiting in the FIFO, cancel the future so the worker skips it
|
||||||
|
cancel_watch = create_task(request.wait_for_disconnection())
|
||||||
|
def _cancel_if_disconnected(_):
|
||||||
|
if not fut.done():
|
||||||
|
fut.cancel()
|
||||||
|
cancel_watch.add_done_callback(_cancel_if_disconnected)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.request_queue.put((handler, request, fut))
|
||||||
|
return await fut
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Propagate cancellation to ensure aiohttp doesn't expect a response body
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# Best-effort cleanup of the watcher
|
||||||
|
cancel_watch.cancel()
|
||||||
|
|
||||||
|
async def __process_enqueued_request(
|
||||||
|
self,
|
||||||
|
handler: EndpointHandler[ApiPayload_T],
|
||||||
|
request: web.Request,
|
||||||
|
) -> Union[web.Response, web.StreamResponse]:
|
||||||
|
"""
|
||||||
|
This contains the original __handle_request logic and is invoked by workers,
|
||||||
|
ensuring FIFO execution via asyncio.Queue.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
auth_data, payload = handler.get_data_from_request(data)
|
auth_data, payload = handler.get_data_from_request(data)
|
||||||
@@ -135,8 +185,11 @@ class Backend:
|
|||||||
return web.json_response(data=e.message, status=422)
|
return web.json_response(data=e.message, status=422)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return web.json_response(dict(error="invalid JSON"), status=422)
|
return web.json_response(dict(error="invalid JSON"), status=422)
|
||||||
|
|
||||||
workload = payload.count_workload()
|
workload = payload.count_workload()
|
||||||
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
|
request_metrics: RequestMetrics = RequestMetrics(
|
||||||
|
request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created"
|
||||||
|
)
|
||||||
|
|
||||||
async def cancel_api_call_if_disconnected() -> web.Response:
|
async def cancel_api_call_if_disconnected() -> web.Response:
|
||||||
await request.wait_for_disconnection()
|
await request.wait_for_disconnection()
|
||||||
@@ -177,6 +230,8 @@ class Backend:
|
|||||||
acquired = False
|
acquired = False
|
||||||
try:
|
try:
|
||||||
self.metrics._request_start(request_metrics)
|
self.metrics._request_start(request_metrics)
|
||||||
|
|
||||||
|
# Preserve existing semaphore behavior for serializing requests when requested
|
||||||
if self.allow_parallel_requests is False:
|
if self.allow_parallel_requests is False:
|
||||||
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
|
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
|
||||||
await self.sem.acquire()
|
await self.sem.acquire()
|
||||||
@@ -186,6 +241,7 @@ class Backend:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
|
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
|
||||||
|
|
||||||
done, pending = await wait(
|
done, pending = await wait(
|
||||||
[
|
[
|
||||||
create_task(make_request()),
|
create_task(make_request()),
|
||||||
@@ -253,8 +309,14 @@ class Backend:
|
|||||||
self.backend_errored(str(e))
|
self.backend_errored(str(e))
|
||||||
|
|
||||||
async def _start_tracking(self) -> None:
|
async def _start_tracking(self) -> None:
|
||||||
|
# Start the FIFO workers alongside existing loops
|
||||||
|
worker_tasks = tuple(self._worker() for _ in range(self._num_workers))
|
||||||
await gather(
|
await gather(
|
||||||
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop()
|
self.__read_logs(),
|
||||||
|
self.metrics._send_metrics_loop(),
|
||||||
|
self.__healthcheck(),
|
||||||
|
self.metrics._send_delete_requests_loop(),
|
||||||
|
*worker_tasks,
|
||||||
)
|
)
|
||||||
|
|
||||||
def backend_errored(self, msg: str) -> None:
|
def backend_errored(self, msg: str) -> None:
|
||||||
|
|||||||
+2
-3
@@ -190,12 +190,11 @@ class SystemMetrics:
|
|||||||
self.additional_disk_usage = disk_usage - self.last_disk_usage
|
self.additional_disk_usage = disk_usage - self.last_disk_usage
|
||||||
self.last_disk_usage = disk_usage
|
self.last_disk_usage = disk_usage
|
||||||
|
|
||||||
def reset(self, expected: float | None) -> None:
|
def reset(self):
|
||||||
# autoscaler excepts model_loading_time to be populated only once, when the instance has
|
# autoscaler excepts model_loading_time to be populated only once, when the instance has
|
||||||
# finished benchmarking and is ready to receive requests. This applies to restarted instances
|
# finished benchmarking and is ready to receive requests. This applies to restarted instances
|
||||||
# as well: they should send model_loading_time once when they are done loading
|
# as well: they should send model_loading_time once when they are done loading
|
||||||
if self.model_loading_time == expected:
|
self.model_loading_time = None
|
||||||
self.model_loading_time = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
+7
-13
@@ -200,13 +200,11 @@ class Metrics:
|
|||||||
|
|
||||||
async def __send_metrics_and_reset(self):
|
async def __send_metrics_and_reset(self):
|
||||||
|
|
||||||
loadtime_snapshot = self.system_metrics.model_loading_time
|
|
||||||
|
|
||||||
def compute_autoscaler_data() -> AutoScalerData:
|
def compute_autoscaler_data() -> AutoScalerData:
|
||||||
return AutoScalerData(
|
return AutoScalerData(
|
||||||
id=self.id,
|
id=self.id,
|
||||||
version=self.version,
|
version=self.version,
|
||||||
loadtime=(loadtime_snapshot or 0.0),
|
loadtime=(self.system_metrics.model_loading_time or 0.0),
|
||||||
new_load=self.model_metrics.workload_processing,
|
new_load=self.model_metrics.workload_processing,
|
||||||
cur_load=self.model_metrics.cur_load,
|
cur_load=self.model_metrics.cur_load,
|
||||||
rej_load=self.model_metrics.workload_rejected,
|
rej_load=self.model_metrics.workload_rejected,
|
||||||
@@ -254,15 +252,11 @@ class Metrics:
|
|||||||
|
|
||||||
self.system_metrics.update_disk_usage()
|
self.system_metrics.update_disk_usage()
|
||||||
|
|
||||||
sent = False
|
|
||||||
for report_addr in self.report_addr:
|
for report_addr in self.report_addr:
|
||||||
if await send_data(report_addr):
|
success = await send_data(report_addr)
|
||||||
sent = True
|
if success is True:
|
||||||
break
|
break
|
||||||
|
self.update_pending = False
|
||||||
if sent:
|
self.model_metrics.reset()
|
||||||
# clear the one-shot loadtime only if we actually sent *this* value
|
self.system_metrics.reset()
|
||||||
self.system_metrics.reset(expected=loadtime_snapshot)
|
self.last_metric_update = time.time()
|
||||||
self.update_pending = False
|
|
||||||
self.model_metrics.reset()
|
|
||||||
self.last_metric_update = time.time()
|
|
||||||
|
|||||||
Reference in New Issue
Block a user