diff --git a/lib/backend.py b/lib/backend.py index e01f7c4..dea39a3 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -5,9 +5,10 @@ import base64 import subprocess import dataclasses import logging -from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task +from asyncio import sleep, gather, Semaphore from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional from functools import cached_property +from distutils.util import strtobool from anyio import open_file from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError @@ -55,6 +56,9 @@ class Backend: reqnum = -1 msg_history = [] sem: Semaphore = dataclasses.field(default_factory=Semaphore) + unsecured: bool = dataclasses.field( + default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))), + ) def __post_init__(self): self.metrics = Metrics() @@ -119,16 +123,6 @@ class Backend: return web.json_response(dict(error="invalid JSON"), status=422) workload = payload.count_workload() - async def wait_for_disconnection() -> None: - while request.transport and not request.transport.is_closing(): - await sleep(0.5) - - async def cancel_api_call_if_disconnected() -> web.Response: - await wait_for_disconnection() - log.debug(f"request with reqnum: {auth_data.reqnum} was canceled") - self.metrics._request_canceled(workload=workload, reqnum=auth_data.reqnum) - return web.Response(status=500) - async def make_request() -> Union[web.Response, web.StreamResponse]: log.debug(f"got request, {auth_data.reqnum}") self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum) @@ -174,18 +168,16 @@ class Backend: return web.Response(status=401) try: - done, pending = await wait( - [ - create_task(make_request()), - create_task(cancel_api_call_if_disconnected()), - ], - return_when=FIRST_COMPLETED, - ) - [task.cancel() for task in pending] - return done.pop().result() + return await make_request() except Exception as e: log.debug(f"Exception in main handler loop {e}") return web.Response(status=500) + finally: + if request.task.cancelled(): + log.debug(f"request with reqnum: {auth_data.reqnum} was canceled") + self.metrics._request_canceled( + workload=workload, reqnum=auth_data.reqnum + ) async def __healthcheck(self): health_check_url = self.benchmark_handler.healthcheck_endpoint @@ -229,6 +221,9 @@ class Backend: return await self.session.post(url=handler.endpoint, json=api_payload) def __check_signature(self, auth_data: AuthData) -> bool: + if self.unsecured is True: + return True + def verify_signature(message, signature): if self.pubkey is None: log.debug(f"No Public Key!") diff --git a/lib/server.py b/lib/server.py index b21c880..80e2959 100644 --- a/lib/server.py +++ b/lib/server.py @@ -27,7 +27,7 @@ def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs): log.debug("starting server...") app = web.Application() app.add_routes(routes) - runner = web.AppRunner(app) + runner = web.AppRunner(app, handler_cancellation=True) await runner.setup() site = web.TCPSite( runner, diff --git a/lib/test_utils.py b/lib/test_utils.py index 791b7dd..ba97611 100644 --- a/lib/test_utils.py +++ b/lib/test_utils.py @@ -53,6 +53,13 @@ test_args.add_argument( default="https://run.vast.ai", help="Call local autoscaler instead of prod, for dev use only", ) +test_args.add_argument( + "-i", + dest="instance", + type=str, + default="prod", + help="Autoscaler shard to run the command against, default: prod", +) GetPayloadAndWorkload = Callable[[], Tuple[Dict[str, Any], float]] @@ -70,6 +77,7 @@ class ClientState: api_key: str server_url: str worker_endpoint: str + instance: str payload: ApiPayload url: str = "" status: ClientStatus = ClientStatus.FetchEndpoint @@ -79,11 +87,7 @@ class ClientState: def make_call(self): self.status = ClientStatus.FetchEndpoint - endpoint_api_key = Endpoint.get_endpoint_api_key( - endpoint_name=self.endpoint_group_name, - account_api_key=self.api_key, - ) - if not endpoint_api_key: + if not self.api_key: self.as_error.append( f"Endpoint {self.endpoint_group_name} not found for API key", ) @@ -91,12 +95,14 @@ class ClientState: return route_payload = { "endpoint": self.endpoint_group_name, - "api_key": endpoint_api_key, + "api_key": self.api_key, "cost": self.payload.count_workload(), } + headers = {"Authorization": f"Bearer {self.api_key}"} response = requests.post( urljoin(self.server_url, "/route/"), json=route_payload, + headers=headers, timeout=4, ) if response.status_code != 200: @@ -135,6 +141,7 @@ class ClientState: try: self.make_call() except Exception as e: + print(e) self.status = ClientStatus.Error _ = e self.conn_errors[self.url] += 1 @@ -226,6 +233,7 @@ def run_test( server_url: str, worker_endpoint: str, payload_cls: Type[ApiPayload], + instance: str, ): threads = [] @@ -234,8 +242,7 @@ def run_test( print_thread.daemon = True # makes threads get killed on program exit print_thread.start() endpoint_api_key = Endpoint.get_endpoint_api_key( - endpoint_name=endpoint_group_name, - account_api_key=api_key, + endpoint_name=endpoint_group_name, account_api_key=api_key, instance=instance ) if not endpoint_api_key: log.debug(f"Endpoint {endpoint_group_name} not found for API key") @@ -248,6 +255,7 @@ def run_test( server_url=server_url, worker_endpoint=worker_endpoint, payload=payload_cls.for_test(), + instance=instance, ) clients.append(client) thread = threading.Thread(target=client.simulate_user, args=()) @@ -281,12 +289,19 @@ def test_load_cmd( args = arg_parser.parse_args() if hasattr(args, "comfy_model"): os.environ["COMFY_MODEL"] = args.comfy_model + server_url = dict( + prod="https://run.vast.ai", + alpha="https://run-alpha.vast.ai", + candidate="https://run-candidate.vast.ai", + local="http://localhost:8080", + )[args.instance] run_test( num_requests=args.num_requests, requests_per_second=args.requests_per_second, api_key=args.api_key, - server_url=args.server_url, + server_url=server_url, endpoint_group_name=args.endpoint_group_name, worker_endpoint=endpoint, payload_cls=payload_cls, + instance=args.instance, ) diff --git a/start_server.sh b/start_server.sh index 57a96e8..e6949c5 100755 --- a/start_server.sh +++ b/start_server.sh @@ -87,23 +87,23 @@ if [ "$USE_SSL" = true ]; then IP.1 = 0.0.0.0 EOF -openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \ - -nodes \ - -sha256 \ - -keyout /etc/instance.key \ - -out /etc/instance.csr \ - -config /etc/openssl-san.cnf + openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \ + -nodes \ + -sha256 \ + -keyout /etc/instance.key \ + -out /etc/instance.csr \ + -config /etc/openssl-san.cnf -curl --header 'Content-Type: application/octet-stream' \ - --data-binary @//etc/instance.csr \ - -X \ - POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt; + curl --header 'Content-Type: application/octet-stream' \ + --data-binary @//etc/instance.csr \ + -X \ + POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt; fi -export REPORT_ADDR WORKER_PORT USE_SSL +export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED cd "$SERVER_DIR" diff --git a/utils/endpoint_util.py b/utils/endpoint_util.py index 42fea1c..48c4cdb 100644 --- a/utils/endpoint_util.py +++ b/utils/endpoint_util.py @@ -17,7 +17,9 @@ class Endpoint: """ @staticmethod - def get_endpoint_api_key(endpoint_name: str, account_api_key: str) -> Optional[str]: + def get_endpoint_api_key( + endpoint_name: str, account_api_key: str, instance: str + ) -> Optional[str]: """ Fetch endpoint API key from VastAI console following the healthcheck pattern. @@ -33,7 +35,9 @@ class Endpoint: try: log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}") - response = requests.get(vast_console_url, headers=headers) + response = requests.get( + f"{vast_console_url}?autoscaler_instance={instance}", headers=headers + ) if response.status_code != 200: error_msg = f"Failed to fetch endpoint API key: {response.status_code} - {response.text}" diff --git a/workers/comfyui/client.py b/workers/comfyui/client.py index 23c8ca0..6563e00 100644 --- a/workers/comfyui/client.py +++ b/workers/comfyui/client.py @@ -153,6 +153,7 @@ if __name__ == "__main__": endpoint_api_key = Endpoint.get_endpoint_api_key( endpoint_name=args.endpoint_group_name, account_api_key=args.api_key, + instance=args.instance, ) if endpoint_api_key: try: diff --git a/workers/tgi/client.py b/workers/tgi/client.py index cfa91f8..7e4f1bb 100644 --- a/workers/tgi/client.py +++ b/workers/tgi/client.py @@ -100,6 +100,7 @@ if __name__ == "__main__": endpoint_api_key = Endpoint.get_endpoint_api_key( endpoint_name=args.endpoint_group_name, account_api_key=args.api_key, + instance=args.instance, ) if endpoint_api_key: try: