fix wrong aiohttp version and add pubkey fetch retry
This commit is contained in:
+36
-23
@@ -6,7 +6,7 @@ import subprocess
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
|
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
|
||||||
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable
|
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
|
||||||
from anyio import open_file
|
from anyio import open_file
|
||||||
@@ -32,6 +32,7 @@ log = logging.getLogger(__file__)
|
|||||||
# defines the minimum wait time between sending updates to autoscaler
|
# defines the minimum wait time between sending updates to autoscaler
|
||||||
LOG_POLL_INTERVAL = 0.1
|
LOG_POLL_INTERVAL = 0.1
|
||||||
BENCHMARK_INDICATOR_FILE = ".has_benchmark"
|
BENCHMARK_INDICATOR_FILE = ".has_benchmark"
|
||||||
|
MAX_PUBKEY_FETCH_ATTEMPTS = 3
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
@@ -56,26 +57,15 @@ class Backend:
|
|||||||
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
|
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
|
||||||
def fetch_public_key():
|
|
||||||
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
|
|
||||||
result = subprocess.check_output(command, universal_newlines=True)
|
|
||||||
log.debug("public key:")
|
|
||||||
log.debug(result)
|
|
||||||
key = None
|
|
||||||
for _ in range(5):
|
|
||||||
try:
|
|
||||||
key = RSA.import_key(result)
|
|
||||||
break
|
|
||||||
except ValueError as e:
|
|
||||||
log.debug(f"Error downloading key: {e}")
|
|
||||||
time.sleep(15)
|
|
||||||
return key
|
|
||||||
|
|
||||||
###########
|
|
||||||
|
|
||||||
self.PUBLIC_KEY = fetch_public_key()
|
|
||||||
self.metrics = Metrics()
|
self.metrics = Metrics()
|
||||||
|
self._total_pubkey_fetch_errors = 0
|
||||||
|
self._pubkey = self._fetch_pubkey()
|
||||||
|
|
||||||
|
@property
def pubkey(self) -> Optional[RSA.RsaKey]:
    """Return the autoscaler public key, retrying the fetch if none is cached.

    When the cached ``_pubkey`` is ``None``, ``_fetch_pubkey()`` is called
    again and its result (which may still be ``None``) is stored and returned.
    """
    cached = self._pubkey
    if cached is None:
        cached = self._fetch_pubkey()
        self._pubkey = cached
    return cached
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def session(self):
|
def session(self):
|
||||||
@@ -94,6 +84,25 @@ class Backend:
|
|||||||
return handler_fn
|
return handler_fn
|
||||||
|
|
||||||
#######################################Private#######################################
|
#######################################Private#######################################
|
||||||
|
def _fetch_pubkey(self):
    """Download and parse the autoscaler's RSA public key.

    Runs ``curl`` against the pubkey endpoint and parses the response with
    ``RSA.import_key``. The download itself is retried (up to 5 attempts,
    15 seconds apart) so that a transient bad response can actually recover;
    previously only the parse of one already-downloaded string was retried.

    Returns:
        The parsed RSA key, or ``None`` if every attempt failed.

    Side effects:
        On overall failure ``_total_pubkey_fetch_errors`` is incremented, and
        once it reaches ``MAX_PUBKEY_FETCH_ATTEMPTS`` the backend is reported
        as errored via ``self.backend_errored``.

    NOTE(review): this call is synchronous (subprocess + ``time.sleep``), so
    when it is reached through the ``pubkey`` property during request
    handling it blocks the caller for the full retry duration.
    """
    command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
    max_attempts = 5
    key = None
    for attempt in range(1, max_attempts + 1):
        try:
            # Fetch inside the loop: re-parsing the same payload can never
            # succeed, so each retry needs a fresh download.
            result = subprocess.check_output(
                command, universal_newlines=True, timeout=30
            )
            log.debug("public key:")
            log.debug(result)
            key = RSA.import_key(result)
            break
        except (
            subprocess.SubprocessError,  # non-zero curl exit or timeout
            OSError,  # curl binary missing / not executable
            ValueError,
            IndexError,
            TypeError,  # import_key may raise any of these on a bad payload
        ) as e:
            log.debug(f"Error downloading key: {e}")
        # No point sleeping after the final attempt.
        if attempt < max_attempts:
            time.sleep(15)
    if key is None:
        self._total_pubkey_fetch_errors += 1
        if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
            self.backend_errored("Failed to get autoscaler pubkey")
    return key
|
||||||
|
|
||||||
async def __handle_request(
|
async def __handle_request(
|
||||||
self,
|
self,
|
||||||
handler: EndpointHandler[ApiPayload_T],
|
handler: EndpointHandler[ApiPayload_T],
|
||||||
@@ -109,8 +118,12 @@ class Backend:
|
|||||||
return web.json_response(dict(error="invalid JSON"), status=422)
|
return web.json_response(dict(error="invalid JSON"), status=422)
|
||||||
workload = payload.count_workload()
|
workload = payload.count_workload()
|
||||||
|
|
||||||
|
async def wait_for_disconnection() -> None:
    """Poll until the connection behind ``request`` is gone.

    Returns as soon as ``request.transport`` is falsy or reports that it is
    closing; otherwise re-checks every 0.5 seconds.
    """
    while True:
        transport = request.transport
        if not transport or transport.is_closing():
            return
        await sleep(0.5)
|
||||||
|
|
||||||
async def cancel_api_call_if_disconnected() -> web.Response:
|
async def cancel_api_call_if_disconnected() -> web.Response:
|
||||||
await request.wait_for_disconnection()
|
await wait_for_disconnection()
|
||||||
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
|
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
|
||||||
self.metrics._request_canceled(workload=workload, reqnum=auth_data.reqnum)
|
self.metrics._request_canceled(workload=workload, reqnum=auth_data.reqnum)
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
@@ -188,13 +201,13 @@ class Backend:
|
|||||||
|
|
||||||
def __check_signature(self, auth_data: AuthData) -> bool:
|
def __check_signature(self, auth_data: AuthData) -> bool:
|
||||||
def verify_signature(message, signature):
|
def verify_signature(message, signature):
|
||||||
if self.PUBLIC_KEY is None:
|
if self.pubkey is None:
|
||||||
log.debug(f"No Public Key!")
|
log.debug(f"No Public Key!")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
h = SHA256.new(message.encode())
|
h = SHA256.new(message.encode())
|
||||||
try:
|
try:
|
||||||
pkcs1_15.new(self.PUBLIC_KEY).verify(h, base64.b64decode(signature))
|
pkcs1_15.new(self.pubkey).verify(h, base64.b64decode(signature))
|
||||||
return True
|
return True
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
return False
|
return False
|
||||||
|
|||||||
+1
-1
@@ -1,6 +1,6 @@
|
|||||||
aiofiles==24.1.0
|
aiofiles==24.1.0
|
||||||
aiohappyeyeballs==2.3.4
|
aiohappyeyeballs==2.3.4
|
||||||
aiohttp==3.11.0b0
|
aiohttp==3.11.16
|
||||||
aiojobs==1.2.1
|
aiojobs==1.2.1
|
||||||
aiosignal==1.3.1
|
aiosignal==1.3.1
|
||||||
anyio==4.4.0
|
anyio==4.4.0
|
||||||
|
|||||||
Reference in New Issue
Block a user