Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 249ca2eb99 | |||
| d8bb1fcc68 | |||
| 7db54f3bd7 | |||
| d63a060202 | |||
| c6521cb6d4 | |||
| b7fe4ebb91 | |||
| 8ae7b74605 | |||
| 106067d716 | |||
| f5134d4bf5 | |||
| 47e5460532 | |||
| ec2ac0a21a | |||
| 2cde573c56 | |||
| b2e4a5db0c | |||
| 02c8307af7 | |||
| 7c0f316eeb | |||
| b4025a744f | |||
| d190308329 | |||
| 944f83fc03 | |||
| f56bbc0ebe |
+106
-42
@@ -9,6 +9,7 @@ from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
|
|||||||
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
|
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from distutils.util import strtobool
|
from distutils.util import strtobool
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
from anyio import open_file
|
from anyio import open_file
|
||||||
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
|
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
|
||||||
@@ -30,7 +31,7 @@ from lib.data_types import (
|
|||||||
BenchmarkResult
|
BenchmarkResult
|
||||||
)
|
)
|
||||||
|
|
||||||
VERSION = "0.1.0"
|
VERSION = "0.2.1"
|
||||||
|
|
||||||
MSG_HISTORY_LEN = 100
|
MSG_HISTORY_LEN = 100
|
||||||
log = logging.getLogger(__file__)
|
log = logging.getLogger(__file__)
|
||||||
@@ -63,16 +64,21 @@ class Backend:
|
|||||||
version = VERSION
|
version = VERSION
|
||||||
msg_history = []
|
msg_history = []
|
||||||
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
|
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
|
||||||
|
queue: deque = dataclasses.field(default_factory=deque, repr=False)
|
||||||
unsecured: bool = dataclasses.field(
|
unsecured: bool = dataclasses.field(
|
||||||
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
||||||
)
|
)
|
||||||
report_addr: str = dataclasses.field(
|
report_addr: str = dataclasses.field(
|
||||||
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
|
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
|
||||||
)
|
)
|
||||||
|
mtoken: str = dataclasses.field(
|
||||||
|
default_factory=lambda: os.environ.get("MASTER_TOKEN", "")
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.metrics = Metrics()
|
self.metrics = Metrics()
|
||||||
self.metrics._set_version(self.version)
|
self.metrics._set_version(self.version)
|
||||||
|
self.metrics._set_mtoken(self.mtoken)
|
||||||
self._total_pubkey_fetch_errors = 0
|
self._total_pubkey_fetch_errors = 0
|
||||||
self._pubkey = self._fetch_pubkey()
|
self._pubkey = self._fetch_pubkey()
|
||||||
self.__start_healthcheck: bool = False
|
self.__start_healthcheck: bool = False
|
||||||
@@ -137,11 +143,26 @@ class Backend:
|
|||||||
workload = payload.count_workload()
|
workload = payload.count_workload()
|
||||||
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
|
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
|
||||||
|
|
||||||
async def cancel_api_call_if_disconnected() -> web.Response:
|
|
||||||
|
def advance_queue_after_completion(event: asyncio.Event):
|
||||||
|
"""Pop current head and wake next waiter, if any."""
|
||||||
|
# If this event is current head, wake next waiter
|
||||||
|
if self.queue and self.queue[0] is event:
|
||||||
|
self.queue.popleft()
|
||||||
|
if self.queue:
|
||||||
|
self.queue[0].set()
|
||||||
|
else:
|
||||||
|
# Else, remove it from the queue
|
||||||
|
try:
|
||||||
|
self.queue.remove(event)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def cancel_api_call_if_disconnected() -> None:
|
||||||
await request.wait_for_disconnection()
|
await request.wait_for_disconnection()
|
||||||
log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled")
|
log.debug(f"Request with reqnum: {request_metrics.reqnum} was canceled")
|
||||||
self.metrics._request_canceled(request_metrics)
|
self.metrics._request_canceled(request_metrics)
|
||||||
raise asyncio.CancelledError
|
return
|
||||||
|
|
||||||
async def make_request() -> Union[web.Response, web.StreamResponse]:
|
async def make_request() -> Union[web.Response, web.StreamResponse]:
|
||||||
try:
|
try:
|
||||||
@@ -158,7 +179,9 @@ class Backend:
|
|||||||
res = await handler.generate_client_response(request, response)
|
res = await handler.generate_client_response(request, response)
|
||||||
self.metrics._request_success(request_metrics)
|
self.metrics._request_success(request_metrics)
|
||||||
return res
|
return res
|
||||||
except requests.exceptions.RequestException as e:
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
log.debug(f"[backend] Request error: {e}")
|
log.debug(f"[backend] Request error: {e}")
|
||||||
self.metrics._request_errored(request_metrics)
|
self.metrics._request_errored(request_metrics)
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
@@ -173,46 +196,87 @@ class Backend:
|
|||||||
self.metrics._request_reject(request_metrics)
|
self.metrics._request_reject(request_metrics)
|
||||||
return web.Response(status=429)
|
return web.Response(status=429)
|
||||||
|
|
||||||
acquired = False
|
disconnect_task = create_task(cancel_api_call_if_disconnected())
|
||||||
try:
|
next_request_task = None
|
||||||
self.metrics._request_start(request_metrics)
|
work_task = None
|
||||||
if self.allow_parallel_requests is False:
|
event = asyncio.Event() # Used in finally block, so initialize here
|
||||||
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
|
|
||||||
await self.sem.acquire()
|
self.metrics._request_start(request_metrics)
|
||||||
acquired = True
|
|
||||||
log.debug(
|
try:
|
||||||
f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..."
|
if self.allow_parallel_requests:
|
||||||
)
|
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
|
||||||
else:
|
work_task = create_task(make_request())
|
||||||
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
|
done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED)
|
||||||
done, pending = await wait(
|
|
||||||
[
|
for t in pending:
|
||||||
create_task(make_request()),
|
t.cancel()
|
||||||
create_task(cancel_api_call_if_disconnected()),
|
await asyncio.gather(*pending, return_exceptions=True)
|
||||||
],
|
|
||||||
return_when=FIRST_COMPLETED,
|
if disconnect_task in done:
|
||||||
)
|
return web.Response(status=499)
|
||||||
for t in pending:
|
|
||||||
t.cancel()
|
# otherwise work_task completed
|
||||||
await asyncio.gather(*pending, return_exceptions=True)
|
return await work_task
|
||||||
|
|
||||||
|
# FIFO-queue branch
|
||||||
|
else:
|
||||||
|
# Insert a Event into the queue for this request
|
||||||
|
# Event.set() == our request is up next
|
||||||
|
self.queue.append(event)
|
||||||
|
if self.queue and self.queue[0] is event:
|
||||||
|
event.set()
|
||||||
|
|
||||||
|
# Race between our request being next and request being cancelled
|
||||||
|
next_request_task = create_task(event.wait())
|
||||||
|
first_done, first_pending = await wait(
|
||||||
|
[next_request_task, disconnect_task], return_when=FIRST_COMPLETED
|
||||||
|
)
|
||||||
|
|
||||||
|
# If the disconnect task wins the race
|
||||||
|
if disconnect_task in first_done:
|
||||||
|
# Clean up the next_request_task, then exit
|
||||||
|
for t in first_pending:
|
||||||
|
t.cancel()
|
||||||
|
await asyncio.gather(*first_pending, return_exceptions=True)
|
||||||
|
return web.Response(status=499)
|
||||||
|
|
||||||
|
# We are the next-up request in the queue
|
||||||
|
log.debug(f"Starting work on request {request_metrics.reqnum}...")
|
||||||
|
|
||||||
|
# Race the backend API call with the disconnect task
|
||||||
|
work_task = create_task(make_request())
|
||||||
|
|
||||||
|
done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED)
|
||||||
|
for t in pending:
|
||||||
|
t.cancel()
|
||||||
|
await asyncio.gather(*pending, return_exceptions=True)
|
||||||
|
|
||||||
|
if disconnect_task in done:
|
||||||
|
return web.Response(status=499)
|
||||||
|
|
||||||
|
# otherwise work_task completed
|
||||||
|
return await work_task
|
||||||
|
|
||||||
done_task = done.pop()
|
|
||||||
try:
|
|
||||||
return done_task.result()
|
|
||||||
except Exception as e:
|
|
||||||
log.debug(f"Request task raised exception: {e}")
|
|
||||||
return web.Response(status=500)
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
# Client is gone. Do not write a response; just unwind.
|
|
||||||
return web.Response(status=499)
|
return web.Response(status=499)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug(f"Exception in main handler loop {e}")
|
log.debug(f"Exception in main handler loop {e}")
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Always release the semaphore if it was acquired
|
if not self.allow_parallel_requests:
|
||||||
if acquired:
|
advance_queue_after_completion(event)
|
||||||
self.sem.release()
|
|
||||||
self.metrics._request_end(request_metrics)
|
self.metrics._request_end(request_metrics)
|
||||||
|
cleanup_tasks = [t for t in (next_request_task, work_task, disconnect_task) if t]
|
||||||
|
for t in cleanup_tasks:
|
||||||
|
if not t.done():
|
||||||
|
t.cancel()
|
||||||
|
if cleanup_tasks:
|
||||||
|
await asyncio.gather(*cleanup_tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def healthcheck_session(self):
|
def healthcheck_session(self):
|
||||||
@@ -314,10 +378,10 @@ class Backend:
|
|||||||
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
||||||
log.debug("already ran benchmark")
|
log.debug("already ran benchmark")
|
||||||
# trigger model load
|
# trigger model load
|
||||||
payload = self.benchmark_handler.make_benchmark_payload()
|
# payload = self.benchmark_handler.make_benchmark_payload()
|
||||||
_ = await self.__call_api(
|
# _ = await self.__call_api(
|
||||||
handler=self.benchmark_handler, payload=payload
|
# handler=self.benchmark_handler, payload=payload
|
||||||
)
|
# )
|
||||||
return float(f.readline())
|
return float(f.readline())
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
@@ -392,7 +456,7 @@ class Backend:
|
|||||||
)
|
)
|
||||||
# some backends need a few seconds after logging successful startup before
|
# some backends need a few seconds after logging successful startup before
|
||||||
# they can begin accepting requests
|
# they can begin accepting requests
|
||||||
await sleep(5)
|
# await sleep(5)
|
||||||
try:
|
try:
|
||||||
max_throughput = await run_benchmark()
|
max_throughput = await run_benchmark()
|
||||||
self.__start_healthcheck = True
|
self.__start_healthcheck = True
|
||||||
|
|||||||
@@ -286,6 +286,7 @@ class AutoScalerData:
|
|||||||
"""Data that is reported to autoscaler"""
|
"""Data that is reported to autoscaler"""
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
|
mtoken: str
|
||||||
version: str
|
version: str
|
||||||
loadtime: float
|
loadtime: float
|
||||||
cur_load: float
|
cur_load: float
|
||||||
|
|||||||
+16
-2
@@ -28,6 +28,7 @@ def get_url() -> str:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class Metrics:
|
class Metrics:
|
||||||
version: str = "0"
|
version: str = "0"
|
||||||
|
mtoken: str = ""
|
||||||
last_metric_update: float = 0.0
|
last_metric_update: float = 0.0
|
||||||
last_request_served: float = 0.0
|
last_request_served: float = 0.0
|
||||||
update_pending: bool = False
|
update_pending: bool = False
|
||||||
@@ -142,12 +143,16 @@ class Metrics:
|
|||||||
def _set_version(self, version: str) -> None:
|
def _set_version(self, version: str) -> None:
|
||||||
self.version = version
|
self.version = version
|
||||||
|
|
||||||
|
def _set_mtoken(self, mtoken: str) -> None:
|
||||||
|
self.mtoken = mtoken
|
||||||
|
|
||||||
#######################################Private#######################################
|
#######################################Private#######################################
|
||||||
|
|
||||||
async def __send_delete_requests_and_reset(self):
|
async def __send_delete_requests_and_reset(self):
|
||||||
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
|
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
|
||||||
data = {
|
data = {
|
||||||
"worker_id": self.id,
|
"worker_id": self.id,
|
||||||
|
"mtoken": self.mtoken,
|
||||||
"request_idxs": idxs,
|
"request_idxs": idxs,
|
||||||
"success": success_flag,
|
"success": success_flag,
|
||||||
}
|
}
|
||||||
@@ -209,6 +214,7 @@ class Metrics:
|
|||||||
def compute_autoscaler_data() -> AutoScalerData:
|
def compute_autoscaler_data() -> AutoScalerData:
|
||||||
return AutoScalerData(
|
return AutoScalerData(
|
||||||
id=self.id,
|
id=self.id,
|
||||||
|
mtoken=self.mtoken,
|
||||||
version=self.version,
|
version=self.version,
|
||||||
loadtime=(loadtime_snapshot or 0.0),
|
loadtime=(loadtime_snapshot or 0.0),
|
||||||
new_load=self.model_metrics.workload_processing,
|
new_load=self.model_metrics.workload_processing,
|
||||||
@@ -228,17 +234,25 @@ class Metrics:
|
|||||||
|
|
||||||
async def send_data(report_addr: str) -> bool:
|
async def send_data(report_addr: str) -> bool:
|
||||||
data = compute_autoscaler_data()
|
data = compute_autoscaler_data()
|
||||||
full_path = report_addr.rstrip("/") + "/worker_status/"
|
log_data = asdict(data)
|
||||||
|
def obfuscate(secret: str) -> str:
|
||||||
|
if secret is None:
|
||||||
|
return ""
|
||||||
|
return secret[:7] + "..." if len(secret) > 7 else ("*" * len(secret))
|
||||||
|
|
||||||
|
log_data["mtoken"] = obfuscate(log_data.get("mtoken"))
|
||||||
log.debug(
|
log.debug(
|
||||||
"\n".join(
|
"\n".join(
|
||||||
[
|
[
|
||||||
"#" * 60,
|
"#" * 60,
|
||||||
f"sending data to autoscaler",
|
f"sending data to autoscaler",
|
||||||
f"{json.dumps((asdict(data)), indent=2)}",
|
f"{json.dumps(log_data, indent=2)}",
|
||||||
"#" * 60,
|
"#" * 60,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
full_path = report_addr.rstrip("/") + "/worker_status/"
|
||||||
for attempt in range(1, 4):
|
for attempt in range(1, 4):
|
||||||
try:
|
try:
|
||||||
session = await self.http()
|
session = await self.http()
|
||||||
|
|||||||
+1
-1
@@ -9,7 +9,7 @@ ENV_PATH="$WORKSPACE_DIR/worker-env"
|
|||||||
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
|
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
|
||||||
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
|
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
|
||||||
|
|
||||||
REPORT_ADDR="${REPORT_ADDR:-https://cloud.vast.ai/api/v0,https://run.vast.ai}"
|
REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}"
|
||||||
USE_SSL="${USE_SSL:-true}"
|
USE_SSL="${USE_SSL:-true}"
|
||||||
WORKER_PORT="${WORKER_PORT:-3000}"
|
WORKER_PORT="${WORKER_PORT:-3000}"
|
||||||
mkdir -p "$WORKSPACE_DIR"
|
mkdir -p "$WORKSPACE_DIR"
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ def call_text2image_workflow(
|
|||||||
endpoint=route_response["endpoint"],
|
endpoint=route_response["endpoint"],
|
||||||
reqnum=route_response["reqnum"],
|
reqnum=route_response["reqnum"],
|
||||||
url=route_response["url"],
|
url=route_response["url"],
|
||||||
|
request_idx=route_response["request_idx"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build the payload for the worker request
|
# Build the payload for the worker request
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ def call_custom_workflow_for_sd3(
|
|||||||
endpoint=message["endpoint"],
|
endpoint=message["endpoint"],
|
||||||
reqnum=message["reqnum"],
|
reqnum=message["reqnum"],
|
||||||
url=message["url"],
|
url=message["url"],
|
||||||
|
request_idx=message["request_idx"],
|
||||||
)
|
)
|
||||||
workflow = {
|
workflow = {
|
||||||
"3": {
|
"3": {
|
||||||
|
|||||||
Reference in New Issue
Block a user