Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7437028cb2 | |||
| 9f5a432513 | |||
| e09f1fa953 | |||
| ba6f1c2e4b | |||
| 298590fb88 | |||
| 814c3acd4c | |||
| 22bca74087 |
+18
-30
@@ -66,6 +66,9 @@ class Backend:
|
|||||||
unsecured: bool = dataclasses.field(
|
unsecured: bool = dataclasses.field(
|
||||||
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
||||||
)
|
)
|
||||||
|
report_addr: str = dataclasses.field(
|
||||||
|
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.metrics = Metrics()
|
self.metrics = Metrics()
|
||||||
@@ -91,17 +94,6 @@ class Backend:
|
|||||||
timeout = ClientTimeout(total=None)
|
timeout = ClientTimeout(total=None)
|
||||||
return ClientSession(self.model_server_url, timeout=timeout, connector=connector)
|
return ClientSession(self.model_server_url, timeout=timeout, connector=connector)
|
||||||
|
|
||||||
async def _worker(self):
|
|
||||||
while True:
|
|
||||||
handler, request, fut = await self.request_queue.get()
|
|
||||||
try:
|
|
||||||
res = await self.__process_request(handler, request)
|
|
||||||
fut.set_result(res)
|
|
||||||
except Exception as e:
|
|
||||||
fut.set_exception(e)
|
|
||||||
finally:
|
|
||||||
self.request_queue.task_done()
|
|
||||||
|
|
||||||
def create_handler(
|
def create_handler(
|
||||||
self,
|
self,
|
||||||
handler: EndpointHandler[ApiPayload_T],
|
handler: EndpointHandler[ApiPayload_T],
|
||||||
@@ -115,23 +107,19 @@ class Backend:
|
|||||||
|
|
||||||
#######################################Private#######################################
|
#######################################Private#######################################
|
||||||
def _fetch_pubkey(self):
|
def _fetch_pubkey(self):
|
||||||
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
|
report_addr = self.report_addr.rstrip("/")
|
||||||
result = subprocess.check_output(command, universal_newlines=True)
|
command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"]
|
||||||
log.debug("public key:")
|
try:
|
||||||
log.debug(result)
|
result = subprocess.check_output(command, universal_newlines=True)
|
||||||
key = None
|
log.debug("public key:")
|
||||||
for _ in range(5):
|
log.debug(result)
|
||||||
try:
|
key = RSA.import_key(result)
|
||||||
key = RSA.import_key(result)
|
if key is not None:
|
||||||
break
|
return key
|
||||||
except ValueError as e:
|
except (ValueError , subprocess.CalledProcessError) as e:
|
||||||
log.debug(f"Error downloading key: {e}")
|
log.debug(f"Error downloading key: {e}")
|
||||||
time.sleep(15)
|
self.backend_errored("Failed to get autoscaler pubkey")
|
||||||
if key is None:
|
|
||||||
self._total_pubkey_fetch_errors += 1
|
|
||||||
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
|
|
||||||
self.backend_errored("Failed to get autoscaler pubkey")
|
|
||||||
return key
|
|
||||||
|
|
||||||
async def __handle_request(
|
async def __handle_request(
|
||||||
self,
|
self,
|
||||||
@@ -297,7 +285,7 @@ class Backend:
|
|||||||
message = {
|
message = {
|
||||||
key: value
|
key: value
|
||||||
for (key, value) in (dataclasses.asdict(auth_data).items())
|
for (key, value) in (dataclasses.asdict(auth_data).items())
|
||||||
if key != "signature"
|
if key != "signature" and key != "__request_id"
|
||||||
}
|
}
|
||||||
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
|
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
|
||||||
log.debug(
|
log.debug(
|
||||||
@@ -307,7 +295,7 @@ class Backend:
|
|||||||
elif message in self.msg_history:
|
elif message in self.msg_history:
|
||||||
log.debug(f"message: {message} already in message history")
|
log.debug(f"message: {message} already in message history")
|
||||||
return False
|
return False
|
||||||
elif verify_signature(json.dumps(message, indent=4), auth_data.signature):
|
elif verify_signature(json.dumps(message, indent=4, sort_keys=True), auth_data.signature):
|
||||||
self.reqnum = max(auth_data.reqnum, self.reqnum)
|
self.reqnum = max(auth_data.reqnum, self.reqnum)
|
||||||
self.msg_history.append(message)
|
self.msg_history.append(message)
|
||||||
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
|
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
|
||||||
|
|||||||
+5
-4
@@ -65,12 +65,12 @@ class ApiPayload(ABC):
|
|||||||
class AuthData:
|
class AuthData:
|
||||||
"""data used to authenticate requester"""
|
"""data used to authenticate requester"""
|
||||||
|
|
||||||
signature: str
|
|
||||||
cost: str
|
cost: str
|
||||||
endpoint: str
|
endpoint: str
|
||||||
reqnum: int
|
reqnum: int
|
||||||
url: str
|
|
||||||
request_idx: int
|
request_idx: int
|
||||||
|
signature: str
|
||||||
|
url: str
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_json_msg(cls, json_msg: Dict[str, Any]):
|
def from_json_msg(cls, json_msg: Dict[str, Any]):
|
||||||
@@ -190,11 +190,12 @@ class SystemMetrics:
|
|||||||
self.additional_disk_usage = disk_usage - self.last_disk_usage
|
self.additional_disk_usage = disk_usage - self.last_disk_usage
|
||||||
self.last_disk_usage = disk_usage
|
self.last_disk_usage = disk_usage
|
||||||
|
|
||||||
def reset(self):
|
def reset(self, expected: float | None) -> None:
|
||||||
# autoscaler excepts model_loading_time to be populated only once, when the instance has
|
# autoscaler excepts model_loading_time to be populated only once, when the instance has
|
||||||
# finished benchmarking and is ready to receive requests. This applies to restarted instances
|
# finished benchmarking and is ready to receive requests. This applies to restarted instances
|
||||||
# as well: they should send model_loading_time once when they are done loading
|
# as well: they should send model_loading_time once when they are done loading
|
||||||
self.model_loading_time = None
|
if self.model_loading_time == expected:
|
||||||
|
self.model_loading_time = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
+17
-7
@@ -180,6 +180,10 @@ class Metrics:
|
|||||||
return # nothing to do
|
return # nothing to do
|
||||||
|
|
||||||
for report_addr in self.report_addr:
|
for report_addr in self.report_addr:
|
||||||
|
# TODO: Add a Redis subscriber queue for delete_requests
|
||||||
|
if report_addr == "https://cloud.vast.ai/api/v0":
|
||||||
|
# Patch: ignore the Redis API report_addr
|
||||||
|
continue
|
||||||
sent_success = True
|
sent_success = True
|
||||||
sent_failed = True
|
sent_failed = True
|
||||||
|
|
||||||
@@ -200,11 +204,13 @@ class Metrics:
|
|||||||
|
|
||||||
async def __send_metrics_and_reset(self):
|
async def __send_metrics_and_reset(self):
|
||||||
|
|
||||||
|
loadtime_snapshot = self.system_metrics.model_loading_time
|
||||||
|
|
||||||
def compute_autoscaler_data() -> AutoScalerData:
|
def compute_autoscaler_data() -> AutoScalerData:
|
||||||
return AutoScalerData(
|
return AutoScalerData(
|
||||||
id=self.id,
|
id=self.id,
|
||||||
version=self.version,
|
version=self.version,
|
||||||
loadtime=(self.system_metrics.model_loading_time or 0.0),
|
loadtime=(loadtime_snapshot or 0.0),
|
||||||
new_load=self.model_metrics.workload_processing,
|
new_load=self.model_metrics.workload_processing,
|
||||||
cur_load=self.model_metrics.cur_load,
|
cur_load=self.model_metrics.cur_load,
|
||||||
rej_load=self.model_metrics.workload_rejected,
|
rej_load=self.model_metrics.workload_rejected,
|
||||||
@@ -252,11 +258,15 @@ class Metrics:
|
|||||||
|
|
||||||
self.system_metrics.update_disk_usage()
|
self.system_metrics.update_disk_usage()
|
||||||
|
|
||||||
|
sent = False
|
||||||
for report_addr in self.report_addr:
|
for report_addr in self.report_addr:
|
||||||
success = await send_data(report_addr)
|
if await send_data(report_addr):
|
||||||
if success is True:
|
sent = True
|
||||||
break
|
break
|
||||||
self.update_pending = False
|
|
||||||
self.model_metrics.reset()
|
if sent:
|
||||||
self.system_metrics.reset()
|
# clear the one-shot loadtime only if we actually sent *this* value
|
||||||
self.last_metric_update = time.time()
|
self.system_metrics.reset(expected=loadtime_snapshot)
|
||||||
|
self.update_pending = False
|
||||||
|
self.model_metrics.reset()
|
||||||
|
self.last_metric_update = time.time()
|
||||||
|
|||||||
Reference in New Issue
Block a user