Compare commits

..

1 Commits

Author SHA1 Message Date
Lucas Armand 0471f6b219 trying queue 2025-10-27 17:34:37 -07:00
6 changed files with 139 additions and 168 deletions
+125 -126
View File
@@ -5,11 +5,10 @@ import base64
import subprocess
import dataclasses
import logging
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task, get_running_loop
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
from functools import cached_property
from distutils.util import strtobool
from collections import deque
from anyio import open_file
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
@@ -31,7 +30,7 @@ from lib.data_types import (
BenchmarkResult
)
VERSION = "0.2.1"
VERSION = "0.1.0"
MSG_HISTORY_LEN = 100
log = logging.getLogger(__file__)
@@ -48,7 +47,7 @@ class Backend:
This class is responsible for:
1. Tailing logs and updating load time metrics
2. Taking an EndpointHandler alongside incoming payload, preparing a json to be sent to the model, and
sending the request. It also updates metrics as it makes those requests.
sending the request. It also updates metrics as it makes those requests.
3. Running a benchmark from an EndpointHandler
"""
@@ -64,25 +63,22 @@ class Backend:
version = VERSION
msg_history = []
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
queue: deque = dataclasses.field(default_factory=deque, repr=False)
unsecured: bool = dataclasses.field(
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
)
report_addr: str = dataclasses.field(
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
)
mtoken: str = dataclasses.field(
default_factory=lambda: os.environ.get("MASTER_TOKEN", "")
)
def __post_init__(self):
self.metrics = Metrics()
self.metrics._set_version(self.version)
self.metrics._set_mtoken(self.mtoken)
self._total_pubkey_fetch_errors = 0
self._pubkey = self._fetch_pubkey()
self.__start_healthcheck: bool = False
# NEW: FIFO queue + worker count
self.request_queue: "asyncio.Queue[tuple[EndpointHandler[ApiPayload_T], web.Request, asyncio.Future]]" = asyncio.Queue()
# If parallel allowed, let multiple workers drain the queue (order preserved by FIFO per worker; overall start order is FIFO).
self._num_workers: int = 1 if not self.allow_parallel_requests else int(os.environ.get("WORKERS", "4"))
@property
def pubkey(self) -> Optional[RSA.RsaKey]:
if self._pubkey is None:
@@ -100,6 +96,22 @@ class Backend:
timeout = ClientTimeout(total=None)
return ClientSession(self.model_server_url, timeout=timeout, connector=connector)
async def _worker(self):
while True:
handler, request, fut = await self.request_queue.get()
try:
# Skip if already cancelled while waiting in the queue
if fut.cancelled():
continue
res = await self.__process_enqueued_request(handler, request)
if not fut.cancelled():
fut.set_result(res)
except Exception as e:
if not fut.cancelled():
fut.set_exception(e)
finally:
self.request_queue.task_done()
def create_handler(
self,
handler: EndpointHandler[ApiPayload_T],
@@ -113,26 +125,59 @@ class Backend:
#######################################Private#######################################
def _fetch_pubkey(self):
report_addr = self.report_addr.rstrip("/")
command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"]
try:
result = subprocess.check_output(command, universal_newlines=True)
log.debug("public key:")
log.debug(result)
key = RSA.import_key(result)
if key is not None:
return key
except (ValueError , subprocess.CalledProcessError) as e:
log.debug(f"Error downloading key: {e}")
self.backend_errored("Failed to get autoscaler pubkey")
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
result = subprocess.check_output(command, universal_newlines=True)
log.debug("public key:")
log.debug(result)
key = None
for _ in range(5):
try:
key = RSA.import_key(result)
break
except ValueError as e:
log.debug(f"Error downloading key: {e}")
time.sleep(15)
if key is None:
self._total_pubkey_fetch_errors += 1
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
self.backend_errored("Failed to get autoscaler pubkey")
return key
async def __handle_request(
self,
handler: EndpointHandler[ApiPayload_T],
request: web.Request,
) -> Union[web.Response, web.StreamResponse]:
"""use this function to forward requests to the model endpoint"""
"""use this function to enqueue requests for FIFO processing"""
loop = get_running_loop()
fut: asyncio.Future = loop.create_future()
# If the client disconnects while waiting in the FIFO, cancel the future so the worker skips it
cancel_watch = create_task(request.wait_for_disconnection())
def _cancel_if_disconnected(_):
if not fut.done():
fut.cancel()
cancel_watch.add_done_callback(_cancel_if_disconnected)
try:
await self.request_queue.put((handler, request, fut))
return await fut
except asyncio.CancelledError:
# Propagate cancellation to ensure aiohttp doesn't expect a response body
raise
finally:
# Best-effort cleanup of the watcher
cancel_watch.cancel()
async def __process_enqueued_request(
self,
handler: EndpointHandler[ApiPayload_T],
request: web.Request,
) -> Union[web.Response, web.StreamResponse]:
"""
This contains the original __handle_request logic and is invoked by workers,
ensuring FIFO execution via asyncio.Queue.
"""
try:
data = await request.json()
auth_data, payload = handler.get_data_from_request(data)
@@ -140,29 +185,17 @@ class Backend:
return web.json_response(data=e.message, status=422)
except json.JSONDecodeError:
return web.json_response(dict(error="invalid JSON"), status=422)
workload = payload.count_workload()
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
request_metrics: RequestMetrics = RequestMetrics(
request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created"
)
def advance_queue_after_completion(event: asyncio.Event):
"""Pop current head and wake next waiter, if any."""
# If this event is current head, wake next waiter
if self.queue and self.queue[0] is event:
self.queue.popleft()
if self.queue:
self.queue[0].set()
else:
# Else, remove it from the queue
try:
self.queue.remove(event)
except ValueError:
pass
async def cancel_api_call_if_disconnected() -> None:
async def cancel_api_call_if_disconnected() -> web.Response:
await request.wait_for_disconnection()
log.debug(f"Request with reqnum: {request_metrics.reqnum} was canceled")
log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled")
self.metrics._request_canceled(request_metrics)
return
raise asyncio.CancelledError
async def make_request() -> Union[web.Response, web.StreamResponse]:
try:
@@ -179,9 +212,7 @@ class Backend:
res = await handler.generate_client_response(request, response)
self.metrics._request_success(request_metrics)
return res
except asyncio.CancelledError:
raise
except Exception as e:
except requests.exceptions.RequestException as e:
log.debug(f"[backend] Request error: {e}")
self.metrics._request_errored(request_metrics)
return web.Response(status=500)
@@ -196,87 +227,49 @@ class Backend:
self.metrics._request_reject(request_metrics)
return web.Response(status=429)
disconnect_task = create_task(cancel_api_call_if_disconnected())
next_request_task = None
work_task = None
event = asyncio.Event() # Used in finally block, so initialize here
self.metrics._request_start(request_metrics)
acquired = False
try:
if self.allow_parallel_requests:
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
work_task = create_task(make_request())
done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED)
self.metrics._request_start(request_metrics)
for t in pending:
t.cancel()
await asyncio.gather(*pending, return_exceptions=True)
if disconnect_task in done:
return web.Response(status=499)
# otherwise work_task completed
return await work_task
# FIFO-queue branch
else:
# Insert a Event into the queue for this request
# Event.set() == our request is up next
self.queue.append(event)
if self.queue and self.queue[0] is event:
event.set()
# Race between our request being next and request being cancelled
next_request_task = create_task(event.wait())
first_done, first_pending = await wait(
[next_request_task, disconnect_task], return_when=FIRST_COMPLETED
# Preserve existing semaphore behavior for serializing requests when requested
if self.allow_parallel_requests is False:
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
await self.sem.acquire()
acquired = True
log.debug(
f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..."
)
else:
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
# If the disconnect task wins the race
if disconnect_task in first_done:
# Clean up the next_request_task, then exit
for t in first_pending:
t.cancel()
await asyncio.gather(*first_pending, return_exceptions=True)
return web.Response(status=499)
done, pending = await wait(
[
create_task(make_request()),
create_task(cancel_api_call_if_disconnected()),
],
return_when=FIRST_COMPLETED,
)
for t in pending:
t.cancel()
await asyncio.gather(*pending, return_exceptions=True)
# We are the next-up request in the queue
log.debug(f"Starting work on request {request_metrics.reqnum}...")
# Race the backend API call with the disconnect task
work_task = create_task(make_request())
done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED)
for t in pending:
t.cancel()
await asyncio.gather(*pending, return_exceptions=True)
if disconnect_task in done:
return web.Response(status=499)
# otherwise work_task completed
return await work_task
done_task = done.pop()
try:
return done_task.result()
except Exception as e:
log.debug(f"Request task raised exception: {e}")
return web.Response(status=500)
except asyncio.CancelledError:
return web.Response(status=499)
# Client is gone. Do not write a response; just unwind.
return web.Response(status=499)
except Exception as e:
log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500)
finally:
if not self.allow_parallel_requests:
advance_queue_after_completion(event)
# Always release the semaphore if it was acquired
if acquired:
self.sem.release()
self.metrics._request_end(request_metrics)
cleanup_tasks = [t for t in (next_request_task, work_task, disconnect_task) if t]
for t in cleanup_tasks:
if not t.done():
t.cancel()
if cleanup_tasks:
await asyncio.gather(*cleanup_tasks, return_exceptions=True)
@cached_property
def healthcheck_session(self):
@@ -316,8 +309,14 @@ class Backend:
self.backend_errored(str(e))
async def _start_tracking(self) -> None:
# Start the FIFO workers alongside existing loops
worker_tasks = tuple(self._worker() for _ in range(self._num_workers))
await gather(
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop()
self.__read_logs(),
self.metrics._send_metrics_loop(),
self.__healthcheck(),
self.metrics._send_delete_requests_loop(),
*worker_tasks,
)
def backend_errored(self, msg: str) -> None:
@@ -349,7 +348,7 @@ class Backend:
message = {
key: value
for (key, value) in (dataclasses.asdict(auth_data).items())
if key != "signature" and key != "__request_id"
if key != "signature"
}
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
log.debug(
@@ -359,7 +358,7 @@ class Backend:
elif message in self.msg_history:
log.debug(f"message: {message} already in message history")
return False
elif verify_signature(json.dumps(message, indent=4, sort_keys=True), auth_data.signature):
elif verify_signature(json.dumps(message, indent=4), auth_data.signature):
self.reqnum = max(auth_data.reqnum, self.reqnum)
self.msg_history.append(message)
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
@@ -378,10 +377,10 @@ class Backend:
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
log.debug("already ran benchmark")
# trigger model load
# payload = self.benchmark_handler.make_benchmark_payload()
# _ = await self.__call_api(
# handler=self.benchmark_handler, payload=payload
# )
payload = self.benchmark_handler.make_benchmark_payload()
_ = await self.__call_api(
handler=self.benchmark_handler, payload=payload
)
return float(f.readline())
except FileNotFoundError:
pass
@@ -456,7 +455,7 @@ class Backend:
)
# some backends need a few seconds after logging successful startup before
# they can begin accepting requests
# await sleep(5)
await sleep(5)
try:
max_throughput = await run_benchmark()
self.__start_healthcheck = True
+4 -6
View File
@@ -65,12 +65,12 @@ class ApiPayload(ABC):
class AuthData:
"""data used to authenticate requester"""
signature: str
cost: str
endpoint: str
reqnum: int
request_idx: int
signature: str
url: str
request_idx: int
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]):
@@ -190,12 +190,11 @@ class SystemMetrics:
self.additional_disk_usage = disk_usage - self.last_disk_usage
self.last_disk_usage = disk_usage
def reset(self, expected: float | None) -> None:
def reset(self):
# autoscaler excepts model_loading_time to be populated only once, when the instance has
# finished benchmarking and is ready to receive requests. This applies to restarted instances
# as well: they should send model_loading_time once when they are done loading
if self.model_loading_time == expected:
self.model_loading_time = None
self.model_loading_time = None
@dataclass
@@ -286,7 +285,6 @@ class AutoScalerData:
"""Data that is reported to autoscaler"""
id: int
mtoken: str
version: str
loadtime: float
cur_load: float
+9 -33
View File
@@ -28,7 +28,6 @@ def get_url() -> str:
@dataclass
class Metrics:
version: str = "0"
mtoken: str = ""
last_metric_update: float = 0.0
last_request_served: float = 0.0
update_pending: bool = False
@@ -143,16 +142,12 @@ class Metrics:
def _set_version(self, version: str) -> None:
self.version = version
def _set_mtoken(self, mtoken: str) -> None:
self.mtoken = mtoken
#######################################Private#######################################
async def __send_delete_requests_and_reset(self):
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
data = {
"worker_id": self.id,
"mtoken": self.mtoken,
"request_idxs": idxs,
"success": success_flag,
}
@@ -185,10 +180,6 @@ class Metrics:
return # nothing to do
for report_addr in self.report_addr:
# TODO: Add a Redis subscriber queue for delete_requests
if report_addr == "https://cloud.vast.ai/api/v0":
# Patch: ignore the Redis API report_addr
continue
sent_success = True
sent_failed = True
@@ -209,14 +200,11 @@ class Metrics:
async def __send_metrics_and_reset(self):
loadtime_snapshot = self.system_metrics.model_loading_time
def compute_autoscaler_data() -> AutoScalerData:
return AutoScalerData(
id=self.id,
mtoken=self.mtoken,
version=self.version,
loadtime=(loadtime_snapshot or 0.0),
loadtime=(self.system_metrics.model_loading_time or 0.0),
new_load=self.model_metrics.workload_processing,
cur_load=self.model_metrics.cur_load,
rej_load=self.model_metrics.workload_rejected,
@@ -234,25 +222,17 @@ class Metrics:
async def send_data(report_addr: str) -> bool:
data = compute_autoscaler_data()
log_data = asdict(data)
def obfuscate(secret: str) -> str:
if secret is None:
return ""
return secret[:7] + "..." if len(secret) > 7 else ("*" * len(secret))
log_data["mtoken"] = obfuscate(log_data.get("mtoken"))
full_path = report_addr.rstrip("/") + "/worker_status/"
log.debug(
"\n".join(
[
"#" * 60,
f"sending data to autoscaler",
f"{json.dumps(log_data, indent=2)}",
f"{json.dumps((asdict(data)), indent=2)}",
"#" * 60,
]
)
)
full_path = report_addr.rstrip("/") + "/worker_status/"
for attempt in range(1, 4):
try:
session = await self.http()
@@ -272,15 +252,11 @@ class Metrics:
self.system_metrics.update_disk_usage()
sent = False
for report_addr in self.report_addr:
if await send_data(report_addr):
sent = True
success = await send_data(report_addr)
if success is True:
break
if sent:
# clear the one-shot loadtime only if we actually sent *this* value
self.system_metrics.reset(expected=loadtime_snapshot)
self.update_pending = False
self.model_metrics.reset()
self.last_metric_update = time.time()
self.update_pending = False
self.model_metrics.reset()
self.system_metrics.reset()
self.last_metric_update = time.time()
+1 -1
View File
@@ -9,7 +9,7 @@ ENV_PATH="$WORKSPACE_DIR/worker-env"
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}"
REPORT_ADDR="${REPORT_ADDR:-https://cloud.vast.ai/api/v0,https://run.vast.ai}"
USE_SSL="${USE_SSL:-true}"
WORKER_PORT="${WORKER_PORT:-3000}"
mkdir -p "$WORKSPACE_DIR"
-1
View File
@@ -98,7 +98,6 @@ def call_text2image_workflow(
endpoint=route_response["endpoint"],
reqnum=route_response["reqnum"],
url=route_response["url"],
request_idx=route_response["request_idx"],
)
# Build the payload for the worker request
-1
View File
@@ -82,7 +82,6 @@ def call_custom_workflow_for_sd3(
endpoint=message["endpoint"],
reqnum=message["reqnum"],
url=message["url"],
request_idx=message["request_idx"],
)
workflow = {
"3": {