From 8cb98c84f9509a66e8f9fb30e6ce6ecef14bb414 Mon Sep 17 00:00:00 2001 From: Colter Downing Date: Fri, 24 Oct 2025 19:08:36 -0700 Subject: [PATCH] non vibe coded test_load --- workers/openai/test_load.py | 122 +++++++++++++++++++++++------------- 1 file changed, 79 insertions(+), 43 deletions(-) diff --git a/workers/openai/test_load.py b/workers/openai/test_load.py index cf0e9c0..5fb29a2 100644 --- a/workers/openai/test_load.py +++ b/workers/openai/test_load.py @@ -42,6 +42,7 @@ class ReqResult: total_ms: float ok: bool error: str = "" + status_code: int = 0 t_start: float = 0.0 t_end: float = 0.0 workload: float = 0.0 @@ -58,31 +59,72 @@ def do_one(endpoint_name: str, route_session, worker_session): try: - u = payload.count_workload() - route_payload = {"endpoint": endpoint_name, "api_key": endpoint_api_key, "cost": u} + workload = payload.count_workload() + route_payload = {"endpoint": endpoint_name, "api_key": endpoint_api_key, "cost": workload} headers = {"Authorization": f"Bearer {endpoint_api_key}"} start = time.time() r0 = route_session.post(urljoin(server_url, "/route/"), json=route_payload, headers=headers, timeout=4) t_after_route = time.time() if r0.status_code != 200: - results_list.append(ReqResult("", (t_after_route - start) * 1000.0, 0.0, (t_after_route - start) * 1000.0, False, - f"route {r0.status_code} {r0.text}")) + results_list.append(ReqResult(worker_url="", + route_ms=(t_after_route - start) * 1000.0, + worker_ms=0.0, + total_ms=(t_after_route - start) * 1000.0, + ok=False, + error=f"route error {r0.reason} {r0.text}", + status_code=r0.status_code, + t_start=start - t0, + t_end=t_after_route - t0, + workload=workload)) return msg = r0.json() - # 1) "Status" is in the response when no worker is ready - worker_sampled = True - status = msg.get("status", "") - if status: + # 1) Check if we got a worker back from route + worker_url = msg.get("url", "") + if not worker_url: m = re.search(r"total workers:\s*(\d+).*loading workers:\s*(\d+).*standby workers:\s*(\d+).*error workers:\s*(\d+)", status, re.I | re.S) if m: tot, loading, standby, err = map(int, m.groups()) idle = max(tot - loading - standby - err, 0) status_samples.append((time.time() - t0, idle)) - worker_sampled = False - # 2) Otherwise (successful request), sample via /get_endpoint_workers/ for eligible (idle) worker tracking - if worker_sampled: + # 2) If we got a worker, send the request + if worker_url: + req = dict(payload=payload.__dict__, auth_data=AuthData.from_json_msg(msg).__dict__) + t_before_worker = time.time() + r1 = worker_session.post( + urljoin(worker_url, worker_endpoint), + json=req, + verify=get_cert_file_path(), + timeout=(4, 120), + ) + t_after_worker = time.time() + if r1.status_code != 200: + results_list.append(ReqResult(worker_url=worker_url, + route_ms=(t_after_route - start) * 1000.0, + worker_ms=(t_after_worker - t_before_worker) * 1000.0, + total_ms=(t_after_worker - start) * 1000.0, + ok=False, + error=f"worker inference error {r1.reason} {r1.text}", + status_code=r1.status_code, + t_start=start - t0, + t_end=t_after_worker - t0, + workload=workload)) + return + # Success case + results_list.append(ReqResult(worker_url=worker_url, + route_ms=(t_after_route - start) * 1000.0, + worker_ms=(t_after_worker - t_before_worker) * 1000.0, + total_ms=(t_after_worker - start) * 1000.0, + ok=True, + error="", + status_code=200, + t_start=start - t0, + t_end=t_after_worker - t0, + workload=workload)) + + # 3) If so, sample via /get_endpoint_workers/ for eligible (idle) worker tracking + if worker_url: try: r_status = route_session.post( urljoin(server_url, "/get_endpoint_workers/"), @@ -100,29 +142,18 @@ def do_one(endpoint_name: str, status_samples.append((time.time() - t0, idle)) except Exception: pass - - # 3) Send the request - worker_address = msg["url"] - req = dict(payload=payload.__dict__, auth_data=AuthData.from_json_msg(msg).__dict__) - t1 = time.time() - # Use explicit connect/read timeouts to avoid long hangs - r1 = worker_session.post( - urljoin(worker_address, worker_endpoint), - json=req, - verify=get_cert_file_path(), - timeout=(4, 120), - ) - t2 = time.time() - if r1.status_code != 200: - results_list.append(ReqResult(worker_address, (t_after_route - start) * 1000.0, (t2 - t1) * 1000.0, - (t2 - start) * 1000.0, False, - f"infer {r1.status_code} {r1.text}")) - return - results_list.append(ReqResult(worker_address, (t_after_route - start) * 1000.0, (t2 - t1) * 1000.0, (t2 - start) * 1000.0, - True, "", t_start=start - t0, t_end=t2 - t0, workload=u)) except Exception as e: t = time.time() - results_list.append(ReqResult("", (t - start) * 1000.0, 0.0, (t - start) * 1000.0, False, str(e))) + results_list.append(ReqResult(worker_url="", + route_ms=0.0, + worker_ms=0.0, + total_ms=0.0, + ok=False, + error=f"unknown error {e}", + status_code=0, + t_start=t - t0, + t_end=t - t0, + workload=0.0)) def run_load_with_metrics(num_requests: int, requests_per_second: float, @@ -132,7 +163,7 @@ def run_load_with_metrics(num_requests: int, worker_endpoint: str, instance: str, out_path: str): - # Resolve endpoint id + endpoint-scoped API key + ep_info = Endpoint.get_endpoint_info(endpoint_name=endpoint_group_name, account_api_key=account_api_key, instance=instance) @@ -145,8 +176,7 @@ def run_load_with_metrics(num_requests: int, t0 = time.time() results = [] status_samples = [] - # Concurrency control - max_concurrency = int(os.environ.get("MAX_CONCURRENCY", "1024")) + max_concurrency = int(os.environ.get("MAX_CONCURRENCY", "8192")) submit_queue_factor = 2 # cap queued tasks to reduce memory # Shared HTTP sessions with connection pooling (persistent connections) @@ -158,9 +188,9 @@ def run_load_with_metrics(num_requests: int, return sess # Router: mostly single host, small connection pool is sufficient - route_session = make_session(pool_connections=8, pool_maxsize=max_concurrency) + route_session = make_session(pool_connections=1, pool_maxsize=max_concurrency) # Workers: many hosts; allow many pools and per-host concurrency up to max_concurrency - worker_session = make_session(pool_connections=max(256, max_concurrency), pool_maxsize=max_concurrency) + worker_session = make_session(pool_connections=64, pool_maxsize=max_concurrency // 8) # Fire requests using a thread pool, scheduling at requested RPS inflight = set() @@ -209,11 +239,12 @@ def run_load_with_metrics(num_requests: int, total_ms = np.array([r.total_ms for r in oks]) if succ else np.array([]) worker_ms = np.array([r.worker_ms for r in oks]) if succ else np.array([]) - #route_ms = np.array([r.route_ms for r in oks]) if succ else np.array([]) + route_ms = np.array([r.route_ms for r in oks]) if succ else np.array([]) avg_total = float(np.mean(total_ms)) if succ else 0.0 + avg_worker = float(np.mean(worker_ms)) if succ else 0.0 + avg_route = float(np.mean(route_ms)) if succ else 0.0 p50_total, p95_total = (float(np.percentile(total_ms, 50)), float(np.percentile(total_ms, 95))) if succ else (0.0, 0.0) - total_compute_time_ms = float(np.sum(worker_ms)) if succ else 0.0 # Distribution over workers (by host:port) hosts = [urlparse(r.worker_url).netloc for r in oks if r.worker_url] @@ -240,11 +271,11 @@ def run_load_with_metrics(num_requests: int, print(f"\nResults: total={total_reqs} success={succ} errors={len(errs)}") print(f"Avg latency (ms): {avg_total:.1f} p50: {p50_total:.1f} p95: {p95_total:.1f}") - print(f"Total compute time (sum worker latency, s): {total_compute_time_ms/1000.0:.2f}") + print(f"Avg route latency (ms): {avg_route:.1f} Avg worker latency (ms): {avg_worker:.1f}") if errs: print("Sample errors:") for e in errs[:5]: - print(f" {e.error}") + print(f" {e.status_code} {e.error}") # Plot: 2x3 grid fig, axes = plt.subplots(2, 3, figsize=(15, 8)) @@ -298,9 +329,14 @@ def run_load_with_metrics(num_requests: int, text = ( f"Total requests: {total_reqs}\n" f"Success: {succ} Errors: {len(errs)}\n" - f"Avg latency: {avg_total:.1f} ms\n" + f"Avg total latency: {avg_total:.1f} ms\n" f"p50: {p50_total:.1f} ms p95: {p95_total:.1f} ms\n" - f"Total compute time: {total_compute_time_ms/1000.0:.2f} s" + f"Avg route latency: {avg_route:.1f} ms\n" + f"Avg worker latency: {avg_worker:.1f} ms\n" + f"300 errors: {len([r for r in errs if r.status_code >= 300 and r.status_code < 400])}\n" + f"429 errors: {len([r for r in errs if r.status_code == 429])}\n" + f"500 errors: {len([r for r in errs if r.status_code >= 500])}\n" + f"Other errors: {len([r for r in errs if r.status_code not in [300, 429, 500]])}\n" ) ax3.set_title("Summary") ax3.text(0.02, 0.98, text, va="top", ha="left", fontsize=11, transform=ax3.transAxes)