trying queue

This commit is contained in:
Lucas Armand
2025-10-27 17:34:37 -07:00
parent 9c795e2a01
commit 0471f6b219
+59 -8
View File
@@ -5,7 +5,7 @@ import base64
import subprocess import subprocess
import dataclasses import dataclasses
import logging import logging
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task, get_running_loop
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
from functools import cached_property from functools import cached_property
from distutils.util import strtobool from distutils.util import strtobool
@@ -47,7 +47,7 @@ class Backend:
This class is responsible for: This class is responsible for:
1. Tailing logs and updating load time metrics 1. Tailing logs and updating load time metrics
2. Taking an EndpointHandler alongside incoming payload, preparing a json to be sent to the model, and 2. Taking an EndpointHandler alongside incoming payload, preparing a json to be sent to the model, and
sending the request. It also updates metrics as it makes those requests. sending the request. It also updates metrics as it makes those requests.
3. Running a benchmark from an EndpointHandler 3. Running a benchmark from an EndpointHandler
""" """
@@ -74,6 +74,11 @@ class Backend:
self._pubkey = self._fetch_pubkey() self._pubkey = self._fetch_pubkey()
self.__start_healthcheck: bool = False self.__start_healthcheck: bool = False
# NEW: FIFO queue + worker count
self.request_queue: "asyncio.Queue[tuple[EndpointHandler[ApiPayload_T], web.Request, asyncio.Future]]" = asyncio.Queue()
# If parallel allowed, let multiple workers drain the queue (order preserved by FIFO per worker; overall start order is FIFO).
self._num_workers: int = 1 if not self.allow_parallel_requests else int(os.environ.get("WORKERS", "4"))
@property @property
def pubkey(self) -> Optional[RSA.RsaKey]: def pubkey(self) -> Optional[RSA.RsaKey]:
if self._pubkey is None: if self._pubkey is None:
@@ -95,10 +100,15 @@ class Backend:
while True: while True:
handler, request, fut = await self.request_queue.get() handler, request, fut = await self.request_queue.get()
try: try:
res = await self.__process_request(handler, request) # Skip if already cancelled while waiting in the queue
fut.set_result(res) if fut.cancelled():
continue
res = await self.__process_enqueued_request(handler, request)
if not fut.cancelled():
fut.set_result(res)
except Exception as e: except Exception as e:
fut.set_exception(e) if not fut.cancelled():
fut.set_exception(e)
finally: finally:
self.request_queue.task_done() self.request_queue.task_done()
@@ -138,7 +148,36 @@ class Backend:
handler: EndpointHandler[ApiPayload_T], handler: EndpointHandler[ApiPayload_T],
request: web.Request, request: web.Request,
) -> Union[web.Response, web.StreamResponse]: ) -> Union[web.Response, web.StreamResponse]:
"""use this function to forward requests to the model endpoint""" """use this function to enqueue requests for FIFO processing"""
loop = get_running_loop()
fut: asyncio.Future = loop.create_future()
# If the client disconnects while waiting in the FIFO, cancel the future so the worker skips it
cancel_watch = create_task(request.wait_for_disconnection())
def _cancel_if_disconnected(_):
if not fut.done():
fut.cancel()
cancel_watch.add_done_callback(_cancel_if_disconnected)
try:
await self.request_queue.put((handler, request, fut))
return await fut
except asyncio.CancelledError:
# Propagate cancellation to ensure aiohttp doesn't expect a response body
raise
finally:
# Best-effort cleanup of the watcher
cancel_watch.cancel()
async def __process_enqueued_request(
self,
handler: EndpointHandler[ApiPayload_T],
request: web.Request,
) -> Union[web.Response, web.StreamResponse]:
"""
This contains the original __handle_request logic and is invoked by workers,
ensuring FIFO execution via asyncio.Queue.
"""
try: try:
data = await request.json() data = await request.json()
auth_data, payload = handler.get_data_from_request(data) auth_data, payload = handler.get_data_from_request(data)
@@ -146,8 +185,11 @@ class Backend:
return web.json_response(data=e.message, status=422) return web.json_response(data=e.message, status=422)
except json.JSONDecodeError: except json.JSONDecodeError:
return web.json_response(dict(error="invalid JSON"), status=422) return web.json_response(dict(error="invalid JSON"), status=422)
workload = payload.count_workload() workload = payload.count_workload()
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created") request_metrics: RequestMetrics = RequestMetrics(
request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created"
)
async def cancel_api_call_if_disconnected() -> web.Response: async def cancel_api_call_if_disconnected() -> web.Response:
await request.wait_for_disconnection() await request.wait_for_disconnection()
@@ -188,6 +230,8 @@ class Backend:
acquired = False acquired = False
try: try:
self.metrics._request_start(request_metrics) self.metrics._request_start(request_metrics)
# Preserve existing semaphore behavior for serializing requests when requested
if self.allow_parallel_requests is False: if self.allow_parallel_requests is False:
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}") log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
await self.sem.acquire() await self.sem.acquire()
@@ -197,6 +241,7 @@ class Backend:
) )
else: else:
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}") log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
done, pending = await wait( done, pending = await wait(
[ [
create_task(make_request()), create_task(make_request()),
@@ -264,8 +309,14 @@ class Backend:
self.backend_errored(str(e)) self.backend_errored(str(e))
async def _start_tracking(self) -> None: async def _start_tracking(self) -> None:
# Start the FIFO workers alongside existing loops
worker_tasks = tuple(self._worker() for _ in range(self._num_workers))
await gather( await gather(
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop() self.__read_logs(),
self.metrics._send_metrics_loop(),
self.__healthcheck(),
self.metrics._send_delete_requests_loop(),
*worker_tasks,
) )
def backend_errored(self, msg: str) -> None: def backend_errored(self, msg: str) -> None: