Compare commits

..

1 Commits

Author SHA1 Message Date
Lucas Armand 0471f6b219 trying queue 2025-10-27 17:34:37 -07:00
8 changed files with 103 additions and 55 deletions
+91 -28
View File
@@ -5,7 +5,7 @@ import base64
import subprocess import subprocess
import dataclasses import dataclasses
import logging import logging
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task, get_running_loop
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
from functools import cached_property from functools import cached_property
from distutils.util import strtobool from distutils.util import strtobool
@@ -47,7 +47,7 @@ class Backend:
This class is responsible for: This class is responsible for:
1. Tailing logs and updating load time metrics 1. Tailing logs and updating load time metrics
2. Taking an EndpointHandler alongside incoming payload, preparing a json to be sent to the model, and 2. Taking an EndpointHandler alongside incoming payload, preparing a json to be sent to the model, and
sending the request. It also updates metrics as it makes those requests. sending the request. It also updates metrics as it makes those requests.
3. Running a benchmark from an EndpointHandler 3. Running a benchmark from an EndpointHandler
""" """
@@ -66,9 +66,6 @@ class Backend:
unsecured: bool = dataclasses.field( unsecured: bool = dataclasses.field(
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))), default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
) )
report_addr: str = dataclasses.field(
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
)
def __post_init__(self): def __post_init__(self):
self.metrics = Metrics() self.metrics = Metrics()
@@ -77,6 +74,11 @@ class Backend:
self._pubkey = self._fetch_pubkey() self._pubkey = self._fetch_pubkey()
self.__start_healthcheck: bool = False self.__start_healthcheck: bool = False
# NEW: FIFO queue + worker count
self.request_queue: "asyncio.Queue[tuple[EndpointHandler[ApiPayload_T], web.Request, asyncio.Future]]" = asyncio.Queue()
# If parallel allowed, let multiple workers drain the queue (order preserved by FIFO per worker; overall start order is FIFO).
self._num_workers: int = 1 if not self.allow_parallel_requests else int(os.environ.get("WORKERS", "4"))
@property @property
def pubkey(self) -> Optional[RSA.RsaKey]: def pubkey(self) -> Optional[RSA.RsaKey]:
if self._pubkey is None: if self._pubkey is None:
@@ -94,6 +96,22 @@ class Backend:
timeout = ClientTimeout(total=None) timeout = ClientTimeout(total=None)
return ClientSession(self.model_server_url, timeout=timeout, connector=connector) return ClientSession(self.model_server_url, timeout=timeout, connector=connector)
async def _worker(self):
while True:
handler, request, fut = await self.request_queue.get()
try:
# Skip if already cancelled while waiting in the queue
if fut.cancelled():
continue
res = await self.__process_enqueued_request(handler, request)
if not fut.cancelled():
fut.set_result(res)
except Exception as e:
if not fut.cancelled():
fut.set_exception(e)
finally:
self.request_queue.task_done()
def create_handler( def create_handler(
self, self,
handler: EndpointHandler[ApiPayload_T], handler: EndpointHandler[ApiPayload_T],
@@ -107,26 +125,59 @@ class Backend:
#######################################Private####################################### #######################################Private#######################################
def _fetch_pubkey(self): def _fetch_pubkey(self):
report_addr = self.report_addr.rstrip("/") command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"] result = subprocess.check_output(command, universal_newlines=True)
try: log.debug("public key:")
result = subprocess.check_output(command, universal_newlines=True) log.debug(result)
log.debug("public key:") key = None
log.debug(result) for _ in range(5):
key = RSA.import_key(result) try:
if key is not None: key = RSA.import_key(result)
return key break
except (ValueError , subprocess.CalledProcessError) as e: except ValueError as e:
log.debug(f"Error downloading key: {e}") log.debug(f"Error downloading key: {e}")
self.backend_errored("Failed to get autoscaler pubkey") time.sleep(15)
if key is None:
self._total_pubkey_fetch_errors += 1
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
self.backend_errored("Failed to get autoscaler pubkey")
return key
async def __handle_request( async def __handle_request(
self, self,
handler: EndpointHandler[ApiPayload_T], handler: EndpointHandler[ApiPayload_T],
request: web.Request, request: web.Request,
) -> Union[web.Response, web.StreamResponse]: ) -> Union[web.Response, web.StreamResponse]:
"""use this function to forward requests to the model endpoint""" """use this function to enqueue requests for FIFO processing"""
loop = get_running_loop()
fut: asyncio.Future = loop.create_future()
# If the client disconnects while waiting in the FIFO, cancel the future so the worker skips it
cancel_watch = create_task(request.wait_for_disconnection())
def _cancel_if_disconnected(_):
if not fut.done():
fut.cancel()
cancel_watch.add_done_callback(_cancel_if_disconnected)
try:
await self.request_queue.put((handler, request, fut))
return await fut
except asyncio.CancelledError:
# Propagate cancellation to ensure aiohttp doesn't expect a response body
raise
finally:
# Best-effort cleanup of the watcher
cancel_watch.cancel()
async def __process_enqueued_request(
self,
handler: EndpointHandler[ApiPayload_T],
request: web.Request,
) -> Union[web.Response, web.StreamResponse]:
"""
This contains the original __handle_request logic and is invoked by workers,
ensuring FIFO execution via asyncio.Queue.
"""
try: try:
data = await request.json() data = await request.json()
auth_data, payload = handler.get_data_from_request(data) auth_data, payload = handler.get_data_from_request(data)
@@ -134,8 +185,11 @@ class Backend:
return web.json_response(data=e.message, status=422) return web.json_response(data=e.message, status=422)
except json.JSONDecodeError: except json.JSONDecodeError:
return web.json_response(dict(error="invalid JSON"), status=422) return web.json_response(dict(error="invalid JSON"), status=422)
workload = payload.count_workload() workload = payload.count_workload()
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created") request_metrics: RequestMetrics = RequestMetrics(
request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created"
)
async def cancel_api_call_if_disconnected() -> web.Response: async def cancel_api_call_if_disconnected() -> web.Response:
await request.wait_for_disconnection() await request.wait_for_disconnection()
@@ -176,6 +230,8 @@ class Backend:
acquired = False acquired = False
try: try:
self.metrics._request_start(request_metrics) self.metrics._request_start(request_metrics)
# Preserve existing semaphore behavior for serializing requests when requested
if self.allow_parallel_requests is False: if self.allow_parallel_requests is False:
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}") log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
await self.sem.acquire() await self.sem.acquire()
@@ -185,6 +241,7 @@ class Backend:
) )
else: else:
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}") log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
done, pending = await wait( done, pending = await wait(
[ [
create_task(make_request()), create_task(make_request()),
@@ -252,8 +309,14 @@ class Backend:
self.backend_errored(str(e)) self.backend_errored(str(e))
async def _start_tracking(self) -> None: async def _start_tracking(self) -> None:
# Start the FIFO workers alongside existing loops
worker_tasks = tuple(self._worker() for _ in range(self._num_workers))
await gather( await gather(
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop() self.__read_logs(),
self.metrics._send_metrics_loop(),
self.__healthcheck(),
self.metrics._send_delete_requests_loop(),
*worker_tasks,
) )
def backend_errored(self, msg: str) -> None: def backend_errored(self, msg: str) -> None:
@@ -285,7 +348,7 @@ class Backend:
message = { message = {
key: value key: value
for (key, value) in (dataclasses.asdict(auth_data).items()) for (key, value) in (dataclasses.asdict(auth_data).items())
if key != "signature" and key != "__request_id" if key != "signature"
} }
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN): if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
log.debug( log.debug(
@@ -295,7 +358,7 @@ class Backend:
elif message in self.msg_history: elif message in self.msg_history:
log.debug(f"message: {message} already in message history") log.debug(f"message: {message} already in message history")
return False return False
elif verify_signature(json.dumps(message, indent=4, sort_keys=True), auth_data.signature): elif verify_signature(json.dumps(message, indent=4), auth_data.signature):
self.reqnum = max(auth_data.reqnum, self.reqnum) self.reqnum = max(auth_data.reqnum, self.reqnum)
self.msg_history.append(message) self.msg_history.append(message)
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:] self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
@@ -314,10 +377,10 @@ class Backend:
with open(BENCHMARK_INDICATOR_FILE, "r") as f: with open(BENCHMARK_INDICATOR_FILE, "r") as f:
log.debug("already ran benchmark") log.debug("already ran benchmark")
# trigger model load # trigger model load
# payload = self.benchmark_handler.make_benchmark_payload() payload = self.benchmark_handler.make_benchmark_payload()
# _ = await self.__call_api( _ = await self.__call_api(
# handler=self.benchmark_handler, payload=payload handler=self.benchmark_handler, payload=payload
# ) )
return float(f.readline()) return float(f.readline())
except FileNotFoundError: except FileNotFoundError:
pass pass
@@ -392,7 +455,7 @@ class Backend:
) )
# some backends need a few seconds after logging successful startup before # some backends need a few seconds after logging successful startup before
# they can begin accepting requests # they can begin accepting requests
# await sleep(5) await sleep(5)
try: try:
max_throughput = await run_benchmark() max_throughput = await run_benchmark()
self.__start_healthcheck = True self.__start_healthcheck = True
+4 -5
View File
@@ -65,12 +65,12 @@ class ApiPayload(ABC):
class AuthData: class AuthData:
"""data used to authenticate requester""" """data used to authenticate requester"""
signature: str
cost: str cost: str
endpoint: str endpoint: str
reqnum: int reqnum: int
request_idx: int
signature: str
url: str url: str
request_idx: int
@classmethod @classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]): def from_json_msg(cls, json_msg: Dict[str, Any]):
@@ -190,12 +190,11 @@ class SystemMetrics:
self.additional_disk_usage = disk_usage - self.last_disk_usage self.additional_disk_usage = disk_usage - self.last_disk_usage
self.last_disk_usage = disk_usage self.last_disk_usage = disk_usage
def reset(self, expected: float | None) -> None: def reset(self):
# autoscaler excepts model_loading_time to be populated only once, when the instance has # autoscaler excepts model_loading_time to be populated only once, when the instance has
# finished benchmarking and is ready to receive requests. This applies to restarted instances # finished benchmarking and is ready to receive requests. This applies to restarted instances
# as well: they should send model_loading_time once when they are done loading # as well: they should send model_loading_time once when they are done loading
if self.model_loading_time == expected: self.model_loading_time = None
self.model_loading_time = None
@dataclass @dataclass
+7 -17
View File
@@ -180,10 +180,6 @@ class Metrics:
return # nothing to do return # nothing to do
for report_addr in self.report_addr: for report_addr in self.report_addr:
# TODO: Add a Redis subscriber queue for delete_requests
if report_addr == "https://cloud.vast.ai/api/v0":
# Patch: ignore the Redis API report_addr
continue
sent_success = True sent_success = True
sent_failed = True sent_failed = True
@@ -204,13 +200,11 @@ class Metrics:
async def __send_metrics_and_reset(self): async def __send_metrics_and_reset(self):
loadtime_snapshot = self.system_metrics.model_loading_time
def compute_autoscaler_data() -> AutoScalerData: def compute_autoscaler_data() -> AutoScalerData:
return AutoScalerData( return AutoScalerData(
id=self.id, id=self.id,
version=self.version, version=self.version,
loadtime=(loadtime_snapshot or 0.0), loadtime=(self.system_metrics.model_loading_time or 0.0),
new_load=self.model_metrics.workload_processing, new_load=self.model_metrics.workload_processing,
cur_load=self.model_metrics.cur_load, cur_load=self.model_metrics.cur_load,
rej_load=self.model_metrics.workload_rejected, rej_load=self.model_metrics.workload_rejected,
@@ -258,15 +252,11 @@ class Metrics:
self.system_metrics.update_disk_usage() self.system_metrics.update_disk_usage()
sent = False
for report_addr in self.report_addr: for report_addr in self.report_addr:
if await send_data(report_addr): success = await send_data(report_addr)
sent = True if success is True:
break break
self.update_pending = False
if sent: self.model_metrics.reset()
# clear the one-shot loadtime only if we actually sent *this* value self.system_metrics.reset()
self.system_metrics.reset(expected=loadtime_snapshot) self.last_metric_update = time.time()
self.update_pending = False
self.model_metrics.reset()
self.last_metric_update = time.time()
+1 -1
View File
@@ -9,7 +9,7 @@ ENV_PATH="$WORKSPACE_DIR/worker-env"
DEBUG_LOG="$WORKSPACE_DIR/debug.log" DEBUG_LOG="$WORKSPACE_DIR/debug.log"
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log" PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}" REPORT_ADDR="${REPORT_ADDR:-https://cloud.vast.ai/api/v0,https://run.vast.ai}"
USE_SSL="${USE_SSL:-true}" USE_SSL="${USE_SSL:-true}"
WORKER_PORT="${WORKER_PORT:-3000}" WORKER_PORT="${WORKER_PORT:-3000}"
mkdir -p "$WORKSPACE_DIR" mkdir -p "$WORKSPACE_DIR"
-1
View File
@@ -98,7 +98,6 @@ def call_text2image_workflow(
endpoint=route_response["endpoint"], endpoint=route_response["endpoint"],
reqnum=route_response["reqnum"], reqnum=route_response["reqnum"],
url=route_response["url"], url=route_response["url"],
request_idx=route_response["request_idx"],
) )
# Build the payload for the worker request # Build the payload for the worker request
-1
View File
@@ -82,7 +82,6 @@ def call_custom_workflow_for_sd3(
endpoint=message["endpoint"], endpoint=message["endpoint"],
reqnum=message["reqnum"], reqnum=message["reqnum"],
url=message["url"], url=message["url"],
request_idx=message["request_idx"],
) )
workflow = { workflow = {
"3": { "3": {
-1
View File
@@ -43,7 +43,6 @@ backend = Backend(
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
], ],
], ],
max_wait_time=600
) )
-1
View File
@@ -113,7 +113,6 @@ backend = Backend(
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
], ],
], ],
max_wait_time=600
) )