fix pyworker miscounting active connections (#20)

* fix pyworker miscounting active connections

* clean up some issues

* add option to skip auth
This commit is contained in:
Nader Arbabian
2025-07-15 15:33:27 -07:00
committed by GitHub
parent 0bf2d04223
commit 6fb610cb5b
7 changed files with 59 additions and 43 deletions
+15 -20
View File
@@ -5,9 +5,10 @@ import base64
import subprocess import subprocess
import dataclasses import dataclasses
import logging import logging
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task from asyncio import sleep, gather, Semaphore
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
from functools import cached_property from functools import cached_property
from distutils.util import strtobool
from anyio import open_file from anyio import open_file
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError
@@ -55,6 +56,9 @@ class Backend:
reqnum = -1 reqnum = -1
msg_history = [] msg_history = []
sem: Semaphore = dataclasses.field(default_factory=Semaphore) sem: Semaphore = dataclasses.field(default_factory=Semaphore)
unsecured: bool = dataclasses.field(
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
)
def __post_init__(self): def __post_init__(self):
self.metrics = Metrics() self.metrics = Metrics()
@@ -119,16 +123,6 @@ class Backend:
return web.json_response(dict(error="invalid JSON"), status=422) return web.json_response(dict(error="invalid JSON"), status=422)
workload = payload.count_workload() workload = payload.count_workload()
async def wait_for_disconnection() -> None:
while request.transport and not request.transport.is_closing():
await sleep(0.5)
async def cancel_api_call_if_disconnected() -> web.Response:
await wait_for_disconnection()
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
self.metrics._request_canceled(workload=workload, reqnum=auth_data.reqnum)
return web.Response(status=500)
async def make_request() -> Union[web.Response, web.StreamResponse]: async def make_request() -> Union[web.Response, web.StreamResponse]:
log.debug(f"got request, {auth_data.reqnum}") log.debug(f"got request, {auth_data.reqnum}")
self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum) self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum)
@@ -174,18 +168,16 @@ class Backend:
return web.Response(status=401) return web.Response(status=401)
try: try:
done, pending = await wait( return await make_request()
[
create_task(make_request()),
create_task(cancel_api_call_if_disconnected()),
],
return_when=FIRST_COMPLETED,
)
[task.cancel() for task in pending]
return done.pop().result()
except Exception as e: except Exception as e:
log.debug(f"Exception in main handler loop {e}") log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500) return web.Response(status=500)
finally:
if request.task.cancelled():
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
self.metrics._request_canceled(
workload=workload, reqnum=auth_data.reqnum
)
async def __healthcheck(self): async def __healthcheck(self):
health_check_url = self.benchmark_handler.healthcheck_endpoint health_check_url = self.benchmark_handler.healthcheck_endpoint
@@ -229,6 +221,9 @@ class Backend:
return await self.session.post(url=handler.endpoint, json=api_payload) return await self.session.post(url=handler.endpoint, json=api_payload)
def __check_signature(self, auth_data: AuthData) -> bool: def __check_signature(self, auth_data: AuthData) -> bool:
if self.unsecured is True:
return True
def verify_signature(message, signature): def verify_signature(message, signature):
if self.pubkey is None: if self.pubkey is None:
log.debug(f"No Public Key!") log.debug(f"No Public Key!")
+1 -1
View File
@@ -27,7 +27,7 @@ def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
log.debug("starting server...") log.debug("starting server...")
app = web.Application() app = web.Application()
app.add_routes(routes) app.add_routes(routes)
runner = web.AppRunner(app) runner = web.AppRunner(app, handler_cancellation=True)
await runner.setup() await runner.setup()
site = web.TCPSite( site = web.TCPSite(
runner, runner,
+24 -9
View File
@@ -53,6 +53,13 @@ test_args.add_argument(
default="https://run.vast.ai", default="https://run.vast.ai",
help="Call local autoscaler instead of prod, for dev use only", help="Call local autoscaler instead of prod, for dev use only",
) )
test_args.add_argument(
"-i",
dest="instance",
type=str,
default="prod",
help="Autoscaler shard to run the command against, default: prod",
)
GetPayloadAndWorkload = Callable[[], Tuple[Dict[str, Any], float]] GetPayloadAndWorkload = Callable[[], Tuple[Dict[str, Any], float]]
@@ -70,6 +77,7 @@ class ClientState:
api_key: str api_key: str
server_url: str server_url: str
worker_endpoint: str worker_endpoint: str
instance: str
payload: ApiPayload payload: ApiPayload
url: str = "" url: str = ""
status: ClientStatus = ClientStatus.FetchEndpoint status: ClientStatus = ClientStatus.FetchEndpoint
@@ -79,11 +87,7 @@ class ClientState:
def make_call(self): def make_call(self):
self.status = ClientStatus.FetchEndpoint self.status = ClientStatus.FetchEndpoint
endpoint_api_key = Endpoint.get_endpoint_api_key( if not self.api_key:
endpoint_name=self.endpoint_group_name,
account_api_key=self.api_key,
)
if not endpoint_api_key:
self.as_error.append( self.as_error.append(
f"Endpoint {self.endpoint_group_name} not found for API key", f"Endpoint {self.endpoint_group_name} not found for API key",
) )
@@ -91,12 +95,14 @@ class ClientState:
return return
route_payload = { route_payload = {
"endpoint": self.endpoint_group_name, "endpoint": self.endpoint_group_name,
"api_key": endpoint_api_key, "api_key": self.api_key,
"cost": self.payload.count_workload(), "cost": self.payload.count_workload(),
} }
headers = {"Authorization": f"Bearer {self.api_key}"}
response = requests.post( response = requests.post(
urljoin(self.server_url, "/route/"), urljoin(self.server_url, "/route/"),
json=route_payload, json=route_payload,
headers=headers,
timeout=4, timeout=4,
) )
if response.status_code != 200: if response.status_code != 200:
@@ -135,6 +141,7 @@ class ClientState:
try: try:
self.make_call() self.make_call()
except Exception as e: except Exception as e:
print(e)
self.status = ClientStatus.Error self.status = ClientStatus.Error
_ = e _ = e
self.conn_errors[self.url] += 1 self.conn_errors[self.url] += 1
@@ -226,6 +233,7 @@ def run_test(
server_url: str, server_url: str,
worker_endpoint: str, worker_endpoint: str,
payload_cls: Type[ApiPayload], payload_cls: Type[ApiPayload],
instance: str,
): ):
threads = [] threads = []
@@ -234,8 +242,7 @@ def run_test(
print_thread.daemon = True # makes threads get killed on program exit print_thread.daemon = True # makes threads get killed on program exit
print_thread.start() print_thread.start()
endpoint_api_key = Endpoint.get_endpoint_api_key( endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=endpoint_group_name, endpoint_name=endpoint_group_name, account_api_key=api_key, instance=instance
account_api_key=api_key,
) )
if not endpoint_api_key: if not endpoint_api_key:
log.debug(f"Endpoint {endpoint_group_name} not found for API key") log.debug(f"Endpoint {endpoint_group_name} not found for API key")
@@ -248,6 +255,7 @@ def run_test(
server_url=server_url, server_url=server_url,
worker_endpoint=worker_endpoint, worker_endpoint=worker_endpoint,
payload=payload_cls.for_test(), payload=payload_cls.for_test(),
instance=instance,
) )
clients.append(client) clients.append(client)
thread = threading.Thread(target=client.simulate_user, args=()) thread = threading.Thread(target=client.simulate_user, args=())
@@ -281,12 +289,19 @@ def test_load_cmd(
args = arg_parser.parse_args() args = arg_parser.parse_args()
if hasattr(args, "comfy_model"): if hasattr(args, "comfy_model"):
os.environ["COMFY_MODEL"] = args.comfy_model os.environ["COMFY_MODEL"] = args.comfy_model
server_url = dict(
prod="https://run.vast.ai",
alpha="https://run-alpha.vast.ai",
candidate="https://run-candidate.vast.ai",
local="http://localhost:8080",
)[args.instance]
run_test( run_test(
num_requests=args.num_requests, num_requests=args.num_requests,
requests_per_second=args.requests_per_second, requests_per_second=args.requests_per_second,
api_key=args.api_key, api_key=args.api_key,
server_url=args.server_url, server_url=server_url,
endpoint_group_name=args.endpoint_group_name, endpoint_group_name=args.endpoint_group_name,
worker_endpoint=endpoint, worker_endpoint=endpoint,
payload_cls=payload_cls, payload_cls=payload_cls,
instance=args.instance,
) )
+3 -3
View File
@@ -87,14 +87,14 @@ if [ "$USE_SSL" = true ]; then
IP.1 = 0.0.0.0 IP.1 = 0.0.0.0
EOF EOF
openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \ openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \
-nodes \ -nodes \
-sha256 \ -sha256 \
-keyout /etc/instance.key \ -keyout /etc/instance.key \
-out /etc/instance.csr \ -out /etc/instance.csr \
-config /etc/openssl-san.cnf -config /etc/openssl-san.cnf
curl --header 'Content-Type: application/octet-stream' \ curl --header 'Content-Type: application/octet-stream' \
--data-binary @//etc/instance.csr \ --data-binary @//etc/instance.csr \
-X \ -X \
POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt; POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt;
@@ -103,7 +103,7 @@ fi
export REPORT_ADDR WORKER_PORT USE_SSL export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED
cd "$SERVER_DIR" cd "$SERVER_DIR"
+6 -2
View File
@@ -17,7 +17,9 @@ class Endpoint:
""" """
@staticmethod @staticmethod
def get_endpoint_api_key(endpoint_name: str, account_api_key: str) -> Optional[str]: def get_endpoint_api_key(
endpoint_name: str, account_api_key: str, instance: str
) -> Optional[str]:
""" """
Fetch endpoint API key from VastAI console following the healthcheck pattern. Fetch endpoint API key from VastAI console following the healthcheck pattern.
@@ -33,7 +35,9 @@ class Endpoint:
try: try:
log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}") log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}")
response = requests.get(vast_console_url, headers=headers) response = requests.get(
f"{vast_console_url}?autoscaler_instance={instance}", headers=headers
)
if response.status_code != 200: if response.status_code != 200:
error_msg = f"Failed to fetch endpoint API key: {response.status_code} - {response.text}" error_msg = f"Failed to fetch endpoint API key: {response.status_code} - {response.text}"
+1
View File
@@ -153,6 +153,7 @@ if __name__ == "__main__":
endpoint_api_key = Endpoint.get_endpoint_api_key( endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name, endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key, account_api_key=args.api_key,
instance=args.instance,
) )
if endpoint_api_key: if endpoint_api_key:
try: try:
+1
View File
@@ -100,6 +100,7 @@ if __name__ == "__main__":
endpoint_api_key = Endpoint.get_endpoint_api_key( endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name, endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key, account_api_key=args.api_key,
instance=args.instance,
) )
if endpoint_api_key: if endpoint_api_key:
try: try: