From e251afda2b877db513c906f1a3682f2ab8d9ab5a Mon Sep 17 00:00:00 2001
From: Colter Downing
Date: Thu, 9 Oct 2025 19:37:39 -0700
Subject: [PATCH] improved test load

---
Review notes (not part of the commit message):

  * NOTE(review): the line breaks, blank lines and indentation of this patch
    were reconstructed from the hunk line counts, the diffstat and Python
    syntax; verify against the original commit before applying.
  * Endpoint.get_endpoint_info: a 200 response that parses but does not
    contain a matching endpoint with both "id" and "api_key" falls through
    to the next attempt without sleeping, unlike the non-200, JSON-error
    and exception paths, which back off by 0.3 * (attempt + 1) seconds.
  * Endpoint.get_server_url falls back to the "alpha" host for any instance
    missing from its mapping, while get_autoscaler_server_url and both
    server_url lookups fall back to http://localhost:8080 -- confirm this
    asymmetry is intended.
  * `Tuple` (utils/endpoint_util.py) and `Endpoint`, `get_cert_file_path`,
    `AuthData` (workers/openai/test_load.py) are newly imported but not
    referenced in any hunk shown here -- presumably used elsewhere; verify.
  * workers/openai/test_load.py is left without a trailing newline.

 lib/test_utils.py           | 12 +++++-----
 utils/endpoint_util.py      | 48 +++++++++++++++++++++++++++++++++----
 workers/openai/test_load.py | 35 ++++++++++++++++++++++-----
 3 files changed, 78 insertions(+), 17 deletions(-)

diff --git a/lib/test_utils.py b/lib/test_utils.py
index 8635027..d64a4b6 100644
--- a/lib/test_utils.py
+++ b/lib/test_utils.py
@@ -292,12 +292,12 @@ def test_load_cmd(
     args = arg_parser.parse_args()
     if hasattr(args, "comfy_model"):
         os.environ["COMFY_MODEL"] = args.comfy_model
-    server_url = dict(
-        prod="https://run.vast.ai",
-        alpha="https://run-alpha.vast.ai",
-        candidate="https://run-candidate.vast.ai",
-        local="http://localhost:8080",
-    )[args.instance]
+    server_url = {
+        "prod": "https://run.vast.ai",
+        "alpha": "https://run-alpha.vast.ai",
+        "candidate": "https://run-candidate.vast.ai",
+        "local": "http://localhost:8080",
+    }.get(args.instance, "http://localhost:8080")
     run_test(
         num_requests=args.num_requests,
         requests_per_second=args.requests_per_second,
diff --git a/utils/endpoint_util.py b/utils/endpoint_util.py
index 37930af..927262e 100644
--- a/utils/endpoint_util.py
+++ b/utils/endpoint_util.py
@@ -1,5 +1,6 @@
 import logging
-from typing import Any, Dict, Optional
+import time
+from typing import Any, Dict, Optional, Tuple
 
 import requests
 
@@ -16,6 +17,38 @@ class Endpoint:
     Utility class for handling endpoint operations.
     """
 
+    @staticmethod
+    def get_endpoint_info(
+        endpoint_name: str, account_api_key: str, instance: str
+    ) -> Optional[Dict[str, Any]]:
+        headers = {"Authorization": f"Bearer {account_api_key}"}
+        url = f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}"
+        # Retry a few times to smooth over transient propagation/network delays
+        for attempt in range(4):
+            try:
+                response = requests.get(url, headers=headers, timeout=8)
+                if response.status_code != 200:
+                    # brief backoff and retry
+                    time.sleep(0.3 * (attempt + 1))
+                    continue
+                try:
+                    data = response.json()
+                except Exception:
+                    # JSON parse failed; backoff and retry
+                    time.sleep(0.3 * (attempt + 1))
+                    continue
+                result = data.get("results", []) if isinstance(data, dict) else []
+                endpoint = next(
+                    (item for item in result if item.get("endpoint_name") == endpoint_name),
+                    None,
+                )
+                if endpoint and endpoint.get("id") and endpoint.get("api_key"):
+                    return {"id": endpoint.get("id"), "api_key": endpoint.get("api_key")}
+            except Exception:
+                # network or other transient error; retry
+                time.sleep(0.3 * (attempt + 1))
+        return None
+
     @staticmethod
     def get_autoscaler_server_url(instance: str) -> str:
         endpoints = {
@@ -23,7 +56,10 @@ class Endpoint:
             "candidate": "run-candidate",
             "prod": "run",
         }
-        return f"https://{endpoints[instance]}.vast.ai/"
+        host = endpoints.get(instance)
+        if host:
+            return f"https://{host}.vast.ai/"
+        return "http://localhost:8080"
 
     @staticmethod
     def get_server_url(instance: str) -> str:
@@ -32,7 +68,8 @@ class Endpoint:
             "candidate": "candidate",
             "prod": "console",
         }
-        return f"https://{endpoints[instance]}.vast.ai/api/v0/endptjobs/"
+        host = endpoints.get(instance, "alpha")
+        return f"https://{host}.vast.ai/api/v0/endptjobs/"
 
     @staticmethod
     def get_endpoint_api_key(
@@ -55,6 +92,7 @@ class Endpoint:
         response = requests.get(
             f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}",
             headers=headers,
+            timeout=8,
         )
 
         if response.status_code != 200:
@@ -64,14 +102,14 @@ class Endpoint:
 
         try:
             data = response.json()
-        except requests.exceptions.JSONDecodeError as e:
+        except Exception as e:
             log.debug(f"Failed to parse JSON response: {e}")
             return None
 
         result = data.get("results", [])
 
         endpoint: Optional[Dict[str, Any]] = next(
-            (item for item in result if item["endpoint_name"] == endpoint_name),
+            (item for item in result if item.get("endpoint_name") == endpoint_name),
             None,
         )
         if not endpoint:
diff --git a/workers/openai/test_load.py b/workers/openai/test_load.py
index 5c12bc6..cf0e9c0 100644
--- a/workers/openai/test_load.py
+++ b/workers/openai/test_load.py
@@ -1,5 +1,9 @@
-from lib.test_utils import test_load_cmd, test_args
+from lib.test_utils import test_args
+from utils.endpoint_util import Endpoint
+from utils.ssl import get_cert_file_path
+from lib.data_types import AuthData
 from .data_types.server import CompletionsData
+
 import os
 import time
 import threading
@@ -353,13 +357,32 @@ if __name__ == "__main__":
         help="Model to use for completions request (required if MODEL_NAME env var not set)",
     )
 
-    # Parse known args to get model early, before test_load_cmd adds its args
+    # Parse known args to get model early, before adding load args
     known_args, _ = test_args.parse_known_args()
-
-    # Set environment variable if model was provided
     if hasattr(known_args, "model") and known_args.model:
         os.environ["MODEL_NAME"] = known_args.model
         print(f"Set MODEL_NAME environment variable to: {known_args.model}")
 
-    # Now call test_load_cmd normally - it will add its own args and re-parse
-    test_load_cmd(CompletionsData, WORKER_ENDPOINT, arg_parser=test_args)
+    # Load test args
+    test_args.add_argument("-n", dest="num_requests", type=int, required=True, help="total number of requests")
+    test_args.add_argument("-rps", dest="requests_per_second", type=float, required=True, help="requests per second")
+    test_args.add_argument("--out", dest="out_path", type=str, default="load_test_report.png", help="path to save the report image")
+    args = test_args.parse_args()
+
+    server_url = {
+        "prod": "https://run.vast.ai",
+        "alpha": "https://run-alpha.vast.ai",
+        "candidate": "https://run-candidate.vast.ai",
+        "local": "http://localhost:8080"
+    }.get(args.instance, "http://localhost:8080")
+
+    run_load_with_metrics(
+        num_requests=args.num_requests,
+        requests_per_second=args.requests_per_second,
+        endpoint_group_name=args.endpoint_group_name,
+        account_api_key=args.api_key,
+        server_url=server_url,
+        worker_endpoint=WORKER_ENDPOINT,
+        instance=args.instance,
+        out_path=args.out_path,
+    )
\ No newline at end of file