From f15c7a35406ec812e29acf50a64708d887b78cba Mon Sep 17 00:00:00 2001 From: Nader Arbabian Date: Wed, 2 Apr 2025 16:21:27 -0700 Subject: [PATCH] fix wrong aiohttp version and add pubkey fetch retry --- lib/backend.py | 59 +++++++++++++++++++++++++++++------------------- requirements.txt | 2 +- 2 files changed, 37 insertions(+), 24 deletions(-) diff --git a/lib/backend.py b/lib/backend.py index 1298774..416e0ac 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -6,7 +6,7 @@ import subprocess import dataclasses import logging from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task -from typing import Tuple, Awaitable, NoReturn, List, Union, Callable +from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional from functools import cached_property from anyio import open_file @@ -32,6 +32,7 @@ log = logging.getLogger(__file__) # defines the minimum wait time between sending updates to autoscaler LOG_POLL_INTERVAL = 0.1 BENCHMARK_INDICATOR_FILE = ".has_benchmark" +MAX_PUBKEY_FETCH_ATTEMPTS = 3 @dataclasses.dataclass @@ -56,26 +57,15 @@ class Backend: sem: Semaphore = dataclasses.field(default_factory=Semaphore) def __post_init__(self): - - def fetch_public_key(): - command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"] - result = subprocess.check_output(command, universal_newlines=True) - log.debug("public key:") - log.debug(result) - key = None - for _ in range(5): - try: - key = RSA.import_key(result) - break - except ValueError as e: - log.debug(f"Error downloading key: {e}") - time.sleep(15) - return key - - ########### - - self.PUBLIC_KEY = fetch_public_key() self.metrics = Metrics() + self._total_pubkey_fetch_errors = 0 + self._pubkey = self._fetch_pubkey() + + @property + def pubkey(self) -> Optional[RSA.RsaKey]: + if self._pubkey is None: + self._pubkey = self._fetch_pubkey() + return self._pubkey @cached_property def session(self): @@ -94,6 +84,25 @@ class Backend: return handler_fn #######################################Private####################################### + def _fetch_pubkey(self): + command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"] + result = subprocess.check_output(command, universal_newlines=True) + log.debug("public key:") + log.debug(result) + key = None + for _ in range(5): + try: + key = RSA.import_key(result) + break + except ValueError as e: + log.debug(f"Error downloading key: {e}") + time.sleep(15) + if key is None: + self._total_pubkey_fetch_errors += 1 + if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS: + self.backend_errored("Failed to get autoscaler pubkey") + return key + async def __handle_request( self, handler: EndpointHandler[ApiPayload_T], @@ -109,8 +118,12 @@ class Backend: return web.json_response(dict(error="invalid JSON"), status=422) workload = payload.count_workload() + async def wait_for_disconnection() -> None: + while request.transport and not request.transport.is_closing(): + await sleep(0.5) + async def cancel_api_call_if_disconnected() -> web.Response: - await request.wait_for_disconnection() + await wait_for_disconnection() log.debug(f"request with reqnum: {auth_data.reqnum} was canceled") self.metrics._request_canceled(workload=workload, reqnum=auth_data.reqnum) return web.Response(status=500) @@ -188,13 +201,13 @@ class Backend: def __check_signature(self, auth_data: AuthData) -> bool: def verify_signature(message, signature): - if self.PUBLIC_KEY is None: + if self.pubkey is None: log.debug(f"No Public Key!") return False h = SHA256.new(message.encode()) try: - pkcs1_15.new(self.PUBLIC_KEY).verify(h, base64.b64decode(signature)) + pkcs1_15.new(self.pubkey).verify(h, base64.b64decode(signature)) return True except (ValueError, TypeError): return False diff --git a/requirements.txt b/requirements.txt index 3b18bf1..8b2619b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ aiofiles==24.1.0 aiohappyeyeballs==2.3.4 -aiohttp==3.11.0b0 +aiohttp==3.11.16 aiojobs==1.2.1 aiosignal==1.3.1 anyio==4.4.0