From 0471f6b219cf621cd197b85f6729fc607c6e9866 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Mon, 27 Oct 2025 17:34:37 -0700 Subject: [PATCH] trying queue --- lib/backend.py | 67 ++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 59 insertions(+), 8 deletions(-) diff --git a/lib/backend.py b/lib/backend.py index 8f0cae1..d0a14ed 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -5,7 +5,7 @@ import base64 import subprocess import dataclasses import logging -from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task +from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task, get_running_loop from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional from functools import cached_property from distutils.util import strtobool @@ -47,7 +47,7 @@ class Backend: This class is responsible for: 1. Tailing logs and updating load time metrics 2. Taking an EndpointHandler alongside incoming payload, preparing a json to be sent to the model, and - sending the request. It also updates metrics as it makes those requests. + sending the request. It also updates metrics as it makes those requests. 3. Running a benchmark from an EndpointHandler """ @@ -74,6 +74,11 @@ class Backend: self._pubkey = self._fetch_pubkey() self.__start_healthcheck: bool = False + # NEW: FIFO queue + worker count + self.request_queue: "asyncio.Queue[tuple[EndpointHandler[ApiPayload_T], web.Request, asyncio.Future]]" = asyncio.Queue() + # If parallel allowed, let multiple workers drain the queue (order preserved by FIFO per worker; overall start order is FIFO). + self._num_workers: int = 1 if not self.allow_parallel_requests else int(os.environ.get("WORKERS", "4")) + @property def pubkey(self) -> Optional[RSA.RsaKey]: if self._pubkey is None: @@ -95,10 +100,15 @@ class Backend: while True: handler, request, fut = await self.request_queue.get() try: - res = await self.__process_request(handler, request) - fut.set_result(res) + # Skip if already cancelled while waiting in the queue + if fut.cancelled(): + continue + res = await self.__process_enqueued_request(handler, request) + if not fut.cancelled(): + fut.set_result(res) except Exception as e: - fut.set_exception(e) + if not fut.cancelled(): + fut.set_exception(e) finally: self.request_queue.task_done() @@ -138,7 +148,36 @@ class Backend: handler: EndpointHandler[ApiPayload_T], request: web.Request, ) -> Union[web.Response, web.StreamResponse]: - """use this function to forward requests to the model endpoint""" + """use this function to enqueue requests for FIFO processing""" + loop = get_running_loop() + fut: asyncio.Future = loop.create_future() + + # If the client disconnects while waiting in the FIFO, cancel the future so the worker skips it + cancel_watch = create_task(request.wait_for_disconnection()) + def _cancel_if_disconnected(_): + if not fut.done(): + fut.cancel() + cancel_watch.add_done_callback(_cancel_if_disconnected) + + try: + await self.request_queue.put((handler, request, fut)) + return await fut + except asyncio.CancelledError: + # Propagate cancellation to ensure aiohttp doesn't expect a response body + raise + finally: + # Best-effort cleanup of the watcher + cancel_watch.cancel() + + async def __process_enqueued_request( + self, + handler: EndpointHandler[ApiPayload_T], + request: web.Request, + ) -> Union[web.Response, web.StreamResponse]: + """ + This contains the original __handle_request logic and is invoked by workers, + ensuring FIFO execution via asyncio.Queue. + """ try: data = await request.json() auth_data, payload = handler.get_data_from_request(data) @@ -146,8 +185,11 @@ class Backend: return web.json_response(data=e.message, status=422) except json.JSONDecodeError: return web.json_response(dict(error="invalid JSON"), status=422) + workload = payload.count_workload() - request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created") + request_metrics: RequestMetrics = RequestMetrics( + request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created" + ) async def cancel_api_call_if_disconnected() -> web.Response: await request.wait_for_disconnection() @@ -188,6 +230,8 @@ class Backend: acquired = False try: self.metrics._request_start(request_metrics) + + # Preserve existing semaphore behavior for serializing requests when requested if self.allow_parallel_requests is False: log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}") await self.sem.acquire() @@ -197,6 +241,7 @@ class Backend: ) else: log.debug(f"Starting request for reqnum:{request_metrics.reqnum}") + done, pending = await wait( [ create_task(make_request()), @@ -264,8 +309,14 @@ class Backend: self.backend_errored(str(e)) async def _start_tracking(self) -> None: + # Start the FIFO workers alongside existing loops + worker_tasks = tuple(self._worker() for _ in range(self._num_workers)) await gather( - self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop() + self.__read_logs(), + self.metrics._send_metrics_loop(), + self.__healthcheck(), + self.metrics._send_delete_requests_loop(), + *worker_tasks, ) def backend_errored(self, msg: str) -> None: