Compare commits

..

6 Commits

Author SHA1 Message Date
Nader Arbabian 6b0f019cf7 add option to skip auth 2025-07-15 15:18:09 -07:00
Nader Arbabian ce52419023 AUTO-421: clean up some issues 2025-07-11 15:11:48 -07:00
Nader Arbabian 3e49b7d04b AUTO-421: fix pyworker miscounting active connections 2025-07-10 19:30:35 -07:00
Nader Arbabian 0bf2d04223 stop using urljoin for worker_status endpoint 2025-06-17 23:09:45 -07:00
Nader Arbabian 9ebf1924ea don't healthcheck endpoints until model is loaded and benchmarks have run 2025-06-11 15:26:50 -07:00
Nader Arbabian 0ab9a13a46 update tokenizers deps 2025-06-10 17:56:06 -07:00
8 changed files with 82 additions and 62 deletions
+37 -37
View File
@@ -5,9 +5,10 @@ import base64
import subprocess import subprocess
import dataclasses import dataclasses
import logging import logging
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task from asyncio import sleep, gather, Semaphore
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
from functools import cached_property from functools import cached_property
from distutils.util import strtobool
from anyio import open_file from anyio import open_file
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError
@@ -55,11 +56,15 @@ class Backend:
reqnum = -1 reqnum = -1
msg_history = [] msg_history = []
sem: Semaphore = dataclasses.field(default_factory=Semaphore) sem: Semaphore = dataclasses.field(default_factory=Semaphore)
unsecured: bool = dataclasses.field(
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
)
def __post_init__(self): def __post_init__(self):
self.metrics = Metrics() self.metrics = Metrics()
self._total_pubkey_fetch_errors = 0 self._total_pubkey_fetch_errors = 0
self._pubkey = self._fetch_pubkey() self._pubkey = self._fetch_pubkey()
self.__start_healthcheck: bool = False
@property @property
def pubkey(self) -> Optional[RSA.RsaKey]: def pubkey(self) -> Optional[RSA.RsaKey]:
@@ -118,16 +123,6 @@ class Backend:
return web.json_response(dict(error="invalid JSON"), status=422) return web.json_response(dict(error="invalid JSON"), status=422)
workload = payload.count_workload() workload = payload.count_workload()
async def wait_for_disconnection() -> None:
while request.transport and not request.transport.is_closing():
await sleep(0.5)
async def cancel_api_call_if_disconnected() -> web.Response:
await wait_for_disconnection()
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
self.metrics._request_canceled(workload=workload, reqnum=auth_data.reqnum)
return web.Response(status=500)
async def make_request() -> Union[web.Response, web.StreamResponse]: async def make_request() -> Union[web.Response, web.StreamResponse]:
log.debug(f"got request, {auth_data.reqnum}") log.debug(f"got request, {auth_data.reqnum}")
self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum) self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum)
@@ -173,41 +168,42 @@ class Backend:
return web.Response(status=401) return web.Response(status=401)
try: try:
done, pending = await wait( return await make_request()
[
create_task(make_request()),
create_task(cancel_api_call_if_disconnected()),
],
return_when=FIRST_COMPLETED,
)
[task.cancel() for task in pending]
return done.pop().result()
except Exception as e: except Exception as e:
log.debug(f"Exception in main handler loop {e}") log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500) return web.Response(status=500)
finally:
if request.task.cancelled():
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
self.metrics._request_canceled(
workload=workload, reqnum=auth_data.reqnum
)
async def __healthcheck(self): async def __healthcheck(self):
health_check_url = self.benchmark_handler.healthcheck_endpoint health_check_url = self.benchmark_handler.healthcheck_endpoint
if health_check_url is None: if health_check_url is None:
log.debug("No healthcheck endpoint defined, skipping healthcheck") log.debug("No healthcheck endpoint defined, skipping healthcheck")
return return
await sleep(5) while True:
try: await sleep(10)
log.debug(f"Performing healthcheck on {health_check_url}") if self.__start_healthcheck is False:
async with self.session.get(health_check_url) as response: continue
if response.status == 200: try:
log.debug("Healthcheck successful") log.debug(f"Performing healthcheck on {health_check_url}")
elif response.status == 503: async with self.session.get(health_check_url) as response:
log.debug(f"Healthcheck failed with status: {response.status}") if response.status == 200:
self.backend_errored( log.debug("Healthcheck successful")
f"Healthcheck failed with status: {response.status}" elif response.status == 503:
) log.debug(f"Healthcheck failed with status: {response.status}")
else: self.backend_errored(
# endpoint not ready yet so bail f"Healthcheck failed with status: {response.status}"
log.debug(f"Healthcheck Endpoint not ready: {response.status}") )
except Exception as e: else:
log.debug(f"Healthcheck failed with exception: {e}") # endpoint not ready yet so bail
self.backend_errored(str(e)) log.debug(f"Healthcheck Endpoint not ready: {response.status}")
except Exception as e:
log.debug(f"Healthcheck failed with exception: {e}")
self.backend_errored(str(e))
async def _start_tracking(self) -> None: async def _start_tracking(self) -> None:
await gather( await gather(
@@ -225,6 +221,9 @@ class Backend:
return await self.session.post(url=handler.endpoint, json=api_payload) return await self.session.post(url=handler.endpoint, json=api_payload)
def __check_signature(self, auth_data: AuthData) -> bool: def __check_signature(self, auth_data: AuthData) -> bool:
if self.unsecured is True:
return True
def verify_signature(message, signature): def verify_signature(message, signature):
if self.pubkey is None: if self.pubkey is None:
log.debug(f"No Public Key!") log.debug(f"No Public Key!")
@@ -331,6 +330,7 @@ class Backend:
await sleep(5) await sleep(5)
try: try:
max_throughput = await run_benchmark() max_throughput = await run_benchmark()
self.__start_healthcheck = True
self.metrics._model_loaded( self.metrics._model_loaded(
max_throughput=max_throughput, max_throughput=max_throughput,
) )
+1 -2
View File
@@ -5,7 +5,6 @@ import json
from asyncio import sleep from asyncio import sleep
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict, field
from functools import cache from functools import cache
from urllib.parse import urljoin
import requests import requests
@@ -119,7 +118,7 @@ class Metrics:
def send_data(report_addr: str) -> None: def send_data(report_addr: str) -> None:
data = compute_autoscaler_data() data = compute_autoscaler_data()
full_path = urljoin(report_addr, "/worker_status/") full_path = report_addr.rstrip("/") + "/worker_status/"
log.debug( log.debug(
"\n".join( "\n".join(
[ [
+1 -1
View File
@@ -27,7 +27,7 @@ def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
log.debug("starting server...") log.debug("starting server...")
app = web.Application() app = web.Application()
app.add_routes(routes) app.add_routes(routes)
runner = web.AppRunner(app) runner = web.AppRunner(app, handler_cancellation=True)
await runner.setup() await runner.setup()
site = web.TCPSite( site = web.TCPSite(
runner, runner,
+24 -9
View File
@@ -53,6 +53,13 @@ test_args.add_argument(
default="https://run.vast.ai", default="https://run.vast.ai",
help="Call local autoscaler instead of prod, for dev use only", help="Call local autoscaler instead of prod, for dev use only",
) )
test_args.add_argument(
"-i",
dest="instance",
type=str,
default="prod",
help="Autoscaler shard to run the command against, default: prod",
)
GetPayloadAndWorkload = Callable[[], Tuple[Dict[str, Any], float]] GetPayloadAndWorkload = Callable[[], Tuple[Dict[str, Any], float]]
@@ -70,6 +77,7 @@ class ClientState:
api_key: str api_key: str
server_url: str server_url: str
worker_endpoint: str worker_endpoint: str
instance: str
payload: ApiPayload payload: ApiPayload
url: str = "" url: str = ""
status: ClientStatus = ClientStatus.FetchEndpoint status: ClientStatus = ClientStatus.FetchEndpoint
@@ -79,11 +87,7 @@ class ClientState:
def make_call(self): def make_call(self):
self.status = ClientStatus.FetchEndpoint self.status = ClientStatus.FetchEndpoint
endpoint_api_key = Endpoint.get_endpoint_api_key( if not self.api_key:
endpoint_name=self.endpoint_group_name,
account_api_key=self.api_key,
)
if not endpoint_api_key:
self.as_error.append( self.as_error.append(
f"Endpoint {self.endpoint_group_name} not found for API key", f"Endpoint {self.endpoint_group_name} not found for API key",
) )
@@ -91,12 +95,14 @@ class ClientState:
return return
route_payload = { route_payload = {
"endpoint": self.endpoint_group_name, "endpoint": self.endpoint_group_name,
"api_key": endpoint_api_key, "api_key": self.api_key,
"cost": self.payload.count_workload(), "cost": self.payload.count_workload(),
} }
headers = {"Authorization": f"Bearer {self.api_key}"}
response = requests.post( response = requests.post(
urljoin(self.server_url, "/route/"), urljoin(self.server_url, "/route/"),
json=route_payload, json=route_payload,
headers=headers,
timeout=4, timeout=4,
) )
if response.status_code != 200: if response.status_code != 200:
@@ -135,6 +141,7 @@ class ClientState:
try: try:
self.make_call() self.make_call()
except Exception as e: except Exception as e:
print(e)
self.status = ClientStatus.Error self.status = ClientStatus.Error
_ = e _ = e
self.conn_errors[self.url] += 1 self.conn_errors[self.url] += 1
@@ -226,6 +233,7 @@ def run_test(
server_url: str, server_url: str,
worker_endpoint: str, worker_endpoint: str,
payload_cls: Type[ApiPayload], payload_cls: Type[ApiPayload],
instance: str,
): ):
threads = [] threads = []
@@ -234,8 +242,7 @@ def run_test(
print_thread.daemon = True # makes threads get killed on program exit print_thread.daemon = True # makes threads get killed on program exit
print_thread.start() print_thread.start()
endpoint_api_key = Endpoint.get_endpoint_api_key( endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=endpoint_group_name, endpoint_name=endpoint_group_name, account_api_key=api_key, instance=instance
account_api_key=api_key,
) )
if not endpoint_api_key: if not endpoint_api_key:
log.debug(f"Endpoint {endpoint_group_name} not found for API key") log.debug(f"Endpoint {endpoint_group_name} not found for API key")
@@ -248,6 +255,7 @@ def run_test(
server_url=server_url, server_url=server_url,
worker_endpoint=worker_endpoint, worker_endpoint=worker_endpoint,
payload=payload_cls.for_test(), payload=payload_cls.for_test(),
instance=instance,
) )
clients.append(client) clients.append(client)
thread = threading.Thread(target=client.simulate_user, args=()) thread = threading.Thread(target=client.simulate_user, args=())
@@ -281,12 +289,19 @@ def test_load_cmd(
args = arg_parser.parse_args() args = arg_parser.parse_args()
if hasattr(args, "comfy_model"): if hasattr(args, "comfy_model"):
os.environ["COMFY_MODEL"] = args.comfy_model os.environ["COMFY_MODEL"] = args.comfy_model
server_url = dict(
prod="https://run.vast.ai",
alpha="https://run-alpha.vast.ai",
candidate="https://run-candidate.vast.ai",
local="http://localhost:8080",
)[args.instance]
run_test( run_test(
num_requests=args.num_requests, num_requests=args.num_requests,
requests_per_second=args.requests_per_second, requests_per_second=args.requests_per_second,
api_key=args.api_key, api_key=args.api_key,
server_url=args.server_url, server_url=server_url,
endpoint_group_name=args.endpoint_group_name, endpoint_group_name=args.endpoint_group_name,
worker_endpoint=endpoint, worker_endpoint=endpoint,
payload_cls=payload_cls, payload_cls=payload_cls,
instance=args.instance,
) )
+11 -11
View File
@@ -87,23 +87,23 @@ if [ "$USE_SSL" = true ]; then
IP.1 = 0.0.0.0 IP.1 = 0.0.0.0
EOF EOF
openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \ openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \
-nodes \ -nodes \
-sha256 \ -sha256 \
-keyout /etc/instance.key \ -keyout /etc/instance.key \
-out /etc/instance.csr \ -out /etc/instance.csr \
-config /etc/openssl-san.cnf -config /etc/openssl-san.cnf
curl --header 'Content-Type: application/octet-stream' \ curl --header 'Content-Type: application/octet-stream' \
--data-binary @//etc/instance.csr \ --data-binary @//etc/instance.csr \
-X \ -X \
POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt; POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt;
fi fi
export REPORT_ADDR WORKER_PORT USE_SSL export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED
cd "$SERVER_DIR" cd "$SERVER_DIR"
+6 -2
View File
@@ -17,7 +17,9 @@ class Endpoint:
""" """
@staticmethod @staticmethod
def get_endpoint_api_key(endpoint_name: str, account_api_key: str) -> Optional[str]: def get_endpoint_api_key(
endpoint_name: str, account_api_key: str, instance: str
) -> Optional[str]:
""" """
Fetch endpoint API key from VastAI console following the healthcheck pattern. Fetch endpoint API key from VastAI console following the healthcheck pattern.
@@ -33,7 +35,9 @@ class Endpoint:
try: try:
log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}") log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}")
response = requests.get(vast_console_url, headers=headers) response = requests.get(
f"{vast_console_url}?autoscaler_instance={instance}", headers=headers
)
if response.status_code != 200: if response.status_code != 200:
error_msg = f"Failed to fetch endpoint API key: {response.status_code} - {response.text}" error_msg = f"Failed to fetch endpoint API key: {response.status_code} - {response.text}"
+1
View File
@@ -153,6 +153,7 @@ if __name__ == "__main__":
endpoint_api_key = Endpoint.get_endpoint_api_key( endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name, endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key, account_api_key=args.api_key,
instance=args.instance,
) )
if endpoint_api_key: if endpoint_api_key:
try: try:
+1
View File
@@ -100,6 +100,7 @@ if __name__ == "__main__":
endpoint_api_key = Endpoint.get_endpoint_api_key( endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name, endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key, account_api_key=args.api_key,
instance=args.instance,
) )
if endpoint_api_key: if endpoint_api_key:
try: try: