fix wrong aiohttp version and add pubkey fetch retry

This commit is contained in:
Nader Arbabian
2025-04-02 16:21:27 -07:00
parent e8484e7c08
commit f15c7a3540
2 changed files with 37 additions and 24 deletions
+36 -23
View File
@@ -6,7 +6,7 @@ import subprocess
import dataclasses
import logging
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
from functools import cached_property
from anyio import open_file
@@ -32,6 +32,7 @@ log = logging.getLogger(__file__)
# defines the minimum wait time between sending updates to autoscaler
LOG_POLL_INTERVAL = 0.1
BENCHMARK_INDICATOR_FILE = ".has_benchmark"
MAX_PUBKEY_FETCH_ATTEMPTS = 3
@dataclasses.dataclass
@@ -56,26 +57,15 @@ class Backend:
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
def __post_init__(self):
def fetch_public_key():
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
result = subprocess.check_output(command, universal_newlines=True)
log.debug("public key:")
log.debug(result)
key = None
for _ in range(5):
try:
key = RSA.import_key(result)
break
except ValueError as e:
log.debug(f"Error downloading key: {e}")
time.sleep(15)
return key
###########
self.PUBLIC_KEY = fetch_public_key()
self.metrics = Metrics()
self._total_pubkey_fetch_errors = 0
self._pubkey = self._fetch_pubkey()
@property
def pubkey(self) -> Optional[RSA.RsaKey]:
if self._pubkey is None:
self._pubkey = self._fetch_pubkey()
return self._pubkey
@cached_property
def session(self):
@@ -94,6 +84,25 @@ class Backend:
return handler_fn
#######################################Private#######################################
def _fetch_pubkey(self):
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
result = subprocess.check_output(command, universal_newlines=True)
log.debug("public key:")
log.debug(result)
key = None
for _ in range(5):
try:
key = RSA.import_key(result)
break
except ValueError as e:
log.debug(f"Error downloading key: {e}")
time.sleep(15)
if key is None:
self._total_pubkey_fetch_errors += 1
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
self.backend_errored("Failed to get autoscaler pubkey")
return key
async def __handle_request(
self,
handler: EndpointHandler[ApiPayload_T],
@@ -109,8 +118,12 @@ class Backend:
return web.json_response(dict(error="invalid JSON"), status=422)
workload = payload.count_workload()
async def wait_for_disconnection() -> None:
while request.transport and not request.transport.is_closing():
await sleep(0.5)
async def cancel_api_call_if_disconnected() -> web.Response:
await request.wait_for_disconnection()
await wait_for_disconnection()
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
self.metrics._request_canceled(workload=workload, reqnum=auth_data.reqnum)
return web.Response(status=500)
@@ -188,13 +201,13 @@ class Backend:
def __check_signature(self, auth_data: AuthData) -> bool:
def verify_signature(message, signature):
if self.PUBLIC_KEY is None:
if self.pubkey is None:
log.debug(f"No Public Key!")
return False
h = SHA256.new(message.encode())
try:
pkcs1_15.new(self.PUBLIC_KEY).verify(h, base64.b64decode(signature))
pkcs1_15.new(self.pubkey).verify(h, base64.b64decode(signature))
return True
except (ValueError, TypeError):
return False
+1 -1
View File
@@ -1,6 +1,6 @@
aiofiles==24.1.0
aiohappyeyeballs==2.3.4
aiohttp==3.11.0b0
aiohttp==3.11.16
aiojobs==1.2.1
aiosignal==1.3.1
anyio==4.4.0