Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| fd9d56e576 | |||
| 8d9ffb3a6c | |||
| 5d5bc197d7 |
+28
-13
@@ -73,6 +73,9 @@ class Backend:
|
|||||||
self._total_pubkey_fetch_errors = 0
|
self._total_pubkey_fetch_errors = 0
|
||||||
self._pubkey = self._fetch_pubkey()
|
self._pubkey = self._fetch_pubkey()
|
||||||
self.__start_healthcheck: bool = False
|
self.__start_healthcheck: bool = False
|
||||||
|
self._model_tail_start_time = None
|
||||||
|
self._model_loaded_time = None
|
||||||
|
self._first_healthcheck_ok = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pubkey(self) -> Optional[RSA.RsaKey]:
|
def pubkey(self) -> Optional[RSA.RsaKey]:
|
||||||
@@ -104,6 +107,7 @@ class Backend:
|
|||||||
|
|
||||||
#######################################Private#######################################
|
#######################################Private#######################################
|
||||||
def _fetch_pubkey(self):
|
def _fetch_pubkey(self):
|
||||||
|
t0 = time.time()
|
||||||
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
|
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
|
||||||
result = subprocess.check_output(command, universal_newlines=True)
|
result = subprocess.check_output(command, universal_newlines=True)
|
||||||
log.debug("public key:")
|
log.debug("public key:")
|
||||||
@@ -120,6 +124,8 @@ class Backend:
|
|||||||
self._total_pubkey_fetch_errors += 1
|
self._total_pubkey_fetch_errors += 1
|
||||||
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
|
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
|
||||||
self.backend_errored("Failed to get autoscaler pubkey")
|
self.backend_errored("Failed to get autoscaler pubkey")
|
||||||
|
else:
|
||||||
|
log.debug(f"pubkey fetch+parse took {time.time()-t0:.2f}s")
|
||||||
return key
|
return key
|
||||||
|
|
||||||
async def __handle_request(
|
async def __handle_request(
|
||||||
@@ -240,6 +246,10 @@ class Backend:
|
|||||||
log.debug(f"Performing healthcheck on {health_check_url}")
|
log.debug(f"Performing healthcheck on {health_check_url}")
|
||||||
async with self.healthcheck_session.get(health_check_url) as response:
|
async with self.healthcheck_session.get(health_check_url) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
|
if not self._first_healthcheck_ok:
|
||||||
|
if self._model_loaded_time:
|
||||||
|
log.debug(f"first healthcheck OK after {time.time()-self._model_loaded_time:.2f}s since ModelLoaded")
|
||||||
|
self._first_healthcheck_ok = True
|
||||||
log.debug("Healthcheck successful")
|
log.debug("Healthcheck successful")
|
||||||
elif response.status == 503:
|
elif response.status == 503:
|
||||||
log.debug(f"Healthcheck failed with status: {response.status}")
|
log.debug(f"Healthcheck failed with status: {response.status}")
|
||||||
@@ -286,7 +296,7 @@ class Backend:
|
|||||||
message = {
|
message = {
|
||||||
key: value
|
key: value
|
||||||
for (key, value) in (dataclasses.asdict(auth_data).items())
|
for (key, value) in (dataclasses.asdict(auth_data).items())
|
||||||
if key != "signature" and key != "__request_id"
|
if key != "signature"
|
||||||
}
|
}
|
||||||
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
|
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
|
||||||
log.debug(
|
log.debug(
|
||||||
@@ -296,7 +306,7 @@ class Backend:
|
|||||||
elif message in self.msg_history:
|
elif message in self.msg_history:
|
||||||
log.debug(f"message: {message} already in message history")
|
log.debug(f"message: {message} already in message history")
|
||||||
return False
|
return False
|
||||||
elif verify_signature(json.dumps(message, indent=4, sort_keys=True), auth_data.signature):
|
elif verify_signature(json.dumps(message, indent=4), auth_data.signature):
|
||||||
self.reqnum = max(auth_data.reqnum, self.reqnum)
|
self.reqnum = max(auth_data.reqnum, self.reqnum)
|
||||||
self.msg_history.append(message)
|
self.msg_history.append(message)
|
||||||
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
|
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
|
||||||
@@ -315,17 +325,20 @@ class Backend:
|
|||||||
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
||||||
log.debug("already ran benchmark")
|
log.debug("already ran benchmark")
|
||||||
# trigger model load
|
# trigger model load
|
||||||
payload = self.benchmark_handler.make_benchmark_payload()
|
# payload = self.benchmark_handler.make_benchmark_payload()
|
||||||
_ = await self.__call_api(
|
# _ = await self.__call_api(
|
||||||
handler=self.benchmark_handler, payload=payload
|
# handler=self.benchmark_handler, payload=payload
|
||||||
)
|
# )
|
||||||
return float(f.readline())
|
return float(f.readline())
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
log.debug("Initial run to trigger model loading...")
|
log.debug("Initial run to trigger model loading...")
|
||||||
|
t_bench0 = time.time()
|
||||||
payload = self.benchmark_handler.make_benchmark_payload()
|
payload = self.benchmark_handler.make_benchmark_payload()
|
||||||
await self.__call_api(handler=self.benchmark_handler, payload=payload)
|
await self.__call_api(handler=self.benchmark_handler, payload=payload)
|
||||||
|
log.debug(f"warmup request took {time.time()-t_bench0:.2f}s")
|
||||||
|
t_benchmark_loop0 = time.time()
|
||||||
|
|
||||||
max_throughput = 0
|
max_throughput = 0
|
||||||
sum_throughput = 0
|
sum_throughput = 0
|
||||||
@@ -350,9 +363,6 @@ class Backend:
|
|||||||
total_workload = sum(br.workload for br in benchmark_requests if br.is_successful)
|
total_workload = sum(br.workload for br in benchmark_requests if br.is_successful)
|
||||||
time_elapsed = time.time() - start
|
time_elapsed = time.time() - start
|
||||||
successful_responses = sum([1 for br in benchmark_requests if br.is_successful])
|
successful_responses = sum([1 for br in benchmark_requests if br.is_successful])
|
||||||
if successful_responses == 0:
|
|
||||||
self.backend_errored("No successful responses from benchmark")
|
|
||||||
log.debug(f"benchmark failed: {successful_responses}/{concurrent_requests} successful responses")
|
|
||||||
|
|
||||||
throughput = total_workload / time_elapsed
|
throughput = total_workload / time_elapsed
|
||||||
sum_throughput += throughput
|
sum_throughput += throughput
|
||||||
@@ -376,6 +386,7 @@ class Backend:
|
|||||||
log.debug(
|
log.debug(
|
||||||
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
|
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
|
||||||
)
|
)
|
||||||
|
log.debug(f"benchmark loop took {time.time()-t_benchmark_loop0:.2f}s")
|
||||||
with open(BENCHMARK_INDICATOR_FILE, "w") as f:
|
with open(BENCHMARK_INDICATOR_FILE, "w") as f:
|
||||||
f.write(str(max_throughput))
|
f.write(str(max_throughput))
|
||||||
return max_throughput
|
return max_throughput
|
||||||
@@ -388,14 +399,17 @@ class Backend:
|
|||||||
for action, msg in self.log_actions:
|
for action, msg in self.log_actions:
|
||||||
match action:
|
match action:
|
||||||
case LogAction.ModelLoaded if msg in log_line:
|
case LogAction.ModelLoaded if msg in log_line:
|
||||||
log.debug(
|
now = time.time()
|
||||||
f"Got log line indicating model is loaded: {log_line}"
|
elapsed = now - self._model_tail_start_time
|
||||||
)
|
log.debug(f"ModelLoaded observed after {elapsed:.2f}s: {log_line}")
|
||||||
# some backends need a few seconds after logging successful startup before
|
# some backends need a few seconds after logging successful startup before
|
||||||
# they can begin accepting requests
|
# they can begin accepting requests
|
||||||
await sleep(5)
|
# await sleep(5)
|
||||||
try:
|
try:
|
||||||
|
t_bench0 = time.time()
|
||||||
max_throughput = await run_benchmark()
|
max_throughput = await run_benchmark()
|
||||||
|
self._model_loaded_time = time.time()
|
||||||
|
log.debug(f"benchmark total took {self._model_loaded_time - t_bench0:.2f}s")
|
||||||
self.__start_healthcheck = True
|
self.__start_healthcheck = True
|
||||||
self.metrics._model_loaded(
|
self.metrics._model_loaded(
|
||||||
max_throughput=max_throughput,
|
max_throughput=max_throughput,
|
||||||
@@ -414,6 +428,7 @@ class Backend:
|
|||||||
|
|
||||||
async def tail_log():
|
async def tail_log():
|
||||||
log.debug(f"tailing file: {self.model_log_file}")
|
log.debug(f"tailing file: {self.model_log_file}")
|
||||||
|
self._model_tail_start_time = time.time()
|
||||||
async with await open_file(self.model_log_file) as f:
|
async with await open_file(self.model_log_file) as f:
|
||||||
while True:
|
while True:
|
||||||
line = await f.readline()
|
line = await f.readline()
|
||||||
|
|||||||
+5
-6
@@ -65,12 +65,12 @@ class ApiPayload(ABC):
|
|||||||
class AuthData:
|
class AuthData:
|
||||||
"""data used to authenticate requester"""
|
"""data used to authenticate requester"""
|
||||||
|
|
||||||
|
signature: str
|
||||||
cost: str
|
cost: str
|
||||||
endpoint: str
|
endpoint: str
|
||||||
reqnum: int
|
reqnum: int
|
||||||
request_idx: int
|
|
||||||
signature: str
|
|
||||||
url: str
|
url: str
|
||||||
|
request_idx: int
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_json_msg(cls, json_msg: Dict[str, Any]):
|
def from_json_msg(cls, json_msg: Dict[str, Any]):
|
||||||
@@ -190,12 +190,11 @@ class SystemMetrics:
|
|||||||
self.additional_disk_usage = disk_usage - self.last_disk_usage
|
self.additional_disk_usage = disk_usage - self.last_disk_usage
|
||||||
self.last_disk_usage = disk_usage
|
self.last_disk_usage = disk_usage
|
||||||
|
|
||||||
def reset(self, expected: float | None) -> None:
|
def reset(self):
|
||||||
# autoscaler excepts model_loading_time to be populated only once, when the instance has
|
# autoscaler excepts model_loading_time to be populated only once, when the instance has
|
||||||
# finished benchmarking and is ready to receive requests. This applies to restarted instances
|
# finished benchmarking and is ready to receive requests. This applies to restarted instances
|
||||||
# as well: they should send model_loading_time once when they are done loading
|
# as well: they should send model_loading_time once when they are done loading
|
||||||
if self.model_loading_time == expected:
|
self.model_loading_time = None
|
||||||
self.model_loading_time = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -258,7 +257,7 @@ class ModelMetrics:
|
|||||||
def wait_time(self) -> float:
|
def wait_time(self) -> float:
|
||||||
if (len(self.requests_working) == 0):
|
if (len(self.requests_working) == 0):
|
||||||
return 0.0
|
return 0.0
|
||||||
return sum([request.workload for request in self.requests_working.values()]) / max(self.max_throughput, 0.00001)
|
return sum([request.workload for request in self.requests_working.values()]) / self.max_throughput
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cur_load(self) -> float:
|
def cur_load(self) -> float:
|
||||||
|
|||||||
+16
-45
@@ -145,15 +145,14 @@ class Metrics:
|
|||||||
#######################################Private#######################################
|
#######################################Private#######################################
|
||||||
|
|
||||||
async def __send_delete_requests_and_reset(self):
|
async def __send_delete_requests_and_reset(self):
|
||||||
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
|
|
||||||
|
async def send_data(report_addr: str, success: bool) -> bool:
|
||||||
data = {
|
data = {
|
||||||
"worker_id": self.id,
|
"worker_id": self.id,
|
||||||
"request_idxs": idxs,
|
"request_idxs": [r.request_idx for r in self.model_metrics.requests_deleting if r.success == success],
|
||||||
"success": success_flag,
|
"success": success
|
||||||
}
|
}
|
||||||
log.debug(
|
log.debug(f"Deleting requests that {'succeeded' if success else 'failed'}: {data['request_idxs']}")
|
||||||
f"Deleting requests that {'succeeded' if success_flag else 'failed'}: {data['request_idxs']}"
|
|
||||||
)
|
|
||||||
full_path = report_addr.rstrip("/") + "/delete_requests/"
|
full_path = report_addr.rstrip("/") + "/delete_requests/"
|
||||||
for attempt in range(1, 4):
|
for attempt in range(1, 4):
|
||||||
try:
|
try:
|
||||||
@@ -163,50 +162,26 @@ class Metrics:
|
|||||||
res.raise_for_status()
|
res.raise_for_status()
|
||||||
return True
|
return True
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
log.debug("delete_requests timed out")
|
log.debug(f"delete_requests timed out")
|
||||||
except (ClientResponseError, Exception) as e:
|
except (ClientResponseError, Exception) as e:
|
||||||
log.debug(f"delete_requests failed with error: {e}")
|
log.debug(f"delete_requests failed with error: {e}")
|
||||||
await asyncio.sleep(2)
|
await asyncio.sleep(2)
|
||||||
log.debug(f"retrying delete_request, attempt: {attempt}")
|
log.debug(f"retrying delete_request, attempt: {attempt}")
|
||||||
return False
|
|
||||||
|
|
||||||
# Take a snapshot of what we plan to send this tick.
|
|
||||||
# New arrivals after this snapshot will remain in the queue for the next tick.
|
|
||||||
snapshot = list(self.model_metrics.requests_deleting)
|
|
||||||
success_idxs = [r.request_idx for r in snapshot if r.success is True]
|
|
||||||
failed_idxs = [r.request_idx for r in snapshot if r.success is False]
|
|
||||||
|
|
||||||
if not success_idxs and not failed_idxs:
|
|
||||||
return # nothing to do
|
|
||||||
|
|
||||||
for report_addr in self.report_addr:
|
for report_addr in self.report_addr:
|
||||||
sent_success = True
|
success = await send_data(report_addr, success=True) and await send_data(report_addr, success=False)
|
||||||
sent_failed = True
|
if success is True:
|
||||||
|
self.model_metrics.requests_deleting.clear()
|
||||||
if success_idxs:
|
|
||||||
sent_success = await post(report_addr, success_idxs, True)
|
|
||||||
if failed_idxs:
|
|
||||||
sent_failed = await post(report_addr, failed_idxs, False)
|
|
||||||
|
|
||||||
if sent_success and sent_failed:
|
|
||||||
# Remove only the items we actually sent from the live queue.
|
|
||||||
sent_set = set(success_idxs) | set(failed_idxs)
|
|
||||||
self.model_metrics.requests_deleting[:] = [
|
|
||||||
r for r in self.model_metrics.requests_deleting
|
|
||||||
if r.request_idx not in sent_set
|
|
||||||
]
|
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
async def __send_metrics_and_reset(self):
|
async def __send_metrics_and_reset(self):
|
||||||
|
|
||||||
loadtime_snapshot = self.system_metrics.model_loading_time
|
|
||||||
|
|
||||||
def compute_autoscaler_data() -> AutoScalerData:
|
def compute_autoscaler_data() -> AutoScalerData:
|
||||||
return AutoScalerData(
|
return AutoScalerData(
|
||||||
id=self.id,
|
id=self.id,
|
||||||
version=self.version,
|
version=self.version,
|
||||||
loadtime=(loadtime_snapshot or 0.0),
|
loadtime=(self.system_metrics.model_loading_time or 0.0),
|
||||||
new_load=self.model_metrics.workload_processing,
|
new_load=self.model_metrics.workload_processing,
|
||||||
cur_load=self.model_metrics.cur_load,
|
cur_load=self.model_metrics.cur_load,
|
||||||
rej_load=self.model_metrics.workload_rejected,
|
rej_load=self.model_metrics.workload_rejected,
|
||||||
@@ -254,15 +229,11 @@ class Metrics:
|
|||||||
|
|
||||||
self.system_metrics.update_disk_usage()
|
self.system_metrics.update_disk_usage()
|
||||||
|
|
||||||
sent = False
|
|
||||||
for report_addr in self.report_addr:
|
for report_addr in self.report_addr:
|
||||||
if await send_data(report_addr):
|
success = await send_data(report_addr)
|
||||||
sent = True
|
if success is True:
|
||||||
break
|
break
|
||||||
|
self.update_pending = False
|
||||||
if sent:
|
self.model_metrics.reset()
|
||||||
# clear the one-shot loadtime only if we actually sent *this* value
|
self.system_metrics.reset()
|
||||||
self.system_metrics.reset(expected=loadtime_snapshot)
|
self.last_metric_update = time.time()
|
||||||
self.update_pending = False
|
|
||||||
self.model_metrics.reset()
|
|
||||||
self.last_metric_update = time.time()
|
|
||||||
|
|||||||
+22
-55
@@ -2,6 +2,9 @@
|
|||||||
|
|
||||||
set -e -o pipefail
|
set -e -o pipefail
|
||||||
|
|
||||||
|
log() { echo "$(date +'%Y-%m-%d %H:%M:%S') $*"; }
|
||||||
|
step(){ _t0=$(date +%s); eval "$1"; _dt=$(($(date +%s)-_t0)); log "$2 took ${_dt}s"; }
|
||||||
|
|
||||||
WORKSPACE_DIR="${WORKSPACE_DIR:-/workspace}"
|
WORKSPACE_DIR="${WORKSPACE_DIR:-/workspace}"
|
||||||
|
|
||||||
SERVER_DIR="$WORKSPACE_DIR/vast-pyworker"
|
SERVER_DIR="$WORKSPACE_DIR/vast-pyworker"
|
||||||
@@ -41,33 +44,28 @@ echo_var DEBUG_LOG
|
|||||||
echo_var PYWORKER_LOG
|
echo_var PYWORKER_LOG
|
||||||
echo_var MODEL_LOG
|
echo_var MODEL_LOG
|
||||||
|
|
||||||
# Populate /etc/environment with quoted values
|
# # Populate /etc/environment with quoted values
|
||||||
if ! grep -q "VAST" /etc/environment; then
|
# if ! grep -q "VAST" /etc/environment; then
|
||||||
env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do
|
# env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do
|
||||||
name=${line%%=*}
|
# name=${line%%=*}
|
||||||
value=${line#*=}
|
# value=${line#*=}
|
||||||
printf '%s="%s"\n' "$name" "$value"
|
# printf '%s="%s"\n' "$name" "$value"
|
||||||
done > /etc/environment
|
# done > /etc/environment
|
||||||
fi
|
# fi
|
||||||
|
|
||||||
if [ ! -d "$ENV_PATH" ]
|
if [ ! -d "$ENV_PATH" ]
|
||||||
then
|
then
|
||||||
echo "setting up venv"
|
echo "setting up venv"
|
||||||
if ! which uv; then
|
step 'if ! which uv; then curl -LsSf https://astral.sh/uv/install.sh | sh; source ~/.local/bin/env; fi' "uv install"
|
||||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
|
||||||
source ~/.local/bin/env
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Fork testing
|
# Fork testing
|
||||||
[[ ! -d $SERVER_DIR ]] && git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"
|
step '[[ ! -d $SERVER_DIR ]] && git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"' "git clone"
|
||||||
if [[ -n ${PYWORKER_REF:-} ]]; then
|
step 'if [[ -n ${PYWORKER_REF:-} ]]; then (cd "$SERVER_DIR" && git checkout "$PYWORKER_REF"); fi' "git checkout"
|
||||||
(cd "$SERVER_DIR" && git checkout "$PYWORKER_REF")
|
|
||||||
fi
|
|
||||||
|
|
||||||
uv venv --python-preference only-managed "$ENV_PATH" -p 3.10
|
|
||||||
source "$ENV_PATH/bin/activate"
|
|
||||||
|
|
||||||
uv pip install -r "${SERVER_DIR}/requirements.txt"
|
step 'uv venv --python-preference only-managed "$ENV_PATH" -p 3.10' "venv create"
|
||||||
|
step 'source "$ENV_PATH/bin/activate"' "venv activate"
|
||||||
|
step 'uv pip install -r "${SERVER_DIR}/requirements.txt"' "pip install requirements"
|
||||||
|
|
||||||
touch ~/.no_auto_tmux
|
touch ~/.no_auto_tmux
|
||||||
else
|
else
|
||||||
@@ -80,39 +78,8 @@ fi
|
|||||||
[ ! -d "$SERVER_DIR/workers/$BACKEND" ] && echo "$BACKEND not supported!" && exit 1
|
[ ! -d "$SERVER_DIR/workers/$BACKEND" ] && echo "$BACKEND not supported!" && exit 1
|
||||||
|
|
||||||
if [ "$USE_SSL" = true ]; then
|
if [ "$USE_SSL" = true ]; then
|
||||||
|
step 'openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" -nodes -sha256 -keyout /etc/instance.key -out /etc/instance.csr -config /etc/openssl-san.cnf' "openssl csr"
|
||||||
cat << EOF > /etc/openssl-san.cnf
|
step 'curl --header "Content-Type: application/octet-stream" --data-binary @//etc/instance.csr -X POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt' "sign cert"
|
||||||
[req]
|
|
||||||
default_bits = 2048
|
|
||||||
distinguished_name = req_distinguished_name
|
|
||||||
req_extensions = v3_req
|
|
||||||
|
|
||||||
[req_distinguished_name]
|
|
||||||
countryName = US
|
|
||||||
stateOrProvinceName = CA
|
|
||||||
organizationName = Vast.ai Inc.
|
|
||||||
commonName = vast.ai
|
|
||||||
|
|
||||||
[v3_req]
|
|
||||||
basicConstraints = CA:FALSE
|
|
||||||
keyUsage = nonRepudiation, digitalSignature, keyEncipherment
|
|
||||||
subjectAltName = @alt_names
|
|
||||||
|
|
||||||
[alt_names]
|
|
||||||
IP.1 = 0.0.0.0
|
|
||||||
EOF
|
|
||||||
|
|
||||||
openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \
|
|
||||||
-nodes \
|
|
||||||
-sha256 \
|
|
||||||
-keyout /etc/instance.key \
|
|
||||||
-out /etc/instance.csr \
|
|
||||||
-config /etc/openssl-san.cnf
|
|
||||||
|
|
||||||
curl --header 'Content-Type: application/octet-stream' \
|
|
||||||
--data-binary @//etc/instance.csr \
|
|
||||||
-X \
|
|
||||||
POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt;
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
@@ -122,11 +89,11 @@ export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED
|
|||||||
|
|
||||||
cd "$SERVER_DIR"
|
cd "$SERVER_DIR"
|
||||||
|
|
||||||
echo "launching PyWorker server"
|
log "launching PyWorker server"
|
||||||
|
|
||||||
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
|
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
|
||||||
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
|
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
|
||||||
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
|
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
|
||||||
|
|
||||||
(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
|
_t0=$(date +%s); (python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") & _dt=$(($(date +%s)-_t0)); log "PyWorker spawn took ${_dt}s"
|
||||||
echo "launching PyWorker server done"
|
log "launching PyWorker server done"
|
||||||
|
|||||||
@@ -12,21 +12,9 @@ A docker image is provided but you may use any if the above requirements are met
|
|||||||
|
|
||||||
## Benchmarking
|
## Benchmarking
|
||||||
|
|
||||||
### Custom Benchmark Workflows
|
A simple image generation benchmark runs when each worker initializes to validate GPU performance and identify underperforming machines.
|
||||||
|
|
||||||
You can provide a custom ComfyUI workflow for benchmarking by creating `workers/comfyui-json/misc/benchmark.json`. This allows you to test performance using your preferred models and workflow complexity.
|
The benchmark uses Stable Diffusion v1.5 with ComfyUI's default text-to-image workflow. Configure the benchmark complexity and duration using these variables:
|
||||||
|
|
||||||
**Ways to provide the benchmark file:**
|
|
||||||
- Fork this repository and add your `benchmark.json` file
|
|
||||||
- Write the file during worker provisioning (onstart script or setup phase)
|
|
||||||
|
|
||||||
An example file is provided in the repository. To ensure varied generations, use the placeholder `__RANDOM_INT__` in place of static seed values - it will be replaced with a random integer for each benchmark run.
|
|
||||||
|
|
||||||
### Default Benchmark (Fallback)
|
|
||||||
|
|
||||||
If `benchmark.json` is not available, a simple image generation benchmark runs when each worker initializes. This validates GPU performance and helps identify underperforming machines.
|
|
||||||
|
|
||||||
The default benchmark uses Stable Diffusion v1.5 with ComfyUI's standard text-to-image workflow. Configure it using these environment variables:
|
|
||||||
|
|
||||||
| Environment Variable | Default Value | Description |
|
| Environment Variable | Default Value | Description |
|
||||||
| -------------------- | ------------- | ----------- |
|
| -------------------- | ------------- | ----------- |
|
||||||
@@ -36,7 +24,7 @@ The default benchmark uses Stable Diffusion v1.5 with ComfyUI's standard text-to
|
|||||||
|
|
||||||
Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns.
|
Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns.
|
||||||
|
|
||||||
#### Calibrating Fallback Benchmark Duration
|
### Calibrating Benchmark Duration
|
||||||
|
|
||||||
To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements.
|
To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements.
|
||||||
|
|
||||||
|
|||||||
@@ -5,13 +5,12 @@ import dataclasses
|
|||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
from functools import cache
|
from functools import cache
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from pathlib import Path
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from lib.data_types import ApiPayload, JsonDataException
|
from lib.data_types import ApiPayload, JsonDataException
|
||||||
|
|
||||||
log = logging.getLogger(__file__)
|
|
||||||
|
with open("workers/comfyui/misc/test_prompts.txt", "r") as f:
|
||||||
|
test_prompts = f.readlines()
|
||||||
|
|
||||||
def count_workload() -> float:
|
def count_workload() -> float:
|
||||||
# Always 100.0 where there is a single instance of ComfyUI handling requests
|
# Always 100.0 where there is a single instance of ComfyUI handling requests
|
||||||
@@ -25,32 +24,9 @@ class ComfyWorkflowData(ApiPayload):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def for_test(cls):
|
def for_test(cls):
|
||||||
"""
|
"""
|
||||||
If the user has provided a benchmark workflow we can use it here to properly gauge performance.
|
Use the variables available to simulate workflows of the required running time
|
||||||
Otherwise, use the variables available to simulate workflows of the required running time
|
|
||||||
Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090)
|
Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090)
|
||||||
"""
|
"""
|
||||||
# Try to load benchmark.json
|
|
||||||
benchmark_file = Path("workers/comfyui-json/misc/benchmark.json")
|
|
||||||
|
|
||||||
if benchmark_file.exists():
|
|
||||||
try:
|
|
||||||
with open(benchmark_file, "r") as f:
|
|
||||||
benchmark_workflow = json.load(f)
|
|
||||||
return cls(
|
|
||||||
input={
|
|
||||||
"request_id": f"test-{random.randint(1000, 99999)}",
|
|
||||||
"workflow_json": benchmark_workflow
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except (json.JSONDecodeError, IOError):
|
|
||||||
# JSON is malformed or file can't be read, fall through to default
|
|
||||||
log.error(f"Failed to benchmark using {benchmark_file}")
|
|
||||||
|
|
||||||
# Fallback: read prompts and construct payload
|
|
||||||
log.info("Using fallback method for benchmarking")
|
|
||||||
with open("workers/comfyui-json/misc/test_prompts.txt", "r") as f:
|
|
||||||
test_prompts = f.readlines()
|
|
||||||
|
|
||||||
test_prompt = random.choice(test_prompts).rstrip()
|
test_prompt = random.choice(test_prompts).rstrip()
|
||||||
return cls(
|
return cls(
|
||||||
input={
|
input={
|
||||||
|
|||||||
@@ -1,107 +0,0 @@
|
|||||||
{
|
|
||||||
"3": {
|
|
||||||
"inputs": {
|
|
||||||
"seed": "__RANDOM_INT__",
|
|
||||||
"steps": 20,
|
|
||||||
"cfg": 8,
|
|
||||||
"sampler_name": "euler",
|
|
||||||
"scheduler": "normal",
|
|
||||||
"denoise": 1,
|
|
||||||
"model": [
|
|
||||||
"4",
|
|
||||||
0
|
|
||||||
],
|
|
||||||
"positive": [
|
|
||||||
"6",
|
|
||||||
0
|
|
||||||
],
|
|
||||||
"negative": [
|
|
||||||
"7",
|
|
||||||
0
|
|
||||||
],
|
|
||||||
"latent_image": [
|
|
||||||
"5",
|
|
||||||
0
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"class_type": "KSampler",
|
|
||||||
"_meta": {
|
|
||||||
"title": "KSampler"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"4": {
|
|
||||||
"inputs": {
|
|
||||||
"ckpt_name": "v1-5-pruned-emaonly-fp16.safetensors"
|
|
||||||
},
|
|
||||||
"class_type": "CheckpointLoaderSimple",
|
|
||||||
"_meta": {
|
|
||||||
"title": "Load Checkpoint"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"5": {
|
|
||||||
"inputs": {
|
|
||||||
"width": 512,
|
|
||||||
"height": 512,
|
|
||||||
"batch_size": 1
|
|
||||||
},
|
|
||||||
"class_type": "EmptyLatentImage",
|
|
||||||
"_meta": {
|
|
||||||
"title": "Empty Latent Image"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"6": {
|
|
||||||
"inputs": {
|
|
||||||
"text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,",
|
|
||||||
"clip": [
|
|
||||||
"4",
|
|
||||||
1
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"class_type": "CLIPTextEncode",
|
|
||||||
"_meta": {
|
|
||||||
"title": "CLIP Text Encode (Prompt)"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"7": {
|
|
||||||
"inputs": {
|
|
||||||
"text": "text, watermark",
|
|
||||||
"clip": [
|
|
||||||
"4",
|
|
||||||
1
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"class_type": "CLIPTextEncode",
|
|
||||||
"_meta": {
|
|
||||||
"title": "CLIP Text Encode (Prompt)"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"8": {
|
|
||||||
"inputs": {
|
|
||||||
"samples": [
|
|
||||||
"3",
|
|
||||||
0
|
|
||||||
],
|
|
||||||
"vae": [
|
|
||||||
"4",
|
|
||||||
2
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"class_type": "VAEDecode",
|
|
||||||
"_meta": {
|
|
||||||
"title": "VAE Decode"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"9": {
|
|
||||||
"inputs": {
|
|
||||||
"filename_prefix": "ComfyUI",
|
|
||||||
"images": [
|
|
||||||
"8",
|
|
||||||
0
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"class_type": "SaveImage",
|
|
||||||
"_meta": {
|
|
||||||
"title": "Save Image"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -19,7 +19,6 @@ MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: "
|
|||||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||||
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
|
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
|
||||||
"Value not in list: ", # This error is emitted when the model file is not there at all
|
"Value not in list: ", # This error is emitted when the model file is not there at all
|
||||||
"[ERROR] Provisioning Script failed", # Error inserted by provisioning script if models/nodes fail to download
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,10 +6,13 @@ from typing import Union, Type, Dict, Any, Optional
|
|||||||
from aiohttp import web, ClientResponse
|
from aiohttp import web, ClientResponse
|
||||||
import nltk
|
import nltk
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
t0 = time.time()
|
||||||
nltk.download("words")
|
nltk.download("words")
|
||||||
WORD_LIST = nltk.corpus.words.words()
|
WORD_LIST = nltk.corpus.words.words()
|
||||||
log = logging.getLogger(__name__)
|
print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} NLTK words download+load took {time.time()-t0:.2f}s")
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Generic dataclass accepts any dictionary in input.
|
Generic dataclass accepts any dictionary in input.
|
||||||
|
|||||||
Reference in New Issue
Block a user