Compare commits

..

5 Commits

Author SHA1 Message Date
Lucas Armand e09f1fa953 patch for redis queue 2025-10-28 16:03:50 -07:00
edgaratvast ba6f1c2e4b Fix signature (#50)
* change order of fields in auth_data to match autoscaler for signature verification

* also ignore __request_id

* Revert "change order of fields in auth_data to match autoscaler for signature verification" so that it's alphabetical again

This reverts commit b8223879c9.

* enforce alphabetical json dumping of message for signature verification

---------

Co-authored-by: Edgar Lin <edgarlin2000@gmail.com>
2025-10-28 16:01:32 -07:00
edgaratvast 298590fb88 Merge pull request #45 from vast-ai/new-pyworker
New PyWorker
2025-10-28 14:02:53 -07:00
Lucas Armand 814c3acd4c remove unused code 2025-10-28 13:43:57 -07:00
Lucas Armand 22bca74087 Prevent load time race 2025-10-27 18:25:21 -07:00
3 changed files with 29 additions and 80 deletions
+7 -69
View File
@@ -5,7 +5,7 @@ import base64
import subprocess import subprocess
import dataclasses import dataclasses
import logging import logging
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task, get_running_loop from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
from functools import cached_property from functools import cached_property
from distutils.util import strtobool from distutils.util import strtobool
@@ -47,7 +47,7 @@ class Backend:
This class is responsible for: This class is responsible for:
1. Tailing logs and updating load time metrics 1. Tailing logs and updating load time metrics
2. Taking an EndpointHandler alongside incoming payload, preparing a json to be sent to the model, and 2. Taking an EndpointHandler alongside incoming payload, preparing a json to be sent to the model, and
sending the request. It also updates metrics as it makes those requests. sending the request. It also updates metrics as it makes those requests.
3. Running a benchmark from an EndpointHandler 3. Running a benchmark from an EndpointHandler
""" """
@@ -74,11 +74,6 @@ class Backend:
self._pubkey = self._fetch_pubkey() self._pubkey = self._fetch_pubkey()
self.__start_healthcheck: bool = False self.__start_healthcheck: bool = False
# NEW: FIFO queue + worker count
self.request_queue: "asyncio.Queue[tuple[EndpointHandler[ApiPayload_T], web.Request, asyncio.Future]]" = asyncio.Queue()
# If parallel allowed, let multiple workers drain the queue (order preserved by FIFO per worker; overall start order is FIFO).
self._num_workers: int = 1 if not self.allow_parallel_requests else int(os.environ.get("WORKERS", "4"))
@property @property
def pubkey(self) -> Optional[RSA.RsaKey]: def pubkey(self) -> Optional[RSA.RsaKey]:
if self._pubkey is None: if self._pubkey is None:
@@ -96,22 +91,6 @@ class Backend:
timeout = ClientTimeout(total=None) timeout = ClientTimeout(total=None)
return ClientSession(self.model_server_url, timeout=timeout, connector=connector) return ClientSession(self.model_server_url, timeout=timeout, connector=connector)
async def _worker(self):
while True:
handler, request, fut = await self.request_queue.get()
try:
# Skip if already cancelled while waiting in the queue
if fut.cancelled():
continue
res = await self.__process_enqueued_request(handler, request)
if not fut.cancelled():
fut.set_result(res)
except Exception as e:
if not fut.cancelled():
fut.set_exception(e)
finally:
self.request_queue.task_done()
def create_handler( def create_handler(
self, self,
handler: EndpointHandler[ApiPayload_T], handler: EndpointHandler[ApiPayload_T],
@@ -148,36 +127,7 @@ class Backend:
handler: EndpointHandler[ApiPayload_T], handler: EndpointHandler[ApiPayload_T],
request: web.Request, request: web.Request,
) -> Union[web.Response, web.StreamResponse]: ) -> Union[web.Response, web.StreamResponse]:
"""use this function to enqueue requests for FIFO processing""" """use this function to forward requests to the model endpoint"""
loop = get_running_loop()
fut: asyncio.Future = loop.create_future()
# If the client disconnects while waiting in the FIFO, cancel the future so the worker skips it
cancel_watch = create_task(request.wait_for_disconnection())
def _cancel_if_disconnected(_):
if not fut.done():
fut.cancel()
cancel_watch.add_done_callback(_cancel_if_disconnected)
try:
await self.request_queue.put((handler, request, fut))
return await fut
except asyncio.CancelledError:
# Propagate cancellation to ensure aiohttp doesn't expect a response body
raise
finally:
# Best-effort cleanup of the watcher
cancel_watch.cancel()
async def __process_enqueued_request(
self,
handler: EndpointHandler[ApiPayload_T],
request: web.Request,
) -> Union[web.Response, web.StreamResponse]:
"""
This contains the original __handle_request logic and is invoked by workers,
ensuring FIFO execution via asyncio.Queue.
"""
try: try:
data = await request.json() data = await request.json()
auth_data, payload = handler.get_data_from_request(data) auth_data, payload = handler.get_data_from_request(data)
@@ -185,11 +135,8 @@ class Backend:
return web.json_response(data=e.message, status=422) return web.json_response(data=e.message, status=422)
except json.JSONDecodeError: except json.JSONDecodeError:
return web.json_response(dict(error="invalid JSON"), status=422) return web.json_response(dict(error="invalid JSON"), status=422)
workload = payload.count_workload() workload = payload.count_workload()
request_metrics: RequestMetrics = RequestMetrics( request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created"
)
async def cancel_api_call_if_disconnected() -> web.Response: async def cancel_api_call_if_disconnected() -> web.Response:
await request.wait_for_disconnection() await request.wait_for_disconnection()
@@ -230,8 +177,6 @@ class Backend:
acquired = False acquired = False
try: try:
self.metrics._request_start(request_metrics) self.metrics._request_start(request_metrics)
# Preserve existing semaphore behavior for serializing requests when requested
if self.allow_parallel_requests is False: if self.allow_parallel_requests is False:
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}") log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
await self.sem.acquire() await self.sem.acquire()
@@ -241,7 +186,6 @@ class Backend:
) )
else: else:
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}") log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
done, pending = await wait( done, pending = await wait(
[ [
create_task(make_request()), create_task(make_request()),
@@ -309,14 +253,8 @@ class Backend:
self.backend_errored(str(e)) self.backend_errored(str(e))
async def _start_tracking(self) -> None: async def _start_tracking(self) -> None:
# Start the FIFO workers alongside existing loops
worker_tasks = tuple(self._worker() for _ in range(self._num_workers))
await gather( await gather(
self.__read_logs(), self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop()
self.metrics._send_metrics_loop(),
self.__healthcheck(),
self.metrics._send_delete_requests_loop(),
*worker_tasks,
) )
def backend_errored(self, msg: str) -> None: def backend_errored(self, msg: str) -> None:
@@ -348,7 +286,7 @@ class Backend:
message = { message = {
key: value key: value
for (key, value) in (dataclasses.asdict(auth_data).items()) for (key, value) in (dataclasses.asdict(auth_data).items())
if key != "signature" if key != "signature" and key != "__request_id"
} }
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN): if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
log.debug( log.debug(
@@ -358,7 +296,7 @@ class Backend:
elif message in self.msg_history: elif message in self.msg_history:
log.debug(f"message: {message} already in message history") log.debug(f"message: {message} already in message history")
return False return False
elif verify_signature(json.dumps(message, indent=4), auth_data.signature): elif verify_signature(json.dumps(message, indent=4, sort_keys=True), auth_data.signature):
self.reqnum = max(auth_data.reqnum, self.reqnum) self.reqnum = max(auth_data.reqnum, self.reqnum)
self.msg_history.append(message) self.msg_history.append(message)
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:] self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
+5 -4
View File
@@ -65,12 +65,12 @@ class ApiPayload(ABC):
class AuthData: class AuthData:
"""data used to authenticate requester""" """data used to authenticate requester"""
signature: str
cost: str cost: str
endpoint: str endpoint: str
reqnum: int reqnum: int
url: str
request_idx: int request_idx: int
signature: str
url: str
@classmethod @classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]): def from_json_msg(cls, json_msg: Dict[str, Any]):
@@ -190,11 +190,12 @@ class SystemMetrics:
self.additional_disk_usage = disk_usage - self.last_disk_usage self.additional_disk_usage = disk_usage - self.last_disk_usage
self.last_disk_usage = disk_usage self.last_disk_usage = disk_usage
def reset(self): def reset(self, expected: float | None) -> None:
# autoscaler excepts model_loading_time to be populated only once, when the instance has # autoscaler excepts model_loading_time to be populated only once, when the instance has
# finished benchmarking and is ready to receive requests. This applies to restarted instances # finished benchmarking and is ready to receive requests. This applies to restarted instances
# as well: they should send model_loading_time once when they are done loading # as well: they should send model_loading_time once when they are done loading
self.model_loading_time = None if self.model_loading_time == expected:
self.model_loading_time = None
@dataclass @dataclass
+17 -7
View File
@@ -180,6 +180,10 @@ class Metrics:
return # nothing to do return # nothing to do
for report_addr in self.report_addr: for report_addr in self.report_addr:
# TODO: Add a Redis subscriber queue for delete_requests
if report_addr == "https://cloud.vast.ai/api/v0":
# Patch: ignore the Redis API report_addr
continue
sent_success = True sent_success = True
sent_failed = True sent_failed = True
@@ -200,11 +204,13 @@ class Metrics:
async def __send_metrics_and_reset(self): async def __send_metrics_and_reset(self):
loadtime_snapshot = self.system_metrics.model_loading_time
def compute_autoscaler_data() -> AutoScalerData: def compute_autoscaler_data() -> AutoScalerData:
return AutoScalerData( return AutoScalerData(
id=self.id, id=self.id,
version=self.version, version=self.version,
loadtime=(self.system_metrics.model_loading_time or 0.0), loadtime=(loadtime_snapshot or 0.0),
new_load=self.model_metrics.workload_processing, new_load=self.model_metrics.workload_processing,
cur_load=self.model_metrics.cur_load, cur_load=self.model_metrics.cur_load,
rej_load=self.model_metrics.workload_rejected, rej_load=self.model_metrics.workload_rejected,
@@ -252,11 +258,15 @@ class Metrics:
self.system_metrics.update_disk_usage() self.system_metrics.update_disk_usage()
sent = False
for report_addr in self.report_addr: for report_addr in self.report_addr:
success = await send_data(report_addr) if await send_data(report_addr):
if success is True: sent = True
break break
self.update_pending = False
self.model_metrics.reset() if sent:
self.system_metrics.reset() # clear the one-shot loadtime only if we actually sent *this* value
self.last_metric_update = time.time() self.system_metrics.reset(expected=loadtime_snapshot)
self.update_pending = False
self.model_metrics.reset()
self.last_metric_update = time.time()