Compare commits
30 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 50f13d6288 | |||
| a6921de6a2 | |||
| dcb7d036ed | |||
| b8223879c9 | |||
| 298590fb88 | |||
| 814c3acd4c | |||
| 22bca74087 | |||
| 9c795e2a01 | |||
| 830b532781 | |||
| d6a6e34c6b | |||
| ac1e109c48 | |||
| d6eb498ee4 | |||
| bcecd6df40 | |||
| 4d9bf2048c | |||
| 7788bc4a62 | |||
| 70d51bafe1 | |||
| 63909736bb | |||
| f4f7080df1 | |||
| d51a338e8f | |||
| 92a04bd7af | |||
| c98d661513 | |||
| f6fd1c6ac1 | |||
| 055e346c8c | |||
| 1cedb28acf | |||
| ec25dda3ad | |||
| 0397af719d | |||
| 3786cf978d | |||
| a86d4bcf9c | |||
| e9b6a14a5e | |||
| cadac033e1 |
+20
-11
@@ -26,7 +26,8 @@ from lib.data_types import (
|
|||||||
LogAction,
|
LogAction,
|
||||||
ApiPayload_T,
|
ApiPayload_T,
|
||||||
JsonDataException,
|
JsonDataException,
|
||||||
RequestMetrics
|
RequestMetrics,
|
||||||
|
BenchmarkResult
|
||||||
)
|
)
|
||||||
|
|
||||||
VERSION = "0.1.0"
|
VERSION = "0.1.0"
|
||||||
@@ -285,7 +286,7 @@ class Backend:
|
|||||||
message = {
|
message = {
|
||||||
key: value
|
key: value
|
||||||
for (key, value) in (dataclasses.asdict(auth_data).items())
|
for (key, value) in (dataclasses.asdict(auth_data).items())
|
||||||
if key != "signature"
|
if key != "signature" and key != "__request_id"
|
||||||
}
|
}
|
||||||
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
|
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
|
||||||
log.debug(
|
log.debug(
|
||||||
@@ -295,7 +296,7 @@ class Backend:
|
|||||||
elif message in self.msg_history:
|
elif message in self.msg_history:
|
||||||
log.debug(f"message: {message} already in message history")
|
log.debug(f"message: {message} already in message history")
|
||||||
return False
|
return False
|
||||||
elif verify_signature(json.dumps(message, indent=4), auth_data.signature):
|
elif verify_signature(json.dumps(message, indent=4, sort_keys=True), auth_data.signature):
|
||||||
self.reqnum = max(auth_data.reqnum, self.reqnum)
|
self.reqnum = max(auth_data.reqnum, self.reqnum)
|
||||||
self.msg_history.append(message)
|
self.msg_history.append(message)
|
||||||
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
|
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
|
||||||
@@ -332,18 +333,26 @@ class Backend:
|
|||||||
|
|
||||||
for run in range(1, self.benchmark_handler.benchmark_runs + 1):
|
for run in range(1, self.benchmark_handler.benchmark_runs + 1):
|
||||||
start = time.time()
|
start = time.time()
|
||||||
tasks = []
|
benchmark_requests = []
|
||||||
total_workload = 0
|
|
||||||
|
|
||||||
for _ in range(concurrent_requests):
|
for i in range(concurrent_requests):
|
||||||
payload = self.benchmark_handler.make_benchmark_payload()
|
payload = self.benchmark_handler.make_benchmark_payload()
|
||||||
total_workload += payload.count_workload()
|
workload = payload.count_workload()
|
||||||
tasks.append(
|
task = self.__call_api(handler=self.benchmark_handler, payload=payload)
|
||||||
self.__call_api(handler=self.benchmark_handler, payload=payload)
|
benchmark_requests.append(
|
||||||
|
BenchmarkResult(request_idx=i, workload=workload, task=task)
|
||||||
)
|
)
|
||||||
|
|
||||||
responses = await gather(*tasks)
|
responses = await gather(*[br.task for br in benchmark_requests])
|
||||||
|
for br, response in zip(benchmark_requests, responses):
|
||||||
|
br.response = response
|
||||||
|
|
||||||
|
total_workload = sum(br.workload for br in benchmark_requests if br.is_successful)
|
||||||
time_elapsed = time.time() - start
|
time_elapsed = time.time() - start
|
||||||
|
successful_responses = sum([1 for br in benchmark_requests if br.is_successful])
|
||||||
|
if successful_responses == 0:
|
||||||
|
self.backend_errored("No successful responses from benchmark")
|
||||||
|
log.debug(f"benchmark failed: {successful_responses}/{concurrent_requests} successful responses")
|
||||||
|
|
||||||
throughput = total_workload / time_elapsed
|
throughput = total_workload / time_elapsed
|
||||||
sum_throughput += throughput
|
sum_throughput += throughput
|
||||||
@@ -357,7 +366,7 @@ class Backend:
|
|||||||
f"Run: {run}, concurrent_requests: {concurrent_requests}",
|
f"Run: {run}, concurrent_requests: {concurrent_requests}",
|
||||||
f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s",
|
f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s",
|
||||||
f"Throughput: {throughput} workload/s",
|
f"Throughput: {throughput} workload/s",
|
||||||
f"Successful responses: {len([r for r in responses if r.status == 200])}",
|
f"Successful responses: {successful_responses}/{concurrent_requests}",
|
||||||
"#" * 60,
|
"#" * 60,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
+17
-5
@@ -3,7 +3,7 @@ import logging
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type
|
from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type, Awaitable
|
||||||
from aiohttp import web, ClientResponse
|
from aiohttp import web, ClientResponse
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
@@ -65,12 +65,12 @@ class ApiPayload(ABC):
|
|||||||
class AuthData:
|
class AuthData:
|
||||||
"""data used to authenticate requester"""
|
"""data used to authenticate requester"""
|
||||||
|
|
||||||
signature: str
|
|
||||||
cost: str
|
cost: str
|
||||||
endpoint: str
|
endpoint: str
|
||||||
reqnum: int
|
reqnum: int
|
||||||
url: str
|
|
||||||
request_idx: int
|
request_idx: int
|
||||||
|
signature: str
|
||||||
|
url: str
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_json_msg(cls, json_msg: Dict[str, Any]):
|
def from_json_msg(cls, json_msg: Dict[str, Any]):
|
||||||
@@ -190,10 +190,11 @@ class SystemMetrics:
|
|||||||
self.additional_disk_usage = disk_usage - self.last_disk_usage
|
self.additional_disk_usage = disk_usage - self.last_disk_usage
|
||||||
self.last_disk_usage = disk_usage
|
self.last_disk_usage = disk_usage
|
||||||
|
|
||||||
def reset(self):
|
def reset(self, expected: float | None) -> None:
|
||||||
# autoscaler excepts model_loading_time to be populated only once, when the instance has
|
# autoscaler excepts model_loading_time to be populated only once, when the instance has
|
||||||
# finished benchmarking and is ready to receive requests. This applies to restarted instances
|
# finished benchmarking and is ready to receive requests. This applies to restarted instances
|
||||||
# as well: they should send model_loading_time once when they are done loading
|
# as well: they should send model_loading_time once when they are done loading
|
||||||
|
if self.model_loading_time == expected:
|
||||||
self.model_loading_time = None
|
self.model_loading_time = None
|
||||||
|
|
||||||
|
|
||||||
@@ -206,6 +207,17 @@ class RequestMetrics:
|
|||||||
status: str
|
status: str
|
||||||
success: bool = False
|
success: bool = False
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BenchmarkResult:
|
||||||
|
request_idx: int
|
||||||
|
workload: float
|
||||||
|
task: Awaitable[ClientResponse]
|
||||||
|
response: Optional[ClientResponse] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_successful(self) -> bool:
|
||||||
|
return self.response is not None and self.response.status == 200
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelMetrics:
|
class ModelMetrics:
|
||||||
"""Model specific metrics"""
|
"""Model specific metrics"""
|
||||||
@@ -246,7 +258,7 @@ class ModelMetrics:
|
|||||||
def wait_time(self) -> float:
|
def wait_time(self) -> float:
|
||||||
if (len(self.requests_working) == 0):
|
if (len(self.requests_working) == 0):
|
||||||
return 0.0
|
return 0.0
|
||||||
return sum([request.workload for request in self.requests_working.values()]) / self.max_throughput
|
return sum([request.workload for request in self.requests_working.values()]) / max(self.max_throughput, 0.00001)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cur_load(self) -> float:
|
def cur_load(self) -> float:
|
||||||
|
|||||||
+43
-12
@@ -145,41 +145,68 @@ class Metrics:
|
|||||||
#######################################Private#######################################
|
#######################################Private#######################################
|
||||||
|
|
||||||
async def __send_delete_requests_and_reset(self):
|
async def __send_delete_requests_and_reset(self):
|
||||||
|
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
|
||||||
async def send_data(report_addr: str, success: bool) -> bool:
|
|
||||||
data = {
|
data = {
|
||||||
"worker_id": self.id,
|
"worker_id": self.id,
|
||||||
"request_idxs": [r.request_idx for r in self.model_metrics.requests_deleting if r.success == success],
|
"request_idxs": idxs,
|
||||||
"success": success
|
"success": success_flag,
|
||||||
}
|
}
|
||||||
|
log.debug(
|
||||||
|
f"Deleting requests that {'succeeded' if success_flag else 'failed'}: {data['request_idxs']}"
|
||||||
|
)
|
||||||
full_path = report_addr.rstrip("/") + "/delete_requests/"
|
full_path = report_addr.rstrip("/") + "/delete_requests/"
|
||||||
for attempt in range(1, 4):
|
for attempt in range(1, 4):
|
||||||
try:
|
try:
|
||||||
session = await self.http()
|
session = await self.http()
|
||||||
async with session.post(full_path, json=data) as res:
|
async with session.post(full_path, json=data) as res:
|
||||||
|
log.debug(f"delete_requests response: {res.status}")
|
||||||
res.raise_for_status()
|
res.raise_for_status()
|
||||||
return True
|
return True
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
log.debug(f"delete_requests timed out")
|
log.debug("delete_requests timed out")
|
||||||
except (ClientResponseError, Exception) as e:
|
except (ClientResponseError, Exception) as e:
|
||||||
log.debug(f"delete_requests failed with error: {e}")
|
log.debug(f"delete_requests failed with error: {e}")
|
||||||
await asyncio.sleep(2)
|
await asyncio.sleep(2)
|
||||||
log.debug(f"retrying delete_request, attempt: {attempt}")
|
log.debug(f"retrying delete_request, attempt: {attempt}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Take a snapshot of what we plan to send this tick.
|
||||||
|
# New arrivals after this snapshot will remain in the queue for the next tick.
|
||||||
|
snapshot = list(self.model_metrics.requests_deleting)
|
||||||
|
success_idxs = [r.request_idx for r in snapshot if r.success is True]
|
||||||
|
failed_idxs = [r.request_idx for r in snapshot if r.success is False]
|
||||||
|
|
||||||
|
if not success_idxs and not failed_idxs:
|
||||||
|
return # nothing to do
|
||||||
|
|
||||||
for report_addr in self.report_addr:
|
for report_addr in self.report_addr:
|
||||||
success = await send_data(report_addr, success=True) and await send_data(report_addr, success=False)
|
sent_success = True
|
||||||
if success is True:
|
sent_failed = True
|
||||||
self.model_metrics.requests_deleting.clear()
|
|
||||||
|
if success_idxs:
|
||||||
|
sent_success = await post(report_addr, success_idxs, True)
|
||||||
|
if failed_idxs:
|
||||||
|
sent_failed = await post(report_addr, failed_idxs, False)
|
||||||
|
|
||||||
|
if sent_success and sent_failed:
|
||||||
|
# Remove only the items we actually sent from the live queue.
|
||||||
|
sent_set = set(success_idxs) | set(failed_idxs)
|
||||||
|
self.model_metrics.requests_deleting[:] = [
|
||||||
|
r for r in self.model_metrics.requests_deleting
|
||||||
|
if r.request_idx not in sent_set
|
||||||
|
]
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
async def __send_metrics_and_reset(self):
|
async def __send_metrics_and_reset(self):
|
||||||
|
|
||||||
|
loadtime_snapshot = self.system_metrics.model_loading_time
|
||||||
|
|
||||||
def compute_autoscaler_data() -> AutoScalerData:
|
def compute_autoscaler_data() -> AutoScalerData:
|
||||||
return AutoScalerData(
|
return AutoScalerData(
|
||||||
id=self.id,
|
id=self.id,
|
||||||
version=self.version,
|
version=self.version,
|
||||||
loadtime=(self.system_metrics.model_loading_time or 0.0),
|
loadtime=(loadtime_snapshot or 0.0),
|
||||||
new_load=self.model_metrics.workload_processing,
|
new_load=self.model_metrics.workload_processing,
|
||||||
cur_load=self.model_metrics.cur_load,
|
cur_load=self.model_metrics.cur_load,
|
||||||
rej_load=self.model_metrics.workload_rejected,
|
rej_load=self.model_metrics.workload_rejected,
|
||||||
@@ -227,11 +254,15 @@ class Metrics:
|
|||||||
|
|
||||||
self.system_metrics.update_disk_usage()
|
self.system_metrics.update_disk_usage()
|
||||||
|
|
||||||
|
sent = False
|
||||||
for report_addr in self.report_addr:
|
for report_addr in self.report_addr:
|
||||||
success = await send_data(report_addr)
|
if await send_data(report_addr):
|
||||||
if success is True:
|
sent = True
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if sent:
|
||||||
|
# clear the one-shot loadtime only if we actually sent *this* value
|
||||||
|
self.system_metrics.reset(expected=loadtime_snapshot)
|
||||||
self.update_pending = False
|
self.update_pending = False
|
||||||
self.model_metrics.reset()
|
self.model_metrics.reset()
|
||||||
self.system_metrics.reset()
|
|
||||||
self.last_metric_update = time.time()
|
self.last_metric_update = time.time()
|
||||||
|
|||||||
+6
-6
@@ -292,12 +292,12 @@ def test_load_cmd(
|
|||||||
args = arg_parser.parse_args()
|
args = arg_parser.parse_args()
|
||||||
if hasattr(args, "comfy_model"):
|
if hasattr(args, "comfy_model"):
|
||||||
os.environ["COMFY_MODEL"] = args.comfy_model
|
os.environ["COMFY_MODEL"] = args.comfy_model
|
||||||
server_url = dict(
|
server_url = {
|
||||||
prod="https://run.vast.ai",
|
"prod": "https://run.vast.ai",
|
||||||
alpha="https://run-alpha.vast.ai",
|
"alpha": "https://run-alpha.vast.ai",
|
||||||
candidate="https://run-candidate.vast.ai",
|
"candidate": "https://run-candidate.vast.ai",
|
||||||
local="http://localhost:8080",
|
"local": "http://localhost:8080",
|
||||||
)[args.instance]
|
}.get(args.instance, "http://localhost:8080")
|
||||||
run_test(
|
run_test(
|
||||||
num_requests=args.num_requests,
|
num_requests=args.num_requests,
|
||||||
requests_per_second=args.requests_per_second,
|
requests_per_second=args.requests_per_second,
|
||||||
|
|||||||
+43
-5
@@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Optional
|
import time
|
||||||
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@@ -16,6 +17,38 @@ class Endpoint:
|
|||||||
Utility class for handling endpoint operations.
|
Utility class for handling endpoint operations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_endpoint_info(
|
||||||
|
endpoint_name: str, account_api_key: str, instance: str
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
headers = {"Authorization": f"Bearer {account_api_key}"}
|
||||||
|
url = f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}"
|
||||||
|
# Retry a few times to smooth over transient propagation/network delays
|
||||||
|
for attempt in range(4):
|
||||||
|
try:
|
||||||
|
response = requests.get(url, headers=headers, timeout=8)
|
||||||
|
if response.status_code != 200:
|
||||||
|
# brief backoff and retry
|
||||||
|
time.sleep(0.3 * (attempt + 1))
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
data = response.json()
|
||||||
|
except Exception:
|
||||||
|
# JSON parse failed; backoff and retry
|
||||||
|
time.sleep(0.3 * (attempt + 1))
|
||||||
|
continue
|
||||||
|
result = data.get("results", []) if isinstance(data, dict) else []
|
||||||
|
endpoint = next(
|
||||||
|
(item for item in result if item.get("endpoint_name") == endpoint_name),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if endpoint and endpoint.get("id") and endpoint.get("api_key"):
|
||||||
|
return {"id": endpoint.get("id"), "api_key": endpoint.get("api_key")}
|
||||||
|
except Exception:
|
||||||
|
# network or other transient error; retry
|
||||||
|
time.sleep(0.3 * (attempt + 1))
|
||||||
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_autoscaler_server_url(instance: str) -> str:
|
def get_autoscaler_server_url(instance: str) -> str:
|
||||||
endpoints = {
|
endpoints = {
|
||||||
@@ -23,7 +56,10 @@ class Endpoint:
|
|||||||
"candidate": "run-candidate",
|
"candidate": "run-candidate",
|
||||||
"prod": "run",
|
"prod": "run",
|
||||||
}
|
}
|
||||||
return f"https://{endpoints[instance]}.vast.ai/"
|
host = endpoints.get(instance)
|
||||||
|
if host:
|
||||||
|
return f"https://{host}.vast.ai/"
|
||||||
|
return "http://localhost:8080"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_server_url(instance: str) -> str:
|
def get_server_url(instance: str) -> str:
|
||||||
@@ -32,7 +68,8 @@ class Endpoint:
|
|||||||
"candidate": "candidate",
|
"candidate": "candidate",
|
||||||
"prod": "console",
|
"prod": "console",
|
||||||
}
|
}
|
||||||
return f"https://{endpoints[instance]}.vast.ai/api/v0/endptjobs/"
|
host = endpoints.get(instance, "alpha")
|
||||||
|
return f"https://{host}.vast.ai/api/v0/endptjobs/"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_endpoint_api_key(
|
def get_endpoint_api_key(
|
||||||
@@ -55,6 +92,7 @@ class Endpoint:
|
|||||||
response = requests.get(
|
response = requests.get(
|
||||||
f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}",
|
f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
|
timeout=8,
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
@@ -64,14 +102,14 @@ class Endpoint:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
data = response.json()
|
data = response.json()
|
||||||
except requests.exceptions.JSONDecodeError as e:
|
except Exception as e:
|
||||||
log.debug(f"Failed to parse JSON response: {e}")
|
log.debug(f"Failed to parse JSON response: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
result = data.get("results", [])
|
result = data.get("results", [])
|
||||||
|
|
||||||
endpoint: Optional[Dict[str, Any]] = next(
|
endpoint: Optional[Dict[str, Any]] = next(
|
||||||
(item for item in result if item["endpoint_name"] == endpoint_name),
|
(item for item in result if item.get("endpoint_name") == endpoint_name),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
if not endpoint:
|
if not endpoint:
|
||||||
|
|||||||
@@ -12,9 +12,21 @@ A docker image is provided but you may use any if the above requirements are met
|
|||||||
|
|
||||||
## Benchmarking
|
## Benchmarking
|
||||||
|
|
||||||
A simple image generation benchmark runs when each worker initializes to validate GPU performance and identify underperforming machines.
|
### Custom Benchmark Workflows
|
||||||
|
|
||||||
The benchmark uses Stable Diffusion v1.5 with ComfyUI's default text-to-image workflow. Configure the benchmark complexity and duration using these variables:
|
You can provide a custom ComfyUI workflow for benchmarking by creating `workers/comfyui-json/misc/benchmark.json`. This allows you to test performance using your preferred models and workflow complexity.
|
||||||
|
|
||||||
|
**Ways to provide the benchmark file:**
|
||||||
|
- Fork this repository and add your `benchmark.json` file
|
||||||
|
- Write the file during worker provisioning (onstart script or setup phase)
|
||||||
|
|
||||||
|
An example file is provided in the repository. To ensure varied generations, use the placeholder `__RANDOM_INT__` in place of static seed values - it will be replaced with a random integer for each benchmark run.
|
||||||
|
|
||||||
|
### Default Benchmark (Fallback)
|
||||||
|
|
||||||
|
If `benchmark.json` is not available, a simple image generation benchmark runs when each worker initializes. This validates GPU performance and helps identify underperforming machines.
|
||||||
|
|
||||||
|
The default benchmark uses Stable Diffusion v1.5 with ComfyUI's standard text-to-image workflow. Configure it using these environment variables:
|
||||||
|
|
||||||
| Environment Variable | Default Value | Description |
|
| Environment Variable | Default Value | Description |
|
||||||
| -------------------- | ------------- | ----------- |
|
| -------------------- | ------------- | ----------- |
|
||||||
@@ -24,7 +36,7 @@ The benchmark uses Stable Diffusion v1.5 with ComfyUI's default text-to-image wo
|
|||||||
|
|
||||||
Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns.
|
Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns.
|
||||||
|
|
||||||
### Calibrating Benchmark Duration
|
#### Calibrating Fallback Benchmark Duration
|
||||||
|
|
||||||
To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements.
|
To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements.
|
||||||
|
|
||||||
|
|||||||
@@ -5,12 +5,13 @@ import dataclasses
|
|||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
from functools import cache
|
from functools import cache
|
||||||
from math import ceil
|
from math import ceil
|
||||||
|
from pathlib import Path
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
from lib.data_types import ApiPayload, JsonDataException
|
from lib.data_types import ApiPayload, JsonDataException
|
||||||
|
|
||||||
|
log = logging.getLogger(__file__)
|
||||||
with open("workers/comfyui/misc/test_prompts.txt", "r") as f:
|
|
||||||
test_prompts = f.readlines()
|
|
||||||
|
|
||||||
def count_workload() -> float:
|
def count_workload() -> float:
|
||||||
# Always 100.0 where there is a single instance of ComfyUI handling requests
|
# Always 100.0 where there is a single instance of ComfyUI handling requests
|
||||||
@@ -24,9 +25,32 @@ class ComfyWorkflowData(ApiPayload):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def for_test(cls):
|
def for_test(cls):
|
||||||
"""
|
"""
|
||||||
Use the variables available to simulate workflows of the required running time
|
If the user has provided a benchmark workflow we can use it here to properly gauge performance.
|
||||||
|
Otherwise, use the variables available to simulate workflows of the required running time
|
||||||
Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090)
|
Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090)
|
||||||
"""
|
"""
|
||||||
|
# Try to load benchmark.json
|
||||||
|
benchmark_file = Path("workers/comfyui-json/misc/benchmark.json")
|
||||||
|
|
||||||
|
if benchmark_file.exists():
|
||||||
|
try:
|
||||||
|
with open(benchmark_file, "r") as f:
|
||||||
|
benchmark_workflow = json.load(f)
|
||||||
|
return cls(
|
||||||
|
input={
|
||||||
|
"request_id": f"test-{random.randint(1000, 99999)}",
|
||||||
|
"workflow_json": benchmark_workflow
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except (json.JSONDecodeError, IOError):
|
||||||
|
# JSON is malformed or file can't be read, fall through to default
|
||||||
|
log.error(f"Failed to benchmark using {benchmark_file}")
|
||||||
|
|
||||||
|
# Fallback: read prompts and construct payload
|
||||||
|
log.info("Using fallback method for benchmarking")
|
||||||
|
with open("workers/comfyui-json/misc/test_prompts.txt", "r") as f:
|
||||||
|
test_prompts = f.readlines()
|
||||||
|
|
||||||
test_prompt = random.choice(test_prompts).rstrip()
|
test_prompt = random.choice(test_prompts).rstrip()
|
||||||
return cls(
|
return cls(
|
||||||
input={
|
input={
|
||||||
|
|||||||
@@ -0,0 +1,107 @@
|
|||||||
|
{
|
||||||
|
"3": {
|
||||||
|
"inputs": {
|
||||||
|
"seed": "__RANDOM_INT__",
|
||||||
|
"steps": 20,
|
||||||
|
"cfg": 8,
|
||||||
|
"sampler_name": "euler",
|
||||||
|
"scheduler": "normal",
|
||||||
|
"denoise": 1,
|
||||||
|
"model": [
|
||||||
|
"4",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"positive": [
|
||||||
|
"6",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"negative": [
|
||||||
|
"7",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"latent_image": [
|
||||||
|
"5",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "KSampler",
|
||||||
|
"_meta": {
|
||||||
|
"title": "KSampler"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"inputs": {
|
||||||
|
"ckpt_name": "v1-5-pruned-emaonly-fp16.safetensors"
|
||||||
|
},
|
||||||
|
"class_type": "CheckpointLoaderSimple",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Load Checkpoint"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"5": {
|
||||||
|
"inputs": {
|
||||||
|
"width": 512,
|
||||||
|
"height": 512,
|
||||||
|
"batch_size": 1
|
||||||
|
},
|
||||||
|
"class_type": "EmptyLatentImage",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Empty Latent Image"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"6": {
|
||||||
|
"inputs": {
|
||||||
|
"text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,",
|
||||||
|
"clip": [
|
||||||
|
"4",
|
||||||
|
1
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "CLIP Text Encode (Prompt)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"7": {
|
||||||
|
"inputs": {
|
||||||
|
"text": "text, watermark",
|
||||||
|
"clip": [
|
||||||
|
"4",
|
||||||
|
1
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "CLIP Text Encode (Prompt)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"inputs": {
|
||||||
|
"samples": [
|
||||||
|
"3",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"vae": [
|
||||||
|
"4",
|
||||||
|
2
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "VAEDecode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "VAE Decode"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"9": {
|
||||||
|
"inputs": {
|
||||||
|
"filename_prefix": "ComfyUI",
|
||||||
|
"images": [
|
||||||
|
"8",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "SaveImage",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Save Image"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -19,6 +19,7 @@ MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: "
|
|||||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||||
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
|
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
|
||||||
"Value not in list: ", # This error is emitted when the model file is not there at all
|
"Value not in list: ", # This error is emitted when the model file is not there at all
|
||||||
|
"[ERROR] Provisioning Script failed", # Error inserted by provisioning script if models/nodes fail to download
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -119,14 +119,25 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
|
|||||||
class CompletionsData(GenericData):
|
class CompletionsData(GenericData):
|
||||||
@classmethod
|
@classmethod
|
||||||
def for_test(cls) -> "CompletionsData":
|
def for_test(cls) -> "CompletionsData":
|
||||||
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
|
||||||
|
|
||||||
|
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
|
||||||
|
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
|
||||||
|
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
|
||||||
|
genus Equus with horses and asses, the three groups being the only living members of the family
|
||||||
|
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
|
||||||
|
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
|
||||||
|
woodlands, shrublands, and mountainous areas.
|
||||||
|
|
||||||
|
Please answer the following question based on the above context."""
|
||||||
|
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
|
||||||
model = os.environ.get("MODEL_NAME")
|
model = os.environ.get("MODEL_NAME")
|
||||||
if not model:
|
if not model:
|
||||||
raise ValueError("MODEL_NAME environment variable not set")
|
raise ValueError("MODEL_NAME environment variable not set")
|
||||||
|
|
||||||
test_input = {
|
test_input = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"prompt": prompt,
|
"prompt": f"{system_prompt}\n\n{unique_question}",
|
||||||
"temperature": 0.7,
|
"temperature": 0.7,
|
||||||
"max_tokens": 500,
|
"max_tokens": 500,
|
||||||
}
|
}
|
||||||
@@ -153,7 +164,18 @@ class ChatCompletionsData(GenericData):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def for_test(cls) -> "ChatCompletionsData":
|
def for_test(cls) -> "ChatCompletionsData":
|
||||||
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
|
||||||
|
|
||||||
|
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
|
||||||
|
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
|
||||||
|
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
|
||||||
|
genus Equus with horses and asses, the three groups being the only living members of the family
|
||||||
|
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
|
||||||
|
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
|
||||||
|
woodlands, shrublands, and mountainous areas.
|
||||||
|
|
||||||
|
Please answer the following question based on the above context."""
|
||||||
|
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
|
||||||
model = os.environ.get("MODEL_NAME")
|
model = os.environ.get("MODEL_NAME")
|
||||||
if not model:
|
if not model:
|
||||||
raise ValueError("MODEL_NAME environment variable not set")
|
raise ValueError("MODEL_NAME environment variable not set")
|
||||||
@@ -161,7 +183,10 @@ class ChatCompletionsData(GenericData):
|
|||||||
# Chat completions use messages format instead of prompt
|
# Chat completions use messages format instead of prompt
|
||||||
test_input = {
|
test_input = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
"messages": [
|
||||||
|
{"role": "system", "content": system_prompt}, # Shared prefix
|
||||||
|
{"role": "user", "content": unique_question} # Unique per request
|
||||||
|
],
|
||||||
"temperature": 0.7,
|
"temperature": 0.7,
|
||||||
"max_tokens": 500,
|
"max_tokens": 500,
|
||||||
}
|
}
|
||||||
|
|||||||
+414
-8
@@ -1,8 +1,395 @@
|
|||||||
from lib.test_utils import test_load_cmd, test_args
|
from lib.test_utils import test_args
|
||||||
|
from utils.endpoint_util import Endpoint
|
||||||
|
from utils.ssl import get_cert_file_path
|
||||||
|
from lib.data_types import AuthData
|
||||||
from .data_types.server import CompletionsData
|
from .data_types.server import CompletionsData
|
||||||
import os
|
|
||||||
|
|
||||||
WORKER_ENDPOINT = "/v1/completions"
|
import os
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
import requests
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from collections import Counter
|
||||||
|
from urllib.parse import urljoin, urlparse
|
||||||
|
import re
|
||||||
|
|
||||||
|
# Headless plotting
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
import logging
|
||||||
|
logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED
|
||||||
|
from requests.adapters import HTTPAdapter
|
||||||
|
|
||||||
|
def get_incremented_path(path: str) -> str:
|
||||||
|
base, ext = os.path.splitext(path)
|
||||||
|
if not os.path.exists(path):
|
||||||
|
return path
|
||||||
|
i = 1
|
||||||
|
while os.path.exists(f"{base}-{i}{ext}"):
|
||||||
|
i += 1
|
||||||
|
return f"{base}-{i}{ext}"
|
||||||
|
|
||||||
|
WORKER_ENDPOINT = "/v1/completions" # This will return the full text output at once. Latency metrics reflect that (ie not measuring TTFT)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ReqResult:
|
||||||
|
worker_url: str
|
||||||
|
route_ms: float
|
||||||
|
worker_ms: float
|
||||||
|
total_ms: float
|
||||||
|
ok: bool
|
||||||
|
error: str = ""
|
||||||
|
status_code: int = 0
|
||||||
|
t_start: float = 0.0
|
||||||
|
t_end: float = 0.0
|
||||||
|
workload: float = 0.0
|
||||||
|
|
||||||
|
def do_one(endpoint_name: str,
|
||||||
|
endpoint_id: int,
|
||||||
|
endpoint_api_key: str,
|
||||||
|
server_url: str,
|
||||||
|
worker_endpoint: str,
|
||||||
|
payload,
|
||||||
|
results_list,
|
||||||
|
t0,
|
||||||
|
status_samples,
|
||||||
|
route_session,
|
||||||
|
worker_session):
|
||||||
|
try:
|
||||||
|
workload = payload.count_workload()
|
||||||
|
route_payload = {"endpoint": endpoint_name, "api_key": endpoint_api_key, "cost": workload}
|
||||||
|
headers = {"Authorization": f"Bearer {endpoint_api_key}"}
|
||||||
|
start = time.time()
|
||||||
|
r0 = route_session.post(urljoin(server_url, "/route/"), json=route_payload, headers=headers, timeout=4)
|
||||||
|
t_after_route = time.time()
|
||||||
|
if r0.status_code != 200:
|
||||||
|
results_list.append(ReqResult(worker_url="",
|
||||||
|
route_ms=(t_after_route - start) * 1000.0,
|
||||||
|
worker_ms=0.0,
|
||||||
|
total_ms=(t_after_route - start) * 1000.0,
|
||||||
|
ok=False,
|
||||||
|
error=f"route error {r0.reason} {r0.text}",
|
||||||
|
status_code=r0.status_code,
|
||||||
|
t_start=start - t0,
|
||||||
|
t_end=t_after_route - t0,
|
||||||
|
workload=workload))
|
||||||
|
return
|
||||||
|
msg = r0.json()
|
||||||
|
|
||||||
|
# 1) Check if we got a worker back from route
|
||||||
|
worker_url = msg.get("url", "")
|
||||||
|
if not worker_url:
|
||||||
|
status = msg.get("status", "")
|
||||||
|
m = re.search(r"total workers:\s*(\d+).*loading workers:\s*(\d+).*standby workers:\s*(\d+).*error workers:\s*(\d+)", status, re.I | re.S)
|
||||||
|
if m:
|
||||||
|
tot, loading, standby, err = map(int, m.groups())
|
||||||
|
idle = max(tot - loading - standby - err, 0)
|
||||||
|
status_samples.append((time.time() - t0, idle))
|
||||||
|
|
||||||
|
# 2) If we got a worker, send the request
|
||||||
|
if worker_url:
|
||||||
|
req = dict(payload=payload.__dict__, auth_data=AuthData.from_json_msg(msg).__dict__)
|
||||||
|
t_before_worker = time.time()
|
||||||
|
r1 = worker_session.post(
|
||||||
|
urljoin(worker_url, worker_endpoint),
|
||||||
|
json=req,
|
||||||
|
verify=get_cert_file_path(),
|
||||||
|
timeout=(4, 120),
|
||||||
|
)
|
||||||
|
t_after_worker = time.time()
|
||||||
|
if r1.status_code != 200:
|
||||||
|
results_list.append(ReqResult(worker_url=worker_url,
|
||||||
|
route_ms=(t_after_route - start) * 1000.0,
|
||||||
|
worker_ms=(t_after_worker - t_before_worker) * 1000.0,
|
||||||
|
total_ms=(t_after_worker - start) * 1000.0,
|
||||||
|
ok=False,
|
||||||
|
error=f"worker inference error {r1.reason} {r1.text}",
|
||||||
|
status_code=r1.status_code,
|
||||||
|
t_start=start - t0,
|
||||||
|
t_end=t_after_worker - t0,
|
||||||
|
workload=workload))
|
||||||
|
return
|
||||||
|
# Success case
|
||||||
|
results_list.append(ReqResult(worker_url=worker_url,
|
||||||
|
route_ms=(t_after_route - start) * 1000.0,
|
||||||
|
worker_ms=(t_after_worker - t_before_worker) * 1000.0,
|
||||||
|
total_ms=(t_after_worker - start) * 1000.0,
|
||||||
|
ok=True,
|
||||||
|
error="",
|
||||||
|
status_code=200,
|
||||||
|
t_start=start - t0,
|
||||||
|
t_end=t_after_worker - t0,
|
||||||
|
workload=workload))
|
||||||
|
|
||||||
|
# 3) If so, sample via /get_endpoint_workers/ for eligible (idle) worker tracking
|
||||||
|
if worker_url:
|
||||||
|
try:
|
||||||
|
r_status = route_session.post(
|
||||||
|
urljoin(server_url, "/get_endpoint_workers/"),
|
||||||
|
json={"id": endpoint_id},
|
||||||
|
headers={"Authorization": f"Bearer {endpoint_api_key}"},
|
||||||
|
timeout=3,
|
||||||
|
)
|
||||||
|
if r_status.status_code == 200:
|
||||||
|
workers = r_status.json()
|
||||||
|
idle = 0
|
||||||
|
for w in workers:
|
||||||
|
st = str(w.get("status", "")).lower()
|
||||||
|
if (st in ("idle")):
|
||||||
|
idle += 1
|
||||||
|
status_samples.append((time.time() - t0, idle))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
t = time.time()
|
||||||
|
results_list.append(ReqResult(worker_url="",
|
||||||
|
route_ms=0.0,
|
||||||
|
worker_ms=0.0,
|
||||||
|
total_ms=0.0,
|
||||||
|
ok=False,
|
||||||
|
error=f"unknown error {e}",
|
||||||
|
status_code=0,
|
||||||
|
t_start=t - t0,
|
||||||
|
t_end=t - t0,
|
||||||
|
workload=0.0))
|
||||||
|
|
||||||
|
def run_load_with_metrics(num_requests: int,
|
||||||
|
requests_per_second: float,
|
||||||
|
endpoint_group_name: str,
|
||||||
|
account_api_key: str,
|
||||||
|
server_url: str,
|
||||||
|
worker_endpoint: str,
|
||||||
|
instance: str,
|
||||||
|
out_path: str):
|
||||||
|
|
||||||
|
ep_info = Endpoint.get_endpoint_info(endpoint_name=endpoint_group_name,
|
||||||
|
account_api_key=account_api_key,
|
||||||
|
instance=instance)
|
||||||
|
if not ep_info or not ep_info.get("api_key") or not ep_info.get("id"):
|
||||||
|
print(f"Endpoint {endpoint_group_name} not found for API key")
|
||||||
|
return
|
||||||
|
endpoint_id = int(ep_info["id"])
|
||||||
|
endpoint_api_key = ep_info["api_key"]
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
results = []
|
||||||
|
status_samples = []
|
||||||
|
max_concurrency = int(os.environ.get("MAX_CONCURRENCY", "8192"))
|
||||||
|
submit_queue_factor = 2 # cap queued tasks to reduce memory
|
||||||
|
|
||||||
|
# Shared HTTP sessions with connection pooling (persistent connections)
|
||||||
|
def make_session(pool_connections: int, pool_maxsize: int) -> requests.Session:
|
||||||
|
sess = requests.Session()
|
||||||
|
adapter = HTTPAdapter(pool_connections=pool_connections, pool_maxsize=pool_maxsize, max_retries=0)
|
||||||
|
sess.mount("https://", adapter)
|
||||||
|
sess.mount("http://", adapter)
|
||||||
|
return sess
|
||||||
|
|
||||||
|
# Router: mostly single host, small connection pool is sufficient
|
||||||
|
route_session = make_session(pool_connections=1, pool_maxsize=max_concurrency)
|
||||||
|
# Workers: many hosts; allow many pools and per-host concurrency up to max_concurrency
|
||||||
|
worker_session = make_session(pool_connections=64, pool_maxsize=max_concurrency // 8)
|
||||||
|
|
||||||
|
# Fire requests using a thread pool, scheduling at requested RPS
|
||||||
|
inflight = set()
|
||||||
|
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
|
||||||
|
for i in range(num_requests):
|
||||||
|
# Pace submissions to RPS
|
||||||
|
target_time = t0 + i / max(requests_per_second, 1e-9)
|
||||||
|
sleep_s = target_time - time.time()
|
||||||
|
if sleep_s > 0:
|
||||||
|
time.sleep(min(sleep_s, 0.5)) # sleep in chunks to stay responsive
|
||||||
|
|
||||||
|
payload = CompletionsData.for_test()
|
||||||
|
fut = executor.submit(
|
||||||
|
do_one,
|
||||||
|
endpoint_group_name,
|
||||||
|
endpoint_id,
|
||||||
|
endpoint_api_key,
|
||||||
|
server_url,
|
||||||
|
worker_endpoint,
|
||||||
|
payload,
|
||||||
|
results,
|
||||||
|
t0,
|
||||||
|
status_samples,
|
||||||
|
route_session,
|
||||||
|
worker_session,
|
||||||
|
)
|
||||||
|
inflight.add(fut)
|
||||||
|
# Prevent unbounded queue growth
|
||||||
|
if len(inflight) >= max_concurrency * submit_queue_factor:
|
||||||
|
done, not_done = wait(inflight, return_when=FIRST_COMPLETED)
|
||||||
|
inflight = not_done
|
||||||
|
# Wait for all outstanding tasks
|
||||||
|
if inflight:
|
||||||
|
wait(inflight)
|
||||||
|
# Close sessions
|
||||||
|
try:
|
||||||
|
route_session.close()
|
||||||
|
finally:
|
||||||
|
worker_session.close()
|
||||||
|
|
||||||
|
# Aggregate results
|
||||||
|
oks = [r for r in results if r.ok]
|
||||||
|
errs = [r for r in results if not r.ok]
|
||||||
|
total_reqs = len(results)
|
||||||
|
succ = len(oks)
|
||||||
|
|
||||||
|
total_ms = np.array([r.total_ms for r in oks]) if succ else np.array([])
|
||||||
|
worker_ms = np.array([r.worker_ms for r in oks]) if succ else np.array([])
|
||||||
|
route_ms = np.array([r.route_ms for r in oks]) if succ else np.array([])
|
||||||
|
|
||||||
|
avg_total = float(np.mean(total_ms)) if succ else 0.0
|
||||||
|
avg_worker = float(np.mean(worker_ms)) if succ else 0.0
|
||||||
|
avg_route = float(np.mean(route_ms)) if succ else 0.0
|
||||||
|
p50_total, p95_total = (float(np.percentile(total_ms, 50)), float(np.percentile(total_ms, 95))) if succ else (0.0, 0.0)
|
||||||
|
|
||||||
|
# Distribution over workers (by host:port)
|
||||||
|
hosts = [urlparse(r.worker_url).netloc for r in oks if r.worker_url]
|
||||||
|
dist = Counter(hosts)
|
||||||
|
|
||||||
|
# Idle over time (mode per second)
|
||||||
|
idle_ts, idle_vals = [], []
|
||||||
|
if status_samples:
|
||||||
|
buckets = {}
|
||||||
|
for ts, idle in status_samples:
|
||||||
|
k = int(ts)
|
||||||
|
buckets.setdefault(k, []).append(idle)
|
||||||
|
keys = sorted(buckets.keys())
|
||||||
|
idle_ts = keys
|
||||||
|
# Use the most frequent sampled value per second (mode) to keep integer counts
|
||||||
|
idle_vals = []
|
||||||
|
for k in keys:
|
||||||
|
vals_k = [int(v) for v in buckets[k]]
|
||||||
|
if vals_k:
|
||||||
|
cnt = Counter(vals_k)
|
||||||
|
idle_vals.append(cnt.most_common(1)[0][0])
|
||||||
|
else:
|
||||||
|
idle_vals.append(0)
|
||||||
|
|
||||||
|
print(f"\nResults: total={total_reqs} success={succ} errors={len(errs)}")
|
||||||
|
print(f"Avg latency (ms): {avg_total:.1f} p50: {p50_total:.1f} p95: {p95_total:.1f}")
|
||||||
|
print(f"Avg route latency (ms): {avg_route:.1f} Avg worker latency (ms): {avg_worker:.1f}")
|
||||||
|
if errs:
|
||||||
|
print("Sample errors:")
|
||||||
|
for e in errs[:5]:
|
||||||
|
print(f" {e.status_code} {e.error}")
|
||||||
|
|
||||||
|
# Plot: 2x3 grid
|
||||||
|
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
|
||||||
|
fig.suptitle(f"Load test: {endpoint_group_name} n={total_reqs}, rps={requests_per_second}, success={succ}")
|
||||||
|
|
||||||
|
# Dist per worker
|
||||||
|
ax0 = axes[0, 0]
|
||||||
|
if dist:
|
||||||
|
items = sorted(dist.items(), key=lambda kv: kv[1], reverse=True)
|
||||||
|
labels, counts = zip(*items)
|
||||||
|
ax0.bar(range(len(labels)), counts)
|
||||||
|
ax0.set_xticks(range(len(labels)))
|
||||||
|
ax0.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
|
||||||
|
ax0.set_title("Request distribution over workers")
|
||||||
|
ax0.set_ylabel("count")
|
||||||
|
|
||||||
|
# Latency histogram (total)
|
||||||
|
ax1 = axes[0, 1]
|
||||||
|
if succ:
|
||||||
|
ax1.hist(total_ms, bins=30)
|
||||||
|
ax1.set_title("Total latency (ms)")
|
||||||
|
ax1.set_xlabel("ms")
|
||||||
|
ax1.set_ylabel("freq")
|
||||||
|
|
||||||
|
# Eligible workers over time
|
||||||
|
ax_idle = axes[0, 2]
|
||||||
|
if idle_ts:
|
||||||
|
ax_idle.plot(idle_ts, idle_vals, "-o", ms=3)
|
||||||
|
ax_idle.set_title("Eligible workers over time")
|
||||||
|
ax_idle.set_xlabel("time (s)")
|
||||||
|
ax_idle.set_ylabel("eligible count")
|
||||||
|
|
||||||
|
# Throughput over time (completions/sec)
|
||||||
|
ax_idle = axes[1, 0]
|
||||||
|
ax_idle.clear()
|
||||||
|
if succ:
|
||||||
|
per_sec = {}
|
||||||
|
for r in oks:
|
||||||
|
s = int(r.t_end)
|
||||||
|
per_sec[s] = per_sec.get(s, 0) + 1
|
||||||
|
ts = sorted(per_sec.keys())
|
||||||
|
vals = [per_sec[t] for t in ts]
|
||||||
|
ax_idle.plot(ts, vals, "-o", ms=3)
|
||||||
|
ax_idle.set_title("Completions per second")
|
||||||
|
ax_idle.set_xlabel("time (s)")
|
||||||
|
ax_idle.set_ylabel("completions / sec")
|
||||||
|
|
||||||
|
# Summary text
|
||||||
|
ax3 = axes[1, 1]
|
||||||
|
ax3.axis("off")
|
||||||
|
text = (
|
||||||
|
f"Total requests: {total_reqs}\n"
|
||||||
|
f"Success: {succ} Errors: {len(errs)}\n"
|
||||||
|
f"Avg total latency: {avg_total:.1f} ms\n"
|
||||||
|
f"p50: {p50_total:.1f} ms p95: {p95_total:.1f} ms\n"
|
||||||
|
f"Avg route latency: {avg_route:.1f} ms\n"
|
||||||
|
f"Avg worker latency: {avg_worker:.1f} ms\n"
|
||||||
|
f"300 errors: {len([r for r in errs if r.status_code >= 300 and r.status_code < 400])}\n"
|
||||||
|
f"429 errors: {len([r for r in errs if r.status_code == 429])}\n"
|
||||||
|
f"500 errors: {len([r for r in errs if r.status_code >= 500])}\n"
|
||||||
|
f"Other errors: {len([r for r in errs if r.status_code not in [300, 429, 500]])}\n"
|
||||||
|
)
|
||||||
|
ax3.set_title("Summary")
|
||||||
|
ax3.text(0.02, 0.98, text, va="top", ha="left", fontsize=11, transform=ax3.transAxes)
|
||||||
|
|
||||||
|
# Error count over time
|
||||||
|
ax_errors = axes[1, 2]
|
||||||
|
all_end_times = [int(r.t_end) for r in results if r.t_end > 0]
|
||||||
|
if all_end_times:
|
||||||
|
min_second = min(all_end_times)
|
||||||
|
max_second = max(all_end_times)
|
||||||
|
# Count errors per second
|
||||||
|
errors_per_second = {}
|
||||||
|
for result in errs:
|
||||||
|
second = int(result.t_end)
|
||||||
|
errors_per_second[second] = errors_per_second.get(second, 0) + 1
|
||||||
|
# Create complete timeline including zeros
|
||||||
|
time_seconds = list(range(min_second, max_second + 1))
|
||||||
|
error_counts = [errors_per_second.get(sec, 0) for sec in time_seconds]
|
||||||
|
ax_errors.plot(time_seconds, error_counts, "-o", ms=3)
|
||||||
|
ax_errors.set_title("Errors per second")
|
||||||
|
ax_errors.set_xlabel("time (s)")
|
||||||
|
ax_errors.set_ylabel("errors / sec")
|
||||||
|
|
||||||
|
# Ensure unique output path and create directory if needed
|
||||||
|
final_out_path = get_incremented_path(out_path)
|
||||||
|
out_dir = os.path.dirname(final_out_path)
|
||||||
|
if out_dir:
|
||||||
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
|
|
||||||
|
plt.tight_layout(rect=[0, 0, 1, 0.96])
|
||||||
|
plt.savefig(final_out_path, dpi=120)
|
||||||
|
print(f"Saved report to: {final_out_path}")
|
||||||
|
|
||||||
|
# Per-worker latency boxplot (top 12 by volume)
|
||||||
|
groups = {}
|
||||||
|
for r in oks:
|
||||||
|
host = urlparse(r.worker_url).netloc
|
||||||
|
groups.setdefault(host, []).append(r.total_ms)
|
||||||
|
items = sorted(groups.items(), key=lambda kv: len(kv[1]), reverse=True)[:12]
|
||||||
|
if items:
|
||||||
|
labels, data = zip(*items)
|
||||||
|
fig2, axb = plt.subplots(1, 1, figsize=(12, 5))
|
||||||
|
axb.boxplot(data, showfliers=False)
|
||||||
|
axb.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
|
||||||
|
axb.set_title("Per-worker latency (ms)")
|
||||||
|
axb.set_ylabel("ms")
|
||||||
|
plt.tight_layout()
|
||||||
|
extra_out = get_incremented_path(os.path.splitext(out_path)[0] + "-workers.png")
|
||||||
|
plt.savefig(extra_out, dpi=120)
|
||||||
|
fig2.tight_layout()
|
||||||
|
fig2.savefig(extra_out, dpi=120)
|
||||||
|
print(f"Saved worker latency plot to: {extra_out}")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Check if MODEL_NAME environment variable is set
|
# Check if MODEL_NAME environment variable is set
|
||||||
@@ -16,13 +403,32 @@ if __name__ == "__main__":
|
|||||||
help="Model to use for completions request (required if MODEL_NAME env var not set)",
|
help="Model to use for completions request (required if MODEL_NAME env var not set)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse known args to get model early, before test_load_cmd adds its args
|
# Parse known args to get model early, before adding load args
|
||||||
known_args, _ = test_args.parse_known_args()
|
known_args, _ = test_args.parse_known_args()
|
||||||
|
|
||||||
# Set environment variable if model was provided
|
|
||||||
if hasattr(known_args, "model") and known_args.model:
|
if hasattr(known_args, "model") and known_args.model:
|
||||||
os.environ["MODEL_NAME"] = known_args.model
|
os.environ["MODEL_NAME"] = known_args.model
|
||||||
print(f"Set MODEL_NAME environment variable to: {known_args.model}")
|
print(f"Set MODEL_NAME environment variable to: {known_args.model}")
|
||||||
|
|
||||||
# Now call test_load_cmd normally - it will add its own args and re-parse
|
# Load test args
|
||||||
test_load_cmd(CompletionsData, WORKER_ENDPOINT, arg_parser=test_args)
|
test_args.add_argument("-n", dest="num_requests", type=int, required=True, help="total number of requests")
|
||||||
|
test_args.add_argument("-rps", dest="requests_per_second", type=float, required=True, help="requests per second")
|
||||||
|
test_args.add_argument("--out", dest="out_path", type=str, default="load_test_report.png", help="path to save the report image")
|
||||||
|
args = test_args.parse_args()
|
||||||
|
|
||||||
|
server_url = {
|
||||||
|
"prod": "https://run.vast.ai",
|
||||||
|
"alpha": "https://run-alpha.vast.ai",
|
||||||
|
"candidate": "https://run-candidate.vast.ai",
|
||||||
|
"local": "http://localhost:8080"
|
||||||
|
}.get(args.instance, "http://localhost:8080")
|
||||||
|
|
||||||
|
run_load_with_metrics(
|
||||||
|
num_requests=args.num_requests,
|
||||||
|
requests_per_second=args.requests_per_second,
|
||||||
|
endpoint_group_name=args.endpoint_group_name,
|
||||||
|
account_api_key=args.api_key,
|
||||||
|
server_url=server_url,
|
||||||
|
worker_endpoint=WORKER_ENDPOINT,
|
||||||
|
instance=args.instance,
|
||||||
|
out_path=args.out_path,
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user