Compare commits

..

17 Commits

Author SHA1 Message Date
Lucas Armand c9d701e8d3 increase wait time for llm backends 2025-11-03 16:21:56 -08:00
Colter-Downing ec2ac0a21a Merge pull request #52 from vast-ai/remove-sleeps-and-delays
Remove sleeps and delays
2025-10-30 11:53:39 -07:00
Abiola Akinnubi 2cde573c56 Merge pull request #48 from vast-ai/comfy-request-idx
Added request_idx to comfy auth_data
2025-10-30 11:27:35 -07:00
Abiola Akinnubi b2e4a5db0c Merge pull request #49 from vast-ai/unsecure_report_addr
Added caller for REPORT_ADDR to backend.py to use the report add
2025-10-30 10:39:46 -07:00
Abiola Akinnubi 7437028cb2 Added caller for REPORT_ADDR to backend.py 2025-10-29 18:02:17 -07:00
edgaratvast 02c8307af7 remove redis pubsub from pyworker (#53)
Co-authored-by: Edgar Lin <edgarlin2000@gmail.com>
2025-10-29 17:07:56 -07:00
Colter Downing 7c0f316eeb leave the env vars alone! 2025-10-29 11:36:46 -07:00
Colter Downing b4025a744f remove env var writing 2025-10-29 09:58:09 -07:00
Colter Downing d190308329 removed 5 sec sleep and warmup request on load 2025-10-29 09:57:46 -07:00
LucasArmandVast 9f5a432513 Merge pull request #51 from vast-ai/delete-reqs-hotfix
Redis subscriber queue patch
2025-10-28 16:07:28 -07:00
Lucas Armand e09f1fa953 patch for redis queue 2025-10-28 16:03:50 -07:00
edgaratvast ba6f1c2e4b Fix signature (#50)
* change order of fields in auth_data to match autoscaler for signature verification

* also ignore __request_id

* Revert "change order of fields in auth_data to match autoscaler for signature verification" so that it's alphabetical again

This reverts commit b8223879c9.

* enforce alphabetical json dumping of message for signature verification

---------

Co-authored-by: Edgar Lin <edgarlin2000@gmail.com>
2025-10-28 16:01:32 -07:00
Abiola Akinnubi 944f83fc03 Removed extra spaces from operator assignment 2025-10-28 21:03:52 +00:00
edgaratvast 298590fb88 Merge pull request #45 from vast-ai/new-pyworker
New PyWorker
2025-10-28 14:02:53 -07:00
Lucas Armand 814c3acd4c remove unused code 2025-10-28 13:43:57 -07:00
Lucas Armand 22bca74087 Prevent load time race 2025-10-27 18:25:21 -07:00
Abiola Akinnubi f56bbc0ebe Added request_idx to comfy auth_data 2025-10-27 03:17:06 +00:00
8 changed files with 55 additions and 103 deletions
+28 -91
View File
@@ -5,7 +5,7 @@ import base64
import subprocess import subprocess
import dataclasses import dataclasses
import logging import logging
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task, get_running_loop from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
from functools import cached_property from functools import cached_property
from distutils.util import strtobool from distutils.util import strtobool
@@ -47,7 +47,7 @@ class Backend:
This class is responsible for: This class is responsible for:
1. Tailing logs and updating load time metrics 1. Tailing logs and updating load time metrics
2. Taking an EndpointHandler alongside incoming payload, preparing a json to be sent to the model, and 2. Taking an EndpointHandler alongside incoming payload, preparing a json to be sent to the model, and
sending the request. It also updates metrics as it makes those requests. sending the request. It also updates metrics as it makes those requests.
3. Running a benchmark from an EndpointHandler 3. Running a benchmark from an EndpointHandler
""" """
@@ -66,6 +66,9 @@ class Backend:
unsecured: bool = dataclasses.field( unsecured: bool = dataclasses.field(
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))), default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
) )
report_addr: str = dataclasses.field(
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
)
def __post_init__(self): def __post_init__(self):
self.metrics = Metrics() self.metrics = Metrics()
@@ -74,11 +77,6 @@ class Backend:
self._pubkey = self._fetch_pubkey() self._pubkey = self._fetch_pubkey()
self.__start_healthcheck: bool = False self.__start_healthcheck: bool = False
# NEW: FIFO queue + worker count
self.request_queue: "asyncio.Queue[tuple[EndpointHandler[ApiPayload_T], web.Request, asyncio.Future]]" = asyncio.Queue()
# If parallel allowed, let multiple workers drain the queue (order preserved by FIFO per worker; overall start order is FIFO).
self._num_workers: int = 1 if not self.allow_parallel_requests else int(os.environ.get("WORKERS", "4"))
@property @property
def pubkey(self) -> Optional[RSA.RsaKey]: def pubkey(self) -> Optional[RSA.RsaKey]:
if self._pubkey is None: if self._pubkey is None:
@@ -96,22 +94,6 @@ class Backend:
timeout = ClientTimeout(total=None) timeout = ClientTimeout(total=None)
return ClientSession(self.model_server_url, timeout=timeout, connector=connector) return ClientSession(self.model_server_url, timeout=timeout, connector=connector)
async def _worker(self):
while True:
handler, request, fut = await self.request_queue.get()
try:
# Skip if already cancelled while waiting in the queue
if fut.cancelled():
continue
res = await self.__process_enqueued_request(handler, request)
if not fut.cancelled():
fut.set_result(res)
except Exception as e:
if not fut.cancelled():
fut.set_exception(e)
finally:
self.request_queue.task_done()
def create_handler( def create_handler(
self, self,
handler: EndpointHandler[ApiPayload_T], handler: EndpointHandler[ApiPayload_T],
@@ -125,59 +107,26 @@ class Backend:
#######################################Private####################################### #######################################Private#######################################
def _fetch_pubkey(self): def _fetch_pubkey(self):
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"] report_addr = self.report_addr.rstrip("/")
result = subprocess.check_output(command, universal_newlines=True) command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"]
log.debug("public key:") try:
log.debug(result) result = subprocess.check_output(command, universal_newlines=True)
key = None log.debug("public key:")
for _ in range(5): log.debug(result)
try: key = RSA.import_key(result)
key = RSA.import_key(result) if key is not None:
break return key
except ValueError as e: except (ValueError , subprocess.CalledProcessError) as e:
log.debug(f"Error downloading key: {e}") log.debug(f"Error downloading key: {e}")
time.sleep(15) self.backend_errored("Failed to get autoscaler pubkey")
if key is None:
self._total_pubkey_fetch_errors += 1
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
self.backend_errored("Failed to get autoscaler pubkey")
return key
async def __handle_request( async def __handle_request(
self, self,
handler: EndpointHandler[ApiPayload_T], handler: EndpointHandler[ApiPayload_T],
request: web.Request, request: web.Request,
) -> Union[web.Response, web.StreamResponse]: ) -> Union[web.Response, web.StreamResponse]:
"""use this function to enqueue requests for FIFO processing""" """use this function to forward requests to the model endpoint"""
loop = get_running_loop()
fut: asyncio.Future = loop.create_future()
# If the client disconnects while waiting in the FIFO, cancel the future so the worker skips it
cancel_watch = create_task(request.wait_for_disconnection())
def _cancel_if_disconnected(_):
if not fut.done():
fut.cancel()
cancel_watch.add_done_callback(_cancel_if_disconnected)
try:
await self.request_queue.put((handler, request, fut))
return await fut
except asyncio.CancelledError:
# Propagate cancellation to ensure aiohttp doesn't expect a response body
raise
finally:
# Best-effort cleanup of the watcher
cancel_watch.cancel()
async def __process_enqueued_request(
self,
handler: EndpointHandler[ApiPayload_T],
request: web.Request,
) -> Union[web.Response, web.StreamResponse]:
"""
This contains the original __handle_request logic and is invoked by workers,
ensuring FIFO execution via asyncio.Queue.
"""
try: try:
data = await request.json() data = await request.json()
auth_data, payload = handler.get_data_from_request(data) auth_data, payload = handler.get_data_from_request(data)
@@ -185,11 +134,8 @@ class Backend:
return web.json_response(data=e.message, status=422) return web.json_response(data=e.message, status=422)
except json.JSONDecodeError: except json.JSONDecodeError:
return web.json_response(dict(error="invalid JSON"), status=422) return web.json_response(dict(error="invalid JSON"), status=422)
workload = payload.count_workload() workload = payload.count_workload()
request_metrics: RequestMetrics = RequestMetrics( request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created"
)
async def cancel_api_call_if_disconnected() -> web.Response: async def cancel_api_call_if_disconnected() -> web.Response:
await request.wait_for_disconnection() await request.wait_for_disconnection()
@@ -230,8 +176,6 @@ class Backend:
acquired = False acquired = False
try: try:
self.metrics._request_start(request_metrics) self.metrics._request_start(request_metrics)
# Preserve existing semaphore behavior for serializing requests when requested
if self.allow_parallel_requests is False: if self.allow_parallel_requests is False:
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}") log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
await self.sem.acquire() await self.sem.acquire()
@@ -241,7 +185,6 @@ class Backend:
) )
else: else:
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}") log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
done, pending = await wait( done, pending = await wait(
[ [
create_task(make_request()), create_task(make_request()),
@@ -309,14 +252,8 @@ class Backend:
self.backend_errored(str(e)) self.backend_errored(str(e))
async def _start_tracking(self) -> None: async def _start_tracking(self) -> None:
# Start the FIFO workers alongside existing loops
worker_tasks = tuple(self._worker() for _ in range(self._num_workers))
await gather( await gather(
self.__read_logs(), self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop()
self.metrics._send_metrics_loop(),
self.__healthcheck(),
self.metrics._send_delete_requests_loop(),
*worker_tasks,
) )
def backend_errored(self, msg: str) -> None: def backend_errored(self, msg: str) -> None:
@@ -348,7 +285,7 @@ class Backend:
message = { message = {
key: value key: value
for (key, value) in (dataclasses.asdict(auth_data).items()) for (key, value) in (dataclasses.asdict(auth_data).items())
if key != "signature" if key != "signature" and key != "__request_id"
} }
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN): if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
log.debug( log.debug(
@@ -358,7 +295,7 @@ class Backend:
elif message in self.msg_history: elif message in self.msg_history:
log.debug(f"message: {message} already in message history") log.debug(f"message: {message} already in message history")
return False return False
elif verify_signature(json.dumps(message, indent=4), auth_data.signature): elif verify_signature(json.dumps(message, indent=4, sort_keys=True), auth_data.signature):
self.reqnum = max(auth_data.reqnum, self.reqnum) self.reqnum = max(auth_data.reqnum, self.reqnum)
self.msg_history.append(message) self.msg_history.append(message)
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:] self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
@@ -377,10 +314,10 @@ class Backend:
with open(BENCHMARK_INDICATOR_FILE, "r") as f: with open(BENCHMARK_INDICATOR_FILE, "r") as f:
log.debug("already ran benchmark") log.debug("already ran benchmark")
# trigger model load # trigger model load
payload = self.benchmark_handler.make_benchmark_payload() # payload = self.benchmark_handler.make_benchmark_payload()
_ = await self.__call_api( # _ = await self.__call_api(
handler=self.benchmark_handler, payload=payload # handler=self.benchmark_handler, payload=payload
) # )
return float(f.readline()) return float(f.readline())
except FileNotFoundError: except FileNotFoundError:
pass pass
@@ -455,7 +392,7 @@ class Backend:
) )
# some backends need a few seconds after logging successful startup before # some backends need a few seconds after logging successful startup before
# they can begin accepting requests # they can begin accepting requests
await sleep(5) # await sleep(5)
try: try:
max_throughput = await run_benchmark() max_throughput = await run_benchmark()
self.__start_healthcheck = True self.__start_healthcheck = True
+5 -4
View File
@@ -65,12 +65,12 @@ class ApiPayload(ABC):
class AuthData: class AuthData:
"""data used to authenticate requester""" """data used to authenticate requester"""
signature: str
cost: str cost: str
endpoint: str endpoint: str
reqnum: int reqnum: int
url: str
request_idx: int request_idx: int
signature: str
url: str
@classmethod @classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]): def from_json_msg(cls, json_msg: Dict[str, Any]):
@@ -190,11 +190,12 @@ class SystemMetrics:
self.additional_disk_usage = disk_usage - self.last_disk_usage self.additional_disk_usage = disk_usage - self.last_disk_usage
self.last_disk_usage = disk_usage self.last_disk_usage = disk_usage
def reset(self): def reset(self, expected: float | None) -> None:
# autoscaler excepts model_loading_time to be populated only once, when the instance has # autoscaler excepts model_loading_time to be populated only once, when the instance has
# finished benchmarking and is ready to receive requests. This applies to restarted instances # finished benchmarking and is ready to receive requests. This applies to restarted instances
# as well: they should send model_loading_time once when they are done loading # as well: they should send model_loading_time once when they are done loading
self.model_loading_time = None if self.model_loading_time == expected:
self.model_loading_time = None
@dataclass @dataclass
+17 -7
View File
@@ -180,6 +180,10 @@ class Metrics:
return # nothing to do return # nothing to do
for report_addr in self.report_addr: for report_addr in self.report_addr:
# TODO: Add a Redis subscriber queue for delete_requests
if report_addr == "https://cloud.vast.ai/api/v0":
# Patch: ignore the Redis API report_addr
continue
sent_success = True sent_success = True
sent_failed = True sent_failed = True
@@ -200,11 +204,13 @@ class Metrics:
async def __send_metrics_and_reset(self): async def __send_metrics_and_reset(self):
loadtime_snapshot = self.system_metrics.model_loading_time
def compute_autoscaler_data() -> AutoScalerData: def compute_autoscaler_data() -> AutoScalerData:
return AutoScalerData( return AutoScalerData(
id=self.id, id=self.id,
version=self.version, version=self.version,
loadtime=(self.system_metrics.model_loading_time or 0.0), loadtime=(loadtime_snapshot or 0.0),
new_load=self.model_metrics.workload_processing, new_load=self.model_metrics.workload_processing,
cur_load=self.model_metrics.cur_load, cur_load=self.model_metrics.cur_load,
rej_load=self.model_metrics.workload_rejected, rej_load=self.model_metrics.workload_rejected,
@@ -252,11 +258,15 @@ class Metrics:
self.system_metrics.update_disk_usage() self.system_metrics.update_disk_usage()
sent = False
for report_addr in self.report_addr: for report_addr in self.report_addr:
success = await send_data(report_addr) if await send_data(report_addr):
if success is True: sent = True
break break
self.update_pending = False
self.model_metrics.reset() if sent:
self.system_metrics.reset() # clear the one-shot loadtime only if we actually sent *this* value
self.last_metric_update = time.time() self.system_metrics.reset(expected=loadtime_snapshot)
self.update_pending = False
self.model_metrics.reset()
self.last_metric_update = time.time()
+1 -1
View File
@@ -9,7 +9,7 @@ ENV_PATH="$WORKSPACE_DIR/worker-env"
DEBUG_LOG="$WORKSPACE_DIR/debug.log" DEBUG_LOG="$WORKSPACE_DIR/debug.log"
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log" PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
REPORT_ADDR="${REPORT_ADDR:-https://cloud.vast.ai/api/v0,https://run.vast.ai}" REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}"
USE_SSL="${USE_SSL:-true}" USE_SSL="${USE_SSL:-true}"
WORKER_PORT="${WORKER_PORT:-3000}" WORKER_PORT="${WORKER_PORT:-3000}"
mkdir -p "$WORKSPACE_DIR" mkdir -p "$WORKSPACE_DIR"
+1
View File
@@ -98,6 +98,7 @@ def call_text2image_workflow(
endpoint=route_response["endpoint"], endpoint=route_response["endpoint"],
reqnum=route_response["reqnum"], reqnum=route_response["reqnum"],
url=route_response["url"], url=route_response["url"],
request_idx=route_response["request_idx"],
) )
# Build the payload for the worker request # Build the payload for the worker request
+1
View File
@@ -82,6 +82,7 @@ def call_custom_workflow_for_sd3(
endpoint=message["endpoint"], endpoint=message["endpoint"],
reqnum=message["reqnum"], reqnum=message["reqnum"],
url=message["url"], url=message["url"],
request_idx=message["request_idx"],
) )
workflow = { workflow = {
"3": { "3": {
+1
View File
@@ -43,6 +43,7 @@ backend = Backend(
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
], ],
], ],
max_wait_time=600
) )
+1
View File
@@ -113,6 +113,7 @@ backend = Backend(
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
], ],
], ],
max_wait_time=600
) )