Compare commits

..

1 Commits

Author SHA1 Message Date
Nader Arbabian 72a5f6ad13 update tokenizers deps 2025-06-10 17:55:25 -07:00
8 changed files with 62 additions and 82 deletions
+37 -37
View File
@@ -5,10 +5,9 @@ import base64
import subprocess import subprocess
import dataclasses import dataclasses
import logging import logging
from asyncio import sleep, gather, Semaphore from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
from functools import cached_property from functools import cached_property
from distutils.util import strtobool
from anyio import open_file from anyio import open_file
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError
@@ -56,15 +55,11 @@ class Backend:
reqnum = -1 reqnum = -1
msg_history = [] msg_history = []
sem: Semaphore = dataclasses.field(default_factory=Semaphore) sem: Semaphore = dataclasses.field(default_factory=Semaphore)
unsecured: bool = dataclasses.field(
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
)
def __post_init__(self): def __post_init__(self):
self.metrics = Metrics() self.metrics = Metrics()
self._total_pubkey_fetch_errors = 0 self._total_pubkey_fetch_errors = 0
self._pubkey = self._fetch_pubkey() self._pubkey = self._fetch_pubkey()
self.__start_healthcheck: bool = False
@property @property
def pubkey(self) -> Optional[RSA.RsaKey]: def pubkey(self) -> Optional[RSA.RsaKey]:
@@ -123,6 +118,16 @@ class Backend:
return web.json_response(dict(error="invalid JSON"), status=422) return web.json_response(dict(error="invalid JSON"), status=422)
workload = payload.count_workload() workload = payload.count_workload()
async def wait_for_disconnection() -> None:
while request.transport and not request.transport.is_closing():
await sleep(0.5)
async def cancel_api_call_if_disconnected() -> web.Response:
await wait_for_disconnection()
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
self.metrics._request_canceled(workload=workload, reqnum=auth_data.reqnum)
return web.Response(status=500)
async def make_request() -> Union[web.Response, web.StreamResponse]: async def make_request() -> Union[web.Response, web.StreamResponse]:
log.debug(f"got request, {auth_data.reqnum}") log.debug(f"got request, {auth_data.reqnum}")
self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum) self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum)
@@ -168,42 +173,41 @@ class Backend:
return web.Response(status=401) return web.Response(status=401)
try: try:
return await make_request() done, pending = await wait(
[
create_task(make_request()),
create_task(cancel_api_call_if_disconnected()),
],
return_when=FIRST_COMPLETED,
)
[task.cancel() for task in pending]
return done.pop().result()
except Exception as e: except Exception as e:
log.debug(f"Exception in main handler loop {e}") log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500) return web.Response(status=500)
finally:
if request.task.cancelled():
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
self.metrics._request_canceled(
workload=workload, reqnum=auth_data.reqnum
)
async def __healthcheck(self): async def __healthcheck(self):
health_check_url = self.benchmark_handler.healthcheck_endpoint health_check_url = self.benchmark_handler.healthcheck_endpoint
if health_check_url is None: if health_check_url is None:
log.debug("No healthcheck endpoint defined, skipping healthcheck") log.debug("No healthcheck endpoint defined, skipping healthcheck")
return return
while True: await sleep(5)
await sleep(10) try:
if self.__start_healthcheck is False: log.debug(f"Performing healthcheck on {health_check_url}")
continue async with self.session.get(health_check_url) as response:
try: if response.status == 200:
log.debug(f"Performing healthcheck on {health_check_url}") log.debug("Healthcheck successful")
async with self.session.get(health_check_url) as response: elif response.status == 503:
if response.status == 200: log.debug(f"Healthcheck failed with status: {response.status}")
log.debug("Healthcheck successful") self.backend_errored(
elif response.status == 503: f"Healthcheck failed with status: {response.status}"
log.debug(f"Healthcheck failed with status: {response.status}") )
self.backend_errored( else:
f"Healthcheck failed with status: {response.status}" # endpoint not ready yet so bail
) log.debug(f"Healthcheck Endpoint not ready: {response.status}")
else: except Exception as e:
# endpoint not ready yet so bail log.debug(f"Healthcheck failed with exception: {e}")
log.debug(f"Healthcheck Endpoint not ready: {response.status}") self.backend_errored(str(e))
except Exception as e:
log.debug(f"Healthcheck failed with exception: {e}")
self.backend_errored(str(e))
async def _start_tracking(self) -> None: async def _start_tracking(self) -> None:
await gather( await gather(
@@ -221,9 +225,6 @@ class Backend:
return await self.session.post(url=handler.endpoint, json=api_payload) return await self.session.post(url=handler.endpoint, json=api_payload)
def __check_signature(self, auth_data: AuthData) -> bool: def __check_signature(self, auth_data: AuthData) -> bool:
if self.unsecured is True:
return True
def verify_signature(message, signature): def verify_signature(message, signature):
if self.pubkey is None: if self.pubkey is None:
log.debug(f"No Public Key!") log.debug(f"No Public Key!")
@@ -330,7 +331,6 @@ class Backend:
await sleep(5) await sleep(5)
try: try:
max_throughput = await run_benchmark() max_throughput = await run_benchmark()
self.__start_healthcheck = True
self.metrics._model_loaded( self.metrics._model_loaded(
max_throughput=max_throughput, max_throughput=max_throughput,
) )
+2 -1
View File
@@ -5,6 +5,7 @@ import json
from asyncio import sleep from asyncio import sleep
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict, field
from functools import cache from functools import cache
from urllib.parse import urljoin
import requests import requests
@@ -118,7 +119,7 @@ class Metrics:
def send_data(report_addr: str) -> None: def send_data(report_addr: str) -> None:
data = compute_autoscaler_data() data = compute_autoscaler_data()
full_path = report_addr.rstrip("/") + "/worker_status/" full_path = urljoin(report_addr, "/worker_status/")
log.debug( log.debug(
"\n".join( "\n".join(
[ [
+1 -1
View File
@@ -27,7 +27,7 @@ def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
log.debug("starting server...") log.debug("starting server...")
app = web.Application() app = web.Application()
app.add_routes(routes) app.add_routes(routes)
runner = web.AppRunner(app, handler_cancellation=True) runner = web.AppRunner(app)
await runner.setup() await runner.setup()
site = web.TCPSite( site = web.TCPSite(
runner, runner,
+9 -24
View File
@@ -53,13 +53,6 @@ test_args.add_argument(
default="https://run.vast.ai", default="https://run.vast.ai",
help="Call local autoscaler instead of prod, for dev use only", help="Call local autoscaler instead of prod, for dev use only",
) )
test_args.add_argument(
"-i",
dest="instance",
type=str,
default="prod",
help="Autoscaler shard to run the command against, default: prod",
)
GetPayloadAndWorkload = Callable[[], Tuple[Dict[str, Any], float]] GetPayloadAndWorkload = Callable[[], Tuple[Dict[str, Any], float]]
@@ -77,7 +70,6 @@ class ClientState:
api_key: str api_key: str
server_url: str server_url: str
worker_endpoint: str worker_endpoint: str
instance: str
payload: ApiPayload payload: ApiPayload
url: str = "" url: str = ""
status: ClientStatus = ClientStatus.FetchEndpoint status: ClientStatus = ClientStatus.FetchEndpoint
@@ -87,7 +79,11 @@ class ClientState:
def make_call(self): def make_call(self):
self.status = ClientStatus.FetchEndpoint self.status = ClientStatus.FetchEndpoint
if not self.api_key: endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=self.endpoint_group_name,
account_api_key=self.api_key,
)
if not endpoint_api_key:
self.as_error.append( self.as_error.append(
f"Endpoint {self.endpoint_group_name} not found for API key", f"Endpoint {self.endpoint_group_name} not found for API key",
) )
@@ -95,14 +91,12 @@ class ClientState:
return return
route_payload = { route_payload = {
"endpoint": self.endpoint_group_name, "endpoint": self.endpoint_group_name,
"api_key": self.api_key, "api_key": endpoint_api_key,
"cost": self.payload.count_workload(), "cost": self.payload.count_workload(),
} }
headers = {"Authorization": f"Bearer {self.api_key}"}
response = requests.post( response = requests.post(
urljoin(self.server_url, "/route/"), urljoin(self.server_url, "/route/"),
json=route_payload, json=route_payload,
headers=headers,
timeout=4, timeout=4,
) )
if response.status_code != 200: if response.status_code != 200:
@@ -141,7 +135,6 @@ class ClientState:
try: try:
self.make_call() self.make_call()
except Exception as e: except Exception as e:
print(e)
self.status = ClientStatus.Error self.status = ClientStatus.Error
_ = e _ = e
self.conn_errors[self.url] += 1 self.conn_errors[self.url] += 1
@@ -233,7 +226,6 @@ def run_test(
server_url: str, server_url: str,
worker_endpoint: str, worker_endpoint: str,
payload_cls: Type[ApiPayload], payload_cls: Type[ApiPayload],
instance: str,
): ):
threads = [] threads = []
@@ -242,7 +234,8 @@ def run_test(
print_thread.daemon = True # makes threads get killed on program exit print_thread.daemon = True # makes threads get killed on program exit
print_thread.start() print_thread.start()
endpoint_api_key = Endpoint.get_endpoint_api_key( endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=endpoint_group_name, account_api_key=api_key, instance=instance endpoint_name=endpoint_group_name,
account_api_key=api_key,
) )
if not endpoint_api_key: if not endpoint_api_key:
log.debug(f"Endpoint {endpoint_group_name} not found for API key") log.debug(f"Endpoint {endpoint_group_name} not found for API key")
@@ -255,7 +248,6 @@ def run_test(
server_url=server_url, server_url=server_url,
worker_endpoint=worker_endpoint, worker_endpoint=worker_endpoint,
payload=payload_cls.for_test(), payload=payload_cls.for_test(),
instance=instance,
) )
clients.append(client) clients.append(client)
thread = threading.Thread(target=client.simulate_user, args=()) thread = threading.Thread(target=client.simulate_user, args=())
@@ -289,19 +281,12 @@ def test_load_cmd(
args = arg_parser.parse_args() args = arg_parser.parse_args()
if hasattr(args, "comfy_model"): if hasattr(args, "comfy_model"):
os.environ["COMFY_MODEL"] = args.comfy_model os.environ["COMFY_MODEL"] = args.comfy_model
server_url = dict(
prod="https://run.vast.ai",
alpha="https://run-alpha.vast.ai",
candidate="https://run-candidate.vast.ai",
local="http://localhost:8080",
)[args.instance]
run_test( run_test(
num_requests=args.num_requests, num_requests=args.num_requests,
requests_per_second=args.requests_per_second, requests_per_second=args.requests_per_second,
api_key=args.api_key, api_key=args.api_key,
server_url=server_url, server_url=args.server_url,
endpoint_group_name=args.endpoint_group_name, endpoint_group_name=args.endpoint_group_name,
worker_endpoint=endpoint, worker_endpoint=endpoint,
payload_cls=payload_cls, payload_cls=payload_cls,
instance=args.instance,
) )
+11 -11
View File
@@ -87,23 +87,23 @@ if [ "$USE_SSL" = true ]; then
IP.1 = 0.0.0.0 IP.1 = 0.0.0.0
EOF EOF
openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \ openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \
-nodes \ -nodes \
-sha256 \ -sha256 \
-keyout /etc/instance.key \ -keyout /etc/instance.key \
-out /etc/instance.csr \ -out /etc/instance.csr \
-config /etc/openssl-san.cnf -config /etc/openssl-san.cnf
curl --header 'Content-Type: application/octet-stream' \ curl --header 'Content-Type: application/octet-stream' \
--data-binary @//etc/instance.csr \ --data-binary @//etc/instance.csr \
-X \ -X \
POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt; POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt;
fi fi
export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED export REPORT_ADDR WORKER_PORT USE_SSL
cd "$SERVER_DIR" cd "$SERVER_DIR"
+2 -6
View File
@@ -17,9 +17,7 @@ class Endpoint:
""" """
@staticmethod @staticmethod
def get_endpoint_api_key( def get_endpoint_api_key(endpoint_name: str, account_api_key: str) -> Optional[str]:
endpoint_name: str, account_api_key: str, instance: str
) -> Optional[str]:
""" """
Fetch endpoint API key from VastAI console following the healthcheck pattern. Fetch endpoint API key from VastAI console following the healthcheck pattern.
@@ -35,9 +33,7 @@ class Endpoint:
try: try:
log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}") log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}")
response = requests.get( response = requests.get(vast_console_url, headers=headers)
f"{vast_console_url}?autoscaler_instance={instance}", headers=headers
)
if response.status_code != 200: if response.status_code != 200:
error_msg = f"Failed to fetch endpoint API key: {response.status_code} - {response.text}" error_msg = f"Failed to fetch endpoint API key: {response.status_code} - {response.text}"
-1
View File
@@ -153,7 +153,6 @@ if __name__ == "__main__":
endpoint_api_key = Endpoint.get_endpoint_api_key( endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name, endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key, account_api_key=args.api_key,
instance=args.instance,
) )
if endpoint_api_key: if endpoint_api_key:
try: try:
-1
View File
@@ -100,7 +100,6 @@ if __name__ == "__main__":
endpoint_api_key = Endpoint.get_endpoint_api_key( endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name, endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key, account_api_key=args.api_key,
instance=args.instance,
) )
if endpoint_api_key: if endpoint_api_key:
try: try: