Compare commits

..

3 Commits

Author SHA1 Message Date
Colter Downing fd9d56e576 remove env var writing 2025-10-28 16:11:35 -07:00
Colter Downing 8d9ffb3a6c removed 5 sec sleep and warmup request on load 2025-10-28 15:28:30 -07:00
Colter Downing 5d5bc197d7 adding timings for cold start 2025-10-26 18:44:23 -07:00
11 changed files with 133 additions and 406 deletions
+76 -124
View File
@@ -9,7 +9,6 @@ from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
from functools import cached_property
from distutils.util import strtobool
from collections import deque
from anyio import open_file
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
@@ -31,7 +30,7 @@ from lib.data_types import (
BenchmarkResult
)
VERSION = "0.2.1"
VERSION = "0.1.0"
MSG_HISTORY_LEN = 100
log = logging.getLogger(__file__)
@@ -64,24 +63,19 @@ class Backend:
version = VERSION
msg_history = []
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
queue: deque = dataclasses.field(default_factory=deque, repr=False)
unsecured: bool = dataclasses.field(
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
)
report_addr: str = dataclasses.field(
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
)
mtoken: str = dataclasses.field(
default_factory=lambda: os.environ.get("MASTER_TOKEN", "")
)
def __post_init__(self):
self.metrics = Metrics()
self.metrics._set_version(self.version)
self.metrics._set_mtoken(self.mtoken)
self._total_pubkey_fetch_errors = 0
self._pubkey = self._fetch_pubkey()
self.__start_healthcheck: bool = False
self._model_tail_start_time = None
self._model_loaded_time = None
self._first_healthcheck_ok = False
@property
def pubkey(self) -> Optional[RSA.RsaKey]:
@@ -113,19 +107,26 @@ class Backend:
#######################################Private#######################################
def _fetch_pubkey(self):
report_addr = self.report_addr.rstrip("/")
command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"]
try:
result = subprocess.check_output(command, universal_newlines=True)
log.debug("public key:")
log.debug(result)
key = RSA.import_key(result)
if key is not None:
return key
except (ValueError , subprocess.CalledProcessError) as e:
log.debug(f"Error downloading key: {e}")
self.backend_errored("Failed to get autoscaler pubkey")
t0 = time.time()
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
result = subprocess.check_output(command, universal_newlines=True)
log.debug("public key:")
log.debug(result)
key = None
for _ in range(5):
try:
key = RSA.import_key(result)
break
except ValueError as e:
log.debug(f"Error downloading key: {e}")
time.sleep(15)
if key is None:
self._total_pubkey_fetch_errors += 1
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
self.backend_errored("Failed to get autoscaler pubkey")
else:
log.debug(f"pubkey fetch+parse took {time.time()-t0:.2f}s")
return key
async def __handle_request(
self,
@@ -143,26 +144,11 @@ class Backend:
workload = payload.count_workload()
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
def advance_queue_after_completion(event: asyncio.Event):
"""Pop current head and wake next waiter, if any."""
# If this event is current head, wake next waiter
if self.queue and self.queue[0] is event:
self.queue.popleft()
if self.queue:
self.queue[0].set()
else:
# Else, remove it from the queue
try:
self.queue.remove(event)
except ValueError:
pass
async def cancel_api_call_if_disconnected() -> None:
async def cancel_api_call_if_disconnected() -> web.Response:
await request.wait_for_disconnection()
log.debug(f"Request with reqnum: {request_metrics.reqnum} was canceled")
log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled")
self.metrics._request_canceled(request_metrics)
return
raise asyncio.CancelledError
async def make_request() -> Union[web.Response, web.StreamResponse]:
try:
@@ -179,9 +165,7 @@ class Backend:
res = await handler.generate_client_response(request, response)
self.metrics._request_success(request_metrics)
return res
except asyncio.CancelledError:
raise
except Exception as e:
except requests.exceptions.RequestException as e:
log.debug(f"[backend] Request error: {e}")
self.metrics._request_errored(request_metrics)
return web.Response(status=500)
@@ -196,87 +180,46 @@ class Backend:
self.metrics._request_reject(request_metrics)
return web.Response(status=429)
disconnect_task = create_task(cancel_api_call_if_disconnected())
next_request_task = None
work_task = None
event = asyncio.Event() # Used in finally block, so initialize here
self.metrics._request_start(request_metrics)
acquired = False
try:
if self.allow_parallel_requests:
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
work_task = create_task(make_request())
done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED)
for t in pending:
t.cancel()
await asyncio.gather(*pending, return_exceptions=True)
if disconnect_task in done:
return web.Response(status=499)
# otherwise work_task completed
return await work_task
# FIFO-queue branch
else:
# Insert a Event into the queue for this request
# Event.set() == our request is up next
self.queue.append(event)
if self.queue and self.queue[0] is event:
event.set()
# Race between our request being next and request being cancelled
next_request_task = create_task(event.wait())
first_done, first_pending = await wait(
[next_request_task, disconnect_task], return_when=FIRST_COMPLETED
self.metrics._request_start(request_metrics)
if self.allow_parallel_requests is False:
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
await self.sem.acquire()
acquired = True
log.debug(
f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..."
)
else:
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
done, pending = await wait(
[
create_task(make_request()),
create_task(cancel_api_call_if_disconnected()),
],
return_when=FIRST_COMPLETED,
)
for t in pending:
t.cancel()
await asyncio.gather(*pending, return_exceptions=True)
# If the disconnect task wins the race
if disconnect_task in first_done:
# Clean up the next_request_task, then exit
for t in first_pending:
t.cancel()
await asyncio.gather(*first_pending, return_exceptions=True)
return web.Response(status=499)
# We are the next-up request in the queue
log.debug(f"Starting work on request {request_metrics.reqnum}...")
# Race the backend API call with the disconnect task
work_task = create_task(make_request())
done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED)
for t in pending:
t.cancel()
await asyncio.gather(*pending, return_exceptions=True)
if disconnect_task in done:
return web.Response(status=499)
# otherwise work_task completed
return await work_task
done_task = done.pop()
try:
return done_task.result()
except Exception as e:
log.debug(f"Request task raised exception: {e}")
return web.Response(status=500)
except asyncio.CancelledError:
return web.Response(status=499)
# Client is gone. Do not write a response; just unwind.
return web.Response(status=499)
except Exception as e:
log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500)
finally:
if not self.allow_parallel_requests:
advance_queue_after_completion(event)
# Always release the semaphore if it was acquired
if acquired:
self.sem.release()
self.metrics._request_end(request_metrics)
cleanup_tasks = [t for t in (next_request_task, work_task, disconnect_task) if t]
for t in cleanup_tasks:
if not t.done():
t.cancel()
if cleanup_tasks:
await asyncio.gather(*cleanup_tasks, return_exceptions=True)
@cached_property
def healthcheck_session(self):
@@ -303,6 +246,10 @@ class Backend:
log.debug(f"Performing healthcheck on {health_check_url}")
async with self.healthcheck_session.get(health_check_url) as response:
if response.status == 200:
if not self._first_healthcheck_ok:
if self._model_loaded_time:
log.debug(f"first healthcheck OK after {time.time()-self._model_loaded_time:.2f}s since ModelLoaded")
self._first_healthcheck_ok = True
log.debug("Healthcheck successful")
elif response.status == 503:
log.debug(f"Healthcheck failed with status: {response.status}")
@@ -349,7 +296,7 @@ class Backend:
message = {
key: value
for (key, value) in (dataclasses.asdict(auth_data).items())
if key != "signature" and key != "__request_id"
if key != "signature"
}
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
log.debug(
@@ -359,7 +306,7 @@ class Backend:
elif message in self.msg_history:
log.debug(f"message: {message} already in message history")
return False
elif verify_signature(json.dumps(message, indent=4, sort_keys=True), auth_data.signature):
elif verify_signature(json.dumps(message, indent=4), auth_data.signature):
self.reqnum = max(auth_data.reqnum, self.reqnum)
self.msg_history.append(message)
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
@@ -387,8 +334,11 @@ class Backend:
pass
log.debug("Initial run to trigger model loading...")
t_bench0 = time.time()
payload = self.benchmark_handler.make_benchmark_payload()
await self.__call_api(handler=self.benchmark_handler, payload=payload)
log.debug(f"warmup request took {time.time()-t_bench0:.2f}s")
t_benchmark_loop0 = time.time()
max_throughput = 0
sum_throughput = 0
@@ -413,9 +363,6 @@ class Backend:
total_workload = sum(br.workload for br in benchmark_requests if br.is_successful)
time_elapsed = time.time() - start
successful_responses = sum([1 for br in benchmark_requests if br.is_successful])
if successful_responses == 0:
self.backend_errored("No successful responses from benchmark")
log.debug(f"benchmark failed: {successful_responses}/{concurrent_requests} successful responses")
throughput = total_workload / time_elapsed
sum_throughput += throughput
@@ -439,6 +386,7 @@ class Backend:
log.debug(
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
)
log.debug(f"benchmark loop took {time.time()-t_benchmark_loop0:.2f}s")
with open(BENCHMARK_INDICATOR_FILE, "w") as f:
f.write(str(max_throughput))
return max_throughput
@@ -451,14 +399,17 @@ class Backend:
for action, msg in self.log_actions:
match action:
case LogAction.ModelLoaded if msg in log_line:
log.debug(
f"Got log line indicating model is loaded: {log_line}"
)
now = time.time()
elapsed = now - self._model_tail_start_time
log.debug(f"ModelLoaded observed after {elapsed:.2f}s: {log_line}")
# some backends need a few seconds after logging successful startup before
# they can begin accepting requests
# await sleep(5)
try:
t_bench0 = time.time()
max_throughput = await run_benchmark()
self._model_loaded_time = time.time()
log.debug(f"benchmark total took {self._model_loaded_time - t_bench0:.2f}s")
self.__start_healthcheck = True
self.metrics._model_loaded(
max_throughput=max_throughput,
@@ -477,6 +428,7 @@ class Backend:
async def tail_log():
log.debug(f"tailing file: {self.model_log_file}")
self._model_tail_start_time = time.time()
async with await open_file(self.model_log_file) as f:
while True:
line = await f.readline()
+5 -7
View File
@@ -65,12 +65,12 @@ class ApiPayload(ABC):
class AuthData:
"""data used to authenticate requester"""
signature: str
cost: str
endpoint: str
reqnum: int
request_idx: int
signature: str
url: str
request_idx: int
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]):
@@ -190,12 +190,11 @@ class SystemMetrics:
self.additional_disk_usage = disk_usage - self.last_disk_usage
self.last_disk_usage = disk_usage
def reset(self, expected: float | None) -> None:
def reset(self):
# autoscaler excepts model_loading_time to be populated only once, when the instance has
# finished benchmarking and is ready to receive requests. This applies to restarted instances
# as well: they should send model_loading_time once when they are done loading
if self.model_loading_time == expected:
self.model_loading_time = None
self.model_loading_time = None
@dataclass
@@ -258,7 +257,7 @@ class ModelMetrics:
def wait_time(self) -> float:
if (len(self.requests_working) == 0):
return 0.0
return sum([request.workload for request in self.requests_working.values()]) / max(self.max_throughput, 0.00001)
return sum([request.workload for request in self.requests_working.values()]) / self.max_throughput
@property
def cur_load(self) -> float:
@@ -286,7 +285,6 @@ class AutoScalerData:
"""Data that is reported to autoscaler"""
id: int
mtoken: str
version: str
loadtime: float
cur_load: float
+18 -65
View File
@@ -28,7 +28,6 @@ def get_url() -> str:
@dataclass
class Metrics:
version: str = "0"
mtoken: str = ""
last_metric_update: float = 0.0
last_request_served: float = 0.0
update_pending: bool = False
@@ -143,22 +142,17 @@ class Metrics:
def _set_version(self, version: str) -> None:
self.version = version
def _set_mtoken(self, mtoken: str) -> None:
self.mtoken = mtoken
#######################################Private#######################################
async def __send_delete_requests_and_reset(self):
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
async def send_data(report_addr: str, success: bool) -> bool:
data = {
"worker_id": self.id,
"mtoken": self.mtoken,
"request_idxs": idxs,
"success": success_flag,
"request_idxs": [r.request_idx for r in self.model_metrics.requests_deleting if r.success == success],
"success": success
}
log.debug(
f"Deleting requests that {'succeeded' if success_flag else 'failed'}: {data['request_idxs']}"
)
log.debug(f"Deleting requests that {'succeeded' if success else 'failed'}: {data['request_idxs']}")
full_path = report_addr.rstrip("/") + "/delete_requests/"
for attempt in range(1, 4):
try:
@@ -168,55 +162,26 @@ class Metrics:
res.raise_for_status()
return True
except asyncio.TimeoutError:
log.debug("delete_requests timed out")
log.debug(f"delete_requests timed out")
except (ClientResponseError, Exception) as e:
log.debug(f"delete_requests failed with error: {e}")
await asyncio.sleep(2)
log.debug(f"retrying delete_request, attempt: {attempt}")
return False
# Take a snapshot of what we plan to send this tick.
# New arrivals after this snapshot will remain in the queue for the next tick.
snapshot = list(self.model_metrics.requests_deleting)
success_idxs = [r.request_idx for r in snapshot if r.success is True]
failed_idxs = [r.request_idx for r in snapshot if r.success is False]
if not success_idxs and not failed_idxs:
return # nothing to do
for report_addr in self.report_addr:
# TODO: Add a Redis subscriber queue for delete_requests
if report_addr == "https://cloud.vast.ai/api/v0":
# Patch: ignore the Redis API report_addr
continue
sent_success = True
sent_failed = True
if success_idxs:
sent_success = await post(report_addr, success_idxs, True)
if failed_idxs:
sent_failed = await post(report_addr, failed_idxs, False)
if sent_success and sent_failed:
# Remove only the items we actually sent from the live queue.
sent_set = set(success_idxs) | set(failed_idxs)
self.model_metrics.requests_deleting[:] = [
r for r in self.model_metrics.requests_deleting
if r.request_idx not in sent_set
]
success = await send_data(report_addr, success=True) and await send_data(report_addr, success=False)
if success is True:
self.model_metrics.requests_deleting.clear()
break
async def __send_metrics_and_reset(self):
loadtime_snapshot = self.system_metrics.model_loading_time
def compute_autoscaler_data() -> AutoScalerData:
return AutoScalerData(
id=self.id,
mtoken=self.mtoken,
version=self.version,
loadtime=(loadtime_snapshot or 0.0),
loadtime=(self.system_metrics.model_loading_time or 0.0),
new_load=self.model_metrics.workload_processing,
cur_load=self.model_metrics.cur_load,
rej_load=self.model_metrics.workload_rejected,
@@ -234,25 +199,17 @@ class Metrics:
async def send_data(report_addr: str) -> bool:
data = compute_autoscaler_data()
log_data = asdict(data)
def obfuscate(secret: str) -> str:
if secret is None:
return ""
return secret[:7] + "..." if len(secret) > 7 else ("*" * len(secret))
log_data["mtoken"] = obfuscate(log_data.get("mtoken"))
full_path = report_addr.rstrip("/") + "/worker_status/"
log.debug(
"\n".join(
[
"#" * 60,
f"sending data to autoscaler",
f"{json.dumps(log_data, indent=2)}",
f"{json.dumps((asdict(data)), indent=2)}",
"#" * 60,
]
)
)
full_path = report_addr.rstrip("/") + "/worker_status/"
for attempt in range(1, 4):
try:
session = await self.http()
@@ -272,15 +229,11 @@ class Metrics:
self.system_metrics.update_disk_usage()
sent = False
for report_addr in self.report_addr:
if await send_data(report_addr):
sent = True
success = await send_data(report_addr)
if success is True:
break
if sent:
# clear the one-shot loadtime only if we actually sent *this* value
self.system_metrics.reset(expected=loadtime_snapshot)
self.update_pending = False
self.model_metrics.reset()
self.last_metric_update = time.time()
self.update_pending = False
self.model_metrics.reset()
self.system_metrics.reset()
self.last_metric_update = time.time()
+23 -56
View File
@@ -2,6 +2,9 @@
set -e -o pipefail
log() { echo "$(date +'%Y-%m-%d %H:%M:%S') $*"; }
step(){ _t0=$(date +%s); eval "$1"; _dt=$(($(date +%s)-_t0)); log "$2 took ${_dt}s"; }
WORKSPACE_DIR="${WORKSPACE_DIR:-/workspace}"
SERVER_DIR="$WORKSPACE_DIR/vast-pyworker"
@@ -9,7 +12,7 @@ ENV_PATH="$WORKSPACE_DIR/worker-env"
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}"
REPORT_ADDR="${REPORT_ADDR:-https://cloud.vast.ai/api/v0,https://run.vast.ai}"
USE_SSL="${USE_SSL:-true}"
WORKER_PORT="${WORKER_PORT:-3000}"
mkdir -p "$WORKSPACE_DIR"
@@ -41,33 +44,28 @@ echo_var DEBUG_LOG
echo_var PYWORKER_LOG
echo_var MODEL_LOG
# Populate /etc/environment with quoted values
if ! grep -q "VAST" /etc/environment; then
env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do
name=${line%%=*}
value=${line#*=}
printf '%s="%s"\n' "$name" "$value"
done > /etc/environment
fi
# # Populate /etc/environment with quoted values
# if ! grep -q "VAST" /etc/environment; then
# env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do
# name=${line%%=*}
# value=${line#*=}
# printf '%s="%s"\n' "$name" "$value"
# done > /etc/environment
# fi
if [ ! -d "$ENV_PATH" ]
then
echo "setting up venv"
if ! which uv; then
curl -LsSf https://astral.sh/uv/install.sh | sh
source ~/.local/bin/env
fi
step 'if ! which uv; then curl -LsSf https://astral.sh/uv/install.sh | sh; source ~/.local/bin/env; fi' "uv install"
# Fork testing
[[ ! -d $SERVER_DIR ]] && git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"
if [[ -n ${PYWORKER_REF:-} ]]; then
(cd "$SERVER_DIR" && git checkout "$PYWORKER_REF")
fi
step '[[ ! -d $SERVER_DIR ]] && git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"' "git clone"
step 'if [[ -n ${PYWORKER_REF:-} ]]; then (cd "$SERVER_DIR" && git checkout "$PYWORKER_REF"); fi' "git checkout"
uv venv --python-preference only-managed "$ENV_PATH" -p 3.10
source "$ENV_PATH/bin/activate"
uv pip install -r "${SERVER_DIR}/requirements.txt"
step 'uv venv --python-preference only-managed "$ENV_PATH" -p 3.10' "venv create"
step 'source "$ENV_PATH/bin/activate"' "venv activate"
step 'uv pip install -r "${SERVER_DIR}/requirements.txt"' "pip install requirements"
touch ~/.no_auto_tmux
else
@@ -80,39 +78,8 @@ fi
[ ! -d "$SERVER_DIR/workers/$BACKEND" ] && echo "$BACKEND not supported!" && exit 1
if [ "$USE_SSL" = true ]; then
cat << EOF > /etc/openssl-san.cnf
[req]
default_bits = 2048
distinguished_name = req_distinguished_name
req_extensions = v3_req
[req_distinguished_name]
countryName = US
stateOrProvinceName = CA
organizationName = Vast.ai Inc.
commonName = vast.ai
[v3_req]
basicConstraints = CA:FALSE
keyUsage = nonRepudiation, digitalSignature, keyEncipherment
subjectAltName = @alt_names
[alt_names]
IP.1 = 0.0.0.0
EOF
openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \
-nodes \
-sha256 \
-keyout /etc/instance.key \
-out /etc/instance.csr \
-config /etc/openssl-san.cnf
curl --header 'Content-Type: application/octet-stream' \
--data-binary @//etc/instance.csr \
-X \
POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt;
step 'openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" -nodes -sha256 -keyout /etc/instance.key -out /etc/instance.csr -config /etc/openssl-san.cnf' "openssl csr"
step 'curl --header "Content-Type: application/octet-stream" --data-binary @//etc/instance.csr -X POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt' "sign cert"
fi
@@ -122,11 +89,11 @@ export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED
cd "$SERVER_DIR"
echo "launching PyWorker server"
log "launching PyWorker server"
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
echo "launching PyWorker server done"
_t0=$(date +%s); (python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") & _dt=$(($(date +%s)-_t0)); log "PyWorker spawn took ${_dt}s"
log "launching PyWorker server done"
+3 -15
View File
@@ -12,21 +12,9 @@ A docker image is provided but you may use any if the above requirements are met
## Benchmarking
### Custom Benchmark Workflows
A simple image generation benchmark runs when each worker initializes to validate GPU performance and identify underperforming machines.
You can provide a custom ComfyUI workflow for benchmarking by creating `workers/comfyui-json/misc/benchmark.json`. This allows you to test performance using your preferred models and workflow complexity.
**Ways to provide the benchmark file:**
- Fork this repository and add your `benchmark.json` file
- Write the file during worker provisioning (onstart script or setup phase)
An example file is provided in the repository. To ensure varied generations, use the placeholder `__RANDOM_INT__` in place of static seed values - it will be replaced with a random integer for each benchmark run.
### Default Benchmark (Fallback)
If `benchmark.json` is not available, a simple image generation benchmark runs when each worker initializes. This validates GPU performance and helps identify underperforming machines.
The default benchmark uses Stable Diffusion v1.5 with ComfyUI's standard text-to-image workflow. Configure it using these environment variables:
The benchmark uses Stable Diffusion v1.5 with ComfyUI's default text-to-image workflow. Configure the benchmark complexity and duration using these variables:
| Environment Variable | Default Value | Description |
| -------------------- | ------------- | ----------- |
@@ -36,7 +24,7 @@ The default benchmark uses Stable Diffusion v1.5 with ComfyUI's standard text-to
Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns.
#### Calibrating Fallback Benchmark Duration
### Calibrating Benchmark Duration
To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements.
-1
View File
@@ -98,7 +98,6 @@ def call_text2image_workflow(
endpoint=route_response["endpoint"],
reqnum=route_response["reqnum"],
url=route_response["url"],
request_idx=route_response["request_idx"],
)
# Build the payload for the worker request
+4 -28
View File
@@ -5,13 +5,12 @@ import dataclasses
from typing import Dict, Any
from functools import cache
from math import ceil
from pathlib import Path
import json
import logging
from lib.data_types import ApiPayload, JsonDataException
log = logging.getLogger(__file__)
with open("workers/comfyui/misc/test_prompts.txt", "r") as f:
test_prompts = f.readlines()
def count_workload() -> float:
# Always 100.0 where there is a single instance of ComfyUI handling requests
@@ -25,32 +24,9 @@ class ComfyWorkflowData(ApiPayload):
@classmethod
def for_test(cls):
"""
If the user has provided a benchmark workflow we can use it here to properly gauge performance.
Otherwise, use the variables available to simulate workflows of the required running time
Use the variables available to simulate workflows of the required running time
Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090)
"""
# Try to load benchmark.json
benchmark_file = Path("workers/comfyui-json/misc/benchmark.json")
if benchmark_file.exists():
try:
with open(benchmark_file, "r") as f:
benchmark_workflow = json.load(f)
return cls(
input={
"request_id": f"test-{random.randint(1000, 99999)}",
"workflow_json": benchmark_workflow
}
)
except (json.JSONDecodeError, IOError):
# JSON is malformed or file can't be read, fall through to default
log.error(f"Failed to benchmark using {benchmark_file}")
# Fallback: read prompts and construct payload
log.info("Using fallback method for benchmarking")
with open("workers/comfyui-json/misc/test_prompts.txt", "r") as f:
test_prompts = f.readlines()
test_prompt = random.choice(test_prompts).rstrip()
return cls(
input={
@@ -1,107 +0,0 @@
{
"3": {
"inputs": {
"seed": "__RANDOM_INT__",
"steps": 20,
"cfg": 8,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1,
"model": [
"4",
0
],
"positive": [
"6",
0
],
"negative": [
"7",
0
],
"latent_image": [
"5",
0
]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"4": {
"inputs": {
"ckpt_name": "v1-5-pruned-emaonly-fp16.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"5": {
"inputs": {
"width": 512,
"height": 512,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"6": {
"inputs": {
"text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"7": {
"inputs": {
"text": "text, watermark",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"8": {
"inputs": {
"samples": [
"3",
0
],
"vae": [
"4",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"8",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
}
}
-1
View File
@@ -19,7 +19,6 @@ MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: "
MODEL_SERVER_ERROR_LOG_MSGS = [
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
"Value not in list: ", # This error is emitted when the model file is not there at all
"[ERROR] Provisioning Script failed", # Error inserted by provisioning script if models/nodes fail to download
]
-1
View File
@@ -82,7 +82,6 @@ def call_custom_workflow_for_sd3(
endpoint=message["endpoint"],
reqnum=message["reqnum"],
url=message["url"],
request_idx=message["request_idx"],
)
workflow = {
"3": {
+4 -1
View File
@@ -6,10 +6,13 @@ from typing import Union, Type, Dict, Any, Optional
from aiohttp import web, ClientResponse
import nltk
import logging
import time
log = logging.getLogger(__name__)
t0 = time.time()
nltk.download("words")
WORD_LIST = nltk.corpus.words.words()
log = logging.getLogger(__name__)
print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} NLTK words download+load took {time.time()-t0:.2f}s")
"""
Generic dataclass accepts any dictionary in input.