From a00c1adab5ac9152f78aaf912d0be9fd9911deeb Mon Sep 17 00:00:00 2001 From: Colter Downing Date: Thu, 9 Oct 2025 19:37:39 -0700 Subject: [PATCH] improved test load --- lib/test_utils.py | 12 +- utils/endpoint_util.py | 48 ++++- workers/openai/test_load.py | 374 +++++++++++++++++++++++++++++++++++- 3 files changed, 415 insertions(+), 19 deletions(-) diff --git a/lib/test_utils.py b/lib/test_utils.py index 8635027..d64a4b6 100644 --- a/lib/test_utils.py +++ b/lib/test_utils.py @@ -292,12 +292,12 @@ def test_load_cmd( args = arg_parser.parse_args() if hasattr(args, "comfy_model"): os.environ["COMFY_MODEL"] = args.comfy_model - server_url = dict( - prod="https://run.vast.ai", - alpha="https://run-alpha.vast.ai", - candidate="https://run-candidate.vast.ai", - local="http://localhost:8080", - )[args.instance] + server_url = { + "prod": "https://run.vast.ai", + "alpha": "https://run-alpha.vast.ai", + "candidate": "https://run-candidate.vast.ai", + "local": "http://localhost:8080", + }.get(args.instance, "http://localhost:8080") run_test( num_requests=args.num_requests, requests_per_second=args.requests_per_second, diff --git a/utils/endpoint_util.py b/utils/endpoint_util.py index 37930af..927262e 100644 --- a/utils/endpoint_util.py +++ b/utils/endpoint_util.py @@ -1,5 +1,6 @@ import logging -from typing import Any, Dict, Optional +import time +from typing import Any, Dict, Optional, Tuple import requests @@ -16,6 +17,38 @@ class Endpoint: Utility class for handling endpoint operations. """ + @staticmethod + def get_endpoint_info( + endpoint_name: str, account_api_key: str, instance: str + ) -> Optional[Dict[str, Any]]: + headers = {"Authorization": f"Bearer {account_api_key}"} + url = f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}" + # Retry a few times to smooth over transient propagation/network delays + for attempt in range(4): + try: + response = requests.get(url, headers=headers, timeout=8) + if response.status_code != 200: + # brief backoff and retry + time.sleep(0.3 * (attempt + 1)) + continue + try: + data = response.json() + except Exception: + # JSON parse failed; backoff and retry + time.sleep(0.3 * (attempt + 1)) + continue + result = data.get("results", []) if isinstance(data, dict) else [] + endpoint = next( + (item for item in result if item.get("endpoint_name") == endpoint_name), + None, + ) + if endpoint and endpoint.get("id") and endpoint.get("api_key"): + return {"id": endpoint.get("id"), "api_key": endpoint.get("api_key")} + except Exception: + # network or other transient error; retry + time.sleep(0.3 * (attempt + 1)) + return None + @staticmethod def get_autoscaler_server_url(instance: str) -> str: endpoints = { @@ -23,7 +56,10 @@ class Endpoint: "candidate": "run-candidate", "prod": "run", } - return f"https://{endpoints[instance]}.vast.ai/" + host = endpoints.get(instance) + if host: + return f"https://{host}.vast.ai/" + return "http://localhost:8080" @staticmethod def get_server_url(instance: str) -> str: @@ -32,7 +68,8 @@ class Endpoint: "candidate": "candidate", "prod": "console", } - return f"https://{endpoints[instance]}.vast.ai/api/v0/endptjobs/" + host = endpoints.get(instance, "alpha") + return f"https://{host}.vast.ai/api/v0/endptjobs/" @staticmethod def get_endpoint_api_key( @@ -55,6 +92,7 @@ class Endpoint: response = requests.get( f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}", headers=headers, + timeout=8, ) if response.status_code != 200: @@ -64,14 +102,14 @@ class Endpoint: try: data = response.json() - except requests.exceptions.JSONDecodeError as e: + except Exception as e: log.debug(f"Failed to parse JSON response: {e}") return None result = data.get("results", []) endpoint: Optional[Dict[str, Any]] = next( - (item for item in result if item["endpoint_name"] == endpoint_name), + (item for item in result if item.get("endpoint_name") == endpoint_name), None, ) if not endpoint: diff --git a/workers/openai/test_load.py b/workers/openai/test_load.py index 0c45524..b0e81d3 100644 --- a/workers/openai/test_load.py +++ b/workers/openai/test_load.py @@ -1,8 +1,347 @@ -from lib.test_utils import test_load_cmd, test_args +from lib.test_utils import test_args +from utils.endpoint_util import Endpoint +from utils.ssl import get_cert_file_path +from lib.data_types import AuthData from .data_types.server import CompletionsData -import os -WORKER_ENDPOINT = "/v1/completions" +import os +import time +import threading +import requests +from dataclasses import dataclass +from collections import Counter +from urllib.parse import urljoin, urlparse +import re + +# Headless plotting +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED +from requests.adapters import HTTPAdapter + +def get_incremented_path(path: str) -> str: + base, ext = os.path.splitext(path) + if not os.path.exists(path): + return path + i = 1 + while os.path.exists(f"{base}-{i}{ext}"): + i += 1 + return f"{base}-{i}{ext}" + +WORKER_ENDPOINT = "/v1/completions" # This will return the full text output at once. Latency metrics reflect that (ie not measuring TTFT) + +@dataclass +class ReqResult: + worker_url: str + route_ms: float + worker_ms: float + total_ms: float + ok: bool + error: str = "" + t_start: float = 0.0 + t_end: float = 0.0 + workload: float = 0.0 + +def do_one(endpoint_name: str, + endpoint_id: int, + endpoint_api_key: str, + server_url: str, + worker_endpoint: str, + payload, + results_list, + t0, + status_samples, + route_session, + worker_session): + try: + u = payload.count_workload() + route_payload = {"endpoint": endpoint_name, "api_key": endpoint_api_key, "cost": u} + headers = {"Authorization": f"Bearer {endpoint_api_key}"} + start = time.time() + r0 = route_session.post(urljoin(server_url, "/route/"), json=route_payload, headers=headers, timeout=4) + t_after_route = time.time() + if r0.status_code != 200: + results_list.append(ReqResult("", (t_after_route - start) * 1000.0, 0.0, (t_after_route - start) * 1000.0, False, + f"route {r0.status_code} {r0.text}")) + return + msg = r0.json() + + # 1) "Status" is in the response when no worker is ready + worker_sampled = True + status = msg.get("status", "") + if status: + m = re.search(r"total workers:\s*(\d+).*loading workers:\s*(\d+).*standby workers:\s*(\d+).*error workers:\s*(\d+)", status, re.I | re.S) + if m: + tot, loading, standby, err = map(int, m.groups()) + idle = max(tot - loading - standby - err, 0) + status_samples.append((time.time() - t0, idle)) + worker_sampled = False + + # 2) Otherwise (successful request), sample via /get_endpoint_workers/ for eligible (idle) worker tracking + if worker_sampled: + try: + r_status = route_session.post( + urljoin(server_url, "/get_endpoint_workers/"), + json={"id": endpoint_id}, + headers={"Authorization": f"Bearer {endpoint_api_key}"}, + timeout=3, + ) + if r_status.status_code == 200: + workers = r_status.json() + idle = 0 + for w in workers: + st = str(w.get("status", "")).lower() + if (st in ("idle")): + idle += 1 + status_samples.append((time.time() - t0, idle)) + except Exception: + pass + + # 3) Send the request + worker_address = msg["url"] + req = dict(payload=payload.__dict__, auth_data=AuthData.from_json_msg(msg).__dict__) + t1 = time.time() + # Use explicit connect/read timeouts to avoid long hangs + r1 = worker_session.post( + urljoin(worker_address, worker_endpoint), + json=req, + verify=get_cert_file_path(), + timeout=(4, 120), + ) + t2 = time.time() + if r1.status_code != 200: + results_list.append(ReqResult(worker_address, (t_after_route - start) * 1000.0, (t2 - t1) * 1000.0, + (t2 - start) * 1000.0, False, + f"infer {r1.status_code} {r1.text}")) + return + results_list.append(ReqResult(worker_address, (t_after_route - start) * 1000.0, (t2 - t1) * 1000.0, (t2 - start) * 1000.0, + True, "", t_start=start - t0, t_end=t2 - t0, workload=u)) + except Exception as e: + t = time.time() + results_list.append(ReqResult("", (t - start) * 1000.0, 0.0, (t - start) * 1000.0, False, str(e))) + +def run_load_with_metrics(num_requests: int, + requests_per_second: float, + endpoint_group_name: str, + account_api_key: str, + server_url: str, + worker_endpoint: str, + instance: str, + out_path: str): + # Resolve endpoint id + endpoint-scoped API key + ep_info = Endpoint.get_endpoint_info(endpoint_name=endpoint_group_name, + account_api_key=account_api_key, + instance=instance) + if not ep_info or not ep_info.get("api_key") or not ep_info.get("id"): + print(f"Endpoint {endpoint_group_name} not found for API key") + return + endpoint_id = int(ep_info["id"]) + endpoint_api_key = ep_info["api_key"] + + t0 = time.time() + results = [] + status_samples = [] + # Concurrency control + max_concurrency = int(os.environ.get("MAX_CONCURRENCY", "1024")) + submit_queue_factor = 2 # cap queued tasks to reduce memory + + # Shared HTTP sessions with connection pooling (persistent connections) + def make_session(pool_connections: int, pool_maxsize: int) -> requests.Session: + sess = requests.Session() + adapter = HTTPAdapter(pool_connections=pool_connections, pool_maxsize=pool_maxsize, max_retries=0) + sess.mount("https://", adapter) + sess.mount("http://", adapter) + return sess + + # Router: mostly single host, small connection pool is sufficient + route_session = make_session(pool_connections=8, pool_maxsize=max_concurrency) + # Workers: many hosts; allow many pools and per-host concurrency up to max_concurrency + worker_session = make_session(pool_connections=max(256, max_concurrency), pool_maxsize=max_concurrency) + + # Fire requests using a thread pool, scheduling at requested RPS + inflight = set() + with ThreadPoolExecutor(max_workers=max_concurrency) as executor: + for i in range(num_requests): + # Pace submissions to RPS + target_time = t0 + i / max(requests_per_second, 1e-9) + sleep_s = target_time - time.time() + if sleep_s > 0: + time.sleep(min(sleep_s, 0.5)) # sleep in chunks to stay responsive + + payload = CompletionsData.for_test() + fut = executor.submit( + do_one, + endpoint_group_name, + endpoint_id, + endpoint_api_key, + server_url, + worker_endpoint, + payload, + results, + t0, + status_samples, + route_session, + worker_session, + ) + inflight.add(fut) + # Prevent unbounded queue growth + if len(inflight) >= max_concurrency * submit_queue_factor: + done, not_done = wait(inflight, return_when=FIRST_COMPLETED) + inflight = not_done + # Wait for all outstanding tasks + if inflight: + wait(inflight) + # Close sessions + try: + route_session.close() + finally: + worker_session.close() + + # Aggregate results + oks = [r for r in results if r.ok] + errs = [r for r in results if not r.ok] + total_reqs = len(results) + succ = len(oks) + + total_ms = np.array([r.total_ms for r in oks]) if succ else np.array([]) + worker_ms = np.array([r.worker_ms for r in oks]) if succ else np.array([]) + #route_ms = np.array([r.route_ms for r in oks]) if succ else np.array([]) + + avg_total = float(np.mean(total_ms)) if succ else 0.0 + p50_total, p95_total = (float(np.percentile(total_ms, 50)), float(np.percentile(total_ms, 95))) if succ else (0.0, 0.0) + total_compute_time_ms = float(np.sum(worker_ms)) if succ else 0.0 + + # Distribution over workers (by host:port) + hosts = [urlparse(r.worker_url).netloc for r in oks if r.worker_url] + dist = Counter(hosts) + + # Idle over time (mode per second) + idle_ts, idle_vals = [], [] + if status_samples: + buckets = {} + for ts, idle in status_samples: + k = int(ts) + buckets.setdefault(k, []).append(idle) + keys = sorted(buckets.keys()) + idle_ts = keys + # Use the most frequent sampled value per second (mode) to keep integer counts + idle_vals = [] + for k in keys: + vals_k = [int(v) for v in buckets[k]] + if vals_k: + cnt = Counter(vals_k) + idle_vals.append(cnt.most_common(1)[0][0]) + else: + idle_vals.append(0) + + print(f"\nResults: total={total_reqs} success={succ} errors={len(errs)}") + print(f"Avg latency (ms): {avg_total:.1f} p50: {p50_total:.1f} p95: {p95_total:.1f}") + print(f"Total compute time (sum worker latency, s): {total_compute_time_ms/1000.0:.2f}") + if errs: + print("Sample errors:") + for e in errs[:5]: + print(f" {e.error}") + + # Plot: 2x3 grid + fig, axes = plt.subplots(2, 3, figsize=(15, 8)) + fig.suptitle(f"Load test: {endpoint_group_name} n={total_reqs}, rps={requests_per_second}, success={succ}") + + # Dist per worker + ax0 = axes[0, 0] + if dist: + items = sorted(dist.items(), key=lambda kv: kv[1], reverse=True) + labels, counts = zip(*items) + ax0.bar(range(len(labels)), counts) + ax0.set_xticks(range(len(labels))) + ax0.set_xticklabels(labels, rotation=45, ha="right", fontsize=8) + ax0.set_title("Request distribution over workers") + ax0.set_ylabel("count") + + # Latency histogram (total) + ax1 = axes[0, 1] + if succ: + ax1.hist(total_ms, bins=30, color="#4e79a7") + ax1.set_title("Total latency (ms)") + ax1.set_xlabel("ms") + ax1.set_ylabel("freq") + + # Eligible workers over time + ax_idle = axes[0, 2] + if idle_ts: + ax_idle.plot(idle_ts, idle_vals, "-o", ms=3) + ax_idle.set_title("Eligible workers over time") + ax_idle.set_xlabel("time (s)") + ax_idle.set_ylabel("eligible count") + + # Throughput over time (completions/sec) + ax_idle = axes[1, 0] + ax_idle.clear() + if succ: + per_sec = {} + for r in oks: + s = int(r.t_end) + per_sec[s] = per_sec.get(s, 0) + 1 + ts = sorted(per_sec.keys()) + vals = [per_sec[t] for t in ts] + ax_idle.plot(ts, vals, "-o", ms=3) + ax_idle.set_title("Completions per second") + ax_idle.set_xlabel("time (s)") + ax_idle.set_ylabel("req/s") + + # Summary text + ax3 = axes[1, 1] + ax3.axis("off") + text = ( + f"Total requests: {total_reqs}\n" + f"Success: {succ} Errors: {len(errs)}\n" + f"Avg latency: {avg_total:.1f} ms\n" + f"p50: {p50_total:.1f} ms p95: {p95_total:.1f} ms\n" + f"Total compute time: {total_compute_time_ms/1000.0:.2f} s" + ) + ax3.set_title("Summary") + ax3.text(0.02, 0.98, text, va="top", ha="left", fontsize=11, transform=ax3.transAxes) + + # Latency CDF (total_ms) + ax_cdf = axes[1, 2] + if succ: + x = np.sort(total_ms) + y = np.linspace(0, 1, len(x), endpoint=True) + ax_cdf.plot(x, y) + ax_cdf.set_title("Latency CDF") + ax_cdf.set_xlabel("ms") + ax_cdf.set_ylabel("fraction ≤ x") + + # Ensure unique output path and create directory if needed + final_out_path = get_incremented_path(out_path) + out_dir = os.path.dirname(final_out_path) + if out_dir: + os.makedirs(out_dir, exist_ok=True) + + plt.tight_layout(rect=[0, 0, 1, 0.96]) + plt.savefig(final_out_path, dpi=120) + print(f"Saved report to: {final_out_path}") + + # Per-worker latency boxplot (top 12 by volume) + groups = {} + for r in oks: + host = urlparse(r.worker_url).netloc + groups.setdefault(host, []).append(r.total_ms) + items = sorted(groups.items(), key=lambda kv: len(kv[1]), reverse=True)[:12] + if items: + labels, data = zip(*items) + fig2, axb = plt.subplots(1, 1, figsize=(12, 5)) + axb.boxplot(data, showfliers=False) + axb.set_xticklabels(labels, rotation=45, ha="right", fontsize=8) + axb.set_title("Per-worker latency (ms)") + axb.set_ylabel("ms") + plt.tight_layout() + extra_out = get_incremented_path(os.path.splitext(out_path)[0] + "-workers.png") + plt.savefig(extra_out, dpi=120) + fig2.tight_layout() + fig2.savefig(extra_out, dpi=120) + print(f"Saved worker latency plot to: {extra_out}") if __name__ == "__main__": # Check if MODEL_NAME environment variable is set @@ -16,13 +355,32 @@ if __name__ == "__main__": help="Model to use for completions request (required if MODEL_NAME env var not set)", ) - # Parse known args to get model early, before test_load_cmd adds its args + # Parse known args to get model early, before adding load args known_args, _ = test_args.parse_known_args() - - # Set environment variable if model was provided if hasattr(known_args, "model") and known_args.model: os.environ["MODEL_NAME"] = known_args.model print(f"Set MODEL_NAME environment variable to: {known_args.model}") - # Now call test_load_cmd normally - it will add its own args and re-parse - test_load_cmd(CompletionsData, WORKER_ENDPOINT, arg_parser=test_args) + # Load test args + test_args.add_argument("-n", dest="num_requests", type=int, required=True, help="total number of requests") + test_args.add_argument("-rps", dest="requests_per_second", type=float, required=True, help="requests per second") + test_args.add_argument("--out", dest="out_path", type=str, default="load_test_report.png", help="path to save the report image") + args = test_args.parse_args() + + server_url = { + "prod": "https://run.vast.ai", + "alpha": "https://run-alpha.vast.ai", + "candidate": "https://run-candidate.vast.ai", + "local": "http://localhost:8080" + }.get(args.instance, "http://localhost:8080") + + run_load_with_metrics( + num_requests=args.num_requests, + requests_per_second=args.requests_per_second, + endpoint_group_name=args.endpoint_group_name, + account_api_key=args.api_key, + server_url=server_url, + worker_endpoint=WORKER_ENDPOINT, + instance=args.instance, + out_path=args.out_path, + ) \ No newline at end of file