Compare commits

..

1 Commits

Author SHA1 Message Date
Lucas Armand 0471f6b219 trying queue 2025-10-27 17:34:37 -07:00
4 changed files with 81 additions and 30 deletions
+68 -6
View File
@@ -5,7 +5,7 @@ import base64
import subprocess
import dataclasses
import logging
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task, get_running_loop
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
from functools import cached_property
from distutils.util import strtobool
@@ -74,6 +74,11 @@ class Backend:
self._pubkey = self._fetch_pubkey()
self.__start_healthcheck: bool = False
# NEW: FIFO queue + worker count
self.request_queue: "asyncio.Queue[tuple[EndpointHandler[ApiPayload_T], web.Request, asyncio.Future]]" = asyncio.Queue()
# If parallel allowed, let multiple workers drain the queue (order preserved by FIFO per worker; overall start order is FIFO).
self._num_workers: int = 1 if not self.allow_parallel_requests else int(os.environ.get("WORKERS", "4"))
@property
def pubkey(self) -> Optional[RSA.RsaKey]:
if self._pubkey is None:
@@ -91,6 +96,22 @@ class Backend:
timeout = ClientTimeout(total=None)
return ClientSession(self.model_server_url, timeout=timeout, connector=connector)
async def _worker(self):
while True:
handler, request, fut = await self.request_queue.get()
try:
# Skip if already cancelled while waiting in the queue
if fut.cancelled():
continue
res = await self.__process_enqueued_request(handler, request)
if not fut.cancelled():
fut.set_result(res)
except Exception as e:
if not fut.cancelled():
fut.set_exception(e)
finally:
self.request_queue.task_done()
def create_handler(
self,
handler: EndpointHandler[ApiPayload_T],
@@ -127,7 +148,36 @@ class Backend:
handler: EndpointHandler[ApiPayload_T],
request: web.Request,
) -> Union[web.Response, web.StreamResponse]:
"""use this function to forward requests to the model endpoint"""
"""use this function to enqueue requests for FIFO processing"""
loop = get_running_loop()
fut: asyncio.Future = loop.create_future()
# If the client disconnects while waiting in the FIFO, cancel the future so the worker skips it
cancel_watch = create_task(request.wait_for_disconnection())
def _cancel_if_disconnected(_):
if not fut.done():
fut.cancel()
cancel_watch.add_done_callback(_cancel_if_disconnected)
try:
await self.request_queue.put((handler, request, fut))
return await fut
except asyncio.CancelledError:
# Propagate cancellation to ensure aiohttp doesn't expect a response body
raise
finally:
# Best-effort cleanup of the watcher
cancel_watch.cancel()
async def __process_enqueued_request(
self,
handler: EndpointHandler[ApiPayload_T],
request: web.Request,
) -> Union[web.Response, web.StreamResponse]:
"""
This contains the original __handle_request logic and is invoked by workers,
ensuring FIFO execution via asyncio.Queue.
"""
try:
data = await request.json()
auth_data, payload = handler.get_data_from_request(data)
@@ -135,8 +185,11 @@ class Backend:
return web.json_response(data=e.message, status=422)
except json.JSONDecodeError:
return web.json_response(dict(error="invalid JSON"), status=422)
workload = payload.count_workload()
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
request_metrics: RequestMetrics = RequestMetrics(
request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created"
)
async def cancel_api_call_if_disconnected() -> web.Response:
await request.wait_for_disconnection()
@@ -177,6 +230,8 @@ class Backend:
acquired = False
try:
self.metrics._request_start(request_metrics)
# Preserve existing semaphore behavior for serializing requests when requested
if self.allow_parallel_requests is False:
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
await self.sem.acquire()
@@ -186,6 +241,7 @@ class Backend:
)
else:
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
done, pending = await wait(
[
create_task(make_request()),
@@ -253,8 +309,14 @@ class Backend:
self.backend_errored(str(e))
async def _start_tracking(self) -> None:
# Start the FIFO workers alongside existing loops
worker_tasks = tuple(self._worker() for _ in range(self._num_workers))
await gather(
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop()
self.__read_logs(),
self.metrics._send_metrics_loop(),
self.__healthcheck(),
self.metrics._send_delete_requests_loop(),
*worker_tasks,
)
def backend_errored(self, msg: str) -> None:
@@ -286,7 +348,7 @@ class Backend:
message = {
key: value
for (key, value) in (dataclasses.asdict(auth_data).items())
if key != "signature" and key != "__request_id"
if key != "signature"
}
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
log.debug(
@@ -296,7 +358,7 @@ class Backend:
elif message in self.msg_history:
log.debug(f"message: {message} already in message history")
return False
elif verify_signature(json.dumps(message, indent=4, sort_keys=True), auth_data.signature):
elif verify_signature(json.dumps(message, indent=4), auth_data.signature):
self.reqnum = max(auth_data.reqnum, self.reqnum)
self.msg_history.append(message)
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
+3 -4
View File
@@ -65,12 +65,12 @@ class ApiPayload(ABC):
class AuthData:
"""data used to authenticate requester"""
signature: str
cost: str
endpoint: str
reqnum: int
request_idx: int
signature: str
url: str
request_idx: int
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]):
@@ -190,11 +190,10 @@ class SystemMetrics:
self.additional_disk_usage = disk_usage - self.last_disk_usage
self.last_disk_usage = disk_usage
def reset(self, expected: float | None) -> None:
def reset(self):
# autoscaler excepts model_loading_time to be populated only once, when the instance has
# finished benchmarking and is ready to receive requests. This applies to restarted instances
# as well: they should send model_loading_time once when they are done loading
if self.model_loading_time == expected:
self.model_loading_time = None
+4 -14
View File
@@ -180,10 +180,6 @@ class Metrics:
return # nothing to do
for report_addr in self.report_addr:
# TODO: Add a Redis subscriber queue for delete_requests
if report_addr == "https://cloud.vast.ai/api/v0":
# Patch: ignore the Redis API report_addr
continue
sent_success = True
sent_failed = True
@@ -204,13 +200,11 @@ class Metrics:
async def __send_metrics_and_reset(self):
loadtime_snapshot = self.system_metrics.model_loading_time
def compute_autoscaler_data() -> AutoScalerData:
return AutoScalerData(
id=self.id,
version=self.version,
loadtime=(loadtime_snapshot or 0.0),
loadtime=(self.system_metrics.model_loading_time or 0.0),
new_load=self.model_metrics.workload_processing,
cur_load=self.model_metrics.cur_load,
rej_load=self.model_metrics.workload_rejected,
@@ -258,15 +252,11 @@ class Metrics:
self.system_metrics.update_disk_usage()
sent = False
for report_addr in self.report_addr:
if await send_data(report_addr):
sent = True
success = await send_data(report_addr)
if success is True:
break
if sent:
# clear the one-shot loadtime only if we actually sent *this* value
self.system_metrics.reset(expected=loadtime_snapshot)
self.update_pending = False
self.model_metrics.reset()
self.system_metrics.reset()
self.last_metric_update = time.time()
+1 -1
View File
@@ -9,7 +9,7 @@ ENV_PATH="$WORKSPACE_DIR/worker-env"
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}"
REPORT_ADDR="${REPORT_ADDR:-https://cloud.vast.ai/api/v0,https://run.vast.ai}"
USE_SSL="${USE_SSL:-true}"
WORKER_PORT="${WORKER_PORT:-3000}"
mkdir -p "$WORKSPACE_DIR"