improved test load
This commit is contained in:
+6
-6
@@ -292,12 +292,12 @@ def test_load_cmd(
|
|||||||
args = arg_parser.parse_args()
|
args = arg_parser.parse_args()
|
||||||
if hasattr(args, "comfy_model"):
|
if hasattr(args, "comfy_model"):
|
||||||
os.environ["COMFY_MODEL"] = args.comfy_model
|
os.environ["COMFY_MODEL"] = args.comfy_model
|
||||||
server_url = dict(
|
server_url = {
|
||||||
prod="https://run.vast.ai",
|
"prod": "https://run.vast.ai",
|
||||||
alpha="https://run-alpha.vast.ai",
|
"alpha": "https://run-alpha.vast.ai",
|
||||||
candidate="https://run-candidate.vast.ai",
|
"candidate": "https://run-candidate.vast.ai",
|
||||||
local="http://localhost:8080",
|
"local": "http://localhost:8080",
|
||||||
)[args.instance]
|
}.get(args.instance, "http://localhost:8080")
|
||||||
run_test(
|
run_test(
|
||||||
num_requests=args.num_requests,
|
num_requests=args.num_requests,
|
||||||
requests_per_second=args.requests_per_second,
|
requests_per_second=args.requests_per_second,
|
||||||
|
|||||||
+43
-5
@@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Optional
|
import time
|
||||||
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@@ -16,6 +17,38 @@ class Endpoint:
|
|||||||
Utility class for handling endpoint operations.
|
Utility class for handling endpoint operations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_endpoint_info(
|
||||||
|
endpoint_name: str, account_api_key: str, instance: str
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
headers = {"Authorization": f"Bearer {account_api_key}"}
|
||||||
|
url = f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}"
|
||||||
|
# Retry a few times to smooth over transient propagation/network delays
|
||||||
|
for attempt in range(4):
|
||||||
|
try:
|
||||||
|
response = requests.get(url, headers=headers, timeout=8)
|
||||||
|
if response.status_code != 200:
|
||||||
|
# brief backoff and retry
|
||||||
|
time.sleep(0.3 * (attempt + 1))
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
data = response.json()
|
||||||
|
except Exception:
|
||||||
|
# JSON parse failed; backoff and retry
|
||||||
|
time.sleep(0.3 * (attempt + 1))
|
||||||
|
continue
|
||||||
|
result = data.get("results", []) if isinstance(data, dict) else []
|
||||||
|
endpoint = next(
|
||||||
|
(item for item in result if item.get("endpoint_name") == endpoint_name),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if endpoint and endpoint.get("id") and endpoint.get("api_key"):
|
||||||
|
return {"id": endpoint.get("id"), "api_key": endpoint.get("api_key")}
|
||||||
|
except Exception:
|
||||||
|
# network or other transient error; retry
|
||||||
|
time.sleep(0.3 * (attempt + 1))
|
||||||
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_autoscaler_server_url(instance: str) -> str:
|
def get_autoscaler_server_url(instance: str) -> str:
|
||||||
endpoints = {
|
endpoints = {
|
||||||
@@ -23,7 +56,10 @@ class Endpoint:
|
|||||||
"candidate": "run-candidate",
|
"candidate": "run-candidate",
|
||||||
"prod": "run",
|
"prod": "run",
|
||||||
}
|
}
|
||||||
return f"https://{endpoints[instance]}.vast.ai/"
|
host = endpoints.get(instance)
|
||||||
|
if host:
|
||||||
|
return f"https://{host}.vast.ai/"
|
||||||
|
return "http://localhost:8080"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_server_url(instance: str) -> str:
|
def get_server_url(instance: str) -> str:
|
||||||
@@ -32,7 +68,8 @@ class Endpoint:
|
|||||||
"candidate": "candidate",
|
"candidate": "candidate",
|
||||||
"prod": "console",
|
"prod": "console",
|
||||||
}
|
}
|
||||||
return f"https://{endpoints[instance]}.vast.ai/api/v0/endptjobs/"
|
host = endpoints.get(instance, "alpha")
|
||||||
|
return f"https://{host}.vast.ai/api/v0/endptjobs/"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_endpoint_api_key(
|
def get_endpoint_api_key(
|
||||||
@@ -55,6 +92,7 @@ class Endpoint:
|
|||||||
response = requests.get(
|
response = requests.get(
|
||||||
f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}",
|
f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
|
timeout=8,
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
@@ -64,14 +102,14 @@ class Endpoint:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
data = response.json()
|
data = response.json()
|
||||||
except requests.exceptions.JSONDecodeError as e:
|
except Exception as e:
|
||||||
log.debug(f"Failed to parse JSON response: {e}")
|
log.debug(f"Failed to parse JSON response: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
result = data.get("results", [])
|
result = data.get("results", [])
|
||||||
|
|
||||||
endpoint: Optional[Dict[str, Any]] = next(
|
endpoint: Optional[Dict[str, Any]] = next(
|
||||||
(item for item in result if item["endpoint_name"] == endpoint_name),
|
(item for item in result if item.get("endpoint_name") == endpoint_name),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
if not endpoint:
|
if not endpoint:
|
||||||
|
|||||||
@@ -1,5 +1,9 @@
|
|||||||
from lib.test_utils import test_load_cmd, test_args
|
from lib.test_utils import test_args
|
||||||
|
from utils.endpoint_util import Endpoint
|
||||||
|
from utils.ssl import get_cert_file_path
|
||||||
|
from lib.data_types import AuthData
|
||||||
from .data_types.server import CompletionsData
|
from .data_types.server import CompletionsData
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import threading
|
import threading
|
||||||
@@ -353,13 +357,32 @@ if __name__ == "__main__":
|
|||||||
help="Model to use for completions request (required if MODEL_NAME env var not set)",
|
help="Model to use for completions request (required if MODEL_NAME env var not set)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse known args to get model early, before test_load_cmd adds its args
|
# Parse known args to get model early, before adding load args
|
||||||
known_args, _ = test_args.parse_known_args()
|
known_args, _ = test_args.parse_known_args()
|
||||||
|
|
||||||
# Set environment variable if model was provided
|
|
||||||
if hasattr(known_args, "model") and known_args.model:
|
if hasattr(known_args, "model") and known_args.model:
|
||||||
os.environ["MODEL_NAME"] = known_args.model
|
os.environ["MODEL_NAME"] = known_args.model
|
||||||
print(f"Set MODEL_NAME environment variable to: {known_args.model}")
|
print(f"Set MODEL_NAME environment variable to: {known_args.model}")
|
||||||
|
|
||||||
# Now call test_load_cmd normally - it will add its own args and re-parse
|
# Load test args
|
||||||
test_load_cmd(CompletionsData, WORKER_ENDPOINT, arg_parser=test_args)
|
test_args.add_argument("-n", dest="num_requests", type=int, required=True, help="total number of requests")
|
||||||
|
test_args.add_argument("-rps", dest="requests_per_second", type=float, required=True, help="requests per second")
|
||||||
|
test_args.add_argument("--out", dest="out_path", type=str, default="load_test_report.png", help="path to save the report image")
|
||||||
|
args = test_args.parse_args()
|
||||||
|
|
||||||
|
server_url = {
|
||||||
|
"prod": "https://run.vast.ai",
|
||||||
|
"alpha": "https://run-alpha.vast.ai",
|
||||||
|
"candidate": "https://run-candidate.vast.ai",
|
||||||
|
"local": "http://localhost:8080"
|
||||||
|
}.get(args.instance, "http://localhost:8080")
|
||||||
|
|
||||||
|
run_load_with_metrics(
|
||||||
|
num_requests=args.num_requests,
|
||||||
|
requests_per_second=args.requests_per_second,
|
||||||
|
endpoint_group_name=args.endpoint_group_name,
|
||||||
|
account_api_key=args.api_key,
|
||||||
|
server_url=server_url,
|
||||||
|
worker_endpoint=WORKER_ENDPOINT,
|
||||||
|
instance=args.instance,
|
||||||
|
out_path=args.out_path,
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user