Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3988cf553f | |||
| a00c1adab5 |
+58
-94
@@ -12,7 +12,6 @@ from distutils.util import strtobool
|
|||||||
|
|
||||||
from anyio import open_file
|
from anyio import open_file
|
||||||
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
|
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
|
||||||
import asyncio
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from Crypto.Signature import pkcs1_15
|
from Crypto.Signature import pkcs1_15
|
||||||
@@ -26,12 +25,8 @@ from lib.data_types import (
|
|||||||
LogAction,
|
LogAction,
|
||||||
ApiPayload_T,
|
ApiPayload_T,
|
||||||
JsonDataException,
|
JsonDataException,
|
||||||
RequestMetrics,
|
|
||||||
BenchmarkResult
|
|
||||||
)
|
)
|
||||||
|
|
||||||
VERSION = "0.2.0"
|
|
||||||
|
|
||||||
MSG_HISTORY_LEN = 100
|
MSG_HISTORY_LEN = 100
|
||||||
log = logging.getLogger(__file__)
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
@@ -58,25 +53,15 @@ class Backend:
|
|||||||
EndpointHandler # this endpoint handler will be used for benchmarking
|
EndpointHandler # this endpoint handler will be used for benchmarking
|
||||||
)
|
)
|
||||||
log_actions: List[Tuple[LogAction, str]]
|
log_actions: List[Tuple[LogAction, str]]
|
||||||
max_wait_time: float = 10.0
|
|
||||||
reqnum = -1
|
reqnum = -1
|
||||||
version = VERSION
|
|
||||||
msg_history = []
|
msg_history = []
|
||||||
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
|
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
|
||||||
unsecured: bool = dataclasses.field(
|
unsecured: bool = dataclasses.field(
|
||||||
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
||||||
)
|
)
|
||||||
report_addr: str = dataclasses.field(
|
|
||||||
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
|
|
||||||
)
|
|
||||||
mtoken: str = dataclasses.field(
|
|
||||||
default_factory=lambda: os.environ.get("MASTER_TOKEN", "")
|
|
||||||
)
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.metrics = Metrics()
|
self.metrics = Metrics()
|
||||||
self.metrics._set_version(self.version)
|
|
||||||
self.metrics._set_mtoken(self.mtoken)
|
|
||||||
self._total_pubkey_fetch_errors = 0
|
self._total_pubkey_fetch_errors = 0
|
||||||
self._pubkey = self._fetch_pubkey()
|
self._pubkey = self._fetch_pubkey()
|
||||||
self.__start_healthcheck: bool = False
|
self.__start_healthcheck: bool = False
|
||||||
@@ -111,19 +96,23 @@ class Backend:
|
|||||||
|
|
||||||
#######################################Private#######################################
|
#######################################Private#######################################
|
||||||
def _fetch_pubkey(self):
|
def _fetch_pubkey(self):
|
||||||
report_addr = self.report_addr.rstrip("/")
|
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
|
||||||
command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"]
|
result = subprocess.check_output(command, universal_newlines=True)
|
||||||
try:
|
log.debug("public key:")
|
||||||
result = subprocess.check_output(command, universal_newlines=True)
|
log.debug(result)
|
||||||
log.debug("public key:")
|
key = None
|
||||||
log.debug(result)
|
for _ in range(5):
|
||||||
key = RSA.import_key(result)
|
try:
|
||||||
if key is not None:
|
key = RSA.import_key(result)
|
||||||
return key
|
break
|
||||||
except (ValueError , subprocess.CalledProcessError) as e:
|
except ValueError as e:
|
||||||
log.debug(f"Error downloading key: {e}")
|
log.debug(f"Error downloading key: {e}")
|
||||||
self.backend_errored("Failed to get autoscaler pubkey")
|
time.sleep(15)
|
||||||
|
if key is None:
|
||||||
|
self._total_pubkey_fetch_errors += 1
|
||||||
|
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
|
||||||
|
self.backend_errored("Failed to get autoscaler pubkey")
|
||||||
|
return key
|
||||||
|
|
||||||
async def __handle_request(
|
async def __handle_request(
|
||||||
self,
|
self,
|
||||||
@@ -139,56 +128,55 @@ class Backend:
|
|||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return web.json_response(dict(error="invalid JSON"), status=422)
|
return web.json_response(dict(error="invalid JSON"), status=422)
|
||||||
workload = payload.count_workload()
|
workload = payload.count_workload()
|
||||||
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
|
|
||||||
|
|
||||||
async def cancel_api_call_if_disconnected() -> web.Response:
|
async def cancel_api_call_if_disconnected() -> web.Response:
|
||||||
await request.wait_for_disconnection()
|
await request.wait_for_disconnection()
|
||||||
log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled")
|
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
|
||||||
self.metrics._request_canceled(request_metrics)
|
self.metrics._request_canceled(workload=workload)
|
||||||
raise asyncio.CancelledError
|
return web.Response(status=500)
|
||||||
|
|
||||||
async def make_request() -> Union[web.Response, web.StreamResponse]:
|
async def make_request() -> Union[web.Response, web.StreamResponse]:
|
||||||
|
log.debug(f"got request, {auth_data.reqnum}")
|
||||||
|
self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum)
|
||||||
|
if self.allow_parallel_requests is False:
|
||||||
|
log.debug(f"Waiting to aquire Sem for reqnum:{auth_data.reqnum}")
|
||||||
|
await self.sem.acquire()
|
||||||
|
log.debug(
|
||||||
|
f"Sem acquired for reqnum:{auth_data.reqnum}, starting request..."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
log.debug(f"Starting request for reqnum:{auth_data.reqnum}")
|
||||||
try:
|
try:
|
||||||
response = await self.__call_api(handler=handler, payload=payload)
|
response = await self.__call_api(handler=handler, payload=payload)
|
||||||
status_code = response.status
|
status_code = response.status
|
||||||
log.debug(
|
log.debug(
|
||||||
" ".join(
|
" ".join(
|
||||||
[
|
[
|
||||||
f"request with reqnum:{request_metrics.reqnum}",
|
f"request with reqnum:{auth_data.reqnum}",
|
||||||
f"returned status code: {status_code},",
|
f"returned status code: {status_code},",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
res = await handler.generate_client_response(request, response)
|
res = await handler.generate_client_response(request, response)
|
||||||
self.metrics._request_success(request_metrics)
|
self.metrics._request_success(workload=workload)
|
||||||
return res
|
return res
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
log.debug(f"[backend] Request error: {e}")
|
log.debug(f"[backend] Request error: {e}")
|
||||||
self.metrics._request_errored(request_metrics)
|
self.metrics._request_errored(workload=workload)
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
|
finally:
|
||||||
|
self.metrics._request_end(
|
||||||
|
workload=workload,
|
||||||
|
reqnum=auth_data.reqnum,
|
||||||
|
)
|
||||||
|
self.sem.release()
|
||||||
|
|
||||||
###########
|
###########
|
||||||
|
|
||||||
if self.__check_signature(auth_data) is False:
|
if self.__check_signature(auth_data) is False:
|
||||||
self.metrics._request_reject(request_metrics)
|
|
||||||
return web.Response(status=401)
|
return web.Response(status=401)
|
||||||
|
|
||||||
if self.metrics.model_metrics.wait_time > self.max_wait_time:
|
|
||||||
self.metrics._request_reject(request_metrics)
|
|
||||||
return web.Response(status=429)
|
|
||||||
|
|
||||||
acquired = False
|
|
||||||
try:
|
try:
|
||||||
self.metrics._request_start(request_metrics)
|
|
||||||
if self.allow_parallel_requests is False:
|
|
||||||
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
|
|
||||||
await self.sem.acquire()
|
|
||||||
acquired = True
|
|
||||||
log.debug(
|
|
||||||
f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
|
|
||||||
done, pending = await wait(
|
done, pending = await wait(
|
||||||
[
|
[
|
||||||
create_task(make_request()),
|
create_task(make_request()),
|
||||||
@@ -196,27 +184,11 @@ class Backend:
|
|||||||
],
|
],
|
||||||
return_when=FIRST_COMPLETED,
|
return_when=FIRST_COMPLETED,
|
||||||
)
|
)
|
||||||
for t in pending:
|
[task.cancel() for task in pending]
|
||||||
t.cancel()
|
return done.pop().result()
|
||||||
await asyncio.gather(*pending, return_exceptions=True)
|
|
||||||
|
|
||||||
done_task = done.pop()
|
|
||||||
try:
|
|
||||||
return done_task.result()
|
|
||||||
except Exception as e:
|
|
||||||
log.debug(f"Request task raised exception: {e}")
|
|
||||||
return web.Response(status=500)
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
# Client is gone. Do not write a response; just unwind.
|
|
||||||
return web.Response(status=499)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug(f"Exception in main handler loop {e}")
|
log.debug(f"Exception in main handler loop {e}")
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
finally:
|
|
||||||
# Always release the semaphore if it was acquired
|
|
||||||
if acquired:
|
|
||||||
self.sem.release()
|
|
||||||
self.metrics._request_end(request_metrics)
|
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def healthcheck_session(self):
|
def healthcheck_session(self):
|
||||||
@@ -257,7 +229,7 @@ class Backend:
|
|||||||
|
|
||||||
async def _start_tracking(self) -> None:
|
async def _start_tracking(self) -> None:
|
||||||
await gather(
|
await gather(
|
||||||
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop()
|
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck()
|
||||||
)
|
)
|
||||||
|
|
||||||
def backend_errored(self, msg: str) -> None:
|
def backend_errored(self, msg: str) -> None:
|
||||||
@@ -289,7 +261,7 @@ class Backend:
|
|||||||
message = {
|
message = {
|
||||||
key: value
|
key: value
|
||||||
for (key, value) in (dataclasses.asdict(auth_data).items())
|
for (key, value) in (dataclasses.asdict(auth_data).items())
|
||||||
if key != "signature" and key != "__request_id"
|
if key != "signature"
|
||||||
}
|
}
|
||||||
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
|
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
|
||||||
log.debug(
|
log.debug(
|
||||||
@@ -299,7 +271,7 @@ class Backend:
|
|||||||
elif message in self.msg_history:
|
elif message in self.msg_history:
|
||||||
log.debug(f"message: {message} already in message history")
|
log.debug(f"message: {message} already in message history")
|
||||||
return False
|
return False
|
||||||
elif verify_signature(json.dumps(message, indent=4, sort_keys=True), auth_data.signature):
|
elif verify_signature(json.dumps(message, indent=4), auth_data.signature):
|
||||||
self.reqnum = max(auth_data.reqnum, self.reqnum)
|
self.reqnum = max(auth_data.reqnum, self.reqnum)
|
||||||
self.msg_history.append(message)
|
self.msg_history.append(message)
|
||||||
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
|
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
|
||||||
@@ -318,10 +290,10 @@ class Backend:
|
|||||||
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
||||||
log.debug("already ran benchmark")
|
log.debug("already ran benchmark")
|
||||||
# trigger model load
|
# trigger model load
|
||||||
# payload = self.benchmark_handler.make_benchmark_payload()
|
payload = self.benchmark_handler.make_benchmark_payload()
|
||||||
# _ = await self.__call_api(
|
_ = await self.__call_api(
|
||||||
# handler=self.benchmark_handler, payload=payload
|
handler=self.benchmark_handler, payload=payload
|
||||||
# )
|
)
|
||||||
return float(f.readline())
|
return float(f.readline())
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
@@ -336,26 +308,18 @@ class Backend:
|
|||||||
|
|
||||||
for run in range(1, self.benchmark_handler.benchmark_runs + 1):
|
for run in range(1, self.benchmark_handler.benchmark_runs + 1):
|
||||||
start = time.time()
|
start = time.time()
|
||||||
benchmark_requests = []
|
tasks = []
|
||||||
|
total_workload = 0
|
||||||
|
|
||||||
for i in range(concurrent_requests):
|
for _ in range(concurrent_requests):
|
||||||
payload = self.benchmark_handler.make_benchmark_payload()
|
payload = self.benchmark_handler.make_benchmark_payload()
|
||||||
workload = payload.count_workload()
|
total_workload += payload.count_workload()
|
||||||
task = self.__call_api(handler=self.benchmark_handler, payload=payload)
|
tasks.append(
|
||||||
benchmark_requests.append(
|
self.__call_api(handler=self.benchmark_handler, payload=payload)
|
||||||
BenchmarkResult(request_idx=i, workload=workload, task=task)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
responses = await gather(*[br.task for br in benchmark_requests])
|
responses = await gather(*tasks)
|
||||||
for br, response in zip(benchmark_requests, responses):
|
|
||||||
br.response = response
|
|
||||||
|
|
||||||
total_workload = sum(br.workload for br in benchmark_requests if br.is_successful)
|
|
||||||
time_elapsed = time.time() - start
|
time_elapsed = time.time() - start
|
||||||
successful_responses = sum([1 for br in benchmark_requests if br.is_successful])
|
|
||||||
if successful_responses == 0:
|
|
||||||
self.backend_errored("No successful responses from benchmark")
|
|
||||||
log.debug(f"benchmark failed: {successful_responses}/{concurrent_requests} successful responses")
|
|
||||||
|
|
||||||
throughput = total_workload / time_elapsed
|
throughput = total_workload / time_elapsed
|
||||||
sum_throughput += throughput
|
sum_throughput += throughput
|
||||||
@@ -369,7 +333,7 @@ class Backend:
|
|||||||
f"Run: {run}, concurrent_requests: {concurrent_requests}",
|
f"Run: {run}, concurrent_requests: {concurrent_requests}",
|
||||||
f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s",
|
f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s",
|
||||||
f"Throughput: {throughput} workload/s",
|
f"Throughput: {throughput} workload/s",
|
||||||
f"Successful responses: {successful_responses}/{concurrent_requests}",
|
f"Successful responses: {len([r for r in responses if r.status == 200])}",
|
||||||
"#" * 60,
|
"#" * 60,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -396,7 +360,7 @@ class Backend:
|
|||||||
)
|
)
|
||||||
# some backends need a few seconds after logging successful startup before
|
# some backends need a few seconds after logging successful startup before
|
||||||
# they can begin accepting requests
|
# they can begin accepting requests
|
||||||
# await sleep(5)
|
await sleep(5)
|
||||||
try:
|
try:
|
||||||
max_throughput = await run_benchmark()
|
max_throughput = await run_benchmark()
|
||||||
self.__start_healthcheck = True
|
self.__start_healthcheck = True
|
||||||
@@ -423,7 +387,7 @@ class Backend:
|
|||||||
if line:
|
if line:
|
||||||
await handle_log_line(line.rstrip())
|
await handle_log_line(line.rstrip())
|
||||||
else:
|
else:
|
||||||
await asyncio.sleep(LOG_POLL_INTERVAL)
|
time.sleep(LOG_POLL_INTERVAL)
|
||||||
|
|
||||||
###########
|
###########
|
||||||
|
|
||||||
|
|||||||
+11
-52
@@ -3,7 +3,7 @@ import logging
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type, Awaitable
|
from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type
|
||||||
from aiohttp import web, ClientResponse
|
from aiohttp import web, ClientResponse
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
@@ -65,11 +65,10 @@ class ApiPayload(ABC):
|
|||||||
class AuthData:
|
class AuthData:
|
||||||
"""data used to authenticate requester"""
|
"""data used to authenticate requester"""
|
||||||
|
|
||||||
|
signature: str
|
||||||
cost: str
|
cost: str
|
||||||
endpoint: str
|
endpoint: str
|
||||||
reqnum: int
|
reqnum: int
|
||||||
request_idx: int
|
|
||||||
signature: str
|
|
||||||
url: str
|
url: str
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -190,34 +189,13 @@ class SystemMetrics:
|
|||||||
self.additional_disk_usage = disk_usage - self.last_disk_usage
|
self.additional_disk_usage = disk_usage - self.last_disk_usage
|
||||||
self.last_disk_usage = disk_usage
|
self.last_disk_usage = disk_usage
|
||||||
|
|
||||||
def reset(self, expected: float | None) -> None:
|
def reset(self):
|
||||||
# autoscaler excepts model_loading_time to be populated only once, when the instance has
|
# autoscaler excepts model_loading_time to be populated only once, when the instance has
|
||||||
# finished benchmarking and is ready to receive requests. This applies to restarted instances
|
# finished benchmarking and is ready to receive requests. This applies to restarted instances
|
||||||
# as well: they should send model_loading_time once when they are done loading
|
# as well: they should send model_loading_time once when they are done loading
|
||||||
if self.model_loading_time == expected:
|
self.model_loading_time = None
|
||||||
self.model_loading_time = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RequestMetrics:
|
|
||||||
"""Tracks metrics for an active request."""
|
|
||||||
request_idx: int
|
|
||||||
reqnum: int
|
|
||||||
workload: float
|
|
||||||
status: str
|
|
||||||
success: bool = False
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BenchmarkResult:
|
|
||||||
request_idx: int
|
|
||||||
workload: float
|
|
||||||
task: Awaitable[ClientResponse]
|
|
||||||
response: Optional[ClientResponse] = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_successful(self) -> bool:
|
|
||||||
return self.response is not None and self.response.status == 200
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelMetrics:
|
class ModelMetrics:
|
||||||
"""Model specific metrics"""
|
"""Model specific metrics"""
|
||||||
@@ -227,14 +205,12 @@ class ModelMetrics:
|
|||||||
workload_received: float
|
workload_received: float
|
||||||
workload_cancelled: float
|
workload_cancelled: float
|
||||||
workload_errored: float
|
workload_errored: float
|
||||||
workload_rejected: float
|
|
||||||
# these are not
|
# these are not
|
||||||
workload_pending: float
|
workload_pending: float
|
||||||
error_msg: Optional[str]
|
error_msg: Optional[str]
|
||||||
max_throughput: float
|
max_throughput: float
|
||||||
requests_recieved: Set[int] = field(default_factory=set)
|
requests_recieved: Set[int] = field(default_factory=set)
|
||||||
requests_working: dict[int, RequestMetrics] = field(default_factory=dict)
|
requests_working: Set[int] = field(default_factory=set)
|
||||||
requests_deleting: list[RequestMetrics] = field(default_factory=list)
|
|
||||||
last_update: float = field(default_factory=time.time)
|
last_update: float = field(default_factory=time.time)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -244,30 +220,19 @@ class ModelMetrics:
|
|||||||
workload_served=0.0,
|
workload_served=0.0,
|
||||||
workload_cancelled=0.0,
|
workload_cancelled=0.0,
|
||||||
workload_errored=0.0,
|
workload_errored=0.0,
|
||||||
workload_rejected=0.0,
|
|
||||||
workload_received=0.0,
|
workload_received=0.0,
|
||||||
error_msg=None,
|
error_msg=None,
|
||||||
max_throughput=0.0,
|
max_throughput=0.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cur_perf(self) -> float:
|
||||||
|
return max(self.workload_served / (time.time() - self.last_update), 0.0)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def workload_processing(self) -> float:
|
def workload_processing(self) -> float:
|
||||||
return max(self.workload_received - self.workload_cancelled, 0.0)
|
return max(self.workload_received - self.workload_cancelled, 0.0)
|
||||||
|
|
||||||
@property
|
|
||||||
def wait_time(self) -> float:
|
|
||||||
if (len(self.requests_working) == 0):
|
|
||||||
return 0.0
|
|
||||||
return sum([request.workload for request in self.requests_working.values()]) / max(self.max_throughput, 0.00001)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def cur_load(self) -> float:
|
|
||||||
return sum([request.workload for request in self.requests_working.values()])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def working_request_idxs(self) -> list[int]:
|
|
||||||
return [req.request_idx for req in self.requests_working.values()]
|
|
||||||
|
|
||||||
def set_errored(self, error_msg):
|
def set_errored(self, error_msg):
|
||||||
self.reset()
|
self.reset()
|
||||||
self.error_msg = error_msg
|
self.error_msg = error_msg
|
||||||
@@ -277,21 +242,16 @@ class ModelMetrics:
|
|||||||
self.workload_received = 0
|
self.workload_received = 0
|
||||||
self.workload_cancelled = 0
|
self.workload_cancelled = 0
|
||||||
self.workload_errored = 0
|
self.workload_errored = 0
|
||||||
self.workload_rejected = 0
|
|
||||||
self.last_update = time.time()
|
self.last_update = time.time()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AutoScalerData:
|
class AutoScalaerData:
|
||||||
"""Data that is reported to autoscaler"""
|
"""Data that is reported to autoscaler"""
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
mtoken: str
|
|
||||||
version: str
|
|
||||||
loadtime: float
|
loadtime: float
|
||||||
cur_load: float
|
cur_load: float
|
||||||
rej_load: float
|
|
||||||
new_load: float
|
|
||||||
error_msg: str
|
error_msg: str
|
||||||
max_perf: float
|
max_perf: float
|
||||||
cur_perf: float
|
cur_perf: float
|
||||||
@@ -300,7 +260,6 @@ class AutoScalerData:
|
|||||||
num_requests_working: int
|
num_requests_working: int
|
||||||
num_requests_recieved: int
|
num_requests_recieved: int
|
||||||
additional_disk_usage: float
|
additional_disk_usage: float
|
||||||
working_request_idxs: list[int]
|
|
||||||
url: str
|
url: str
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
+39
-170
@@ -5,14 +5,13 @@ import json
|
|||||||
from asyncio import sleep
|
from asyncio import sleep
|
||||||
from dataclasses import dataclass, asdict, field
|
from dataclasses import dataclass, asdict, field
|
||||||
from functools import cache
|
from functools import cache
|
||||||
import asyncio
|
|
||||||
from aiohttp import ClientSession, ClientTimeout, TCPConnector, ClientResponseError
|
|
||||||
|
|
||||||
from lib.data_types import AutoScalerData, SystemMetrics, ModelMetrics, RequestMetrics
|
import requests
|
||||||
|
|
||||||
|
from lib.data_types import AutoScalaerData, SystemMetrics, ModelMetrics
|
||||||
from typing import Awaitable, NoReturn, List
|
from typing import Awaitable, NoReturn, List
|
||||||
|
|
||||||
METRICS_UPDATE_INTERVAL = 1
|
METRICS_UPDATE_INTERVAL = 1
|
||||||
DELETE_REQUESTS_INTERVAL = 1
|
|
||||||
|
|
||||||
log = logging.getLogger(__file__)
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
@@ -27,10 +26,7 @@ def get_url() -> str:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Metrics:
|
class Metrics:
|
||||||
version: str = "0"
|
|
||||||
mtoken: str = ""
|
|
||||||
last_metric_update: float = 0.0
|
last_metric_update: float = 0.0
|
||||||
last_request_served: float = 0.0
|
|
||||||
update_pending: bool = False
|
update_pending: bool = False
|
||||||
id: int = field(default_factory=lambda: int(os.environ["CONTAINER_ID"]))
|
id: int = field(default_factory=lambda: int(os.environ["CONTAINER_ID"]))
|
||||||
report_addr: List[str] = field(
|
report_addr: List[str] = field(
|
||||||
@@ -39,84 +35,42 @@ class Metrics:
|
|||||||
url: str = field(default_factory=get_url)
|
url: str = field(default_factory=get_url)
|
||||||
system_metrics: SystemMetrics = field(default_factory=SystemMetrics.empty)
|
system_metrics: SystemMetrics = field(default_factory=SystemMetrics.empty)
|
||||||
model_metrics: ModelMetrics = field(default_factory=ModelMetrics.empty)
|
model_metrics: ModelMetrics = field(default_factory=ModelMetrics.empty)
|
||||||
_session: ClientSession | None = field(default=None, init=False, repr=False)
|
|
||||||
|
|
||||||
async def http(self) -> ClientSession:
|
def _request_start(self, workload: float, reqnum: int) -> None:
|
||||||
if self._session is None:
|
|
||||||
self._session = ClientSession(
|
|
||||||
timeout=ClientTimeout(total=10),
|
|
||||||
connector=TCPConnector(limit=8, limit_per_host=4, force_close=True, enable_cleanup_closed=True)
|
|
||||||
)
|
|
||||||
return self._session
|
|
||||||
|
|
||||||
async def aclose(self) -> None:
|
|
||||||
if self._session is not None:
|
|
||||||
await self._session.close()
|
|
||||||
self._session = None
|
|
||||||
|
|
||||||
def _request_start(self, request: RequestMetrics) -> None:
|
|
||||||
"""
|
"""
|
||||||
this function is called prior to forwarding a request to a model API.
|
this function is called prior to forwarding a request to a model API.
|
||||||
"""
|
"""
|
||||||
log.debug("request start")
|
log.debug("request start")
|
||||||
request.status = "Started"
|
self.model_metrics.workload_pending += workload
|
||||||
self.model_metrics.workload_pending += request.workload
|
self.model_metrics.workload_received += workload
|
||||||
self.model_metrics.workload_received += request.workload
|
self.model_metrics.requests_recieved.add(reqnum)
|
||||||
self.model_metrics.requests_recieved.add(request.reqnum)
|
self.model_metrics.requests_working.add(reqnum)
|
||||||
self.model_metrics.requests_working[request.reqnum] = request
|
|
||||||
self.update_pending = True
|
|
||||||
|
|
||||||
def _request_end(self, request: RequestMetrics) -> None:
|
def _request_end(self, workload: float, reqnum: int) -> None:
|
||||||
"""
|
"""
|
||||||
this function is called after handling of a request ends, regardless of the outcome
|
this function is called after handling of a request ends, regardless of the outcome
|
||||||
"""
|
"""
|
||||||
self.model_metrics.workload_pending -= request.workload
|
self.model_metrics.workload_pending -= workload
|
||||||
self.model_metrics.requests_working.pop(request.reqnum, None)
|
self.model_metrics.requests_working.discard(reqnum)
|
||||||
self.model_metrics.requests_deleting.append(request)
|
|
||||||
self.last_request_served = time.time()
|
|
||||||
|
|
||||||
def _request_success(self, request: RequestMetrics) -> None:
|
def _request_success(self, workload: float) -> None:
|
||||||
"""
|
"""
|
||||||
this function is called after a response from model API is received and forwarded.
|
this function is called after a response from model API is received and forwarded.
|
||||||
"""
|
"""
|
||||||
self.model_metrics.workload_served += request.workload
|
self.model_metrics.workload_served += workload
|
||||||
request.status = "Success"
|
|
||||||
request.success = True
|
|
||||||
self.update_pending = True
|
self.update_pending = True
|
||||||
|
|
||||||
def _request_errored(self, request: RequestMetrics) -> None:
|
def _request_errored(self, workload: float) -> None:
|
||||||
"""
|
"""
|
||||||
this function is called if model API returns an error
|
this function is called if model API returns an error
|
||||||
"""
|
"""
|
||||||
self.model_metrics.workload_errored += request.workload
|
self.model_metrics.workload_errored += workload
|
||||||
request.status = "Error"
|
|
||||||
request.success = False
|
|
||||||
self.update_pending = True
|
|
||||||
|
|
||||||
def _request_canceled(self, request: RequestMetrics) -> None:
|
def _request_canceled(self, workload: float) -> None:
|
||||||
"""
|
"""
|
||||||
this function is called if client drops connection before model API has responded
|
this function is called if client drops connection before model API has responded
|
||||||
"""
|
"""
|
||||||
self.model_metrics.workload_cancelled += request.workload
|
self.model_metrics.workload_cancelled += workload
|
||||||
request.success = True
|
|
||||||
request.status = "Cancelled"
|
|
||||||
|
|
||||||
def _request_reject(self, request: RequestMetrics):
|
|
||||||
"""
|
|
||||||
this function is called if the current wait time for the model is above max_wait_time
|
|
||||||
"""
|
|
||||||
self.model_metrics.requests_recieved.add(request.reqnum)
|
|
||||||
self.model_metrics.requests_deleting.append(request)
|
|
||||||
self.model_metrics.workload_rejected += request.workload
|
|
||||||
request.success = False
|
|
||||||
request.status = "Rejected"
|
|
||||||
self.update_pending = True
|
|
||||||
|
|
||||||
async def _send_delete_requests_loop(self) -> Awaitable[NoReturn]:
|
|
||||||
while True:
|
|
||||||
await sleep(DELETE_REQUESTS_INTERVAL)
|
|
||||||
if len(self.model_metrics.requests_deleting) > 0:
|
|
||||||
await self.__send_delete_requests_and_reset()
|
|
||||||
|
|
||||||
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
|
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
|
||||||
while True:
|
while True:
|
||||||
@@ -124,10 +78,10 @@ class Metrics:
|
|||||||
elapsed = time.time() - self.last_metric_update
|
elapsed = time.time() - self.last_metric_update
|
||||||
if self.system_metrics.model_is_loaded is False and elapsed >= 10:
|
if self.system_metrics.model_is_loaded is False and elapsed >= 10:
|
||||||
log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
|
log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
|
||||||
await self.__send_metrics_and_reset()
|
self.__send_metrics_and_reset(elapsed)
|
||||||
elif self.update_pending or elapsed > 10:
|
elif self.update_pending or elapsed > 10:
|
||||||
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
|
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
|
||||||
await self.__send_metrics_and_reset()
|
self.__send_metrics_and_reset(elapsed)
|
||||||
|
|
||||||
def _model_loaded(self, max_throughput: float) -> None:
|
def _model_loaded(self, max_throughput: float) -> None:
|
||||||
self.system_metrics.model_loading_time = (
|
self.system_metrics.model_loading_time = (
|
||||||
@@ -140,130 +94,49 @@ class Metrics:
|
|||||||
self.model_metrics.set_errored(error_msg)
|
self.model_metrics.set_errored(error_msg)
|
||||||
self.system_metrics.model_is_loaded = True
|
self.system_metrics.model_is_loaded = True
|
||||||
|
|
||||||
def _set_version(self, version: str) -> None:
|
|
||||||
self.version = version
|
|
||||||
|
|
||||||
def _set_mtoken(self, mtoken: str) -> None:
|
|
||||||
self.mtoken = mtoken
|
|
||||||
|
|
||||||
#######################################Private#######################################
|
#######################################Private#######################################
|
||||||
|
|
||||||
async def __send_delete_requests_and_reset(self):
|
def __send_metrics_and_reset(self, elapsed):
|
||||||
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
|
|
||||||
data = {
|
|
||||||
"worker_id": self.id,
|
|
||||||
"mtoken": self.mtoken,
|
|
||||||
"request_idxs": idxs,
|
|
||||||
"success": success_flag,
|
|
||||||
}
|
|
||||||
log.debug(
|
|
||||||
f"Deleting requests that {'succeeded' if success_flag else 'failed'}: {data['request_idxs']}"
|
|
||||||
)
|
|
||||||
full_path = report_addr.rstrip("/") + "/delete_requests/"
|
|
||||||
for attempt in range(1, 4):
|
|
||||||
try:
|
|
||||||
session = await self.http()
|
|
||||||
async with session.post(full_path, json=data) as res:
|
|
||||||
log.debug(f"delete_requests response: {res.status}")
|
|
||||||
res.raise_for_status()
|
|
||||||
return True
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
log.debug("delete_requests timed out")
|
|
||||||
except (ClientResponseError, Exception) as e:
|
|
||||||
log.debug(f"delete_requests failed with error: {e}")
|
|
||||||
await asyncio.sleep(2)
|
|
||||||
log.debug(f"retrying delete_request, attempt: {attempt}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Take a snapshot of what we plan to send this tick.
|
def compute_autoscaler_data() -> AutoScalaerData:
|
||||||
# New arrivals after this snapshot will remain in the queue for the next tick.
|
return AutoScalaerData(
|
||||||
snapshot = list(self.model_metrics.requests_deleting)
|
|
||||||
success_idxs = [r.request_idx for r in snapshot if r.success is True]
|
|
||||||
failed_idxs = [r.request_idx for r in snapshot if r.success is False]
|
|
||||||
|
|
||||||
if not success_idxs and not failed_idxs:
|
|
||||||
return # nothing to do
|
|
||||||
|
|
||||||
for report_addr in self.report_addr:
|
|
||||||
# TODO: Add a Redis subscriber queue for delete_requests
|
|
||||||
if report_addr == "https://cloud.vast.ai/api/v0":
|
|
||||||
# Patch: ignore the Redis API report_addr
|
|
||||||
continue
|
|
||||||
sent_success = True
|
|
||||||
sent_failed = True
|
|
||||||
|
|
||||||
if success_idxs:
|
|
||||||
sent_success = await post(report_addr, success_idxs, True)
|
|
||||||
if failed_idxs:
|
|
||||||
sent_failed = await post(report_addr, failed_idxs, False)
|
|
||||||
|
|
||||||
if sent_success and sent_failed:
|
|
||||||
# Remove only the items we actually sent from the live queue.
|
|
||||||
sent_set = set(success_idxs) | set(failed_idxs)
|
|
||||||
self.model_metrics.requests_deleting[:] = [
|
|
||||||
r for r in self.model_metrics.requests_deleting
|
|
||||||
if r.request_idx not in sent_set
|
|
||||||
]
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
async def __send_metrics_and_reset(self):
|
|
||||||
|
|
||||||
loadtime_snapshot = self.system_metrics.model_loading_time
|
|
||||||
|
|
||||||
def compute_autoscaler_data() -> AutoScalerData:
|
|
||||||
return AutoScalerData(
|
|
||||||
id=self.id,
|
id=self.id,
|
||||||
mtoken=self.mtoken,
|
loadtime=(self.system_metrics.model_loading_time or 0.0),
|
||||||
version=self.version,
|
cur_load=(self.model_metrics.workload_processing / elapsed),
|
||||||
loadtime=(loadtime_snapshot or 0.0),
|
|
||||||
new_load=self.model_metrics.workload_processing,
|
|
||||||
cur_load=self.model_metrics.cur_load,
|
|
||||||
rej_load=self.model_metrics.workload_rejected,
|
|
||||||
max_perf=self.model_metrics.max_throughput,
|
max_perf=self.model_metrics.max_throughput,
|
||||||
cur_perf=self.model_metrics.workload_served,
|
cur_perf=self.model_metrics.cur_perf,
|
||||||
error_msg=self.model_metrics.error_msg or "",
|
error_msg=self.model_metrics.error_msg or "",
|
||||||
num_requests_working=len(self.model_metrics.requests_working),
|
num_requests_working=len(self.model_metrics.requests_working),
|
||||||
num_requests_recieved=len(self.model_metrics.requests_recieved),
|
num_requests_recieved=len(self.model_metrics.requests_recieved),
|
||||||
additional_disk_usage=self.system_metrics.additional_disk_usage,
|
additional_disk_usage=self.system_metrics.additional_disk_usage,
|
||||||
working_request_idxs=self.model_metrics.working_request_idxs,
|
|
||||||
cur_capacity=0,
|
cur_capacity=0,
|
||||||
max_capacity=0,
|
max_capacity=0,
|
||||||
url=self.url,
|
url=self.url,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def send_data(report_addr: str) -> bool:
|
def send_data(report_addr: str) -> bool:
|
||||||
data = compute_autoscaler_data()
|
data = compute_autoscaler_data()
|
||||||
log_data = asdict(data)
|
full_path = report_addr.rstrip("/") + "/worker_status/"
|
||||||
def obfuscate(secret: str) -> str:
|
|
||||||
if secret is None:
|
|
||||||
return ""
|
|
||||||
return secret[:7] + "..." if len(secret) > 7 else ("*" * len(secret))
|
|
||||||
|
|
||||||
log_data["mtoken"] = obfuscate(log_data.get("mtoken"))
|
|
||||||
log.debug(
|
log.debug(
|
||||||
"\n".join(
|
"\n".join(
|
||||||
[
|
[
|
||||||
"#" * 60,
|
"#" * 60,
|
||||||
f"sending data to autoscaler",
|
f"sending data to autoscaler",
|
||||||
f"{json.dumps(log_data, indent=2)}",
|
f"{json.dumps((asdict(data)), indent=2)}",
|
||||||
"#" * 60,
|
"#" * 60,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
full_path = report_addr.rstrip("/") + "/worker_status/"
|
|
||||||
for attempt in range(1, 4):
|
for attempt in range(1, 4):
|
||||||
try:
|
try:
|
||||||
session = await self.http()
|
res = requests.post(full_path, json=asdict(data), timeout=1)
|
||||||
async with session.post(full_path, json=asdict(data)) as res:
|
res.raise_for_status()
|
||||||
res.raise_for_status()
|
|
||||||
return True
|
return True
|
||||||
except asyncio.TimeoutError:
|
except requests.Timeout:
|
||||||
log.debug(f"autoscaler status update timed out")
|
log.debug(f"autoscaler status update timed out")
|
||||||
except (ClientResponseError, Exception) as e:
|
except Exception as e:
|
||||||
log.debug(f"autoscaler status update failed with error: {e}")
|
log.debug(f"autoscaler status update failed with error: {e}")
|
||||||
await asyncio.sleep(2)
|
time.sleep(2)
|
||||||
log.debug(f"retrying autoscaler status update, attempt: {attempt}")
|
log.debug(f"retrying autoscaler status update, attempt: {attempt}")
|
||||||
log.debug(f"failed to send update through {report_addr}")
|
log.debug(f"failed to send update through {report_addr}")
|
||||||
return False
|
return False
|
||||||
@@ -272,15 +145,11 @@ class Metrics:
|
|||||||
|
|
||||||
self.system_metrics.update_disk_usage()
|
self.system_metrics.update_disk_usage()
|
||||||
|
|
||||||
sent = False
|
|
||||||
for report_addr in self.report_addr:
|
for report_addr in self.report_addr:
|
||||||
if await send_data(report_addr):
|
success = send_data(report_addr)
|
||||||
sent = True
|
if success is True:
|
||||||
break
|
break
|
||||||
|
self.update_pending = False
|
||||||
if sent:
|
self.model_metrics.reset()
|
||||||
# clear the one-shot loadtime only if we actually sent *this* value
|
self.system_metrics.reset()
|
||||||
self.system_metrics.reset(expected=loadtime_snapshot)
|
self.last_metric_update = time.time()
|
||||||
self.update_pending = False
|
|
||||||
self.model_metrics.reset()
|
|
||||||
self.last_metric_update = time.time()
|
|
||||||
|
|||||||
+3
-3
@@ -9,7 +9,7 @@ ENV_PATH="$WORKSPACE_DIR/worker-env"
|
|||||||
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
|
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
|
||||||
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
|
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
|
||||||
|
|
||||||
REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}"
|
REPORT_ADDR="${REPORT_ADDR:-https://cloud.vast.ai/api/v0,https://run.vast.ai}"
|
||||||
USE_SSL="${USE_SSL:-true}"
|
USE_SSL="${USE_SSL:-true}"
|
||||||
WORKER_PORT="${WORKER_PORT:-3000}"
|
WORKER_PORT="${WORKER_PORT:-3000}"
|
||||||
mkdir -p "$WORKSPACE_DIR"
|
mkdir -p "$WORKSPACE_DIR"
|
||||||
@@ -59,12 +59,12 @@ then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
# Fork testing
|
# Fork testing
|
||||||
[[ ! -d $SERVER_DIR ]] && git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"
|
git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"
|
||||||
if [[ -n ${PYWORKER_REF:-} ]]; then
|
if [[ -n ${PYWORKER_REF:-} ]]; then
|
||||||
(cd "$SERVER_DIR" && git checkout "$PYWORKER_REF")
|
(cd "$SERVER_DIR" && git checkout "$PYWORKER_REF")
|
||||||
fi
|
fi
|
||||||
|
|
||||||
uv venv --python-preference only-managed "$ENV_PATH" -p 3.10
|
uv venv --managed-python "$ENV_PATH" -p 3.10
|
||||||
source "$ENV_PATH/bin/activate"
|
source "$ENV_PATH/bin/activate"
|
||||||
|
|
||||||
uv pip install -r "${SERVER_DIR}/requirements.txt"
|
uv pip install -r "${SERVER_DIR}/requirements.txt"
|
||||||
|
|||||||
@@ -12,21 +12,9 @@ A docker image is provided but you may use any if the above requirements are met
|
|||||||
|
|
||||||
## Benchmarking
|
## Benchmarking
|
||||||
|
|
||||||
### Custom Benchmark Workflows
|
A simple image generation benchmark runs when each worker initializes to validate GPU performance and identify underperforming machines.
|
||||||
|
|
||||||
You can provide a custom ComfyUI workflow for benchmarking by creating `workers/comfyui-json/misc/benchmark.json`. This allows you to test performance using your preferred models and workflow complexity.
|
The benchmark uses Stable Diffusion v1.5 with ComfyUI's default text-to-image workflow. Configure the benchmark complexity and duration using these variables:
|
||||||
|
|
||||||
**Ways to provide the benchmark file:**
|
|
||||||
- Fork this repository and add your `benchmark.json` file
|
|
||||||
- Write the file during worker provisioning (onstart script or setup phase)
|
|
||||||
|
|
||||||
An example file is provided in the repository. To ensure varied generations, use the placeholder `__RANDOM_INT__` in place of static seed values - it will be replaced with a random integer for each benchmark run.
|
|
||||||
|
|
||||||
### Default Benchmark (Fallback)
|
|
||||||
|
|
||||||
If `benchmark.json` is not available, a simple image generation benchmark runs when each worker initializes. This validates GPU performance and helps identify underperforming machines.
|
|
||||||
|
|
||||||
The default benchmark uses Stable Diffusion v1.5 with ComfyUI's standard text-to-image workflow. Configure it using these environment variables:
|
|
||||||
|
|
||||||
| Environment Variable | Default Value | Description |
|
| Environment Variable | Default Value | Description |
|
||||||
| -------------------- | ------------- | ----------- |
|
| -------------------- | ------------- | ----------- |
|
||||||
@@ -36,7 +24,7 @@ The default benchmark uses Stable Diffusion v1.5 with ComfyUI's standard text-to
|
|||||||
|
|
||||||
Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns.
|
Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns.
|
||||||
|
|
||||||
#### Calibrating Fallback Benchmark Duration
|
### Calibrating Benchmark Duration
|
||||||
|
|
||||||
To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements.
|
To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements.
|
||||||
|
|
||||||
|
|||||||
@@ -98,7 +98,6 @@ def call_text2image_workflow(
|
|||||||
endpoint=route_response["endpoint"],
|
endpoint=route_response["endpoint"],
|
||||||
reqnum=route_response["reqnum"],
|
reqnum=route_response["reqnum"],
|
||||||
url=route_response["url"],
|
url=route_response["url"],
|
||||||
request_idx=route_response["request_idx"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build the payload for the worker request
|
# Build the payload for the worker request
|
||||||
|
|||||||
@@ -5,13 +5,12 @@ import dataclasses
|
|||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
from functools import cache
|
from functools import cache
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from pathlib import Path
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from lib.data_types import ApiPayload, JsonDataException
|
from lib.data_types import ApiPayload, JsonDataException
|
||||||
|
|
||||||
log = logging.getLogger(__file__)
|
|
||||||
|
with open("workers/comfyui/misc/test_prompts.txt", "r") as f:
|
||||||
|
test_prompts = f.readlines()
|
||||||
|
|
||||||
def count_workload() -> float:
|
def count_workload() -> float:
|
||||||
# Always 100.0 where there is a single instance of ComfyUI handling requests
|
# Always 100.0 where there is a single instance of ComfyUI handling requests
|
||||||
@@ -25,32 +24,9 @@ class ComfyWorkflowData(ApiPayload):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def for_test(cls):
|
def for_test(cls):
|
||||||
"""
|
"""
|
||||||
If the user has provided a benchmark workflow we can use it here to properly gauge performance.
|
Use the variables available to simulate workflows of the required running time
|
||||||
Otherwise, use the variables available to simulate workflows of the required running time
|
|
||||||
Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090)
|
Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090)
|
||||||
"""
|
"""
|
||||||
# Try to load benchmark.json
|
|
||||||
benchmark_file = Path("workers/comfyui-json/misc/benchmark.json")
|
|
||||||
|
|
||||||
if benchmark_file.exists():
|
|
||||||
try:
|
|
||||||
with open(benchmark_file, "r") as f:
|
|
||||||
benchmark_workflow = json.load(f)
|
|
||||||
return cls(
|
|
||||||
input={
|
|
||||||
"request_id": f"test-{random.randint(1000, 99999)}",
|
|
||||||
"workflow_json": benchmark_workflow
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except (json.JSONDecodeError, IOError):
|
|
||||||
# JSON is malformed or file can't be read, fall through to default
|
|
||||||
log.error(f"Failed to benchmark using {benchmark_file}")
|
|
||||||
|
|
||||||
# Fallback: read prompts and construct payload
|
|
||||||
log.info("Using fallback method for benchmarking")
|
|
||||||
with open("workers/comfyui-json/misc/test_prompts.txt", "r") as f:
|
|
||||||
test_prompts = f.readlines()
|
|
||||||
|
|
||||||
test_prompt = random.choice(test_prompts).rstrip()
|
test_prompt = random.choice(test_prompts).rstrip()
|
||||||
return cls(
|
return cls(
|
||||||
input={
|
input={
|
||||||
|
|||||||
@@ -1,107 +0,0 @@
|
|||||||
{
|
|
||||||
"3": {
|
|
||||||
"inputs": {
|
|
||||||
"seed": "__RANDOM_INT__",
|
|
||||||
"steps": 20,
|
|
||||||
"cfg": 8,
|
|
||||||
"sampler_name": "euler",
|
|
||||||
"scheduler": "normal",
|
|
||||||
"denoise": 1,
|
|
||||||
"model": [
|
|
||||||
"4",
|
|
||||||
0
|
|
||||||
],
|
|
||||||
"positive": [
|
|
||||||
"6",
|
|
||||||
0
|
|
||||||
],
|
|
||||||
"negative": [
|
|
||||||
"7",
|
|
||||||
0
|
|
||||||
],
|
|
||||||
"latent_image": [
|
|
||||||
"5",
|
|
||||||
0
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"class_type": "KSampler",
|
|
||||||
"_meta": {
|
|
||||||
"title": "KSampler"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"4": {
|
|
||||||
"inputs": {
|
|
||||||
"ckpt_name": "v1-5-pruned-emaonly-fp16.safetensors"
|
|
||||||
},
|
|
||||||
"class_type": "CheckpointLoaderSimple",
|
|
||||||
"_meta": {
|
|
||||||
"title": "Load Checkpoint"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"5": {
|
|
||||||
"inputs": {
|
|
||||||
"width": 512,
|
|
||||||
"height": 512,
|
|
||||||
"batch_size": 1
|
|
||||||
},
|
|
||||||
"class_type": "EmptyLatentImage",
|
|
||||||
"_meta": {
|
|
||||||
"title": "Empty Latent Image"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"6": {
|
|
||||||
"inputs": {
|
|
||||||
"text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,",
|
|
||||||
"clip": [
|
|
||||||
"4",
|
|
||||||
1
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"class_type": "CLIPTextEncode",
|
|
||||||
"_meta": {
|
|
||||||
"title": "CLIP Text Encode (Prompt)"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"7": {
|
|
||||||
"inputs": {
|
|
||||||
"text": "text, watermark",
|
|
||||||
"clip": [
|
|
||||||
"4",
|
|
||||||
1
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"class_type": "CLIPTextEncode",
|
|
||||||
"_meta": {
|
|
||||||
"title": "CLIP Text Encode (Prompt)"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"8": {
|
|
||||||
"inputs": {
|
|
||||||
"samples": [
|
|
||||||
"3",
|
|
||||||
0
|
|
||||||
],
|
|
||||||
"vae": [
|
|
||||||
"4",
|
|
||||||
2
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"class_type": "VAEDecode",
|
|
||||||
"_meta": {
|
|
||||||
"title": "VAE Decode"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"9": {
|
|
||||||
"inputs": {
|
|
||||||
"filename_prefix": "ComfyUI",
|
|
||||||
"images": [
|
|
||||||
"8",
|
|
||||||
0
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"class_type": "SaveImage",
|
|
||||||
"_meta": {
|
|
||||||
"title": "Save Image"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -19,7 +19,6 @@ MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: "
|
|||||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||||
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
|
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
|
||||||
"Value not in list: ", # This error is emitted when the model file is not there at all
|
"Value not in list: ", # This error is emitted when the model file is not there at all
|
||||||
"[ERROR] Provisioning Script failed", # Error inserted by provisioning script if models/nodes fail to download
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -71,7 +70,7 @@ class ComfyWorkflowHandler(EndpointHandler[ComfyWorkflowData]):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def healthcheck_endpoint(self) -> Optional[str]:
|
def healthcheck_endpoint(self) -> Optional[str]:
|
||||||
return f"{MODEL_SERVER_URL}/health"
|
return "/health"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def payload_cls(cls) -> Type[ComfyWorkflowData]:
|
def payload_cls(cls) -> Type[ComfyWorkflowData]:
|
||||||
|
|||||||
@@ -82,7 +82,6 @@ def call_custom_workflow_for_sd3(
|
|||||||
endpoint=message["endpoint"],
|
endpoint=message["endpoint"],
|
||||||
reqnum=message["reqnum"],
|
reqnum=message["reqnum"],
|
||||||
url=message["url"],
|
url=message["url"],
|
||||||
request_idx=message["request_idx"],
|
|
||||||
)
|
)
|
||||||
workflow = {
|
workflow = {
|
||||||
"3": {
|
"3": {
|
||||||
|
|||||||
@@ -119,25 +119,14 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
|
|||||||
class CompletionsData(GenericData):
|
class CompletionsData(GenericData):
|
||||||
@classmethod
|
@classmethod
|
||||||
def for_test(cls) -> "CompletionsData":
|
def for_test(cls) -> "CompletionsData":
|
||||||
system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
|
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
||||||
|
|
||||||
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
|
|
||||||
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
|
|
||||||
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
|
|
||||||
genus Equus with horses and asses, the three groups being the only living members of the family
|
|
||||||
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
|
|
||||||
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
|
|
||||||
woodlands, shrublands, and mountainous areas.
|
|
||||||
|
|
||||||
Please answer the following question based on the above context."""
|
|
||||||
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
|
|
||||||
model = os.environ.get("MODEL_NAME")
|
model = os.environ.get("MODEL_NAME")
|
||||||
if not model:
|
if not model:
|
||||||
raise ValueError("MODEL_NAME environment variable not set")
|
raise ValueError("MODEL_NAME environment variable not set")
|
||||||
|
|
||||||
test_input = {
|
test_input = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"prompt": f"{system_prompt}\n\n{unique_question}",
|
"prompt": prompt,
|
||||||
"temperature": 0.7,
|
"temperature": 0.7,
|
||||||
"max_tokens": 500,
|
"max_tokens": 500,
|
||||||
}
|
}
|
||||||
@@ -164,18 +153,7 @@ class ChatCompletionsData(GenericData):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def for_test(cls) -> "ChatCompletionsData":
|
def for_test(cls) -> "ChatCompletionsData":
|
||||||
system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
|
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
||||||
|
|
||||||
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
|
|
||||||
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
|
|
||||||
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
|
|
||||||
genus Equus with horses and asses, the three groups being the only living members of the family
|
|
||||||
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
|
|
||||||
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
|
|
||||||
woodlands, shrublands, and mountainous areas.
|
|
||||||
|
|
||||||
Please answer the following question based on the above context."""
|
|
||||||
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
|
|
||||||
model = os.environ.get("MODEL_NAME")
|
model = os.environ.get("MODEL_NAME")
|
||||||
if not model:
|
if not model:
|
||||||
raise ValueError("MODEL_NAME environment variable not set")
|
raise ValueError("MODEL_NAME environment variable not set")
|
||||||
@@ -183,10 +161,7 @@ class ChatCompletionsData(GenericData):
|
|||||||
# Chat completions use messages format instead of prompt
|
# Chat completions use messages format instead of prompt
|
||||||
test_input = {
|
test_input = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"messages": [
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
{"role": "system", "content": system_prompt}, # Shared prefix
|
|
||||||
{"role": "user", "content": unique_question} # Unique per request
|
|
||||||
],
|
|
||||||
"temperature": 0.7,
|
"temperature": 0.7,
|
||||||
"max_tokens": 500,
|
"max_tokens": 500,
|
||||||
}
|
}
|
||||||
|
|||||||
+54
-100
@@ -42,7 +42,6 @@ class ReqResult:
|
|||||||
total_ms: float
|
total_ms: float
|
||||||
ok: bool
|
ok: bool
|
||||||
error: str = ""
|
error: str = ""
|
||||||
status_code: int = 0
|
|
||||||
t_start: float = 0.0
|
t_start: float = 0.0
|
||||||
t_end: float = 0.0
|
t_end: float = 0.0
|
||||||
workload: float = 0.0
|
workload: float = 0.0
|
||||||
@@ -59,73 +58,31 @@ def do_one(endpoint_name: str,
|
|||||||
route_session,
|
route_session,
|
||||||
worker_session):
|
worker_session):
|
||||||
try:
|
try:
|
||||||
workload = payload.count_workload()
|
u = payload.count_workload()
|
||||||
route_payload = {"endpoint": endpoint_name, "api_key": endpoint_api_key, "cost": workload}
|
route_payload = {"endpoint": endpoint_name, "api_key": endpoint_api_key, "cost": u}
|
||||||
headers = {"Authorization": f"Bearer {endpoint_api_key}"}
|
headers = {"Authorization": f"Bearer {endpoint_api_key}"}
|
||||||
start = time.time()
|
start = time.time()
|
||||||
r0 = route_session.post(urljoin(server_url, "/route/"), json=route_payload, headers=headers, timeout=4)
|
r0 = route_session.post(urljoin(server_url, "/route/"), json=route_payload, headers=headers, timeout=4)
|
||||||
t_after_route = time.time()
|
t_after_route = time.time()
|
||||||
if r0.status_code != 200:
|
if r0.status_code != 200:
|
||||||
results_list.append(ReqResult(worker_url="",
|
results_list.append(ReqResult("", (t_after_route - start) * 1000.0, 0.0, (t_after_route - start) * 1000.0, False,
|
||||||
route_ms=(t_after_route - start) * 1000.0,
|
f"route {r0.status_code} {r0.text}"))
|
||||||
worker_ms=0.0,
|
|
||||||
total_ms=(t_after_route - start) * 1000.0,
|
|
||||||
ok=False,
|
|
||||||
error=f"route error {r0.reason} {r0.text}",
|
|
||||||
status_code=r0.status_code,
|
|
||||||
t_start=start - t0,
|
|
||||||
t_end=t_after_route - t0,
|
|
||||||
workload=workload))
|
|
||||||
return
|
return
|
||||||
msg = r0.json()
|
msg = r0.json()
|
||||||
|
|
||||||
# 1) Check if we got a worker back from route
|
# 1) "Status" is in the response when no worker is ready
|
||||||
worker_url = msg.get("url", "")
|
worker_sampled = True
|
||||||
if not worker_url:
|
status = msg.get("status", "")
|
||||||
status = msg.get("status", "")
|
if status:
|
||||||
m = re.search(r"total workers:\s*(\d+).*loading workers:\s*(\d+).*standby workers:\s*(\d+).*error workers:\s*(\d+)", status, re.I | re.S)
|
m = re.search(r"total workers:\s*(\d+).*loading workers:\s*(\d+).*standby workers:\s*(\d+).*error workers:\s*(\d+)", status, re.I | re.S)
|
||||||
if m:
|
if m:
|
||||||
tot, loading, standby, err = map(int, m.groups())
|
tot, loading, standby, err = map(int, m.groups())
|
||||||
idle = max(tot - loading - standby - err, 0)
|
idle = max(tot - loading - standby - err, 0)
|
||||||
status_samples.append((time.time() - t0, idle))
|
status_samples.append((time.time() - t0, idle))
|
||||||
|
worker_sampled = False
|
||||||
|
|
||||||
# 2) If we got a worker, send the request
|
# 2) Otherwise (successful request), sample via /get_endpoint_workers/ for eligible (idle) worker tracking
|
||||||
if worker_url:
|
if worker_sampled:
|
||||||
req = dict(payload=payload.__dict__, auth_data=AuthData.from_json_msg(msg).__dict__)
|
|
||||||
t_before_worker = time.time()
|
|
||||||
r1 = worker_session.post(
|
|
||||||
urljoin(worker_url, worker_endpoint),
|
|
||||||
json=req,
|
|
||||||
verify=get_cert_file_path(),
|
|
||||||
timeout=(4, 120),
|
|
||||||
)
|
|
||||||
t_after_worker = time.time()
|
|
||||||
if r1.status_code != 200:
|
|
||||||
results_list.append(ReqResult(worker_url=worker_url,
|
|
||||||
route_ms=(t_after_route - start) * 1000.0,
|
|
||||||
worker_ms=(t_after_worker - t_before_worker) * 1000.0,
|
|
||||||
total_ms=(t_after_worker - start) * 1000.0,
|
|
||||||
ok=False,
|
|
||||||
error=f"worker inference error {r1.reason} {r1.text}",
|
|
||||||
status_code=r1.status_code,
|
|
||||||
t_start=start - t0,
|
|
||||||
t_end=t_after_worker - t0,
|
|
||||||
workload=workload))
|
|
||||||
return
|
|
||||||
# Success case
|
|
||||||
results_list.append(ReqResult(worker_url=worker_url,
|
|
||||||
route_ms=(t_after_route - start) * 1000.0,
|
|
||||||
worker_ms=(t_after_worker - t_before_worker) * 1000.0,
|
|
||||||
total_ms=(t_after_worker - start) * 1000.0,
|
|
||||||
ok=True,
|
|
||||||
error="",
|
|
||||||
status_code=200,
|
|
||||||
t_start=start - t0,
|
|
||||||
t_end=t_after_worker - t0,
|
|
||||||
workload=workload))
|
|
||||||
|
|
||||||
# 3) If so, sample via /get_endpoint_workers/ for eligible (idle) worker tracking
|
|
||||||
if worker_url:
|
|
||||||
try:
|
try:
|
||||||
r_status = route_session.post(
|
r_status = route_session.post(
|
||||||
urljoin(server_url, "/get_endpoint_workers/"),
|
urljoin(server_url, "/get_endpoint_workers/"),
|
||||||
@@ -143,18 +100,29 @@ def do_one(endpoint_name: str,
|
|||||||
status_samples.append((time.time() - t0, idle))
|
status_samples.append((time.time() - t0, idle))
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# 3) Send the request
|
||||||
|
worker_address = msg["url"]
|
||||||
|
req = dict(payload=payload.__dict__, auth_data=AuthData.from_json_msg(msg).__dict__)
|
||||||
|
t1 = time.time()
|
||||||
|
# Use explicit connect/read timeouts to avoid long hangs
|
||||||
|
r1 = worker_session.post(
|
||||||
|
urljoin(worker_address, worker_endpoint),
|
||||||
|
json=req,
|
||||||
|
verify=get_cert_file_path(),
|
||||||
|
timeout=(4, 120),
|
||||||
|
)
|
||||||
|
t2 = time.time()
|
||||||
|
if r1.status_code != 200:
|
||||||
|
results_list.append(ReqResult(worker_address, (t_after_route - start) * 1000.0, (t2 - t1) * 1000.0,
|
||||||
|
(t2 - start) * 1000.0, False,
|
||||||
|
f"infer {r1.status_code} {r1.text}"))
|
||||||
|
return
|
||||||
|
results_list.append(ReqResult(worker_address, (t_after_route - start) * 1000.0, (t2 - t1) * 1000.0, (t2 - start) * 1000.0,
|
||||||
|
True, "", t_start=start - t0, t_end=t2 - t0, workload=u))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
t = time.time()
|
t = time.time()
|
||||||
results_list.append(ReqResult(worker_url="",
|
results_list.append(ReqResult("", (t - start) * 1000.0, 0.0, (t - start) * 1000.0, False, str(e)))
|
||||||
route_ms=0.0,
|
|
||||||
worker_ms=0.0,
|
|
||||||
total_ms=0.0,
|
|
||||||
ok=False,
|
|
||||||
error=f"unknown error {e}",
|
|
||||||
status_code=0,
|
|
||||||
t_start=t - t0,
|
|
||||||
t_end=t - t0,
|
|
||||||
workload=0.0))
|
|
||||||
|
|
||||||
def run_load_with_metrics(num_requests: int,
|
def run_load_with_metrics(num_requests: int,
|
||||||
requests_per_second: float,
|
requests_per_second: float,
|
||||||
@@ -164,7 +132,7 @@ def run_load_with_metrics(num_requests: int,
|
|||||||
worker_endpoint: str,
|
worker_endpoint: str,
|
||||||
instance: str,
|
instance: str,
|
||||||
out_path: str):
|
out_path: str):
|
||||||
|
# Resolve endpoint id + endpoint-scoped API key
|
||||||
ep_info = Endpoint.get_endpoint_info(endpoint_name=endpoint_group_name,
|
ep_info = Endpoint.get_endpoint_info(endpoint_name=endpoint_group_name,
|
||||||
account_api_key=account_api_key,
|
account_api_key=account_api_key,
|
||||||
instance=instance)
|
instance=instance)
|
||||||
@@ -177,7 +145,8 @@ def run_load_with_metrics(num_requests: int,
|
|||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
results = []
|
results = []
|
||||||
status_samples = []
|
status_samples = []
|
||||||
max_concurrency = int(os.environ.get("MAX_CONCURRENCY", "8192"))
|
# Concurrency control
|
||||||
|
max_concurrency = int(os.environ.get("MAX_CONCURRENCY", "1024"))
|
||||||
submit_queue_factor = 2 # cap queued tasks to reduce memory
|
submit_queue_factor = 2 # cap queued tasks to reduce memory
|
||||||
|
|
||||||
# Shared HTTP sessions with connection pooling (persistent connections)
|
# Shared HTTP sessions with connection pooling (persistent connections)
|
||||||
@@ -189,9 +158,9 @@ def run_load_with_metrics(num_requests: int,
|
|||||||
return sess
|
return sess
|
||||||
|
|
||||||
# Router: mostly single host, small connection pool is sufficient
|
# Router: mostly single host, small connection pool is sufficient
|
||||||
route_session = make_session(pool_connections=1, pool_maxsize=max_concurrency)
|
route_session = make_session(pool_connections=8, pool_maxsize=max_concurrency)
|
||||||
# Workers: many hosts; allow many pools and per-host concurrency up to max_concurrency
|
# Workers: many hosts; allow many pools and per-host concurrency up to max_concurrency
|
||||||
worker_session = make_session(pool_connections=64, pool_maxsize=max_concurrency // 8)
|
worker_session = make_session(pool_connections=max(256, max_concurrency), pool_maxsize=max_concurrency)
|
||||||
|
|
||||||
# Fire requests using a thread pool, scheduling at requested RPS
|
# Fire requests using a thread pool, scheduling at requested RPS
|
||||||
inflight = set()
|
inflight = set()
|
||||||
@@ -240,12 +209,11 @@ def run_load_with_metrics(num_requests: int,
|
|||||||
|
|
||||||
total_ms = np.array([r.total_ms for r in oks]) if succ else np.array([])
|
total_ms = np.array([r.total_ms for r in oks]) if succ else np.array([])
|
||||||
worker_ms = np.array([r.worker_ms for r in oks]) if succ else np.array([])
|
worker_ms = np.array([r.worker_ms for r in oks]) if succ else np.array([])
|
||||||
route_ms = np.array([r.route_ms for r in oks]) if succ else np.array([])
|
#route_ms = np.array([r.route_ms for r in oks]) if succ else np.array([])
|
||||||
|
|
||||||
avg_total = float(np.mean(total_ms)) if succ else 0.0
|
avg_total = float(np.mean(total_ms)) if succ else 0.0
|
||||||
avg_worker = float(np.mean(worker_ms)) if succ else 0.0
|
|
||||||
avg_route = float(np.mean(route_ms)) if succ else 0.0
|
|
||||||
p50_total, p95_total = (float(np.percentile(total_ms, 50)), float(np.percentile(total_ms, 95))) if succ else (0.0, 0.0)
|
p50_total, p95_total = (float(np.percentile(total_ms, 50)), float(np.percentile(total_ms, 95))) if succ else (0.0, 0.0)
|
||||||
|
total_compute_time_ms = float(np.sum(worker_ms)) if succ else 0.0
|
||||||
|
|
||||||
# Distribution over workers (by host:port)
|
# Distribution over workers (by host:port)
|
||||||
hosts = [urlparse(r.worker_url).netloc for r in oks if r.worker_url]
|
hosts = [urlparse(r.worker_url).netloc for r in oks if r.worker_url]
|
||||||
@@ -272,11 +240,11 @@ def run_load_with_metrics(num_requests: int,
|
|||||||
|
|
||||||
print(f"\nResults: total={total_reqs} success={succ} errors={len(errs)}")
|
print(f"\nResults: total={total_reqs} success={succ} errors={len(errs)}")
|
||||||
print(f"Avg latency (ms): {avg_total:.1f} p50: {p50_total:.1f} p95: {p95_total:.1f}")
|
print(f"Avg latency (ms): {avg_total:.1f} p50: {p50_total:.1f} p95: {p95_total:.1f}")
|
||||||
print(f"Avg route latency (ms): {avg_route:.1f} Avg worker latency (ms): {avg_worker:.1f}")
|
print(f"Total compute time (sum worker latency, s): {total_compute_time_ms/1000.0:.2f}")
|
||||||
if errs:
|
if errs:
|
||||||
print("Sample errors:")
|
print("Sample errors:")
|
||||||
for e in errs[:5]:
|
for e in errs[:5]:
|
||||||
print(f" {e.status_code} {e.error}")
|
print(f" {e.error}")
|
||||||
|
|
||||||
# Plot: 2x3 grid
|
# Plot: 2x3 grid
|
||||||
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
|
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
|
||||||
@@ -296,7 +264,7 @@ def run_load_with_metrics(num_requests: int,
|
|||||||
# Latency histogram (total)
|
# Latency histogram (total)
|
||||||
ax1 = axes[0, 1]
|
ax1 = axes[0, 1]
|
||||||
if succ:
|
if succ:
|
||||||
ax1.hist(total_ms, bins=30)
|
ax1.hist(total_ms, bins=30, color="#4e79a7")
|
||||||
ax1.set_title("Total latency (ms)")
|
ax1.set_title("Total latency (ms)")
|
||||||
ax1.set_xlabel("ms")
|
ax1.set_xlabel("ms")
|
||||||
ax1.set_ylabel("freq")
|
ax1.set_ylabel("freq")
|
||||||
@@ -322,7 +290,7 @@ def run_load_with_metrics(num_requests: int,
|
|||||||
ax_idle.plot(ts, vals, "-o", ms=3)
|
ax_idle.plot(ts, vals, "-o", ms=3)
|
||||||
ax_idle.set_title("Completions per second")
|
ax_idle.set_title("Completions per second")
|
||||||
ax_idle.set_xlabel("time (s)")
|
ax_idle.set_xlabel("time (s)")
|
||||||
ax_idle.set_ylabel("completions / sec")
|
ax_idle.set_ylabel("req/s")
|
||||||
|
|
||||||
# Summary text
|
# Summary text
|
||||||
ax3 = axes[1, 1]
|
ax3 = axes[1, 1]
|
||||||
@@ -330,36 +298,22 @@ def run_load_with_metrics(num_requests: int,
|
|||||||
text = (
|
text = (
|
||||||
f"Total requests: {total_reqs}\n"
|
f"Total requests: {total_reqs}\n"
|
||||||
f"Success: {succ} Errors: {len(errs)}\n"
|
f"Success: {succ} Errors: {len(errs)}\n"
|
||||||
f"Avg total latency: {avg_total:.1f} ms\n"
|
f"Avg latency: {avg_total:.1f} ms\n"
|
||||||
f"p50: {p50_total:.1f} ms p95: {p95_total:.1f} ms\n"
|
f"p50: {p50_total:.1f} ms p95: {p95_total:.1f} ms\n"
|
||||||
f"Avg route latency: {avg_route:.1f} ms\n"
|
f"Total compute time: {total_compute_time_ms/1000.0:.2f} s"
|
||||||
f"Avg worker latency: {avg_worker:.1f} ms\n"
|
|
||||||
f"300 errors: {len([r for r in errs if r.status_code >= 300 and r.status_code < 400])}\n"
|
|
||||||
f"429 errors: {len([r for r in errs if r.status_code == 429])}\n"
|
|
||||||
f"500 errors: {len([r for r in errs if r.status_code >= 500])}\n"
|
|
||||||
f"Other errors: {len([r for r in errs if r.status_code not in [300, 429, 500]])}\n"
|
|
||||||
)
|
)
|
||||||
ax3.set_title("Summary")
|
ax3.set_title("Summary")
|
||||||
ax3.text(0.02, 0.98, text, va="top", ha="left", fontsize=11, transform=ax3.transAxes)
|
ax3.text(0.02, 0.98, text, va="top", ha="left", fontsize=11, transform=ax3.transAxes)
|
||||||
|
|
||||||
# Error count over time
|
# Latency CDF (total_ms)
|
||||||
ax_errors = axes[1, 2]
|
ax_cdf = axes[1, 2]
|
||||||
all_end_times = [int(r.t_end) for r in results if r.t_end > 0]
|
if succ:
|
||||||
if all_end_times:
|
x = np.sort(total_ms)
|
||||||
min_second = min(all_end_times)
|
y = np.linspace(0, 1, len(x), endpoint=True)
|
||||||
max_second = max(all_end_times)
|
ax_cdf.plot(x, y)
|
||||||
# Count errors per second
|
ax_cdf.set_title("Latency CDF")
|
||||||
errors_per_second = {}
|
ax_cdf.set_xlabel("ms")
|
||||||
for result in errs:
|
ax_cdf.set_ylabel("fraction ≤ x")
|
||||||
second = int(result.t_end)
|
|
||||||
errors_per_second[second] = errors_per_second.get(second, 0) + 1
|
|
||||||
# Create complete timeline including zeros
|
|
||||||
time_seconds = list(range(min_second, max_second + 1))
|
|
||||||
error_counts = [errors_per_second.get(sec, 0) for sec in time_seconds]
|
|
||||||
ax_errors.plot(time_seconds, error_counts, "-o", ms=3)
|
|
||||||
ax_errors.set_title("Errors per second")
|
|
||||||
ax_errors.set_xlabel("time (s)")
|
|
||||||
ax_errors.set_ylabel("errors / sec")
|
|
||||||
|
|
||||||
# Ensure unique output path and create directory if needed
|
# Ensure unique output path and create directory if needed
|
||||||
final_out_path = get_incremented_path(out_path)
|
final_out_path = get_incremented_path(out_path)
|
||||||
|
|||||||
Reference in New Issue
Block a user