Compare commits

..

4 Commits

6 changed files with 61 additions and 145 deletions
+58 -121
View File
@@ -9,7 +9,6 @@ from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
from functools import cached_property from functools import cached_property
from distutils.util import strtobool from distutils.util import strtobool
from collections import deque
from anyio import open_file from anyio import open_file
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
@@ -31,7 +30,7 @@ from lib.data_types import (
BenchmarkResult BenchmarkResult
) )
VERSION = "0.2.1" VERSION = "0.1.0"
MSG_HISTORY_LEN = 100 MSG_HISTORY_LEN = 100
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
@@ -64,21 +63,13 @@ class Backend:
version = VERSION version = VERSION
msg_history = [] msg_history = []
sem: Semaphore = dataclasses.field(default_factory=Semaphore) sem: Semaphore = dataclasses.field(default_factory=Semaphore)
queue: deque = dataclasses.field(default_factory=deque, repr=False)
unsecured: bool = dataclasses.field( unsecured: bool = dataclasses.field(
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))), default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
) )
report_addr: str = dataclasses.field(
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
)
mtoken: str = dataclasses.field(
default_factory=lambda: os.environ.get("MASTER_TOKEN", "")
)
def __post_init__(self): def __post_init__(self):
self.metrics = Metrics() self.metrics = Metrics()
self.metrics._set_version(self.version) self.metrics._set_version(self.version)
self.metrics._set_mtoken(self.mtoken)
self._total_pubkey_fetch_errors = 0 self._total_pubkey_fetch_errors = 0
self._pubkey = self._fetch_pubkey() self._pubkey = self._fetch_pubkey()
self.__start_healthcheck: bool = False self.__start_healthcheck: bool = False
@@ -113,19 +104,23 @@ class Backend:
#######################################Private####################################### #######################################Private#######################################
def _fetch_pubkey(self): def _fetch_pubkey(self):
report_addr = self.report_addr.rstrip("/") command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"] result = subprocess.check_output(command, universal_newlines=True)
try: log.debug("public key:")
result = subprocess.check_output(command, universal_newlines=True) log.debug(result)
log.debug("public key:") key = None
log.debug(result) for _ in range(5):
key = RSA.import_key(result) try:
if key is not None: key = RSA.import_key(result)
return key break
except (ValueError , subprocess.CalledProcessError) as e: except ValueError as e:
log.debug(f"Error downloading key: {e}") log.debug(f"Error downloading key: {e}")
self.backend_errored("Failed to get autoscaler pubkey") time.sleep(15)
if key is None:
self._total_pubkey_fetch_errors += 1
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
self.backend_errored("Failed to get autoscaler pubkey")
return key
async def __handle_request( async def __handle_request(
self, self,
@@ -143,26 +138,11 @@ class Backend:
workload = payload.count_workload() workload = payload.count_workload()
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created") request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
async def cancel_api_call_if_disconnected() -> web.Response:
def advance_queue_after_completion(event: asyncio.Event):
"""Pop current head and wake next waiter, if any."""
# If this event is current head, wake next waiter
if self.queue and self.queue[0] is event:
self.queue.popleft()
if self.queue:
self.queue[0].set()
else:
# Else, remove it from the queue
try:
self.queue.remove(event)
except ValueError:
pass
async def cancel_api_call_if_disconnected() -> None:
await request.wait_for_disconnection() await request.wait_for_disconnection()
log.debug(f"Request with reqnum: {request_metrics.reqnum} was canceled") log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled")
self.metrics._request_canceled(request_metrics) self.metrics._request_canceled(request_metrics)
return raise asyncio.CancelledError
async def make_request() -> Union[web.Response, web.StreamResponse]: async def make_request() -> Union[web.Response, web.StreamResponse]:
try: try:
@@ -179,9 +159,7 @@ class Backend:
res = await handler.generate_client_response(request, response) res = await handler.generate_client_response(request, response)
self.metrics._request_success(request_metrics) self.metrics._request_success(request_metrics)
return res return res
except asyncio.CancelledError: except requests.exceptions.RequestException as e:
raise
except Exception as e:
log.debug(f"[backend] Request error: {e}") log.debug(f"[backend] Request error: {e}")
self.metrics._request_errored(request_metrics) self.metrics._request_errored(request_metrics)
return web.Response(status=500) return web.Response(status=500)
@@ -196,87 +174,46 @@ class Backend:
self.metrics._request_reject(request_metrics) self.metrics._request_reject(request_metrics)
return web.Response(status=429) return web.Response(status=429)
disconnect_task = create_task(cancel_api_call_if_disconnected()) acquired = False
next_request_task = None
work_task = None
event = asyncio.Event() # Used in finally block, so initialize here
self.metrics._request_start(request_metrics)
try: try:
if self.allow_parallel_requests: self.metrics._request_start(request_metrics)
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}") if self.allow_parallel_requests is False:
work_task = create_task(make_request()) log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED) await self.sem.acquire()
acquired = True
for t in pending: log.debug(
t.cancel() f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..."
await asyncio.gather(*pending, return_exceptions=True)
if disconnect_task in done:
return web.Response(status=499)
# otherwise work_task completed
return await work_task
# FIFO-queue branch
else:
# Insert a Event into the queue for this request
# Event.set() == our request is up next
self.queue.append(event)
if self.queue and self.queue[0] is event:
event.set()
# Race between our request being next and request being cancelled
next_request_task = create_task(event.wait())
first_done, first_pending = await wait(
[next_request_task, disconnect_task], return_when=FIRST_COMPLETED
) )
else:
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
done, pending = await wait(
[
create_task(make_request()),
create_task(cancel_api_call_if_disconnected()),
],
return_when=FIRST_COMPLETED,
)
for t in pending:
t.cancel()
await asyncio.gather(*pending, return_exceptions=True)
# If the disconnect task wins the race done_task = done.pop()
if disconnect_task in first_done: try:
# Clean up the next_request_task, then exit return done_task.result()
for t in first_pending: except Exception as e:
t.cancel() log.debug(f"Request task raised exception: {e}")
await asyncio.gather(*first_pending, return_exceptions=True) return web.Response(status=500)
return web.Response(status=499)
# We are the next-up request in the queue
log.debug(f"Starting work on request {request_metrics.reqnum}...")
# Race the backend API call with the disconnect task
work_task = create_task(make_request())
done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED)
for t in pending:
t.cancel()
await asyncio.gather(*pending, return_exceptions=True)
if disconnect_task in done:
return web.Response(status=499)
# otherwise work_task completed
return await work_task
except asyncio.CancelledError: except asyncio.CancelledError:
return web.Response(status=499) # Client is gone. Do not write a response; just unwind.
return web.Response(status=499)
except Exception as e: except Exception as e:
log.debug(f"Exception in main handler loop {e}") log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500) return web.Response(status=500)
finally: finally:
if not self.allow_parallel_requests: # Always release the semaphore if it was acquired
advance_queue_after_completion(event) if acquired:
self.sem.release()
self.metrics._request_end(request_metrics) self.metrics._request_end(request_metrics)
cleanup_tasks = [t for t in (next_request_task, work_task, disconnect_task) if t]
for t in cleanup_tasks:
if not t.done():
t.cancel()
if cleanup_tasks:
await asyncio.gather(*cleanup_tasks, return_exceptions=True)
@cached_property @cached_property
def healthcheck_session(self): def healthcheck_session(self):
@@ -378,10 +315,10 @@ class Backend:
with open(BENCHMARK_INDICATOR_FILE, "r") as f: with open(BENCHMARK_INDICATOR_FILE, "r") as f:
log.debug("already ran benchmark") log.debug("already ran benchmark")
# trigger model load # trigger model load
# payload = self.benchmark_handler.make_benchmark_payload() payload = self.benchmark_handler.make_benchmark_payload()
# _ = await self.__call_api( _ = await self.__call_api(
# handler=self.benchmark_handler, payload=payload handler=self.benchmark_handler, payload=payload
# ) )
return float(f.readline()) return float(f.readline())
except FileNotFoundError: except FileNotFoundError:
pass pass
@@ -456,7 +393,7 @@ class Backend:
) )
# some backends need a few seconds after logging successful startup before # some backends need a few seconds after logging successful startup before
# they can begin accepting requests # they can begin accepting requests
# await sleep(5) await sleep(5)
try: try:
max_throughput = await run_benchmark() max_throughput = await run_benchmark()
self.__start_healthcheck = True self.__start_healthcheck = True
-1
View File
@@ -286,7 +286,6 @@ class AutoScalerData:
"""Data that is reported to autoscaler""" """Data that is reported to autoscaler"""
id: int id: int
mtoken: str
version: str version: str
loadtime: float loadtime: float
cur_load: float cur_load: float
+2 -20
View File
@@ -28,7 +28,6 @@ def get_url() -> str:
@dataclass @dataclass
class Metrics: class Metrics:
version: str = "0" version: str = "0"
mtoken: str = ""
last_metric_update: float = 0.0 last_metric_update: float = 0.0
last_request_served: float = 0.0 last_request_served: float = 0.0
update_pending: bool = False update_pending: bool = False
@@ -143,16 +142,12 @@ class Metrics:
def _set_version(self, version: str) -> None: def _set_version(self, version: str) -> None:
self.version = version self.version = version
def _set_mtoken(self, mtoken: str) -> None:
self.mtoken = mtoken
#######################################Private####################################### #######################################Private#######################################
async def __send_delete_requests_and_reset(self): async def __send_delete_requests_and_reset(self):
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool: async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
data = { data = {
"worker_id": self.id, "worker_id": self.id,
"mtoken": self.mtoken,
"request_idxs": idxs, "request_idxs": idxs,
"success": success_flag, "success": success_flag,
} }
@@ -185,10 +180,6 @@ class Metrics:
return # nothing to do return # nothing to do
for report_addr in self.report_addr: for report_addr in self.report_addr:
# TODO: Add a Redis subscriber queue for delete_requests
if report_addr == "https://cloud.vast.ai/api/v0":
# Patch: ignore the Redis API report_addr
continue
sent_success = True sent_success = True
sent_failed = True sent_failed = True
@@ -214,7 +205,6 @@ class Metrics:
def compute_autoscaler_data() -> AutoScalerData: def compute_autoscaler_data() -> AutoScalerData:
return AutoScalerData( return AutoScalerData(
id=self.id, id=self.id,
mtoken=self.mtoken,
version=self.version, version=self.version,
loadtime=(loadtime_snapshot or 0.0), loadtime=(loadtime_snapshot or 0.0),
new_load=self.model_metrics.workload_processing, new_load=self.model_metrics.workload_processing,
@@ -234,25 +224,17 @@ class Metrics:
async def send_data(report_addr: str) -> bool: async def send_data(report_addr: str) -> bool:
data = compute_autoscaler_data() data = compute_autoscaler_data()
log_data = asdict(data) full_path = report_addr.rstrip("/") + "/worker_status/"
def obfuscate(secret: str) -> str:
if secret is None:
return ""
return secret[:7] + "..." if len(secret) > 7 else ("*" * len(secret))
log_data["mtoken"] = obfuscate(log_data.get("mtoken"))
log.debug( log.debug(
"\n".join( "\n".join(
[ [
"#" * 60, "#" * 60,
f"sending data to autoscaler", f"sending data to autoscaler",
f"{json.dumps(log_data, indent=2)}", f"{json.dumps((asdict(data)), indent=2)}",
"#" * 60, "#" * 60,
] ]
) )
) )
full_path = report_addr.rstrip("/") + "/worker_status/"
for attempt in range(1, 4): for attempt in range(1, 4):
try: try:
session = await self.http() session = await self.http()
+1 -1
View File
@@ -9,7 +9,7 @@ ENV_PATH="$WORKSPACE_DIR/worker-env"
DEBUG_LOG="$WORKSPACE_DIR/debug.log" DEBUG_LOG="$WORKSPACE_DIR/debug.log"
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log" PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}" REPORT_ADDR="${REPORT_ADDR:-https://cloud.vast.ai/api/v0,https://run.vast.ai}"
USE_SSL="${USE_SSL:-true}" USE_SSL="${USE_SSL:-true}"
WORKER_PORT="${WORKER_PORT:-3000}" WORKER_PORT="${WORKER_PORT:-3000}"
mkdir -p "$WORKSPACE_DIR" mkdir -p "$WORKSPACE_DIR"
-1
View File
@@ -98,7 +98,6 @@ def call_text2image_workflow(
endpoint=route_response["endpoint"], endpoint=route_response["endpoint"],
reqnum=route_response["reqnum"], reqnum=route_response["reqnum"],
url=route_response["url"], url=route_response["url"],
request_idx=route_response["request_idx"],
) )
# Build the payload for the worker request # Build the payload for the worker request
-1
View File
@@ -82,7 +82,6 @@ def call_custom_workflow_for_sd3(
endpoint=message["endpoint"], endpoint=message["endpoint"],
reqnum=message["reqnum"], reqnum=message["reqnum"],
url=message["url"], url=message["url"],
request_idx=message["request_idx"],
) )
workflow = { workflow = {
"3": { "3": {