Compare commits
21 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0471f6b219 | |||
| 9c795e2a01 | |||
| 830b532781 | |||
| d6a6e34c6b | |||
| ac1e109c48 | |||
| d6eb498ee4 | |||
| 70d51bafe1 | |||
| 63909736bb | |||
| f4f7080df1 | |||
| d51a338e8f | |||
| 92a04bd7af | |||
| c98d661513 | |||
| f6fd1c6ac1 | |||
| 055e346c8c | |||
| 1cedb28acf | |||
| ec25dda3ad | |||
| 0397af719d | |||
| 3786cf978d | |||
| a86d4bcf9c | |||
| e9b6a14a5e | |||
| cadac033e1 |
+78
-31
@@ -5,7 +5,7 @@ import base64
|
||||
import subprocess
|
||||
import dataclasses
|
||||
import logging
|
||||
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
|
||||
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task, get_running_loop
|
||||
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
|
||||
from functools import cached_property
|
||||
from distutils.util import strtobool
|
||||
@@ -47,7 +47,7 @@ class Backend:
|
||||
This class is responsible for:
|
||||
1. Tailing logs and updating load time metrics
|
||||
2. Taking an EndpointHandler alongside incoming payload, preparing a json to be sent to the model, and
|
||||
sending the request. It also updates metrics as it makes those requests.
|
||||
sending the request. It also updates metrics as it makes those requests.
|
||||
3. Running a benchmark from an EndpointHandler
|
||||
"""
|
||||
|
||||
@@ -73,9 +73,11 @@ class Backend:
|
||||
self._total_pubkey_fetch_errors = 0
|
||||
self._pubkey = self._fetch_pubkey()
|
||||
self.__start_healthcheck: bool = False
|
||||
self._model_tail_start_time = None
|
||||
self._model_loaded_time = None
|
||||
self._first_healthcheck_ok = False
|
||||
|
||||
# NEW: FIFO queue + worker count
|
||||
self.request_queue: "asyncio.Queue[tuple[EndpointHandler[ApiPayload_T], web.Request, asyncio.Future]]" = asyncio.Queue()
|
||||
# If parallel allowed, let multiple workers drain the queue (order preserved by FIFO per worker; overall start order is FIFO).
|
||||
self._num_workers: int = 1 if not self.allow_parallel_requests else int(os.environ.get("WORKERS", "4"))
|
||||
|
||||
@property
|
||||
def pubkey(self) -> Optional[RSA.RsaKey]:
|
||||
@@ -94,6 +96,22 @@ class Backend:
|
||||
timeout = ClientTimeout(total=None)
|
||||
return ClientSession(self.model_server_url, timeout=timeout, connector=connector)
|
||||
|
||||
async def _worker(self):
|
||||
while True:
|
||||
handler, request, fut = await self.request_queue.get()
|
||||
try:
|
||||
# Skip if already cancelled while waiting in the queue
|
||||
if fut.cancelled():
|
||||
continue
|
||||
res = await self.__process_enqueued_request(handler, request)
|
||||
if not fut.cancelled():
|
||||
fut.set_result(res)
|
||||
except Exception as e:
|
||||
if not fut.cancelled():
|
||||
fut.set_exception(e)
|
||||
finally:
|
||||
self.request_queue.task_done()
|
||||
|
||||
def create_handler(
|
||||
self,
|
||||
handler: EndpointHandler[ApiPayload_T],
|
||||
@@ -107,7 +125,6 @@ class Backend:
|
||||
|
||||
#######################################Private#######################################
|
||||
def _fetch_pubkey(self):
|
||||
t0 = time.time()
|
||||
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
|
||||
result = subprocess.check_output(command, universal_newlines=True)
|
||||
log.debug("public key:")
|
||||
@@ -124,8 +141,6 @@ class Backend:
|
||||
self._total_pubkey_fetch_errors += 1
|
||||
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
|
||||
self.backend_errored("Failed to get autoscaler pubkey")
|
||||
else:
|
||||
log.debug(f"pubkey fetch+parse took {time.time()-t0:.2f}s")
|
||||
return key
|
||||
|
||||
async def __handle_request(
|
||||
@@ -133,7 +148,36 @@ class Backend:
|
||||
handler: EndpointHandler[ApiPayload_T],
|
||||
request: web.Request,
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
"""use this function to forward requests to the model endpoint"""
|
||||
"""use this function to enqueue requests for FIFO processing"""
|
||||
loop = get_running_loop()
|
||||
fut: asyncio.Future = loop.create_future()
|
||||
|
||||
# If the client disconnects while waiting in the FIFO, cancel the future so the worker skips it
|
||||
cancel_watch = create_task(request.wait_for_disconnection())
|
||||
def _cancel_if_disconnected(_):
|
||||
if not fut.done():
|
||||
fut.cancel()
|
||||
cancel_watch.add_done_callback(_cancel_if_disconnected)
|
||||
|
||||
try:
|
||||
await self.request_queue.put((handler, request, fut))
|
||||
return await fut
|
||||
except asyncio.CancelledError:
|
||||
# Propagate cancellation to ensure aiohttp doesn't expect a response body
|
||||
raise
|
||||
finally:
|
||||
# Best-effort cleanup of the watcher
|
||||
cancel_watch.cancel()
|
||||
|
||||
async def __process_enqueued_request(
|
||||
self,
|
||||
handler: EndpointHandler[ApiPayload_T],
|
||||
request: web.Request,
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
"""
|
||||
This contains the original __handle_request logic and is invoked by workers,
|
||||
ensuring FIFO execution via asyncio.Queue.
|
||||
"""
|
||||
try:
|
||||
data = await request.json()
|
||||
auth_data, payload = handler.get_data_from_request(data)
|
||||
@@ -141,8 +185,11 @@ class Backend:
|
||||
return web.json_response(data=e.message, status=422)
|
||||
except json.JSONDecodeError:
|
||||
return web.json_response(dict(error="invalid JSON"), status=422)
|
||||
|
||||
workload = payload.count_workload()
|
||||
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
|
||||
request_metrics: RequestMetrics = RequestMetrics(
|
||||
request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created"
|
||||
)
|
||||
|
||||
async def cancel_api_call_if_disconnected() -> web.Response:
|
||||
await request.wait_for_disconnection()
|
||||
@@ -183,6 +230,8 @@ class Backend:
|
||||
acquired = False
|
||||
try:
|
||||
self.metrics._request_start(request_metrics)
|
||||
|
||||
# Preserve existing semaphore behavior for serializing requests when requested
|
||||
if self.allow_parallel_requests is False:
|
||||
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
|
||||
await self.sem.acquire()
|
||||
@@ -192,6 +241,7 @@ class Backend:
|
||||
)
|
||||
else:
|
||||
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
|
||||
|
||||
done, pending = await wait(
|
||||
[
|
||||
create_task(make_request()),
|
||||
@@ -246,10 +296,6 @@ class Backend:
|
||||
log.debug(f"Performing healthcheck on {health_check_url}")
|
||||
async with self.healthcheck_session.get(health_check_url) as response:
|
||||
if response.status == 200:
|
||||
if not self._first_healthcheck_ok:
|
||||
if self._model_loaded_time:
|
||||
log.debug(f"first healthcheck OK after {time.time()-self._model_loaded_time:.2f}s since ModelLoaded")
|
||||
self._first_healthcheck_ok = True
|
||||
log.debug("Healthcheck successful")
|
||||
elif response.status == 503:
|
||||
log.debug(f"Healthcheck failed with status: {response.status}")
|
||||
@@ -263,8 +309,14 @@ class Backend:
|
||||
self.backend_errored(str(e))
|
||||
|
||||
async def _start_tracking(self) -> None:
|
||||
# Start the FIFO workers alongside existing loops
|
||||
worker_tasks = tuple(self._worker() for _ in range(self._num_workers))
|
||||
await gather(
|
||||
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop()
|
||||
self.__read_logs(),
|
||||
self.metrics._send_metrics_loop(),
|
||||
self.__healthcheck(),
|
||||
self.metrics._send_delete_requests_loop(),
|
||||
*worker_tasks,
|
||||
)
|
||||
|
||||
def backend_errored(self, msg: str) -> None:
|
||||
@@ -325,20 +377,17 @@ class Backend:
|
||||
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
||||
log.debug("already ran benchmark")
|
||||
# trigger model load
|
||||
# payload = self.benchmark_handler.make_benchmark_payload()
|
||||
# _ = await self.__call_api(
|
||||
# handler=self.benchmark_handler, payload=payload
|
||||
# )
|
||||
payload = self.benchmark_handler.make_benchmark_payload()
|
||||
_ = await self.__call_api(
|
||||
handler=self.benchmark_handler, payload=payload
|
||||
)
|
||||
return float(f.readline())
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
log.debug("Initial run to trigger model loading...")
|
||||
t_bench0 = time.time()
|
||||
payload = self.benchmark_handler.make_benchmark_payload()
|
||||
await self.__call_api(handler=self.benchmark_handler, payload=payload)
|
||||
log.debug(f"warmup request took {time.time()-t_bench0:.2f}s")
|
||||
t_benchmark_loop0 = time.time()
|
||||
|
||||
max_throughput = 0
|
||||
sum_throughput = 0
|
||||
@@ -363,6 +412,9 @@ class Backend:
|
||||
total_workload = sum(br.workload for br in benchmark_requests if br.is_successful)
|
||||
time_elapsed = time.time() - start
|
||||
successful_responses = sum([1 for br in benchmark_requests if br.is_successful])
|
||||
if successful_responses == 0:
|
||||
self.backend_errored("No successful responses from benchmark")
|
||||
log.debug(f"benchmark failed: {successful_responses}/{concurrent_requests} successful responses")
|
||||
|
||||
throughput = total_workload / time_elapsed
|
||||
sum_throughput += throughput
|
||||
@@ -386,7 +438,6 @@ class Backend:
|
||||
log.debug(
|
||||
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
|
||||
)
|
||||
log.debug(f"benchmark loop took {time.time()-t_benchmark_loop0:.2f}s")
|
||||
with open(BENCHMARK_INDICATOR_FILE, "w") as f:
|
||||
f.write(str(max_throughput))
|
||||
return max_throughput
|
||||
@@ -399,17 +450,14 @@ class Backend:
|
||||
for action, msg in self.log_actions:
|
||||
match action:
|
||||
case LogAction.ModelLoaded if msg in log_line:
|
||||
now = time.time()
|
||||
elapsed = now - self._model_tail_start_time
|
||||
log.debug(f"ModelLoaded observed after {elapsed:.2f}s: {log_line}")
|
||||
log.debug(
|
||||
f"Got log line indicating model is loaded: {log_line}"
|
||||
)
|
||||
# some backends need a few seconds after logging successful startup before
|
||||
# they can begin accepting requests
|
||||
# await sleep(5)
|
||||
await sleep(5)
|
||||
try:
|
||||
t_bench0 = time.time()
|
||||
max_throughput = await run_benchmark()
|
||||
self._model_loaded_time = time.time()
|
||||
log.debug(f"benchmark total took {self._model_loaded_time - t_bench0:.2f}s")
|
||||
self.__start_healthcheck = True
|
||||
self.metrics._model_loaded(
|
||||
max_throughput=max_throughput,
|
||||
@@ -428,7 +476,6 @@ class Backend:
|
||||
|
||||
async def tail_log():
|
||||
log.debug(f"tailing file: {self.model_log_file}")
|
||||
self._model_tail_start_time = time.time()
|
||||
async with await open_file(self.model_log_file) as f:
|
||||
while True:
|
||||
line = await f.readline()
|
||||
|
||||
+1
-1
@@ -257,7 +257,7 @@ class ModelMetrics:
|
||||
def wait_time(self) -> float:
|
||||
if (len(self.requests_working) == 0):
|
||||
return 0.0
|
||||
return sum([request.workload for request in self.requests_working.values()]) / self.max_throughput
|
||||
return sum([request.workload for request in self.requests_working.values()]) / max(self.max_throughput, 0.00001)
|
||||
|
||||
@property
|
||||
def cur_load(self) -> float:
|
||||
|
||||
+32
-9
@@ -145,14 +145,15 @@ class Metrics:
|
||||
#######################################Private#######################################
|
||||
|
||||
async def __send_delete_requests_and_reset(self):
|
||||
|
||||
async def send_data(report_addr: str, success: bool) -> bool:
|
||||
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
|
||||
data = {
|
||||
"worker_id": self.id,
|
||||
"request_idxs": [r.request_idx for r in self.model_metrics.requests_deleting if r.success == success],
|
||||
"success": success
|
||||
"request_idxs": idxs,
|
||||
"success": success_flag,
|
||||
}
|
||||
log.debug(f"Deleting requests that {'succeeded' if success else 'failed'}: {data['request_idxs']}")
|
||||
log.debug(
|
||||
f"Deleting requests that {'succeeded' if success_flag else 'failed'}: {data['request_idxs']}"
|
||||
)
|
||||
full_path = report_addr.rstrip("/") + "/delete_requests/"
|
||||
for attempt in range(1, 4):
|
||||
try:
|
||||
@@ -162,16 +163,38 @@ class Metrics:
|
||||
res.raise_for_status()
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
log.debug(f"delete_requests timed out")
|
||||
log.debug("delete_requests timed out")
|
||||
except (ClientResponseError, Exception) as e:
|
||||
log.debug(f"delete_requests failed with error: {e}")
|
||||
await asyncio.sleep(2)
|
||||
log.debug(f"retrying delete_request, attempt: {attempt}")
|
||||
return False
|
||||
|
||||
# Take a snapshot of what we plan to send this tick.
|
||||
# New arrivals after this snapshot will remain in the queue for the next tick.
|
||||
snapshot = list(self.model_metrics.requests_deleting)
|
||||
success_idxs = [r.request_idx for r in snapshot if r.success is True]
|
||||
failed_idxs = [r.request_idx for r in snapshot if r.success is False]
|
||||
|
||||
if not success_idxs and not failed_idxs:
|
||||
return # nothing to do
|
||||
|
||||
for report_addr in self.report_addr:
|
||||
success = await send_data(report_addr, success=True) and await send_data(report_addr, success=False)
|
||||
if success is True:
|
||||
self.model_metrics.requests_deleting.clear()
|
||||
sent_success = True
|
||||
sent_failed = True
|
||||
|
||||
if success_idxs:
|
||||
sent_success = await post(report_addr, success_idxs, True)
|
||||
if failed_idxs:
|
||||
sent_failed = await post(report_addr, failed_idxs, False)
|
||||
|
||||
if sent_success and sent_failed:
|
||||
# Remove only the items we actually sent from the live queue.
|
||||
sent_set = set(success_idxs) | set(failed_idxs)
|
||||
self.model_metrics.requests_deleting[:] = [
|
||||
r for r in self.model_metrics.requests_deleting
|
||||
if r.request_idx not in sent_set
|
||||
]
|
||||
break
|
||||
|
||||
|
||||
|
||||
+55
-22
@@ -2,9 +2,6 @@
|
||||
|
||||
set -e -o pipefail
|
||||
|
||||
log() { echo "$(date +'%Y-%m-%d %H:%M:%S') $*"; }
|
||||
step(){ _t0=$(date +%s); eval "$1"; _dt=$(($(date +%s)-_t0)); log "$2 took ${_dt}s"; }
|
||||
|
||||
WORKSPACE_DIR="${WORKSPACE_DIR:-/workspace}"
|
||||
|
||||
SERVER_DIR="$WORKSPACE_DIR/vast-pyworker"
|
||||
@@ -44,28 +41,33 @@ echo_var DEBUG_LOG
|
||||
echo_var PYWORKER_LOG
|
||||
echo_var MODEL_LOG
|
||||
|
||||
# # Populate /etc/environment with quoted values
|
||||
# if ! grep -q "VAST" /etc/environment; then
|
||||
# env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do
|
||||
# name=${line%%=*}
|
||||
# value=${line#*=}
|
||||
# printf '%s="%s"\n' "$name" "$value"
|
||||
# done > /etc/environment
|
||||
# fi
|
||||
# Populate /etc/environment with quoted values
|
||||
if ! grep -q "VAST" /etc/environment; then
|
||||
env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do
|
||||
name=${line%%=*}
|
||||
value=${line#*=}
|
||||
printf '%s="%s"\n' "$name" "$value"
|
||||
done > /etc/environment
|
||||
fi
|
||||
|
||||
if [ ! -d "$ENV_PATH" ]
|
||||
then
|
||||
echo "setting up venv"
|
||||
step 'if ! which uv; then curl -LsSf https://astral.sh/uv/install.sh | sh; source ~/.local/bin/env; fi' "uv install"
|
||||
if ! which uv; then
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
source ~/.local/bin/env
|
||||
fi
|
||||
|
||||
# Fork testing
|
||||
step '[[ ! -d $SERVER_DIR ]] && git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"' "git clone"
|
||||
step 'if [[ -n ${PYWORKER_REF:-} ]]; then (cd "$SERVER_DIR" && git checkout "$PYWORKER_REF"); fi' "git checkout"
|
||||
[[ ! -d $SERVER_DIR ]] && git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"
|
||||
if [[ -n ${PYWORKER_REF:-} ]]; then
|
||||
(cd "$SERVER_DIR" && git checkout "$PYWORKER_REF")
|
||||
fi
|
||||
|
||||
uv venv --python-preference only-managed "$ENV_PATH" -p 3.10
|
||||
source "$ENV_PATH/bin/activate"
|
||||
|
||||
step 'uv venv --python-preference only-managed "$ENV_PATH" -p 3.10' "venv create"
|
||||
step 'source "$ENV_PATH/bin/activate"' "venv activate"
|
||||
step 'uv pip install -r "${SERVER_DIR}/requirements.txt"' "pip install requirements"
|
||||
uv pip install -r "${SERVER_DIR}/requirements.txt"
|
||||
|
||||
touch ~/.no_auto_tmux
|
||||
else
|
||||
@@ -78,8 +80,39 @@ fi
|
||||
[ ! -d "$SERVER_DIR/workers/$BACKEND" ] && echo "$BACKEND not supported!" && exit 1
|
||||
|
||||
if [ "$USE_SSL" = true ]; then
|
||||
step 'openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" -nodes -sha256 -keyout /etc/instance.key -out /etc/instance.csr -config /etc/openssl-san.cnf' "openssl csr"
|
||||
step 'curl --header "Content-Type: application/octet-stream" --data-binary @//etc/instance.csr -X POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt' "sign cert"
|
||||
|
||||
cat << EOF > /etc/openssl-san.cnf
|
||||
[req]
|
||||
default_bits = 2048
|
||||
distinguished_name = req_distinguished_name
|
||||
req_extensions = v3_req
|
||||
|
||||
[req_distinguished_name]
|
||||
countryName = US
|
||||
stateOrProvinceName = CA
|
||||
organizationName = Vast.ai Inc.
|
||||
commonName = vast.ai
|
||||
|
||||
[v3_req]
|
||||
basicConstraints = CA:FALSE
|
||||
keyUsage = nonRepudiation, digitalSignature, keyEncipherment
|
||||
subjectAltName = @alt_names
|
||||
|
||||
[alt_names]
|
||||
IP.1 = 0.0.0.0
|
||||
EOF
|
||||
|
||||
openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \
|
||||
-nodes \
|
||||
-sha256 \
|
||||
-keyout /etc/instance.key \
|
||||
-out /etc/instance.csr \
|
||||
-config /etc/openssl-san.cnf
|
||||
|
||||
curl --header 'Content-Type: application/octet-stream' \
|
||||
--data-binary @//etc/instance.csr \
|
||||
-X \
|
||||
POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt;
|
||||
fi
|
||||
|
||||
|
||||
@@ -89,11 +122,11 @@ export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED
|
||||
|
||||
cd "$SERVER_DIR"
|
||||
|
||||
log "launching PyWorker server"
|
||||
echo "launching PyWorker server"
|
||||
|
||||
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
|
||||
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
|
||||
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
|
||||
|
||||
_t0=$(date +%s); (python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") & _dt=$(($(date +%s)-_t0)); log "PyWorker spawn took ${_dt}s"
|
||||
log "launching PyWorker server done"
|
||||
(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
|
||||
echo "launching PyWorker server done"
|
||||
|
||||
@@ -12,9 +12,21 @@ A docker image is provided but you may use any if the above requirements are met
|
||||
|
||||
## Benchmarking
|
||||
|
||||
A simple image generation benchmark runs when each worker initializes to validate GPU performance and identify underperforming machines.
|
||||
### Custom Benchmark Workflows
|
||||
|
||||
The benchmark uses Stable Diffusion v1.5 with ComfyUI's default text-to-image workflow. Configure the benchmark complexity and duration using these variables:
|
||||
You can provide a custom ComfyUI workflow for benchmarking by creating `workers/comfyui-json/misc/benchmark.json`. This allows you to test performance using your preferred models and workflow complexity.
|
||||
|
||||
**Ways to provide the benchmark file:**
|
||||
- Fork this repository and add your `benchmark.json` file
|
||||
- Write the file during worker provisioning (onstart script or setup phase)
|
||||
|
||||
An example file is provided in the repository. To ensure varied generations, use the placeholder `__RANDOM_INT__` in place of static seed values - it will be replaced with a random integer for each benchmark run.
|
||||
|
||||
### Default Benchmark (Fallback)
|
||||
|
||||
If `benchmark.json` is not available, a simple image generation benchmark runs when each worker initializes. This validates GPU performance and helps identify underperforming machines.
|
||||
|
||||
The default benchmark uses Stable Diffusion v1.5 with ComfyUI's standard text-to-image workflow. Configure it using these environment variables:
|
||||
|
||||
| Environment Variable | Default Value | Description |
|
||||
| -------------------- | ------------- | ----------- |
|
||||
@@ -24,7 +36,7 @@ The benchmark uses Stable Diffusion v1.5 with ComfyUI's default text-to-image wo
|
||||
|
||||
Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns.
|
||||
|
||||
### Calibrating Benchmark Duration
|
||||
#### Calibrating Fallback Benchmark Duration
|
||||
|
||||
To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements.
|
||||
|
||||
|
||||
@@ -5,12 +5,13 @@ import dataclasses
|
||||
from typing import Dict, Any
|
||||
from functools import cache
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
import json
|
||||
import logging
|
||||
|
||||
from lib.data_types import ApiPayload, JsonDataException
|
||||
|
||||
|
||||
with open("workers/comfyui/misc/test_prompts.txt", "r") as f:
|
||||
test_prompts = f.readlines()
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
def count_workload() -> float:
|
||||
# Always 100.0 where there is a single instance of ComfyUI handling requests
|
||||
@@ -24,9 +25,32 @@ class ComfyWorkflowData(ApiPayload):
|
||||
@classmethod
|
||||
def for_test(cls):
|
||||
"""
|
||||
Use the variables available to simulate workflows of the required running time
|
||||
If the user has provided a benchmark workflow we can use it here to properly gauge performance.
|
||||
Otherwise, use the variables available to simulate workflows of the required running time
|
||||
Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090)
|
||||
"""
|
||||
# Try to load benchmark.json
|
||||
benchmark_file = Path("workers/comfyui-json/misc/benchmark.json")
|
||||
|
||||
if benchmark_file.exists():
|
||||
try:
|
||||
with open(benchmark_file, "r") as f:
|
||||
benchmark_workflow = json.load(f)
|
||||
return cls(
|
||||
input={
|
||||
"request_id": f"test-{random.randint(1000, 99999)}",
|
||||
"workflow_json": benchmark_workflow
|
||||
}
|
||||
)
|
||||
except (json.JSONDecodeError, IOError):
|
||||
# JSON is malformed or file can't be read, fall through to default
|
||||
log.error(f"Failed to benchmark using {benchmark_file}")
|
||||
|
||||
# Fallback: read prompts and construct payload
|
||||
log.info("Using fallback method for benchmarking")
|
||||
with open("workers/comfyui-json/misc/test_prompts.txt", "r") as f:
|
||||
test_prompts = f.readlines()
|
||||
|
||||
test_prompt = random.choice(test_prompts).rstrip()
|
||||
return cls(
|
||||
input={
|
||||
|
||||
@@ -0,0 +1,107 @@
|
||||
{
|
||||
"3": {
|
||||
"inputs": {
|
||||
"seed": "__RANDOM_INT__",
|
||||
"steps": 20,
|
||||
"cfg": 8,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "normal",
|
||||
"denoise": 1,
|
||||
"model": [
|
||||
"4",
|
||||
0
|
||||
],
|
||||
"positive": [
|
||||
"6",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"7",
|
||||
0
|
||||
],
|
||||
"latent_image": [
|
||||
"5",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "KSampler",
|
||||
"_meta": {
|
||||
"title": "KSampler"
|
||||
}
|
||||
},
|
||||
"4": {
|
||||
"inputs": {
|
||||
"ckpt_name": "v1-5-pruned-emaonly-fp16.safetensors"
|
||||
},
|
||||
"class_type": "CheckpointLoaderSimple",
|
||||
"_meta": {
|
||||
"title": "Load Checkpoint"
|
||||
}
|
||||
},
|
||||
"5": {
|
||||
"inputs": {
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
"batch_size": 1
|
||||
},
|
||||
"class_type": "EmptyLatentImage",
|
||||
"_meta": {
|
||||
"title": "Empty Latent Image"
|
||||
}
|
||||
},
|
||||
"6": {
|
||||
"inputs": {
|
||||
"text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,",
|
||||
"clip": [
|
||||
"4",
|
||||
1
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"7": {
|
||||
"inputs": {
|
||||
"text": "text, watermark",
|
||||
"clip": [
|
||||
"4",
|
||||
1
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"8": {
|
||||
"inputs": {
|
||||
"samples": [
|
||||
"3",
|
||||
0
|
||||
],
|
||||
"vae": [
|
||||
"4",
|
||||
2
|
||||
]
|
||||
},
|
||||
"class_type": "VAEDecode",
|
||||
"_meta": {
|
||||
"title": "VAE Decode"
|
||||
}
|
||||
},
|
||||
"9": {
|
||||
"inputs": {
|
||||
"filename_prefix": "ComfyUI",
|
||||
"images": [
|
||||
"8",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "SaveImage",
|
||||
"_meta": {
|
||||
"title": "Save Image"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -19,6 +19,7 @@ MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: "
|
||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
|
||||
"Value not in list: ", # This error is emitted when the model file is not there at all
|
||||
"[ERROR] Provisioning Script failed", # Error inserted by provisioning script if models/nodes fail to download
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -6,13 +6,10 @@ from typing import Union, Type, Dict, Any, Optional
|
||||
from aiohttp import web, ClientResponse
|
||||
import nltk
|
||||
import logging
|
||||
import time
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
t0 = time.time()
|
||||
nltk.download("words")
|
||||
WORD_LIST = nltk.corpus.words.words()
|
||||
print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} NLTK words download+load took {time.time()-t0:.2f}s")
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
"""
|
||||
Generic dataclass accepts any dictionary in input.
|
||||
|
||||
Reference in New Issue
Block a user