Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1d09d7fe96 |
+2
-13
@@ -190,30 +190,18 @@ class Backend:
|
|||||||
log.debug(f"Exception in main handler loop {e}")
|
log.debug(f"Exception in main handler loop {e}")
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def healthcheck_session(self):
|
|
||||||
"""Dedicated session for healthchecks to avoid conflicts with API session"""
|
|
||||||
log.debug("creating dedicated healthcheck session")
|
|
||||||
connector = TCPConnector(
|
|
||||||
force_close=True, # Keep this for isolation
|
|
||||||
enable_cleanup_closed=True,
|
|
||||||
)
|
|
||||||
timeout = ClientTimeout(total=10) # Reasonable timeout for healthchecks
|
|
||||||
return ClientSession(timeout=timeout, connector=connector)
|
|
||||||
|
|
||||||
async def __healthcheck(self):
|
async def __healthcheck(self):
|
||||||
health_check_url = self.benchmark_handler.healthcheck_endpoint
|
health_check_url = self.benchmark_handler.healthcheck_endpoint
|
||||||
if health_check_url is None:
|
if health_check_url is None:
|
||||||
log.debug("No healthcheck endpoint defined, skipping healthcheck")
|
log.debug("No healthcheck endpoint defined, skipping healthcheck")
|
||||||
return
|
return
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
await sleep(10)
|
await sleep(10)
|
||||||
if self.__start_healthcheck is False:
|
if self.__start_healthcheck is False:
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
log.debug(f"Performing healthcheck on {health_check_url}")
|
log.debug(f"Performing healthcheck on {health_check_url}")
|
||||||
async with self.healthcheck_session.get(health_check_url) as response:
|
async with self.session.get(health_check_url) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
log.debug("Healthcheck successful")
|
log.debug("Healthcheck successful")
|
||||||
elif response.status == 503:
|
elif response.status == 503:
|
||||||
@@ -222,6 +210,7 @@ class Backend:
|
|||||||
f"Healthcheck failed with status: {response.status}"
|
f"Healthcheck failed with status: {response.status}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# endpoint not ready yet so bail
|
||||||
log.debug(f"Healthcheck Endpoint not ready: {response.status}")
|
log.debug(f"Healthcheck Endpoint not ready: {response.status}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug(f"Healthcheck failed with exception: {e}")
|
log.debug(f"Healthcheck failed with exception: {e}")
|
||||||
|
|||||||
+6
-6
@@ -292,12 +292,12 @@ def test_load_cmd(
|
|||||||
args = arg_parser.parse_args()
|
args = arg_parser.parse_args()
|
||||||
if hasattr(args, "comfy_model"):
|
if hasattr(args, "comfy_model"):
|
||||||
os.environ["COMFY_MODEL"] = args.comfy_model
|
os.environ["COMFY_MODEL"] = args.comfy_model
|
||||||
server_url = {
|
server_url = dict(
|
||||||
"prod": "https://run.vast.ai",
|
prod="https://run.vast.ai",
|
||||||
"alpha": "https://run-alpha.vast.ai",
|
alpha="https://run-alpha.vast.ai",
|
||||||
"candidate": "https://run-candidate.vast.ai",
|
candidate="https://run-candidate.vast.ai",
|
||||||
"local": "http://localhost:8080",
|
local="http://localhost:8080",
|
||||||
}.get(args.instance, "http://localhost:8080")
|
)[args.instance]
|
||||||
run_test(
|
run_test(
|
||||||
num_requests=args.num_requests,
|
num_requests=args.num_requests,
|
||||||
requests_per_second=args.requests_per_second,
|
requests_per_second=args.requests_per_second,
|
||||||
|
|||||||
+2
-2
@@ -59,12 +59,12 @@ then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
# Fork testing
|
# Fork testing
|
||||||
git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"
|
[[ ! -d $SERVER_DIR ]] && git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"
|
||||||
if [[ -n ${PYWORKER_REF:-} ]]; then
|
if [[ -n ${PYWORKER_REF:-} ]]; then
|
||||||
(cd "$SERVER_DIR" && git checkout "$PYWORKER_REF")
|
(cd "$SERVER_DIR" && git checkout "$PYWORKER_REF")
|
||||||
fi
|
fi
|
||||||
|
|
||||||
uv venv --managed-python "$ENV_PATH" -p 3.10
|
uv venv --python-preference only-managed "$ENV_PATH" -p 3.10
|
||||||
source "$ENV_PATH/bin/activate"
|
source "$ENV_PATH/bin/activate"
|
||||||
|
|
||||||
uv pip install -r "${SERVER_DIR}/requirements.txt"
|
uv pip install -r "${SERVER_DIR}/requirements.txt"
|
||||||
|
|||||||
+5
-43
@@ -1,6 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
import time
|
from typing import Any, Dict, Optional
|
||||||
from typing import Any, Dict, Optional, Tuple
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@@ -17,38 +16,6 @@ class Endpoint:
|
|||||||
Utility class for handling endpoint operations.
|
Utility class for handling endpoint operations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_endpoint_info(
|
|
||||||
endpoint_name: str, account_api_key: str, instance: str
|
|
||||||
) -> Optional[Dict[str, Any]]:
|
|
||||||
headers = {"Authorization": f"Bearer {account_api_key}"}
|
|
||||||
url = f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}"
|
|
||||||
# Retry a few times to smooth over transient propagation/network delays
|
|
||||||
for attempt in range(4):
|
|
||||||
try:
|
|
||||||
response = requests.get(url, headers=headers, timeout=8)
|
|
||||||
if response.status_code != 200:
|
|
||||||
# brief backoff and retry
|
|
||||||
time.sleep(0.3 * (attempt + 1))
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
data = response.json()
|
|
||||||
except Exception:
|
|
||||||
# JSON parse failed; backoff and retry
|
|
||||||
time.sleep(0.3 * (attempt + 1))
|
|
||||||
continue
|
|
||||||
result = data.get("results", []) if isinstance(data, dict) else []
|
|
||||||
endpoint = next(
|
|
||||||
(item for item in result if item.get("endpoint_name") == endpoint_name),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
if endpoint and endpoint.get("id") and endpoint.get("api_key"):
|
|
||||||
return {"id": endpoint.get("id"), "api_key": endpoint.get("api_key")}
|
|
||||||
except Exception:
|
|
||||||
# network or other transient error; retry
|
|
||||||
time.sleep(0.3 * (attempt + 1))
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_autoscaler_server_url(instance: str) -> str:
|
def get_autoscaler_server_url(instance: str) -> str:
|
||||||
endpoints = {
|
endpoints = {
|
||||||
@@ -56,10 +23,7 @@ class Endpoint:
|
|||||||
"candidate": "run-candidate",
|
"candidate": "run-candidate",
|
||||||
"prod": "run",
|
"prod": "run",
|
||||||
}
|
}
|
||||||
host = endpoints.get(instance)
|
return f"https://{endpoints[instance]}.vast.ai/"
|
||||||
if host:
|
|
||||||
return f"https://{host}.vast.ai/"
|
|
||||||
return "http://localhost:8080"
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_server_url(instance: str) -> str:
|
def get_server_url(instance: str) -> str:
|
||||||
@@ -68,8 +32,7 @@ class Endpoint:
|
|||||||
"candidate": "candidate",
|
"candidate": "candidate",
|
||||||
"prod": "console",
|
"prod": "console",
|
||||||
}
|
}
|
||||||
host = endpoints.get(instance, "alpha")
|
return f"https://{endpoints[instance]}.vast.ai/api/v0/endptjobs/"
|
||||||
return f"https://{host}.vast.ai/api/v0/endptjobs/"
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_endpoint_api_key(
|
def get_endpoint_api_key(
|
||||||
@@ -92,7 +55,6 @@ class Endpoint:
|
|||||||
response = requests.get(
|
response = requests.get(
|
||||||
f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}",
|
f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=8,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
@@ -102,14 +64,14 @@ class Endpoint:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
data = response.json()
|
data = response.json()
|
||||||
except Exception as e:
|
except requests.exceptions.JSONDecodeError as e:
|
||||||
log.debug(f"Failed to parse JSON response: {e}")
|
log.debug(f"Failed to parse JSON response: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
result = data.get("results", [])
|
result = data.get("results", [])
|
||||||
|
|
||||||
endpoint: Optional[Dict[str, Any]] = next(
|
endpoint: Optional[Dict[str, Any]] = next(
|
||||||
(item for item in result if item.get("endpoint_name") == endpoint_name),
|
(item for item in result if item["endpoint_name"] == endpoint_name),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
if not endpoint:
|
if not endpoint:
|
||||||
|
|||||||
+7
-367
@@ -1,349 +1,8 @@
|
|||||||
from lib.test_utils import test_args
|
from lib.test_utils import test_load_cmd, test_args
|
||||||
from utils.endpoint_util import Endpoint
|
|
||||||
from utils.ssl import get_cert_file_path
|
|
||||||
from lib.data_types import AuthData
|
|
||||||
from .data_types.server import CompletionsData
|
from .data_types.server import CompletionsData
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
|
||||||
import threading
|
|
||||||
import requests
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from collections import Counter
|
|
||||||
from urllib.parse import urljoin, urlparse
|
|
||||||
import re
|
|
||||||
|
|
||||||
# Headless plotting
|
WORKER_ENDPOINT = "/v1/completions"
|
||||||
import matplotlib
|
|
||||||
matplotlib.use("Agg")
|
|
||||||
import logging
|
|
||||||
logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED
|
|
||||||
from requests.adapters import HTTPAdapter
|
|
||||||
|
|
||||||
def get_incremented_path(path: str) -> str:
|
|
||||||
base, ext = os.path.splitext(path)
|
|
||||||
if not os.path.exists(path):
|
|
||||||
return path
|
|
||||||
i = 1
|
|
||||||
while os.path.exists(f"{base}-{i}{ext}"):
|
|
||||||
i += 1
|
|
||||||
return f"{base}-{i}{ext}"
|
|
||||||
|
|
||||||
WORKER_ENDPOINT = "/v1/completions" # This will return the full text output at once. Latency metrics reflect that (ie not measuring TTFT)
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ReqResult:
|
|
||||||
worker_url: str
|
|
||||||
route_ms: float
|
|
||||||
worker_ms: float
|
|
||||||
total_ms: float
|
|
||||||
ok: bool
|
|
||||||
error: str = ""
|
|
||||||
t_start: float = 0.0
|
|
||||||
t_end: float = 0.0
|
|
||||||
workload: float = 0.0
|
|
||||||
|
|
||||||
def do_one(endpoint_name: str,
|
|
||||||
endpoint_id: int,
|
|
||||||
endpoint_api_key: str,
|
|
||||||
server_url: str,
|
|
||||||
worker_endpoint: str,
|
|
||||||
payload,
|
|
||||||
results_list,
|
|
||||||
t0,
|
|
||||||
status_samples,
|
|
||||||
route_session,
|
|
||||||
worker_session):
|
|
||||||
try:
|
|
||||||
u = payload.count_workload()
|
|
||||||
route_payload = {"endpoint": endpoint_name, "api_key": endpoint_api_key, "cost": u}
|
|
||||||
headers = {"Authorization": f"Bearer {endpoint_api_key}"}
|
|
||||||
start = time.time()
|
|
||||||
r0 = route_session.post(urljoin(server_url, "/route/"), json=route_payload, headers=headers, timeout=4)
|
|
||||||
t_after_route = time.time()
|
|
||||||
if r0.status_code != 200:
|
|
||||||
results_list.append(ReqResult("", (t_after_route - start) * 1000.0, 0.0, (t_after_route - start) * 1000.0, False,
|
|
||||||
f"route {r0.status_code} {r0.text}"))
|
|
||||||
return
|
|
||||||
msg = r0.json()
|
|
||||||
|
|
||||||
# 1) "Status" is in the response when no worker is ready
|
|
||||||
worker_sampled = True
|
|
||||||
status = msg.get("status", "")
|
|
||||||
if status:
|
|
||||||
m = re.search(r"total workers:\s*(\d+).*loading workers:\s*(\d+).*standby workers:\s*(\d+).*error workers:\s*(\d+)", status, re.I | re.S)
|
|
||||||
if m:
|
|
||||||
tot, loading, standby, err = map(int, m.groups())
|
|
||||||
idle = max(tot - loading - standby - err, 0)
|
|
||||||
status_samples.append((time.time() - t0, idle))
|
|
||||||
worker_sampled = False
|
|
||||||
|
|
||||||
# 2) Otherwise (successful request), sample via /get_endpoint_workers/ for eligible (idle) worker tracking
|
|
||||||
if worker_sampled:
|
|
||||||
try:
|
|
||||||
r_status = route_session.post(
|
|
||||||
urljoin(server_url, "/get_endpoint_workers/"),
|
|
||||||
json={"id": endpoint_id},
|
|
||||||
headers={"Authorization": f"Bearer {endpoint_api_key}"},
|
|
||||||
timeout=3,
|
|
||||||
)
|
|
||||||
if r_status.status_code == 200:
|
|
||||||
workers = r_status.json()
|
|
||||||
idle = 0
|
|
||||||
for w in workers:
|
|
||||||
st = str(w.get("status", "")).lower()
|
|
||||||
if (st in ("idle")):
|
|
||||||
idle += 1
|
|
||||||
status_samples.append((time.time() - t0, idle))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 3) Send the request
|
|
||||||
worker_address = msg["url"]
|
|
||||||
req = dict(payload=payload.__dict__, auth_data=AuthData.from_json_msg(msg).__dict__)
|
|
||||||
t1 = time.time()
|
|
||||||
# Use explicit connect/read timeouts to avoid long hangs
|
|
||||||
r1 = worker_session.post(
|
|
||||||
urljoin(worker_address, worker_endpoint),
|
|
||||||
json=req,
|
|
||||||
verify=get_cert_file_path(),
|
|
||||||
timeout=(4, 120),
|
|
||||||
)
|
|
||||||
t2 = time.time()
|
|
||||||
if r1.status_code != 200:
|
|
||||||
results_list.append(ReqResult(worker_address, (t_after_route - start) * 1000.0, (t2 - t1) * 1000.0,
|
|
||||||
(t2 - start) * 1000.0, False,
|
|
||||||
f"infer {r1.status_code} {r1.text}"))
|
|
||||||
return
|
|
||||||
results_list.append(ReqResult(worker_address, (t_after_route - start) * 1000.0, (t2 - t1) * 1000.0, (t2 - start) * 1000.0,
|
|
||||||
True, "", t_start=start - t0, t_end=t2 - t0, workload=u))
|
|
||||||
except Exception as e:
|
|
||||||
t = time.time()
|
|
||||||
results_list.append(ReqResult("", (t - start) * 1000.0, 0.0, (t - start) * 1000.0, False, str(e)))
|
|
||||||
|
|
||||||
def run_load_with_metrics(num_requests: int,
|
|
||||||
requests_per_second: float,
|
|
||||||
endpoint_group_name: str,
|
|
||||||
account_api_key: str,
|
|
||||||
server_url: str,
|
|
||||||
worker_endpoint: str,
|
|
||||||
instance: str,
|
|
||||||
out_path: str):
|
|
||||||
# Resolve endpoint id + endpoint-scoped API key
|
|
||||||
ep_info = Endpoint.get_endpoint_info(endpoint_name=endpoint_group_name,
|
|
||||||
account_api_key=account_api_key,
|
|
||||||
instance=instance)
|
|
||||||
if not ep_info or not ep_info.get("api_key") or not ep_info.get("id"):
|
|
||||||
print(f"Endpoint {endpoint_group_name} not found for API key")
|
|
||||||
return
|
|
||||||
endpoint_id = int(ep_info["id"])
|
|
||||||
endpoint_api_key = ep_info["api_key"]
|
|
||||||
|
|
||||||
t0 = time.time()
|
|
||||||
results = []
|
|
||||||
status_samples = []
|
|
||||||
# Concurrency control
|
|
||||||
max_concurrency = int(os.environ.get("MAX_CONCURRENCY", "1024"))
|
|
||||||
submit_queue_factor = 2 # cap queued tasks to reduce memory
|
|
||||||
|
|
||||||
# Shared HTTP sessions with connection pooling (persistent connections)
|
|
||||||
def make_session(pool_connections: int, pool_maxsize: int) -> requests.Session:
|
|
||||||
sess = requests.Session()
|
|
||||||
adapter = HTTPAdapter(pool_connections=pool_connections, pool_maxsize=pool_maxsize, max_retries=0)
|
|
||||||
sess.mount("https://", adapter)
|
|
||||||
sess.mount("http://", adapter)
|
|
||||||
return sess
|
|
||||||
|
|
||||||
# Router: mostly single host, small connection pool is sufficient
|
|
||||||
route_session = make_session(pool_connections=8, pool_maxsize=max_concurrency)
|
|
||||||
# Workers: many hosts; allow many pools and per-host concurrency up to max_concurrency
|
|
||||||
worker_session = make_session(pool_connections=max(256, max_concurrency), pool_maxsize=max_concurrency)
|
|
||||||
|
|
||||||
# Fire requests using a thread pool, scheduling at requested RPS
|
|
||||||
inflight = set()
|
|
||||||
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
|
|
||||||
for i in range(num_requests):
|
|
||||||
# Pace submissions to RPS
|
|
||||||
target_time = t0 + i / max(requests_per_second, 1e-9)
|
|
||||||
sleep_s = target_time - time.time()
|
|
||||||
if sleep_s > 0:
|
|
||||||
time.sleep(min(sleep_s, 0.5)) # sleep in chunks to stay responsive
|
|
||||||
|
|
||||||
payload = CompletionsData.for_test()
|
|
||||||
fut = executor.submit(
|
|
||||||
do_one,
|
|
||||||
endpoint_group_name,
|
|
||||||
endpoint_id,
|
|
||||||
endpoint_api_key,
|
|
||||||
server_url,
|
|
||||||
worker_endpoint,
|
|
||||||
payload,
|
|
||||||
results,
|
|
||||||
t0,
|
|
||||||
status_samples,
|
|
||||||
route_session,
|
|
||||||
worker_session,
|
|
||||||
)
|
|
||||||
inflight.add(fut)
|
|
||||||
# Prevent unbounded queue growth
|
|
||||||
if len(inflight) >= max_concurrency * submit_queue_factor:
|
|
||||||
done, not_done = wait(inflight, return_when=FIRST_COMPLETED)
|
|
||||||
inflight = not_done
|
|
||||||
# Wait for all outstanding tasks
|
|
||||||
if inflight:
|
|
||||||
wait(inflight)
|
|
||||||
# Close sessions
|
|
||||||
try:
|
|
||||||
route_session.close()
|
|
||||||
finally:
|
|
||||||
worker_session.close()
|
|
||||||
|
|
||||||
# Aggregate results
|
|
||||||
oks = [r for r in results if r.ok]
|
|
||||||
errs = [r for r in results if not r.ok]
|
|
||||||
total_reqs = len(results)
|
|
||||||
succ = len(oks)
|
|
||||||
|
|
||||||
total_ms = np.array([r.total_ms for r in oks]) if succ else np.array([])
|
|
||||||
worker_ms = np.array([r.worker_ms for r in oks]) if succ else np.array([])
|
|
||||||
#route_ms = np.array([r.route_ms for r in oks]) if succ else np.array([])
|
|
||||||
|
|
||||||
avg_total = float(np.mean(total_ms)) if succ else 0.0
|
|
||||||
p50_total, p95_total = (float(np.percentile(total_ms, 50)), float(np.percentile(total_ms, 95))) if succ else (0.0, 0.0)
|
|
||||||
total_compute_time_ms = float(np.sum(worker_ms)) if succ else 0.0
|
|
||||||
|
|
||||||
# Distribution over workers (by host:port)
|
|
||||||
hosts = [urlparse(r.worker_url).netloc for r in oks if r.worker_url]
|
|
||||||
dist = Counter(hosts)
|
|
||||||
|
|
||||||
# Idle over time (mode per second)
|
|
||||||
idle_ts, idle_vals = [], []
|
|
||||||
if status_samples:
|
|
||||||
buckets = {}
|
|
||||||
for ts, idle in status_samples:
|
|
||||||
k = int(ts)
|
|
||||||
buckets.setdefault(k, []).append(idle)
|
|
||||||
keys = sorted(buckets.keys())
|
|
||||||
idle_ts = keys
|
|
||||||
# Use the most frequent sampled value per second (mode) to keep integer counts
|
|
||||||
idle_vals = []
|
|
||||||
for k in keys:
|
|
||||||
vals_k = [int(v) for v in buckets[k]]
|
|
||||||
if vals_k:
|
|
||||||
cnt = Counter(vals_k)
|
|
||||||
idle_vals.append(cnt.most_common(1)[0][0])
|
|
||||||
else:
|
|
||||||
idle_vals.append(0)
|
|
||||||
|
|
||||||
print(f"\nResults: total={total_reqs} success={succ} errors={len(errs)}")
|
|
||||||
print(f"Avg latency (ms): {avg_total:.1f} p50: {p50_total:.1f} p95: {p95_total:.1f}")
|
|
||||||
print(f"Total compute time (sum worker latency, s): {total_compute_time_ms/1000.0:.2f}")
|
|
||||||
if errs:
|
|
||||||
print("Sample errors:")
|
|
||||||
for e in errs[:5]:
|
|
||||||
print(f" {e.error}")
|
|
||||||
|
|
||||||
# Plot: 2x3 grid
|
|
||||||
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
|
|
||||||
fig.suptitle(f"Load test: {endpoint_group_name} n={total_reqs}, rps={requests_per_second}, success={succ}")
|
|
||||||
|
|
||||||
# Dist per worker
|
|
||||||
ax0 = axes[0, 0]
|
|
||||||
if dist:
|
|
||||||
items = sorted(dist.items(), key=lambda kv: kv[1], reverse=True)
|
|
||||||
labels, counts = zip(*items)
|
|
||||||
ax0.bar(range(len(labels)), counts)
|
|
||||||
ax0.set_xticks(range(len(labels)))
|
|
||||||
ax0.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
|
|
||||||
ax0.set_title("Request distribution over workers")
|
|
||||||
ax0.set_ylabel("count")
|
|
||||||
|
|
||||||
# Latency histogram (total)
|
|
||||||
ax1 = axes[0, 1]
|
|
||||||
if succ:
|
|
||||||
ax1.hist(total_ms, bins=30, color="#4e79a7")
|
|
||||||
ax1.set_title("Total latency (ms)")
|
|
||||||
ax1.set_xlabel("ms")
|
|
||||||
ax1.set_ylabel("freq")
|
|
||||||
|
|
||||||
# Eligible workers over time
|
|
||||||
ax_idle = axes[0, 2]
|
|
||||||
if idle_ts:
|
|
||||||
ax_idle.plot(idle_ts, idle_vals, "-o", ms=3)
|
|
||||||
ax_idle.set_title("Eligible workers over time")
|
|
||||||
ax_idle.set_xlabel("time (s)")
|
|
||||||
ax_idle.set_ylabel("eligible count")
|
|
||||||
|
|
||||||
# Throughput over time (completions/sec)
|
|
||||||
ax_idle = axes[1, 0]
|
|
||||||
ax_idle.clear()
|
|
||||||
if succ:
|
|
||||||
per_sec = {}
|
|
||||||
for r in oks:
|
|
||||||
s = int(r.t_end)
|
|
||||||
per_sec[s] = per_sec.get(s, 0) + 1
|
|
||||||
ts = sorted(per_sec.keys())
|
|
||||||
vals = [per_sec[t] for t in ts]
|
|
||||||
ax_idle.plot(ts, vals, "-o", ms=3)
|
|
||||||
ax_idle.set_title("Completions per second")
|
|
||||||
ax_idle.set_xlabel("time (s)")
|
|
||||||
ax_idle.set_ylabel("req/s")
|
|
||||||
|
|
||||||
# Summary text
|
|
||||||
ax3 = axes[1, 1]
|
|
||||||
ax3.axis("off")
|
|
||||||
text = (
|
|
||||||
f"Total requests: {total_reqs}\n"
|
|
||||||
f"Success: {succ} Errors: {len(errs)}\n"
|
|
||||||
f"Avg latency: {avg_total:.1f} ms\n"
|
|
||||||
f"p50: {p50_total:.1f} ms p95: {p95_total:.1f} ms\n"
|
|
||||||
f"Total compute time: {total_compute_time_ms/1000.0:.2f} s"
|
|
||||||
)
|
|
||||||
ax3.set_title("Summary")
|
|
||||||
ax3.text(0.02, 0.98, text, va="top", ha="left", fontsize=11, transform=ax3.transAxes)
|
|
||||||
|
|
||||||
# Latency CDF (total_ms)
|
|
||||||
ax_cdf = axes[1, 2]
|
|
||||||
if succ:
|
|
||||||
x = np.sort(total_ms)
|
|
||||||
y = np.linspace(0, 1, len(x), endpoint=True)
|
|
||||||
ax_cdf.plot(x, y)
|
|
||||||
ax_cdf.set_title("Latency CDF")
|
|
||||||
ax_cdf.set_xlabel("ms")
|
|
||||||
ax_cdf.set_ylabel("fraction ≤ x")
|
|
||||||
|
|
||||||
# Ensure unique output path and create directory if needed
|
|
||||||
final_out_path = get_incremented_path(out_path)
|
|
||||||
out_dir = os.path.dirname(final_out_path)
|
|
||||||
if out_dir:
|
|
||||||
os.makedirs(out_dir, exist_ok=True)
|
|
||||||
|
|
||||||
plt.tight_layout(rect=[0, 0, 1, 0.96])
|
|
||||||
plt.savefig(final_out_path, dpi=120)
|
|
||||||
print(f"Saved report to: {final_out_path}")
|
|
||||||
|
|
||||||
# Per-worker latency boxplot (top 12 by volume)
|
|
||||||
groups = {}
|
|
||||||
for r in oks:
|
|
||||||
host = urlparse(r.worker_url).netloc
|
|
||||||
groups.setdefault(host, []).append(r.total_ms)
|
|
||||||
items = sorted(groups.items(), key=lambda kv: len(kv[1]), reverse=True)[:12]
|
|
||||||
if items:
|
|
||||||
labels, data = zip(*items)
|
|
||||||
fig2, axb = plt.subplots(1, 1, figsize=(12, 5))
|
|
||||||
axb.boxplot(data, showfliers=False)
|
|
||||||
axb.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
|
|
||||||
axb.set_title("Per-worker latency (ms)")
|
|
||||||
axb.set_ylabel("ms")
|
|
||||||
plt.tight_layout()
|
|
||||||
extra_out = get_incremented_path(os.path.splitext(out_path)[0] + "-workers.png")
|
|
||||||
plt.savefig(extra_out, dpi=120)
|
|
||||||
fig2.tight_layout()
|
|
||||||
fig2.savefig(extra_out, dpi=120)
|
|
||||||
print(f"Saved worker latency plot to: {extra_out}")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Check if MODEL_NAME environment variable is set
|
# Check if MODEL_NAME environment variable is set
|
||||||
@@ -357,32 +16,13 @@ if __name__ == "__main__":
|
|||||||
help="Model to use for completions request (required if MODEL_NAME env var not set)",
|
help="Model to use for completions request (required if MODEL_NAME env var not set)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse known args to get model early, before adding load args
|
# Parse known args to get model early, before test_load_cmd adds its args
|
||||||
known_args, _ = test_args.parse_known_args()
|
known_args, _ = test_args.parse_known_args()
|
||||||
|
|
||||||
|
# Set environment variable if model was provided
|
||||||
if hasattr(known_args, "model") and known_args.model:
|
if hasattr(known_args, "model") and known_args.model:
|
||||||
os.environ["MODEL_NAME"] = known_args.model
|
os.environ["MODEL_NAME"] = known_args.model
|
||||||
print(f"Set MODEL_NAME environment variable to: {known_args.model}")
|
print(f"Set MODEL_NAME environment variable to: {known_args.model}")
|
||||||
|
|
||||||
# Load test args
|
# Now call test_load_cmd normally - it will add its own args and re-parse
|
||||||
test_args.add_argument("-n", dest="num_requests", type=int, required=True, help="total number of requests")
|
test_load_cmd(CompletionsData, WORKER_ENDPOINT, arg_parser=test_args)
|
||||||
test_args.add_argument("-rps", dest="requests_per_second", type=float, required=True, help="requests per second")
|
|
||||||
test_args.add_argument("--out", dest="out_path", type=str, default="load_test_report.png", help="path to save the report image")
|
|
||||||
args = test_args.parse_args()
|
|
||||||
|
|
||||||
server_url = {
|
|
||||||
"prod": "https://run.vast.ai",
|
|
||||||
"alpha": "https://run-alpha.vast.ai",
|
|
||||||
"candidate": "https://run-candidate.vast.ai",
|
|
||||||
"local": "http://localhost:8080"
|
|
||||||
}.get(args.instance, "http://localhost:8080")
|
|
||||||
|
|
||||||
run_load_with_metrics(
|
|
||||||
num_requests=args.num_requests,
|
|
||||||
requests_per_second=args.requests_per_second,
|
|
||||||
endpoint_group_name=args.endpoint_group_name,
|
|
||||||
account_api_key=args.api_key,
|
|
||||||
server_url=server_url,
|
|
||||||
worker_endpoint=WORKER_ENDPOINT,
|
|
||||||
instance=args.instance,
|
|
||||||
out_path=args.out_path,
|
|
||||||
)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user