Compare commits

..

2 Commits

Author SHA1 Message Date
Lucas Armand 3988cf553f Suppress matplot debug logs 2025-10-10 11:57:46 -07:00
Colter Downing a00c1adab5 improved test load 2025-10-09 19:37:39 -07:00
10 changed files with 144 additions and 600 deletions
+37 -132
View File
@@ -5,14 +5,13 @@ import base64
import subprocess
import dataclasses
import logging
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task, get_running_loop
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
from functools import cached_property
from distutils.util import strtobool
from anyio import open_file
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
import asyncio
import requests
from Crypto.Signature import pkcs1_15
@@ -26,12 +25,8 @@ from lib.data_types import (
LogAction,
ApiPayload_T,
JsonDataException,
RequestMetrics,
BenchmarkResult
)
VERSION = "0.1.0"
MSG_HISTORY_LEN = 100
log = logging.getLogger(__file__)
@@ -47,7 +42,7 @@ class Backend:
This class is responsible for:
1. Tailing logs and updating load time metrics
2. Taking an EndpointHandler alongside incoming payload, preparing a json to be sent to the model, and
sending the request. It also updates metrics as it makes those requests.
sending the request. It also updates metrics as it makes those requests.
3. Running a benchmark from an EndpointHandler
"""
@@ -58,9 +53,7 @@ class Backend:
EndpointHandler # this endpoint handler will be used for benchmarking
)
log_actions: List[Tuple[LogAction, str]]
max_wait_time: float = 10.0
reqnum = -1
version = VERSION
msg_history = []
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
unsecured: bool = dataclasses.field(
@@ -69,16 +62,10 @@ class Backend:
def __post_init__(self):
self.metrics = Metrics()
self.metrics._set_version(self.version)
self._total_pubkey_fetch_errors = 0
self._pubkey = self._fetch_pubkey()
self.__start_healthcheck: bool = False
# NEW: FIFO queue + worker count
self.request_queue: "asyncio.Queue[tuple[EndpointHandler[ApiPayload_T], web.Request, asyncio.Future]]" = asyncio.Queue()
# If parallel allowed, let multiple workers drain the queue (order preserved by FIFO per worker; overall start order is FIFO).
self._num_workers: int = 1 if not self.allow_parallel_requests else int(os.environ.get("WORKERS", "4"))
@property
def pubkey(self) -> Optional[RSA.RsaKey]:
if self._pubkey is None:
@@ -96,22 +83,6 @@ class Backend:
timeout = ClientTimeout(total=None)
return ClientSession(self.model_server_url, timeout=timeout, connector=connector)
async def _worker(self):
while True:
handler, request, fut = await self.request_queue.get()
try:
# Skip if already cancelled while waiting in the queue
if fut.cancelled():
continue
res = await self.__process_enqueued_request(handler, request)
if not fut.cancelled():
fut.set_result(res)
except Exception as e:
if not fut.cancelled():
fut.set_exception(e)
finally:
self.request_queue.task_done()
def create_handler(
self,
handler: EndpointHandler[ApiPayload_T],
@@ -148,36 +119,7 @@ class Backend:
handler: EndpointHandler[ApiPayload_T],
request: web.Request,
) -> Union[web.Response, web.StreamResponse]:
"""use this function to enqueue requests for FIFO processing"""
loop = get_running_loop()
fut: asyncio.Future = loop.create_future()
# If the client disconnects while waiting in the FIFO, cancel the future so the worker skips it
cancel_watch = create_task(request.wait_for_disconnection())
def _cancel_if_disconnected(_):
if not fut.done():
fut.cancel()
cancel_watch.add_done_callback(_cancel_if_disconnected)
try:
await self.request_queue.put((handler, request, fut))
return await fut
except asyncio.CancelledError:
# Propagate cancellation to ensure aiohttp doesn't expect a response body
raise
finally:
# Best-effort cleanup of the watcher
cancel_watch.cancel()
async def __process_enqueued_request(
self,
handler: EndpointHandler[ApiPayload_T],
request: web.Request,
) -> Union[web.Response, web.StreamResponse]:
"""
This contains the original __handle_request logic and is invoked by workers,
ensuring FIFO execution via asyncio.Queue.
"""
"""use this function to forward requests to the model endpoint"""
try:
data = await request.json()
auth_data, payload = handler.get_data_from_request(data)
@@ -185,63 +127,56 @@ class Backend:
return web.json_response(data=e.message, status=422)
except json.JSONDecodeError:
return web.json_response(dict(error="invalid JSON"), status=422)
workload = payload.count_workload()
request_metrics: RequestMetrics = RequestMetrics(
request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created"
)
async def cancel_api_call_if_disconnected() -> web.Response:
await request.wait_for_disconnection()
log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled")
self.metrics._request_canceled(request_metrics)
raise asyncio.CancelledError
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
self.metrics._request_canceled(workload=workload)
return web.Response(status=500)
async def make_request() -> Union[web.Response, web.StreamResponse]:
log.debug(f"got request, {auth_data.reqnum}")
self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum)
if self.allow_parallel_requests is False:
log.debug(f"Waiting to aquire Sem for reqnum:{auth_data.reqnum}")
await self.sem.acquire()
log.debug(
f"Sem acquired for reqnum:{auth_data.reqnum}, starting request..."
)
else:
log.debug(f"Starting request for reqnum:{auth_data.reqnum}")
try:
response = await self.__call_api(handler=handler, payload=payload)
status_code = response.status
log.debug(
" ".join(
[
f"request with reqnum:{request_metrics.reqnum}",
f"request with reqnum:{auth_data.reqnum}",
f"returned status code: {status_code},",
]
)
)
res = await handler.generate_client_response(request, response)
self.metrics._request_success(request_metrics)
self.metrics._request_success(workload=workload)
return res
except requests.exceptions.RequestException as e:
log.debug(f"[backend] Request error: {e}")
self.metrics._request_errored(request_metrics)
self.metrics._request_errored(workload=workload)
return web.Response(status=500)
finally:
self.metrics._request_end(
workload=workload,
reqnum=auth_data.reqnum,
)
self.sem.release()
###########
if self.__check_signature(auth_data) is False:
self.metrics._request_reject(request_metrics)
return web.Response(status=401)
if self.metrics.model_metrics.wait_time > self.max_wait_time:
self.metrics._request_reject(request_metrics)
return web.Response(status=429)
acquired = False
try:
self.metrics._request_start(request_metrics)
# Preserve existing semaphore behavior for serializing requests when requested
if self.allow_parallel_requests is False:
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
await self.sem.acquire()
acquired = True
log.debug(
f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..."
)
else:
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
done, pending = await wait(
[
create_task(make_request()),
@@ -249,27 +184,11 @@ class Backend:
],
return_when=FIRST_COMPLETED,
)
for t in pending:
t.cancel()
await asyncio.gather(*pending, return_exceptions=True)
done_task = done.pop()
try:
return done_task.result()
except Exception as e:
log.debug(f"Request task raised exception: {e}")
return web.Response(status=500)
except asyncio.CancelledError:
# Client is gone. Do not write a response; just unwind.
return web.Response(status=499)
[task.cancel() for task in pending]
return done.pop().result()
except Exception as e:
log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500)
finally:
# Always release the semaphore if it was acquired
if acquired:
self.sem.release()
self.metrics._request_end(request_metrics)
@cached_property
def healthcheck_session(self):
@@ -309,14 +228,8 @@ class Backend:
self.backend_errored(str(e))
async def _start_tracking(self) -> None:
# Start the FIFO workers alongside existing loops
worker_tasks = tuple(self._worker() for _ in range(self._num_workers))
await gather(
self.__read_logs(),
self.metrics._send_metrics_loop(),
self.__healthcheck(),
self.metrics._send_delete_requests_loop(),
*worker_tasks,
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck()
)
def backend_errored(self, msg: str) -> None:
@@ -395,26 +308,18 @@ class Backend:
for run in range(1, self.benchmark_handler.benchmark_runs + 1):
start = time.time()
benchmark_requests = []
tasks = []
total_workload = 0
for i in range(concurrent_requests):
for _ in range(concurrent_requests):
payload = self.benchmark_handler.make_benchmark_payload()
workload = payload.count_workload()
task = self.__call_api(handler=self.benchmark_handler, payload=payload)
benchmark_requests.append(
BenchmarkResult(request_idx=i, workload=workload, task=task)
total_workload += payload.count_workload()
tasks.append(
self.__call_api(handler=self.benchmark_handler, payload=payload)
)
responses = await gather(*[br.task for br in benchmark_requests])
for br, response in zip(benchmark_requests, responses):
br.response = response
total_workload = sum(br.workload for br in benchmark_requests if br.is_successful)
responses = await gather(*tasks)
time_elapsed = time.time() - start
successful_responses = sum([1 for br in benchmark_requests if br.is_successful])
if successful_responses == 0:
self.backend_errored("No successful responses from benchmark")
log.debug(f"benchmark failed: {successful_responses}/{concurrent_requests} successful responses")
throughput = total_workload / time_elapsed
sum_throughput += throughput
@@ -428,7 +333,7 @@ class Backend:
f"Run: {run}, concurrent_requests: {concurrent_requests}",
f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s",
f"Throughput: {throughput} workload/s",
f"Successful responses: {successful_responses}/{concurrent_requests}",
f"Successful responses: {len([r for r in responses if r.status == 200])}",
"#" * 60,
]
)
@@ -482,7 +387,7 @@ class Backend:
if line:
await handle_log_line(line.rstrip())
else:
await asyncio.sleep(LOG_POLL_INTERVAL)
time.sleep(LOG_POLL_INTERVAL)
###########
+8 -47
View File
@@ -3,7 +3,7 @@ import logging
from dataclasses import dataclass, field
from enum import Enum
from abc import ABC, abstractmethod
from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type, Awaitable
from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type
from aiohttp import web, ClientResponse
import inspect
@@ -70,7 +70,6 @@ class AuthData:
endpoint: str
reqnum: int
url: str
request_idx: int
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]):
@@ -197,26 +196,6 @@ class SystemMetrics:
self.model_loading_time = None
@dataclass
class RequestMetrics:
"""Tracks metrics for an active request."""
request_idx: int
reqnum: int
workload: float
status: str
success: bool = False
@dataclass
class BenchmarkResult:
request_idx: int
workload: float
task: Awaitable[ClientResponse]
response: Optional[ClientResponse] = None
@property
def is_successful(self) -> bool:
return self.response is not None and self.response.status == 200
@dataclass
class ModelMetrics:
"""Model specific metrics"""
@@ -226,14 +205,12 @@ class ModelMetrics:
workload_received: float
workload_cancelled: float
workload_errored: float
workload_rejected: float
# these are not
workload_pending: float
error_msg: Optional[str]
max_throughput: float
requests_recieved: Set[int] = field(default_factory=set)
requests_working: dict[int, RequestMetrics] = field(default_factory=dict)
requests_deleting: list[RequestMetrics] = field(default_factory=list)
requests_working: Set[int] = field(default_factory=set)
last_update: float = field(default_factory=time.time)
@classmethod
@@ -243,30 +220,19 @@ class ModelMetrics:
workload_served=0.0,
workload_cancelled=0.0,
workload_errored=0.0,
workload_rejected=0.0,
workload_received=0.0,
error_msg=None,
max_throughput=0.0,
)
@property
def cur_perf(self) -> float:
return max(self.workload_served / (time.time() - self.last_update), 0.0)
@property
def workload_processing(self) -> float:
return max(self.workload_received - self.workload_cancelled, 0.0)
@property
def wait_time(self) -> float:
if (len(self.requests_working) == 0):
return 0.0
return sum([request.workload for request in self.requests_working.values()]) / max(self.max_throughput, 0.00001)
@property
def cur_load(self) -> float:
return sum([request.workload for request in self.requests_working.values()])
@property
def working_request_idxs(self) -> list[int]:
return [req.request_idx for req in self.requests_working.values()]
def set_errored(self, error_msg):
self.reset()
self.error_msg = error_msg
@@ -276,20 +242,16 @@ class ModelMetrics:
self.workload_received = 0
self.workload_cancelled = 0
self.workload_errored = 0
self.workload_rejected = 0
self.last_update = time.time()
@dataclass
class AutoScalerData:
class AutoScalaerData:
"""Data that is reported to autoscaler"""
id: int
version: str
loadtime: float
cur_load: float
rej_load: float
new_load: float
error_msg: str
max_perf: float
cur_perf: float
@@ -298,7 +260,6 @@ class AutoScalerData:
num_requests_working: int
num_requests_recieved: int
additional_disk_usage: float
working_request_idxs: list[int]
url: str
+31 -138
View File
@@ -5,14 +5,13 @@ import json
from asyncio import sleep
from dataclasses import dataclass, asdict, field
from functools import cache
import asyncio
from aiohttp import ClientSession, ClientTimeout, TCPConnector, ClientResponseError
from lib.data_types import AutoScalerData, SystemMetrics, ModelMetrics, RequestMetrics
import requests
from lib.data_types import AutoScalaerData, SystemMetrics, ModelMetrics
from typing import Awaitable, NoReturn, List
METRICS_UPDATE_INTERVAL = 1
DELETE_REQUESTS_INTERVAL = 1
log = logging.getLogger(__file__)
@@ -27,9 +26,7 @@ def get_url() -> str:
@dataclass
class Metrics:
version: str = "0"
last_metric_update: float = 0.0
last_request_served: float = 0.0
update_pending: bool = False
id: int = field(default_factory=lambda: int(os.environ["CONTAINER_ID"]))
report_addr: List[str] = field(
@@ -38,84 +35,42 @@ class Metrics:
url: str = field(default_factory=get_url)
system_metrics: SystemMetrics = field(default_factory=SystemMetrics.empty)
model_metrics: ModelMetrics = field(default_factory=ModelMetrics.empty)
_session: ClientSession | None = field(default=None, init=False, repr=False)
async def http(self) -> ClientSession:
if self._session is None:
self._session = ClientSession(
timeout=ClientTimeout(total=10),
connector=TCPConnector(limit=8, limit_per_host=4, force_close=True, enable_cleanup_closed=True)
)
return self._session
async def aclose(self) -> None:
if self._session is not None:
await self._session.close()
self._session = None
def _request_start(self, request: RequestMetrics) -> None:
def _request_start(self, workload: float, reqnum: int) -> None:
"""
this function is called prior to forwarding a request to a model API.
"""
log.debug("request start")
request.status = "Started"
self.model_metrics.workload_pending += request.workload
self.model_metrics.workload_received += request.workload
self.model_metrics.requests_recieved.add(request.reqnum)
self.model_metrics.requests_working[request.reqnum] = request
self.update_pending = True
self.model_metrics.workload_pending += workload
self.model_metrics.workload_received += workload
self.model_metrics.requests_recieved.add(reqnum)
self.model_metrics.requests_working.add(reqnum)
def _request_end(self, request: RequestMetrics) -> None:
def _request_end(self, workload: float, reqnum: int) -> None:
"""
this function is called after handling of a request ends, regardless of the outcome
"""
self.model_metrics.workload_pending -= request.workload
self.model_metrics.requests_working.pop(request.reqnum, None)
self.model_metrics.requests_deleting.append(request)
self.last_request_served = time.time()
self.model_metrics.workload_pending -= workload
self.model_metrics.requests_working.discard(reqnum)
def _request_success(self, request: RequestMetrics) -> None:
def _request_success(self, workload: float) -> None:
"""
this function is called after a response from model API is received and forwarded.
"""
self.model_metrics.workload_served += request.workload
request.status = "Success"
request.success = True
self.model_metrics.workload_served += workload
self.update_pending = True
def _request_errored(self, request: RequestMetrics) -> None:
def _request_errored(self, workload: float) -> None:
"""
this function is called if model API returns an error
"""
self.model_metrics.workload_errored += request.workload
request.status = "Error"
request.success = False
self.update_pending = True
self.model_metrics.workload_errored += workload
def _request_canceled(self, request: RequestMetrics) -> None:
def _request_canceled(self, workload: float) -> None:
"""
this function is called if client drops connection before model API has responded
"""
self.model_metrics.workload_cancelled += request.workload
request.success = True
request.status = "Cancelled"
def _request_reject(self, request: RequestMetrics):
"""
this function is called if the current wait time for the model is above max_wait_time
"""
self.model_metrics.requests_recieved.add(request.reqnum)
self.model_metrics.requests_deleting.append(request)
self.model_metrics.workload_rejected += request.workload
request.success = False
request.status = "Rejected"
self.update_pending = True
async def _send_delete_requests_loop(self) -> Awaitable[NoReturn]:
while True:
await sleep(DELETE_REQUESTS_INTERVAL)
if len(self.model_metrics.requests_deleting) > 0:
await self.__send_delete_requests_and_reset()
self.model_metrics.workload_cancelled += workload
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
while True:
@@ -123,10 +78,10 @@ class Metrics:
elapsed = time.time() - self.last_metric_update
if self.system_metrics.model_is_loaded is False and elapsed >= 10:
log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
await self.__send_metrics_and_reset()
self.__send_metrics_and_reset(elapsed)
elif self.update_pending or elapsed > 10:
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
await self.__send_metrics_and_reset()
self.__send_metrics_and_reset(elapsed)
def _model_loaded(self, max_throughput: float) -> None:
self.system_metrics.model_loading_time = (
@@ -139,88 +94,27 @@ class Metrics:
self.model_metrics.set_errored(error_msg)
self.system_metrics.model_is_loaded = True
def _set_version(self, version: str) -> None:
self.version = version
#######################################Private#######################################
async def __send_delete_requests_and_reset(self):
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
data = {
"worker_id": self.id,
"request_idxs": idxs,
"success": success_flag,
}
log.debug(
f"Deleting requests that {'succeeded' if success_flag else 'failed'}: {data['request_idxs']}"
)
full_path = report_addr.rstrip("/") + "/delete_requests/"
for attempt in range(1, 4):
try:
session = await self.http()
async with session.post(full_path, json=data) as res:
log.debug(f"delete_requests response: {res.status}")
res.raise_for_status()
return True
except asyncio.TimeoutError:
log.debug("delete_requests timed out")
except (ClientResponseError, Exception) as e:
log.debug(f"delete_requests failed with error: {e}")
await asyncio.sleep(2)
log.debug(f"retrying delete_request, attempt: {attempt}")
return False
def __send_metrics_and_reset(self, elapsed):
# Take a snapshot of what we plan to send this tick.
# New arrivals after this snapshot will remain in the queue for the next tick.
snapshot = list(self.model_metrics.requests_deleting)
success_idxs = [r.request_idx for r in snapshot if r.success is True]
failed_idxs = [r.request_idx for r in snapshot if r.success is False]
if not success_idxs and not failed_idxs:
return # nothing to do
for report_addr in self.report_addr:
sent_success = True
sent_failed = True
if success_idxs:
sent_success = await post(report_addr, success_idxs, True)
if failed_idxs:
sent_failed = await post(report_addr, failed_idxs, False)
if sent_success and sent_failed:
# Remove only the items we actually sent from the live queue.
sent_set = set(success_idxs) | set(failed_idxs)
self.model_metrics.requests_deleting[:] = [
r for r in self.model_metrics.requests_deleting
if r.request_idx not in sent_set
]
break
async def __send_metrics_and_reset(self):
def compute_autoscaler_data() -> AutoScalerData:
return AutoScalerData(
def compute_autoscaler_data() -> AutoScalaerData:
return AutoScalaerData(
id=self.id,
version=self.version,
loadtime=(self.system_metrics.model_loading_time or 0.0),
new_load=self.model_metrics.workload_processing,
cur_load=self.model_metrics.cur_load,
rej_load=self.model_metrics.workload_rejected,
cur_load=(self.model_metrics.workload_processing / elapsed),
max_perf=self.model_metrics.max_throughput,
cur_perf=self.model_metrics.workload_served,
cur_perf=self.model_metrics.cur_perf,
error_msg=self.model_metrics.error_msg or "",
num_requests_working=len(self.model_metrics.requests_working),
num_requests_recieved=len(self.model_metrics.requests_recieved),
additional_disk_usage=self.system_metrics.additional_disk_usage,
working_request_idxs=self.model_metrics.working_request_idxs,
cur_capacity=0,
max_capacity=0,
url=self.url,
)
async def send_data(report_addr: str) -> bool:
def send_data(report_addr: str) -> bool:
data = compute_autoscaler_data()
full_path = report_addr.rstrip("/") + "/worker_status/"
log.debug(
@@ -235,15 +129,14 @@ class Metrics:
)
for attempt in range(1, 4):
try:
session = await self.http()
async with session.post(full_path, json=asdict(data)) as res:
res.raise_for_status()
res = requests.post(full_path, json=asdict(data), timeout=1)
res.raise_for_status()
return True
except asyncio.TimeoutError:
except requests.Timeout:
log.debug(f"autoscaler status update timed out")
except (ClientResponseError, Exception) as e:
except Exception as e:
log.debug(f"autoscaler status update failed with error: {e}")
await asyncio.sleep(2)
time.sleep(2)
log.debug(f"retrying autoscaler status update, attempt: {attempt}")
log.debug(f"failed to send update through {report_addr}")
return False
@@ -253,7 +146,7 @@ class Metrics:
self.system_metrics.update_disk_usage()
for report_addr in self.report_addr:
success = await send_data(report_addr)
success = send_data(report_addr)
if success is True:
break
self.update_pending = False
+2 -2
View File
@@ -59,12 +59,12 @@ then
fi
# Fork testing
[[ ! -d $SERVER_DIR ]] && git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"
git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"
if [[ -n ${PYWORKER_REF:-} ]]; then
(cd "$SERVER_DIR" && git checkout "$PYWORKER_REF")
fi
uv venv --python-preference only-managed "$ENV_PATH" -p 3.10
uv venv --managed-python "$ENV_PATH" -p 3.10
source "$ENV_PATH/bin/activate"
uv pip install -r "${SERVER_DIR}/requirements.txt"
+3 -15
View File
@@ -12,21 +12,9 @@ A docker image is provided but you may use any if the above requirements are met
## Benchmarking
### Custom Benchmark Workflows
A simple image generation benchmark runs when each worker initializes to validate GPU performance and identify underperforming machines.
You can provide a custom ComfyUI workflow for benchmarking by creating `workers/comfyui-json/misc/benchmark.json`. This allows you to test performance using your preferred models and workflow complexity.
**Ways to provide the benchmark file:**
- Fork this repository and add your `benchmark.json` file
- Write the file during worker provisioning (onstart script or setup phase)
An example file is provided in the repository. To ensure varied generations, use the placeholder `__RANDOM_INT__` in place of static seed values - it will be replaced with a random integer for each benchmark run.
### Default Benchmark (Fallback)
If `benchmark.json` is not available, a simple image generation benchmark runs when each worker initializes. This validates GPU performance and helps identify underperforming machines.
The default benchmark uses Stable Diffusion v1.5 with ComfyUI's standard text-to-image workflow. Configure it using these environment variables:
The benchmark uses Stable Diffusion v1.5 with ComfyUI's default text-to-image workflow. Configure the benchmark complexity and duration using these variables:
| Environment Variable | Default Value | Description |
| -------------------- | ------------- | ----------- |
@@ -36,7 +24,7 @@ The default benchmark uses Stable Diffusion v1.5 with ComfyUI's standard text-to
Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns.
#### Calibrating Fallback Benchmark Duration
### Calibrating Benchmark Duration
To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements.
+4 -28
View File
@@ -5,13 +5,12 @@ import dataclasses
from typing import Dict, Any
from functools import cache
from math import ceil
from pathlib import Path
import json
import logging
from lib.data_types import ApiPayload, JsonDataException
log = logging.getLogger(__file__)
with open("workers/comfyui/misc/test_prompts.txt", "r") as f:
test_prompts = f.readlines()
def count_workload() -> float:
# Always 100.0 where there is a single instance of ComfyUI handling requests
@@ -25,32 +24,9 @@ class ComfyWorkflowData(ApiPayload):
@classmethod
def for_test(cls):
"""
If the user has provided a benchmark workflow we can use it here to properly gauge performance.
Otherwise, use the variables available to simulate workflows of the required running time
Use the variables available to simulate workflows of the required running time
Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090)
"""
# Try to load benchmark.json
benchmark_file = Path("workers/comfyui-json/misc/benchmark.json")
if benchmark_file.exists():
try:
with open(benchmark_file, "r") as f:
benchmark_workflow = json.load(f)
return cls(
input={
"request_id": f"test-{random.randint(1000, 99999)}",
"workflow_json": benchmark_workflow
}
)
except (json.JSONDecodeError, IOError):
# JSON is malformed or file can't be read, fall through to default
log.error(f"Failed to benchmark using {benchmark_file}")
# Fallback: read prompts and construct payload
log.info("Using fallback method for benchmarking")
with open("workers/comfyui-json/misc/test_prompts.txt", "r") as f:
test_prompts = f.readlines()
test_prompt = random.choice(test_prompts).rstrip()
return cls(
input={
@@ -1,107 +0,0 @@
{
"3": {
"inputs": {
"seed": "__RANDOM_INT__",
"steps": 20,
"cfg": 8,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1,
"model": [
"4",
0
],
"positive": [
"6",
0
],
"negative": [
"7",
0
],
"latent_image": [
"5",
0
]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"4": {
"inputs": {
"ckpt_name": "v1-5-pruned-emaonly-fp16.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"5": {
"inputs": {
"width": 512,
"height": 512,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"6": {
"inputs": {
"text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"7": {
"inputs": {
"text": "text, watermark",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"8": {
"inputs": {
"samples": [
"3",
0
],
"vae": [
"4",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"8",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
}
}
+1 -2
View File
@@ -19,7 +19,6 @@ MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: "
MODEL_SERVER_ERROR_LOG_MSGS = [
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
"Value not in list: ", # This error is emitted when the model file is not there at all
"[ERROR] Provisioning Script failed", # Error inserted by provisioning script if models/nodes fail to download
]
@@ -71,7 +70,7 @@ class ComfyWorkflowHandler(EndpointHandler[ComfyWorkflowData]):
@property
def healthcheck_endpoint(self) -> Optional[str]:
return f"{MODEL_SERVER_URL}/health"
return "/health"
@classmethod
def payload_cls(cls) -> Type[ComfyWorkflowData]:
+4 -29
View File
@@ -119,25 +119,14 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
class CompletionsData(GenericData):
@classmethod
def for_test(cls) -> "CompletionsData":
system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
genus Equus with horses and asses, the three groups being the only living members of the family
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
woodlands, shrublands, and mountainous areas.
Please answer the following question based on the above context."""
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
model = os.environ.get("MODEL_NAME")
if not model:
raise ValueError("MODEL_NAME environment variable not set")
test_input = {
"model": model,
"prompt": f"{system_prompt}\n\n{unique_question}",
"prompt": prompt,
"temperature": 0.7,
"max_tokens": 500,
}
@@ -164,18 +153,7 @@ class ChatCompletionsData(GenericData):
@classmethod
def for_test(cls) -> "ChatCompletionsData":
system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
genus Equus with horses and asses, the three groups being the only living members of the family
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
woodlands, shrublands, and mountainous areas.
Please answer the following question based on the above context."""
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
model = os.environ.get("MODEL_NAME")
if not model:
raise ValueError("MODEL_NAME environment variable not set")
@@ -183,10 +161,7 @@ class ChatCompletionsData(GenericData):
# Chat completions use messages format instead of prompt
test_input = {
"model": model,
"messages": [
{"role": "system", "content": system_prompt}, # Shared prefix
{"role": "user", "content": unique_question} # Unique per request
],
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.7,
"max_tokens": 500,
}
+54 -100
View File
@@ -42,7 +42,6 @@ class ReqResult:
total_ms: float
ok: bool
error: str = ""
status_code: int = 0
t_start: float = 0.0
t_end: float = 0.0
workload: float = 0.0
@@ -59,73 +58,31 @@ def do_one(endpoint_name: str,
route_session,
worker_session):
try:
workload = payload.count_workload()
route_payload = {"endpoint": endpoint_name, "api_key": endpoint_api_key, "cost": workload}
u = payload.count_workload()
route_payload = {"endpoint": endpoint_name, "api_key": endpoint_api_key, "cost": u}
headers = {"Authorization": f"Bearer {endpoint_api_key}"}
start = time.time()
r0 = route_session.post(urljoin(server_url, "/route/"), json=route_payload, headers=headers, timeout=4)
t_after_route = time.time()
if r0.status_code != 200:
results_list.append(ReqResult(worker_url="",
route_ms=(t_after_route - start) * 1000.0,
worker_ms=0.0,
total_ms=(t_after_route - start) * 1000.0,
ok=False,
error=f"route error {r0.reason} {r0.text}",
status_code=r0.status_code,
t_start=start - t0,
t_end=t_after_route - t0,
workload=workload))
results_list.append(ReqResult("", (t_after_route - start) * 1000.0, 0.0, (t_after_route - start) * 1000.0, False,
f"route {r0.status_code} {r0.text}"))
return
msg = r0.json()
# 1) Check if we got a worker back from route
worker_url = msg.get("url", "")
if not worker_url:
status = msg.get("status", "")
# 1) "Status" is in the response when no worker is ready
worker_sampled = True
status = msg.get("status", "")
if status:
m = re.search(r"total workers:\s*(\d+).*loading workers:\s*(\d+).*standby workers:\s*(\d+).*error workers:\s*(\d+)", status, re.I | re.S)
if m:
tot, loading, standby, err = map(int, m.groups())
idle = max(tot - loading - standby - err, 0)
status_samples.append((time.time() - t0, idle))
worker_sampled = False
# 2) If we got a worker, send the request
if worker_url:
req = dict(payload=payload.__dict__, auth_data=AuthData.from_json_msg(msg).__dict__)
t_before_worker = time.time()
r1 = worker_session.post(
urljoin(worker_url, worker_endpoint),
json=req,
verify=get_cert_file_path(),
timeout=(4, 120),
)
t_after_worker = time.time()
if r1.status_code != 200:
results_list.append(ReqResult(worker_url=worker_url,
route_ms=(t_after_route - start) * 1000.0,
worker_ms=(t_after_worker - t_before_worker) * 1000.0,
total_ms=(t_after_worker - start) * 1000.0,
ok=False,
error=f"worker inference error {r1.reason} {r1.text}",
status_code=r1.status_code,
t_start=start - t0,
t_end=t_after_worker - t0,
workload=workload))
return
# Success case
results_list.append(ReqResult(worker_url=worker_url,
route_ms=(t_after_route - start) * 1000.0,
worker_ms=(t_after_worker - t_before_worker) * 1000.0,
total_ms=(t_after_worker - start) * 1000.0,
ok=True,
error="",
status_code=200,
t_start=start - t0,
t_end=t_after_worker - t0,
workload=workload))
# 3) If so, sample via /get_endpoint_workers/ for eligible (idle) worker tracking
if worker_url:
# 2) Otherwise (successful request), sample via /get_endpoint_workers/ for eligible (idle) worker tracking
if worker_sampled:
try:
r_status = route_session.post(
urljoin(server_url, "/get_endpoint_workers/"),
@@ -143,18 +100,29 @@ def do_one(endpoint_name: str,
status_samples.append((time.time() - t0, idle))
except Exception:
pass
# 3) Send the request
worker_address = msg["url"]
req = dict(payload=payload.__dict__, auth_data=AuthData.from_json_msg(msg).__dict__)
t1 = time.time()
# Use explicit connect/read timeouts to avoid long hangs
r1 = worker_session.post(
urljoin(worker_address, worker_endpoint),
json=req,
verify=get_cert_file_path(),
timeout=(4, 120),
)
t2 = time.time()
if r1.status_code != 200:
results_list.append(ReqResult(worker_address, (t_after_route - start) * 1000.0, (t2 - t1) * 1000.0,
(t2 - start) * 1000.0, False,
f"infer {r1.status_code} {r1.text}"))
return
results_list.append(ReqResult(worker_address, (t_after_route - start) * 1000.0, (t2 - t1) * 1000.0, (t2 - start) * 1000.0,
True, "", t_start=start - t0, t_end=t2 - t0, workload=u))
except Exception as e:
t = time.time()
results_list.append(ReqResult(worker_url="",
route_ms=0.0,
worker_ms=0.0,
total_ms=0.0,
ok=False,
error=f"unknown error {e}",
status_code=0,
t_start=t - t0,
t_end=t - t0,
workload=0.0))
results_list.append(ReqResult("", (t - start) * 1000.0, 0.0, (t - start) * 1000.0, False, str(e)))
def run_load_with_metrics(num_requests: int,
requests_per_second: float,
@@ -164,7 +132,7 @@ def run_load_with_metrics(num_requests: int,
worker_endpoint: str,
instance: str,
out_path: str):
# Resolve endpoint id + endpoint-scoped API key
ep_info = Endpoint.get_endpoint_info(endpoint_name=endpoint_group_name,
account_api_key=account_api_key,
instance=instance)
@@ -177,7 +145,8 @@ def run_load_with_metrics(num_requests: int,
t0 = time.time()
results = []
status_samples = []
max_concurrency = int(os.environ.get("MAX_CONCURRENCY", "8192"))
# Concurrency control
max_concurrency = int(os.environ.get("MAX_CONCURRENCY", "1024"))
submit_queue_factor = 2 # cap queued tasks to reduce memory
# Shared HTTP sessions with connection pooling (persistent connections)
@@ -189,9 +158,9 @@ def run_load_with_metrics(num_requests: int,
return sess
# Router: mostly single host, small connection pool is sufficient
route_session = make_session(pool_connections=1, pool_maxsize=max_concurrency)
route_session = make_session(pool_connections=8, pool_maxsize=max_concurrency)
# Workers: many hosts; allow many pools and per-host concurrency up to max_concurrency
worker_session = make_session(pool_connections=64, pool_maxsize=max_concurrency // 8)
worker_session = make_session(pool_connections=max(256, max_concurrency), pool_maxsize=max_concurrency)
# Fire requests using a thread pool, scheduling at requested RPS
inflight = set()
@@ -240,12 +209,11 @@ def run_load_with_metrics(num_requests: int,
total_ms = np.array([r.total_ms for r in oks]) if succ else np.array([])
worker_ms = np.array([r.worker_ms for r in oks]) if succ else np.array([])
route_ms = np.array([r.route_ms for r in oks]) if succ else np.array([])
#route_ms = np.array([r.route_ms for r in oks]) if succ else np.array([])
avg_total = float(np.mean(total_ms)) if succ else 0.0
avg_worker = float(np.mean(worker_ms)) if succ else 0.0
avg_route = float(np.mean(route_ms)) if succ else 0.0
p50_total, p95_total = (float(np.percentile(total_ms, 50)), float(np.percentile(total_ms, 95))) if succ else (0.0, 0.0)
total_compute_time_ms = float(np.sum(worker_ms)) if succ else 0.0
# Distribution over workers (by host:port)
hosts = [urlparse(r.worker_url).netloc for r in oks if r.worker_url]
@@ -272,11 +240,11 @@ def run_load_with_metrics(num_requests: int,
print(f"\nResults: total={total_reqs} success={succ} errors={len(errs)}")
print(f"Avg latency (ms): {avg_total:.1f} p50: {p50_total:.1f} p95: {p95_total:.1f}")
print(f"Avg route latency (ms): {avg_route:.1f} Avg worker latency (ms): {avg_worker:.1f}")
print(f"Total compute time (sum worker latency, s): {total_compute_time_ms/1000.0:.2f}")
if errs:
print("Sample errors:")
for e in errs[:5]:
print(f" {e.status_code} {e.error}")
print(f" {e.error}")
# Plot: 2x3 grid
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
@@ -296,7 +264,7 @@ def run_load_with_metrics(num_requests: int,
# Latency histogram (total)
ax1 = axes[0, 1]
if succ:
ax1.hist(total_ms, bins=30)
ax1.hist(total_ms, bins=30, color="#4e79a7")
ax1.set_title("Total latency (ms)")
ax1.set_xlabel("ms")
ax1.set_ylabel("freq")
@@ -322,7 +290,7 @@ def run_load_with_metrics(num_requests: int,
ax_idle.plot(ts, vals, "-o", ms=3)
ax_idle.set_title("Completions per second")
ax_idle.set_xlabel("time (s)")
ax_idle.set_ylabel("completions / sec")
ax_idle.set_ylabel("req/s")
# Summary text
ax3 = axes[1, 1]
@@ -330,36 +298,22 @@ def run_load_with_metrics(num_requests: int,
text = (
f"Total requests: {total_reqs}\n"
f"Success: {succ} Errors: {len(errs)}\n"
f"Avg total latency: {avg_total:.1f} ms\n"
f"Avg latency: {avg_total:.1f} ms\n"
f"p50: {p50_total:.1f} ms p95: {p95_total:.1f} ms\n"
f"Avg route latency: {avg_route:.1f} ms\n"
f"Avg worker latency: {avg_worker:.1f} ms\n"
f"300 errors: {len([r for r in errs if r.status_code >= 300 and r.status_code < 400])}\n"
f"429 errors: {len([r for r in errs if r.status_code == 429])}\n"
f"500 errors: {len([r for r in errs if r.status_code >= 500])}\n"
f"Other errors: {len([r for r in errs if r.status_code not in [300, 429, 500]])}\n"
f"Total compute time: {total_compute_time_ms/1000.0:.2f} s"
)
ax3.set_title("Summary")
ax3.text(0.02, 0.98, text, va="top", ha="left", fontsize=11, transform=ax3.transAxes)
# Error count over time
ax_errors = axes[1, 2]
all_end_times = [int(r.t_end) for r in results if r.t_end > 0]
if all_end_times:
min_second = min(all_end_times)
max_second = max(all_end_times)
# Count errors per second
errors_per_second = {}
for result in errs:
second = int(result.t_end)
errors_per_second[second] = errors_per_second.get(second, 0) + 1
# Create complete timeline including zeros
time_seconds = list(range(min_second, max_second + 1))
error_counts = [errors_per_second.get(sec, 0) for sec in time_seconds]
ax_errors.plot(time_seconds, error_counts, "-o", ms=3)
ax_errors.set_title("Errors per second")
ax_errors.set_xlabel("time (s)")
ax_errors.set_ylabel("errors / sec")
# Latency CDF (total_ms)
ax_cdf = axes[1, 2]
if succ:
x = np.sort(total_ms)
y = np.linspace(0, 1, len(x), endpoint=True)
ax_cdf.plot(x, y)
ax_cdf.set_title("Latency CDF")
ax_cdf.set_xlabel("ms")
ax_cdf.set_ylabel("fraction ≤ x")
# Ensure unique output path and create directory if needed
final_out_path = get_incremented_path(out_path)