Compare commits

...

19 Commits

Author SHA1 Message Date
Lucas Armand 249ca2eb99 refactor, handle zombie tasks 2025-11-12 15:23:42 -08:00
Lucas Armand d8bb1fcc68 add fifo queue
Bump pyworker version
2025-11-12 12:26:15 -08:00
LucasArmandVast 7db54f3bd7 Merge pull request #55 from vast-ai/use-mtoken
Use mtoken
2025-11-10 11:54:04 -08:00
LucasArmandVast d63a060202 Merge pull request #56 from vast-ai/obfuscate-mtoken
Obfuscate mtoken in logs
2025-11-10 11:53:17 -08:00
Lucas Armand c6521cb6d4 add ... 2025-11-07 10:10:35 -08:00
Lucas Armand b7fe4ebb91 Obfuscate mtoken in logs 2025-11-07 10:02:39 -08:00
Lucas Armand 8ae7b74605 bump version to 0.2.0 2025-11-05 13:32:21 -08:00
Lucas Armand 106067d716 bump version to 0.1.1 2025-11-04 17:15:59 -08:00
Lucas Armand f5134d4bf5 Fix spelling mistake 2025-11-04 16:59:39 -08:00
Lucas Armand 47e5460532 added mtoken 2025-11-04 15:55:14 -08:00
Colter-Downing ec2ac0a21a Merge pull request #52 from vast-ai/remove-sleeps-and-delays
Remove sleeps and delays
2025-10-30 11:53:39 -07:00
Abiola Akinnubi 2cde573c56 Merge pull request #48 from vast-ai/comfy-request-idx
Added request_idx to comfy auth_data
2025-10-30 11:27:35 -07:00
Abiola Akinnubi b2e4a5db0c Merge pull request #49 from vast-ai/unsecure_report_addr
Added caller for REPORT_ADDR to backend.py to use the report add
2025-10-30 10:39:46 -07:00
edgaratvast 02c8307af7 remove redis pubsub from pyworker (#53)
Co-authored-by: Edgar Lin <edgarlin2000@gmail.com>
2025-10-29 17:07:56 -07:00
Colter Downing 7c0f316eeb leave the env vars alone! 2025-10-29 11:36:46 -07:00
Colter Downing b4025a744f remove env var writing 2025-10-29 09:58:09 -07:00
Colter Downing d190308329 removed 5 sec sleep and warmup request on load 2025-10-29 09:57:46 -07:00
Abiola Akinnubi 944f83fc03 Removed extra spaces from operator assignment 2025-10-28 21:03:52 +00:00
Abiola Akinnubi f56bbc0ebe Added request_idx to comfy auth_data 2025-10-27 03:17:06 +00:00
6 changed files with 127 additions and 46 deletions
+106 -42
View File
@@ -9,6 +9,7 @@ from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
from functools import cached_property from functools import cached_property
from distutils.util import strtobool from distutils.util import strtobool
from collections import deque
from anyio import open_file from anyio import open_file
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
@@ -30,7 +31,7 @@ from lib.data_types import (
BenchmarkResult BenchmarkResult
) )
VERSION = "0.1.0" VERSION = "0.2.1"
MSG_HISTORY_LEN = 100 MSG_HISTORY_LEN = 100
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
@@ -63,16 +64,21 @@ class Backend:
version = VERSION version = VERSION
msg_history = [] msg_history = []
sem: Semaphore = dataclasses.field(default_factory=Semaphore) sem: Semaphore = dataclasses.field(default_factory=Semaphore)
queue: deque = dataclasses.field(default_factory=deque, repr=False)
unsecured: bool = dataclasses.field( unsecured: bool = dataclasses.field(
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))), default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
) )
report_addr: str = dataclasses.field( report_addr: str = dataclasses.field(
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai") default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
) )
mtoken: str = dataclasses.field(
default_factory=lambda: os.environ.get("MASTER_TOKEN", "")
)
def __post_init__(self): def __post_init__(self):
self.metrics = Metrics() self.metrics = Metrics()
self.metrics._set_version(self.version) self.metrics._set_version(self.version)
self.metrics._set_mtoken(self.mtoken)
self._total_pubkey_fetch_errors = 0 self._total_pubkey_fetch_errors = 0
self._pubkey = self._fetch_pubkey() self._pubkey = self._fetch_pubkey()
self.__start_healthcheck: bool = False self.__start_healthcheck: bool = False
@@ -137,11 +143,26 @@ class Backend:
workload = payload.count_workload() workload = payload.count_workload()
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created") request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
async def cancel_api_call_if_disconnected() -> web.Response:
def advance_queue_after_completion(event: asyncio.Event):
    """Take ``event`` out of the FIFO queue and release the next waiter.

    If ``event`` is at the front of ``self.queue`` it is popped and the new
    front entry, when there is one, is set. If it is anywhere else it is
    simply removed; an event that is not in the queue is ignored.
    """
    waiters = self.queue
    at_front = bool(waiters) and waiters[0] is event

    if not at_front:
        # Not the head: withdraw it without waking anyone.
        try:
            waiters.remove(event)
        except ValueError:
            pass
        return

    # Head of the queue: drop it and wake whoever is now first in line.
    waiters.popleft()
    if waiters:
        waiters[0].set()
async def cancel_api_call_if_disconnected() -> None:
await request.wait_for_disconnection() await request.wait_for_disconnection()
log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled") log.debug(f"Request with reqnum: {request_metrics.reqnum} was canceled")
self.metrics._request_canceled(request_metrics) self.metrics._request_canceled(request_metrics)
raise asyncio.CancelledError return
async def make_request() -> Union[web.Response, web.StreamResponse]: async def make_request() -> Union[web.Response, web.StreamResponse]:
try: try:
@@ -158,7 +179,9 @@ class Backend:
res = await handler.generate_client_response(request, response) res = await handler.generate_client_response(request, response)
self.metrics._request_success(request_metrics) self.metrics._request_success(request_metrics)
return res return res
except requests.exceptions.RequestException as e: except asyncio.CancelledError:
raise
except Exception as e:
log.debug(f"[backend] Request error: {e}") log.debug(f"[backend] Request error: {e}")
self.metrics._request_errored(request_metrics) self.metrics._request_errored(request_metrics)
return web.Response(status=500) return web.Response(status=500)
@@ -173,46 +196,87 @@ class Backend:
self.metrics._request_reject(request_metrics) self.metrics._request_reject(request_metrics)
return web.Response(status=429) return web.Response(status=429)
acquired = False disconnect_task = create_task(cancel_api_call_if_disconnected())
try: next_request_task = None
self.metrics._request_start(request_metrics) work_task = None
if self.allow_parallel_requests is False: event = asyncio.Event() # Used in finally block, so initialize here
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
await self.sem.acquire() self.metrics._request_start(request_metrics)
acquired = True
log.debug( try:
f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..." if self.allow_parallel_requests:
) log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
else: work_task = create_task(make_request())
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}") done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED)
done, pending = await wait(
[ for t in pending:
create_task(make_request()), t.cancel()
create_task(cancel_api_call_if_disconnected()), await asyncio.gather(*pending, return_exceptions=True)
],
return_when=FIRST_COMPLETED, if disconnect_task in done:
) return web.Response(status=499)
for t in pending:
t.cancel() # otherwise work_task completed
await asyncio.gather(*pending, return_exceptions=True) return await work_task
# FIFO-queue branch
else:
# Insert an Event into the queue for this request
# Event.set() == our request is up next
self.queue.append(event)
if self.queue and self.queue[0] is event:
event.set()
# Race between our request being next and request being cancelled
next_request_task = create_task(event.wait())
first_done, first_pending = await wait(
[next_request_task, disconnect_task], return_when=FIRST_COMPLETED
)
# If the disconnect task wins the race
if disconnect_task in first_done:
# Clean up the next_request_task, then exit
for t in first_pending:
t.cancel()
await asyncio.gather(*first_pending, return_exceptions=True)
return web.Response(status=499)
# We are the next-up request in the queue
log.debug(f"Starting work on request {request_metrics.reqnum}...")
# Race the backend API call with the disconnect task
work_task = create_task(make_request())
done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED)
for t in pending:
t.cancel()
await asyncio.gather(*pending, return_exceptions=True)
if disconnect_task in done:
return web.Response(status=499)
# otherwise work_task completed
return await work_task
done_task = done.pop()
try:
return done_task.result()
except Exception as e:
log.debug(f"Request task raised exception: {e}")
return web.Response(status=500)
except asyncio.CancelledError: except asyncio.CancelledError:
# Client is gone. The 499 returned below presumably never reaches it; it only lets the handler unwind.
return web.Response(status=499) return web.Response(status=499)
except Exception as e: except Exception as e:
log.debug(f"Exception in main handler loop {e}") log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500) return web.Response(status=500)
finally: finally:
# Always release the semaphore if it was acquired if not self.allow_parallel_requests:
if acquired: advance_queue_after_completion(event)
self.sem.release()
self.metrics._request_end(request_metrics) self.metrics._request_end(request_metrics)
cleanup_tasks = [t for t in (next_request_task, work_task, disconnect_task) if t]
for t in cleanup_tasks:
if not t.done():
t.cancel()
if cleanup_tasks:
await asyncio.gather(*cleanup_tasks, return_exceptions=True)
@cached_property @cached_property
def healthcheck_session(self): def healthcheck_session(self):
@@ -314,10 +378,10 @@ class Backend:
with open(BENCHMARK_INDICATOR_FILE, "r") as f: with open(BENCHMARK_INDICATOR_FILE, "r") as f:
log.debug("already ran benchmark") log.debug("already ran benchmark")
# trigger model load # trigger model load
payload = self.benchmark_handler.make_benchmark_payload() # payload = self.benchmark_handler.make_benchmark_payload()
_ = await self.__call_api( # _ = await self.__call_api(
handler=self.benchmark_handler, payload=payload # handler=self.benchmark_handler, payload=payload
) # )
return float(f.readline()) return float(f.readline())
except FileNotFoundError: except FileNotFoundError:
pass pass
@@ -392,7 +456,7 @@ class Backend:
) )
# some backends need a few seconds after logging successful startup before # some backends need a few seconds after logging successful startup before
# they can begin accepting requests # they can begin accepting requests
await sleep(5) # await sleep(5)
try: try:
max_throughput = await run_benchmark() max_throughput = await run_benchmark()
self.__start_healthcheck = True self.__start_healthcheck = True
+1
View File
@@ -286,6 +286,7 @@ class AutoScalerData:
"""Data that is reported to autoscaler""" """Data that is reported to autoscaler"""
id: int id: int
mtoken: str
version: str version: str
loadtime: float loadtime: float
cur_load: float cur_load: float
+16 -2
View File
@@ -28,6 +28,7 @@ def get_url() -> str:
@dataclass @dataclass
class Metrics: class Metrics:
version: str = "0" version: str = "0"
mtoken: str = ""
last_metric_update: float = 0.0 last_metric_update: float = 0.0
last_request_served: float = 0.0 last_request_served: float = 0.0
update_pending: bool = False update_pending: bool = False
@@ -142,12 +143,16 @@ class Metrics:
def _set_version(self, version: str) -> None: def _set_version(self, version: str) -> None:
self.version = version self.version = version
def _set_mtoken(self, mtoken: str) -> None:
self.mtoken = mtoken
#######################################Private####################################### #######################################Private#######################################
async def __send_delete_requests_and_reset(self): async def __send_delete_requests_and_reset(self):
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool: async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
data = { data = {
"worker_id": self.id, "worker_id": self.id,
"mtoken": self.mtoken,
"request_idxs": idxs, "request_idxs": idxs,
"success": success_flag, "success": success_flag,
} }
@@ -209,6 +214,7 @@ class Metrics:
def compute_autoscaler_data() -> AutoScalerData: def compute_autoscaler_data() -> AutoScalerData:
return AutoScalerData( return AutoScalerData(
id=self.id, id=self.id,
mtoken=self.mtoken,
version=self.version, version=self.version,
loadtime=(loadtime_snapshot or 0.0), loadtime=(loadtime_snapshot or 0.0),
new_load=self.model_metrics.workload_processing, new_load=self.model_metrics.workload_processing,
@@ -228,17 +234,25 @@ class Metrics:
async def send_data(report_addr: str) -> bool: async def send_data(report_addr: str) -> bool:
data = compute_autoscaler_data() data = compute_autoscaler_data()
full_path = report_addr.rstrip("/") + "/worker_status/" log_data = asdict(data)
def obfuscate(secret: str) -> str:
    """Return a log-safe rendering of ``secret``.

    - ``None`` or an empty string yields ``""``.
    - Secrets of 7 characters or fewer are replaced by ``*`` repeated
      once per character.
    - Longer secrets show a short prefix followed by ``"..."``. The prefix
      is at most 7 characters and never more than a quarter of the secret,
      so a short token is not almost fully disclosed (a fixed 7-character
      prefix would reveal 7 of the 8 characters of an 8-character secret).
    """
    max_prefix = 7

    if not secret:
        return ""

    length = len(secret)
    if length <= max_prefix:
        return "*" * length

    # Cap the visible part at 25% of the secret; 28+ characters show 7.
    visible = min(max_prefix, length // 4)
    return secret[:visible] + "..."
log_data["mtoken"] = obfuscate(log_data.get("mtoken"))
log.debug( log.debug(
"\n".join( "\n".join(
[ [
"#" * 60, "#" * 60,
f"sending data to autoscaler", f"sending data to autoscaler",
f"{json.dumps((asdict(data)), indent=2)}", f"{json.dumps(log_data, indent=2)}",
"#" * 60, "#" * 60,
] ]
) )
) )
full_path = report_addr.rstrip("/") + "/worker_status/"
for attempt in range(1, 4): for attempt in range(1, 4):
try: try:
session = await self.http() session = await self.http()
+1 -1
View File
@@ -9,7 +9,7 @@ ENV_PATH="$WORKSPACE_DIR/worker-env"
DEBUG_LOG="$WORKSPACE_DIR/debug.log" DEBUG_LOG="$WORKSPACE_DIR/debug.log"
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log" PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
REPORT_ADDR="${REPORT_ADDR:-https://cloud.vast.ai/api/v0,https://run.vast.ai}" REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}"
USE_SSL="${USE_SSL:-true}" USE_SSL="${USE_SSL:-true}"
WORKER_PORT="${WORKER_PORT:-3000}" WORKER_PORT="${WORKER_PORT:-3000}"
mkdir -p "$WORKSPACE_DIR" mkdir -p "$WORKSPACE_DIR"
+1
View File
@@ -98,6 +98,7 @@ def call_text2image_workflow(
endpoint=route_response["endpoint"], endpoint=route_response["endpoint"],
reqnum=route_response["reqnum"], reqnum=route_response["reqnum"],
url=route_response["url"], url=route_response["url"],
request_idx=route_response["request_idx"],
) )
# Build the payload for the worker request # Build the payload for the worker request
+1
View File
@@ -82,6 +82,7 @@ def call_custom_workflow_for_sd3(
endpoint=message["endpoint"], endpoint=message["endpoint"],
reqnum=message["reqnum"], reqnum=message["reqnum"],
url=message["url"], url=message["url"],
request_idx=message["request_idx"],
) )
workflow = { workflow = {
"3": { "3": {