Compare commits

..

2 Commits

Author SHA1 Message Date
Lucas Armand 3988cf553f Suppress matplot debug logs 2025-10-10 11:57:46 -07:00
Colter Downing a00c1adab5 improved test load 2025-10-09 19:37:39 -07:00
6 changed files with 441 additions and 61 deletions
+4 -5
View File
@@ -45,7 +45,6 @@ class Metrics:
self.model_metrics.workload_received += workload self.model_metrics.workload_received += workload
self.model_metrics.requests_recieved.add(reqnum) self.model_metrics.requests_recieved.add(reqnum)
self.model_metrics.requests_working.add(reqnum) self.model_metrics.requests_working.add(reqnum)
self.update_pending = True
def _request_end(self, workload: float, reqnum: int) -> None: def _request_end(self, workload: float, reqnum: int) -> None:
""" """
@@ -79,10 +78,10 @@ class Metrics:
elapsed = time.time() - self.last_metric_update elapsed = time.time() - self.last_metric_update
if self.system_metrics.model_is_loaded is False and elapsed >= 10: if self.system_metrics.model_is_loaded is False and elapsed >= 10:
log.debug(f"sending loading model metrics after {int(elapsed)}s wait") log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
self.__send_metrics_and_reset() self.__send_metrics_and_reset(elapsed)
elif self.update_pending or elapsed > 10: elif self.update_pending or elapsed > 10:
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait") log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
self.__send_metrics_and_reset() self.__send_metrics_and_reset(elapsed)
def _model_loaded(self, max_throughput: float) -> None: def _model_loaded(self, max_throughput: float) -> None:
self.system_metrics.model_loading_time = ( self.system_metrics.model_loading_time = (
@@ -97,13 +96,13 @@ class Metrics:
#######################################Private####################################### #######################################Private#######################################
def __send_metrics_and_reset(self): def __send_metrics_and_reset(self, elapsed):
def compute_autoscaler_data() -> AutoScalaerData: def compute_autoscaler_data() -> AutoScalaerData:
return AutoScalaerData( return AutoScalaerData(
id=self.id, id=self.id,
loadtime=(self.system_metrics.model_loading_time or 0.0), loadtime=(self.system_metrics.model_loading_time or 0.0),
cur_load=(self.model_metrics.workload_processing), cur_load=(self.model_metrics.workload_processing / elapsed),
max_perf=self.model_metrics.max_throughput, max_perf=self.model_metrics.max_throughput,
cur_perf=self.model_metrics.cur_perf, cur_perf=self.model_metrics.cur_perf,
error_msg=self.model_metrics.error_msg or "", error_msg=self.model_metrics.error_msg or "",
+6 -6
View File
@@ -292,12 +292,12 @@ def test_load_cmd(
args = arg_parser.parse_args() args = arg_parser.parse_args()
if hasattr(args, "comfy_model"): if hasattr(args, "comfy_model"):
os.environ["COMFY_MODEL"] = args.comfy_model os.environ["COMFY_MODEL"] = args.comfy_model
server_url = dict( server_url = {
prod="https://run.vast.ai", "prod": "https://run.vast.ai",
alpha="https://run-alpha.vast.ai", "alpha": "https://run-alpha.vast.ai",
candidate="https://run-candidate.vast.ai", "candidate": "https://run-candidate.vast.ai",
local="http://localhost:8080", "local": "http://localhost:8080",
)[args.instance] }.get(args.instance, "http://localhost:8080")
run_test( run_test(
num_requests=args.num_requests, num_requests=args.num_requests,
requests_per_second=args.requests_per_second, requests_per_second=args.requests_per_second,
+19 -36
View File
@@ -3,7 +3,8 @@
set -e -o pipefail set -e -o pipefail
WORKSPACE_DIR="${WORKSPACE_DIR:-/workspace}" WORKSPACE_DIR="${WORKSPACE_DIR:-/workspace}"
SERVER_DIR="$WORKSPACE_DIR/worker"
SERVER_DIR="$WORKSPACE_DIR/vast-pyworker"
ENV_PATH="$WORKSPACE_DIR/worker-env" ENV_PATH="$WORKSPACE_DIR/worker-env"
DEBUG_LOG="$WORKSPACE_DIR/debug.log" DEBUG_LOG="$WORKSPACE_DIR/debug.log"
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log" PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
@@ -21,23 +22,24 @@ function echo_var(){
echo "$1: ${!1}" echo "$1: ${!1}"
} }
# Updated validation - BACKEND no longer required, but MODEL_LOG still is [ -z "$BACKEND" ] && echo "BACKEND must be set!" && exit 1
[ -z "$MODEL_LOG" ] && echo "MODEL_LOG must be set!" && exit 1 [ -z "$MODEL_LOG" ] && echo "MODEL_LOG must be set!" && exit 1
[ -z "$HF_TOKEN" ] && echo "HF_TOKEN must be set!" && exit 1 [ -z "$HF_TOKEN" ] && echo "HF_TOKEN must be set!" && exit 1
[ "$BACKEND" = "comfyui" ] && [ -z "$COMFY_MODEL" ] && echo "For comfyui backends, COMFY_MODEL must be set!" && exit 1
echo "start_server.sh - SDK Worker Version"
echo "start_server.sh"
date date
echo_var BACKEND
echo_var REPORT_ADDR echo_var REPORT_ADDR
echo_var WORKER_PORT echo_var WORKER_PORT
echo_var WORKSPACE_DIR echo_var WORKSPACE_DIR
echo_var SERVER_DIR
echo_var ENV_PATH echo_var ENV_PATH
echo_var DEBUG_LOG echo_var DEBUG_LOG
echo_var PYWORKER_LOG echo_var PYWORKER_LOG
echo_var MODEL_LOG echo_var MODEL_LOG
echo_var MODEL_SERVER_URL
echo_var PYWORKER_REPO
echo_var PYWORKER_REF
# Populate /etc/environment with quoted values # Populate /etc/environment with quoted values
if ! grep -q "VAST" /etc/environment; then if ! grep -q "VAST" /etc/environment; then
@@ -56,32 +58,16 @@ then
source ~/.local/bin/env source ~/.local/bin/env
fi fi
if [[ ! -d $SERVER_DIR ]]; then # Fork testing
echo "Cloning worker repository..." git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"
git clone --depth=1 "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"
fi
if [[ -n ${PYWORKER_REF:-} ]]; then if [[ -n ${PYWORKER_REF:-} ]]; then
echo "Checking out ref: $PYWORKER_REF" (cd "$SERVER_DIR" && git checkout "$PYWORKER_REF")
(
cd "$SERVER_DIR"
git fetch --depth=1 origin "$PYWORKER_REF"
git checkout "$PYWORKER_REF"
)
fi fi
uv venv --python-preference only-managed "$ENV_PATH" -p 3.10 uv venv --managed-python "$ENV_PATH" -p 3.10
source "$ENV_PATH/bin/activate" source "$ENV_PATH/bin/activate"
# Install vast-sdk from server-side-sdk branch uv pip install -r "${SERVER_DIR}/requirements.txt"
echo "Installing vast-sdk from GitHub (server-side-sdk branch)..."
uv pip install "git+https://github.com/vast-ai/vast-sdk.git@server-side-sdk"
# Install requirements from worker repo if they exist
if [ -f "${SERVER_DIR}/requirements.txt" ]; then
echo "Installing additional dependencies from requirements.txt..."
uv pip install -r "${SERVER_DIR}/requirements.txt"
fi
touch ~/.no_auto_tmux touch ~/.no_auto_tmux
else else
@@ -91,12 +77,7 @@ else
echo "venv: $VIRTUAL_ENV" echo "venv: $VIRTUAL_ENV"
fi fi
# Check that worker.py exists [ ! -d "$SERVER_DIR/workers/$BACKEND" ] && echo "$BACKEND not supported!" && exit 1
if [ ! -f "$SERVER_DIR/worker.py" ]; then
echo "ERROR: worker.py not found in $SERVER_DIR"
echo "Please ensure your PYWORKER_REPO contains a worker.py file"
exit 1
fi
if [ "$USE_SSL" = true ]; then if [ "$USE_SSL" = true ]; then
@@ -134,6 +115,9 @@ EOF
POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt; POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt;
fi fi
export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED
cd "$SERVER_DIR" cd "$SERVER_DIR"
@@ -144,6 +128,5 @@ echo "launching PyWorker server"
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only # from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG" [ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
# Launch the SDK-based worker instead of the old backend system (python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
(python3 worker.py |& tee -a "$PYWORKER_LOG") & echo "launching PyWorker server done"
echo "launching PyWorker server done"
+43 -5
View File
@@ -1,5 +1,6 @@
import logging import logging
from typing import Any, Dict, Optional import time
from typing import Any, Dict, Optional, Tuple
import requests import requests
@@ -16,6 +17,38 @@ class Endpoint:
Utility class for handling endpoint operations. Utility class for handling endpoint operations.
""" """
@staticmethod
def get_endpoint_info(
endpoint_name: str, account_api_key: str, instance: str
) -> Optional[Dict[str, Any]]:
headers = {"Authorization": f"Bearer {account_api_key}"}
url = f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}"
# Retry a few times to smooth over transient propagation/network delays
for attempt in range(4):
try:
response = requests.get(url, headers=headers, timeout=8)
if response.status_code != 200:
# brief backoff and retry
time.sleep(0.3 * (attempt + 1))
continue
try:
data = response.json()
except Exception:
# JSON parse failed; backoff and retry
time.sleep(0.3 * (attempt + 1))
continue
result = data.get("results", []) if isinstance(data, dict) else []
endpoint = next(
(item for item in result if item.get("endpoint_name") == endpoint_name),
None,
)
if endpoint and endpoint.get("id") and endpoint.get("api_key"):
return {"id": endpoint.get("id"), "api_key": endpoint.get("api_key")}
except Exception:
# network or other transient error; retry
time.sleep(0.3 * (attempt + 1))
return None
@staticmethod @staticmethod
def get_autoscaler_server_url(instance: str) -> str: def get_autoscaler_server_url(instance: str) -> str:
endpoints = { endpoints = {
@@ -23,7 +56,10 @@ class Endpoint:
"candidate": "run-candidate", "candidate": "run-candidate",
"prod": "run", "prod": "run",
} }
return f"https://{endpoints[instance]}.vast.ai/" host = endpoints.get(instance)
if host:
return f"https://{host}.vast.ai/"
return "http://localhost:8080"
@staticmethod @staticmethod
def get_server_url(instance: str) -> str: def get_server_url(instance: str) -> str:
@@ -32,7 +68,8 @@ class Endpoint:
"candidate": "candidate", "candidate": "candidate",
"prod": "console", "prod": "console",
} }
return f"https://{endpoints[instance]}.vast.ai/api/v0/endptjobs/" host = endpoints.get(instance, "alpha")
return f"https://{host}.vast.ai/api/v0/endptjobs/"
@staticmethod @staticmethod
def get_endpoint_api_key( def get_endpoint_api_key(
@@ -55,6 +92,7 @@ class Endpoint:
response = requests.get( response = requests.get(
f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}", f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}",
headers=headers, headers=headers,
timeout=8,
) )
if response.status_code != 200: if response.status_code != 200:
@@ -64,14 +102,14 @@ class Endpoint:
try: try:
data = response.json() data = response.json()
except requests.exceptions.JSONDecodeError as e: except Exception as e:
log.debug(f"Failed to parse JSON response: {e}") log.debug(f"Failed to parse JSON response: {e}")
return None return None
result = data.get("results", []) result = data.get("results", [])
endpoint: Optional[Dict[str, Any]] = next( endpoint: Optional[Dict[str, Any]] = next(
(item for item in result if item["endpoint_name"] == endpoint_name), (item for item in result if item.get("endpoint_name") == endpoint_name),
None, None,
) )
if not endpoint: if not endpoint:
+1 -1
View File
@@ -70,7 +70,7 @@ class ComfyWorkflowHandler(EndpointHandler[ComfyWorkflowData]):
@property @property
def healthcheck_endpoint(self) -> Optional[str]: def healthcheck_endpoint(self) -> Optional[str]:
return f"{MODEL_SERVER_URL}/health" return "/health"
@classmethod @classmethod
def payload_cls(cls) -> Type[ComfyWorkflowData]: def payload_cls(cls) -> Type[ComfyWorkflowData]:
+368 -8
View File
@@ -1,8 +1,349 @@
from lib.test_utils import test_load_cmd, test_args from lib.test_utils import test_args
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from lib.data_types import AuthData
from .data_types.server import CompletionsData from .data_types.server import CompletionsData
import os
WORKER_ENDPOINT = "/v1/completions" import os
import time
import threading
import requests
from dataclasses import dataclass
from collections import Counter
from urllib.parse import urljoin, urlparse
import re
# Headless plotting
import matplotlib
matplotlib.use("Agg")
import logging
logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)
import matplotlib.pyplot as plt
import numpy as np
from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED
from requests.adapters import HTTPAdapter
def get_incremented_path(path: str) -> str:
base, ext = os.path.splitext(path)
if not os.path.exists(path):
return path
i = 1
while os.path.exists(f"{base}-{i}{ext}"):
i += 1
return f"{base}-{i}{ext}"
WORKER_ENDPOINT = "/v1/completions" # This will return the full text output at once. Latency metrics reflect that (ie not measuring TTFT)
@dataclass
class ReqResult:
worker_url: str
route_ms: float
worker_ms: float
total_ms: float
ok: bool
error: str = ""
t_start: float = 0.0
t_end: float = 0.0
workload: float = 0.0
def do_one(endpoint_name: str,
endpoint_id: int,
endpoint_api_key: str,
server_url: str,
worker_endpoint: str,
payload,
results_list,
t0,
status_samples,
route_session,
worker_session):
try:
u = payload.count_workload()
route_payload = {"endpoint": endpoint_name, "api_key": endpoint_api_key, "cost": u}
headers = {"Authorization": f"Bearer {endpoint_api_key}"}
start = time.time()
r0 = route_session.post(urljoin(server_url, "/route/"), json=route_payload, headers=headers, timeout=4)
t_after_route = time.time()
if r0.status_code != 200:
results_list.append(ReqResult("", (t_after_route - start) * 1000.0, 0.0, (t_after_route - start) * 1000.0, False,
f"route {r0.status_code} {r0.text}"))
return
msg = r0.json()
# 1) "Status" is in the response when no worker is ready
worker_sampled = True
status = msg.get("status", "")
if status:
m = re.search(r"total workers:\s*(\d+).*loading workers:\s*(\d+).*standby workers:\s*(\d+).*error workers:\s*(\d+)", status, re.I | re.S)
if m:
tot, loading, standby, err = map(int, m.groups())
idle = max(tot - loading - standby - err, 0)
status_samples.append((time.time() - t0, idle))
worker_sampled = False
# 2) Otherwise (successful request), sample via /get_endpoint_workers/ for eligible (idle) worker tracking
if worker_sampled:
try:
r_status = route_session.post(
urljoin(server_url, "/get_endpoint_workers/"),
json={"id": endpoint_id},
headers={"Authorization": f"Bearer {endpoint_api_key}"},
timeout=3,
)
if r_status.status_code == 200:
workers = r_status.json()
idle = 0
for w in workers:
st = str(w.get("status", "")).lower()
if (st in ("idle")):
idle += 1
status_samples.append((time.time() - t0, idle))
except Exception:
pass
# 3) Send the request
worker_address = msg["url"]
req = dict(payload=payload.__dict__, auth_data=AuthData.from_json_msg(msg).__dict__)
t1 = time.time()
# Use explicit connect/read timeouts to avoid long hangs
r1 = worker_session.post(
urljoin(worker_address, worker_endpoint),
json=req,
verify=get_cert_file_path(),
timeout=(4, 120),
)
t2 = time.time()
if r1.status_code != 200:
results_list.append(ReqResult(worker_address, (t_after_route - start) * 1000.0, (t2 - t1) * 1000.0,
(t2 - start) * 1000.0, False,
f"infer {r1.status_code} {r1.text}"))
return
results_list.append(ReqResult(worker_address, (t_after_route - start) * 1000.0, (t2 - t1) * 1000.0, (t2 - start) * 1000.0,
True, "", t_start=start - t0, t_end=t2 - t0, workload=u))
except Exception as e:
t = time.time()
results_list.append(ReqResult("", (t - start) * 1000.0, 0.0, (t - start) * 1000.0, False, str(e)))
def run_load_with_metrics(num_requests: int,
requests_per_second: float,
endpoint_group_name: str,
account_api_key: str,
server_url: str,
worker_endpoint: str,
instance: str,
out_path: str):
# Resolve endpoint id + endpoint-scoped API key
ep_info = Endpoint.get_endpoint_info(endpoint_name=endpoint_group_name,
account_api_key=account_api_key,
instance=instance)
if not ep_info or not ep_info.get("api_key") or not ep_info.get("id"):
print(f"Endpoint {endpoint_group_name} not found for API key")
return
endpoint_id = int(ep_info["id"])
endpoint_api_key = ep_info["api_key"]
t0 = time.time()
results = []
status_samples = []
# Concurrency control
max_concurrency = int(os.environ.get("MAX_CONCURRENCY", "1024"))
submit_queue_factor = 2 # cap queued tasks to reduce memory
# Shared HTTP sessions with connection pooling (persistent connections)
def make_session(pool_connections: int, pool_maxsize: int) -> requests.Session:
sess = requests.Session()
adapter = HTTPAdapter(pool_connections=pool_connections, pool_maxsize=pool_maxsize, max_retries=0)
sess.mount("https://", adapter)
sess.mount("http://", adapter)
return sess
# Router: mostly single host, small connection pool is sufficient
route_session = make_session(pool_connections=8, pool_maxsize=max_concurrency)
# Workers: many hosts; allow many pools and per-host concurrency up to max_concurrency
worker_session = make_session(pool_connections=max(256, max_concurrency), pool_maxsize=max_concurrency)
# Fire requests using a thread pool, scheduling at requested RPS
inflight = set()
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
for i in range(num_requests):
# Pace submissions to RPS
target_time = t0 + i / max(requests_per_second, 1e-9)
sleep_s = target_time - time.time()
if sleep_s > 0:
time.sleep(min(sleep_s, 0.5)) # sleep in chunks to stay responsive
payload = CompletionsData.for_test()
fut = executor.submit(
do_one,
endpoint_group_name,
endpoint_id,
endpoint_api_key,
server_url,
worker_endpoint,
payload,
results,
t0,
status_samples,
route_session,
worker_session,
)
inflight.add(fut)
# Prevent unbounded queue growth
if len(inflight) >= max_concurrency * submit_queue_factor:
done, not_done = wait(inflight, return_when=FIRST_COMPLETED)
inflight = not_done
# Wait for all outstanding tasks
if inflight:
wait(inflight)
# Close sessions
try:
route_session.close()
finally:
worker_session.close()
# Aggregate results
oks = [r for r in results if r.ok]
errs = [r for r in results if not r.ok]
total_reqs = len(results)
succ = len(oks)
total_ms = np.array([r.total_ms for r in oks]) if succ else np.array([])
worker_ms = np.array([r.worker_ms for r in oks]) if succ else np.array([])
#route_ms = np.array([r.route_ms for r in oks]) if succ else np.array([])
avg_total = float(np.mean(total_ms)) if succ else 0.0
p50_total, p95_total = (float(np.percentile(total_ms, 50)), float(np.percentile(total_ms, 95))) if succ else (0.0, 0.0)
total_compute_time_ms = float(np.sum(worker_ms)) if succ else 0.0
# Distribution over workers (by host:port)
hosts = [urlparse(r.worker_url).netloc for r in oks if r.worker_url]
dist = Counter(hosts)
# Idle over time (mode per second)
idle_ts, idle_vals = [], []
if status_samples:
buckets = {}
for ts, idle in status_samples:
k = int(ts)
buckets.setdefault(k, []).append(idle)
keys = sorted(buckets.keys())
idle_ts = keys
# Use the most frequent sampled value per second (mode) to keep integer counts
idle_vals = []
for k in keys:
vals_k = [int(v) for v in buckets[k]]
if vals_k:
cnt = Counter(vals_k)
idle_vals.append(cnt.most_common(1)[0][0])
else:
idle_vals.append(0)
print(f"\nResults: total={total_reqs} success={succ} errors={len(errs)}")
print(f"Avg latency (ms): {avg_total:.1f} p50: {p50_total:.1f} p95: {p95_total:.1f}")
print(f"Total compute time (sum worker latency, s): {total_compute_time_ms/1000.0:.2f}")
if errs:
print("Sample errors:")
for e in errs[:5]:
print(f" {e.error}")
# Plot: 2x3 grid
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
fig.suptitle(f"Load test: {endpoint_group_name} n={total_reqs}, rps={requests_per_second}, success={succ}")
# Dist per worker
ax0 = axes[0, 0]
if dist:
items = sorted(dist.items(), key=lambda kv: kv[1], reverse=True)
labels, counts = zip(*items)
ax0.bar(range(len(labels)), counts)
ax0.set_xticks(range(len(labels)))
ax0.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
ax0.set_title("Request distribution over workers")
ax0.set_ylabel("count")
# Latency histogram (total)
ax1 = axes[0, 1]
if succ:
ax1.hist(total_ms, bins=30, color="#4e79a7")
ax1.set_title("Total latency (ms)")
ax1.set_xlabel("ms")
ax1.set_ylabel("freq")
# Eligible workers over time
ax_idle = axes[0, 2]
if idle_ts:
ax_idle.plot(idle_ts, idle_vals, "-o", ms=3)
ax_idle.set_title("Eligible workers over time")
ax_idle.set_xlabel("time (s)")
ax_idle.set_ylabel("eligible count")
# Throughput over time (completions/sec)
ax_idle = axes[1, 0]
ax_idle.clear()
if succ:
per_sec = {}
for r in oks:
s = int(r.t_end)
per_sec[s] = per_sec.get(s, 0) + 1
ts = sorted(per_sec.keys())
vals = [per_sec[t] for t in ts]
ax_idle.plot(ts, vals, "-o", ms=3)
ax_idle.set_title("Completions per second")
ax_idle.set_xlabel("time (s)")
ax_idle.set_ylabel("req/s")
# Summary text
ax3 = axes[1, 1]
ax3.axis("off")
text = (
f"Total requests: {total_reqs}\n"
f"Success: {succ} Errors: {len(errs)}\n"
f"Avg latency: {avg_total:.1f} ms\n"
f"p50: {p50_total:.1f} ms p95: {p95_total:.1f} ms\n"
f"Total compute time: {total_compute_time_ms/1000.0:.2f} s"
)
ax3.set_title("Summary")
ax3.text(0.02, 0.98, text, va="top", ha="left", fontsize=11, transform=ax3.transAxes)
# Latency CDF (total_ms)
ax_cdf = axes[1, 2]
if succ:
x = np.sort(total_ms)
y = np.linspace(0, 1, len(x), endpoint=True)
ax_cdf.plot(x, y)
ax_cdf.set_title("Latency CDF")
ax_cdf.set_xlabel("ms")
ax_cdf.set_ylabel("fraction ≤ x")
# Ensure unique output path and create directory if needed
final_out_path = get_incremented_path(out_path)
out_dir = os.path.dirname(final_out_path)
if out_dir:
os.makedirs(out_dir, exist_ok=True)
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.savefig(final_out_path, dpi=120)
print(f"Saved report to: {final_out_path}")
# Per-worker latency boxplot (top 12 by volume)
groups = {}
for r in oks:
host = urlparse(r.worker_url).netloc
groups.setdefault(host, []).append(r.total_ms)
items = sorted(groups.items(), key=lambda kv: len(kv[1]), reverse=True)[:12]
if items:
labels, data = zip(*items)
fig2, axb = plt.subplots(1, 1, figsize=(12, 5))
axb.boxplot(data, showfliers=False)
axb.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
axb.set_title("Per-worker latency (ms)")
axb.set_ylabel("ms")
plt.tight_layout()
extra_out = get_incremented_path(os.path.splitext(out_path)[0] + "-workers.png")
plt.savefig(extra_out, dpi=120)
fig2.tight_layout()
fig2.savefig(extra_out, dpi=120)
print(f"Saved worker latency plot to: {extra_out}")
if __name__ == "__main__": if __name__ == "__main__":
# Check if MODEL_NAME environment variable is set # Check if MODEL_NAME environment variable is set
@@ -16,13 +357,32 @@ if __name__ == "__main__":
help="Model to use for completions request (required if MODEL_NAME env var not set)", help="Model to use for completions request (required if MODEL_NAME env var not set)",
) )
# Parse known args to get model early, before test_load_cmd adds its args # Parse known args to get model early, before adding load args
known_args, _ = test_args.parse_known_args() known_args, _ = test_args.parse_known_args()
# Set environment variable if model was provided
if hasattr(known_args, "model") and known_args.model: if hasattr(known_args, "model") and known_args.model:
os.environ["MODEL_NAME"] = known_args.model os.environ["MODEL_NAME"] = known_args.model
print(f"Set MODEL_NAME environment variable to: {known_args.model}") print(f"Set MODEL_NAME environment variable to: {known_args.model}")
# Now call test_load_cmd normally - it will add its own args and re-parse # Load test args
test_load_cmd(CompletionsData, WORKER_ENDPOINT, arg_parser=test_args) test_args.add_argument("-n", dest="num_requests", type=int, required=True, help="total number of requests")
test_args.add_argument("-rps", dest="requests_per_second", type=float, required=True, help="requests per second")
test_args.add_argument("--out", dest="out_path", type=str, default="load_test_report.png", help="path to save the report image")
args = test_args.parse_args()
server_url = {
"prod": "https://run.vast.ai",
"alpha": "https://run-alpha.vast.ai",
"candidate": "https://run-candidate.vast.ai",
"local": "http://localhost:8080"
}.get(args.instance, "http://localhost:8080")
run_load_with_metrics(
num_requests=args.num_requests,
requests_per_second=args.requests_per_second,
endpoint_group_name=args.endpoint_group_name,
account_api_key=args.api_key,
server_url=server_url,
worker_endpoint=WORKER_ENDPOINT,
instance=args.instance,
out_path=args.out_path,
)