Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 62fbfb061d | |||
| c772e1651b | |||
| ecc6a3ce0d | |||
| 7986e51e9e | |||
| a47c9d1ed0 | |||
| 0b14562a63 | |||
| de9b50abb9 | |||
| c510801723 | |||
| a12523b1d2 |
+33
-7
@@ -235,10 +235,14 @@ class Backend:
|
|||||||
log.debug("No healthcheck endpoint defined, skipping healthcheck")
|
log.debug("No healthcheck endpoint defined, skipping healthcheck")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
first_healthcheck = True
|
||||||
while True:
|
while True:
|
||||||
await sleep(10)
|
await sleep(10)
|
||||||
if self.__start_healthcheck is False:
|
if self.__start_healthcheck is False:
|
||||||
continue
|
continue
|
||||||
|
if first_healthcheck:
|
||||||
|
log.info(f"[healthcheck] First healthcheck starting (model is now loaded)")
|
||||||
|
first_healthcheck = False
|
||||||
try:
|
try:
|
||||||
log.debug(f"Performing healthcheck on {health_check_url}")
|
log.debug(f"Performing healthcheck on {health_check_url}")
|
||||||
async with self.healthcheck_session.get(health_check_url) as response:
|
async with self.healthcheck_session.get(health_check_url) as response:
|
||||||
@@ -256,9 +260,22 @@ class Backend:
|
|||||||
self.backend_errored(str(e))
|
self.backend_errored(str(e))
|
||||||
|
|
||||||
async def _start_tracking(self) -> None:
    """Run the background tracking loops and report any loop that stops.

    Starts four coroutines concurrently: the log reader, the metrics
    sender, the healthcheck and the delete-request sender. They are
    expected to run forever, so a loop that finishes for any reason is
    logged as soon as it finishes.

    Reporting is done from a per-task done-callback rather than from the
    value returned by ``gather``: with ``return_exceptions=True`` the
    ``gather`` call only returns after *every* loop has stopped, so a single
    crashed loop would otherwise stay invisible for as long as the others
    keep running.

    Exceptions raised by the loops are logged, not re-raised. This coroutine
    returns only once all four loops have stopped.
    """
    import asyncio  # function-scope: create_task is the only extra name needed

    def report_exit(task) -> None:
        """Log why a tracking task stopped (runs as the task's done-callback)."""
        name = task.get_name()
        # cancelled() must be checked first: exception() raises on a
        # cancelled task instead of returning the CancelledError.
        if task.cancelled():
            log.info(f"Tracking task '{name}' was cancelled")
            return
        exc = task.exception()
        if exc is not None:
            log.error(f"Tracking task '{name}' crashed with exception: {exc}", exc_info=exc)
        else:
            log.warning(f"Tracking task '{name}' exited unexpectedly with result: {task.result()}")

    log.info("Starting tracking tasks (read_logs, send_metrics_loop, healthcheck, send_delete_requests_loop)")
    loops = {
        "read_logs": self.__read_logs(),
        "send_metrics_loop": self.metrics._send_metrics_loop(),
        "healthcheck": self.__healthcheck(),
        "send_delete_requests_loop": self.metrics._send_delete_requests_loop(),
    }
    tasks = []
    for name, coro in loops.items():
        task = asyncio.create_task(coro, name=name)
        task.add_done_callback(report_exit)
        tasks.append(task)

    # gather still owns the tasks, so cancelling this coroutine cancels them.
    results = await gather(*tasks, return_exceptions=True)

    # If we get here, every task has exited (they should run forever).
    log.error(f"CRITICAL: _start_tracking gather returned! This should never happen. Results: {results}")
|
||||||
|
|
||||||
def backend_errored(self, msg: str) -> None:
|
def backend_errored(self, msg: str) -> None:
|
||||||
self.metrics._model_errored(msg)
|
self.metrics._model_errored(msg)
|
||||||
@@ -399,15 +416,20 @@ class Backend:
|
|||||||
# await sleep(5)
|
# await sleep(5)
|
||||||
try:
|
try:
|
||||||
max_throughput = await run_benchmark()
|
max_throughput = await run_benchmark()
|
||||||
|
log.info(f"[benchmark] Benchmark complete, max_throughput={max_throughput}, setting healthcheck=True")
|
||||||
self.__start_healthcheck = True
|
self.__start_healthcheck = True
|
||||||
self.metrics._model_loaded(
|
self.metrics._model_loaded(
|
||||||
max_throughput=max_throughput,
|
max_throughput=max_throughput,
|
||||||
)
|
)
|
||||||
|
log.info(f"[benchmark] _model_loaded() called, returning from handle_log_line")
|
||||||
except ClientConnectorError as e:
|
except ClientConnectorError as e:
|
||||||
log.debug(
|
log.debug(
|
||||||
f"failed to connect to comfyui api during benchmark"
|
f"failed to connect to model api during benchmark"
|
||||||
)
|
)
|
||||||
self.backend_errored(str(e))
|
self.backend_errored(str(e))
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Unexpected error during benchmark: {e}", exc_info=True)
|
||||||
|
self.backend_errored(f"Benchmark failed: {e}")
|
||||||
case LogAction.ModelError if msg in log_line:
|
case LogAction.ModelError if msg in log_line:
|
||||||
log.debug(f"Got log line indicating error: {log_line}")
|
log.debug(f"Got log line indicating error: {log_line}")
|
||||||
self.backend_errored(msg)
|
self.backend_errored(msg)
|
||||||
@@ -419,10 +441,14 @@ class Backend:
|
|||||||
log.debug(f"tailing file: {self.model_log_file}")
|
log.debug(f"tailing file: {self.model_log_file}")
|
||||||
async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore') as f:
|
async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore') as f:
|
||||||
while True:
|
while True:
|
||||||
line = await f.readline()
|
try:
|
||||||
if line:
|
line = await f.readline()
|
||||||
await handle_log_line(line.rstrip())
|
if line:
|
||||||
else:
|
await handle_log_line(line.rstrip())
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(LOG_POLL_INTERVAL)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error processing log line: {e}", exc_info=True)
|
||||||
await asyncio.sleep(LOG_POLL_INTERVAL)
|
await asyncio.sleep(LOG_POLL_INTERVAL)
|
||||||
|
|
||||||
###########
|
###########
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
@@ -17,6 +18,14 @@ DELETE_REQUESTS_INTERVAL = 1
|
|||||||
log = logging.getLogger(__file__)
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
def _flush_logs():
    """Force flush all root log handlers and stdout/stderr.

    Flushing is best-effort. A handler or stream that cannot be flushed
    (for example because its underlying file is already closed) is skipped,
    so that this diagnostic helper can never raise into the code that
    called it.
    """
    # Iterate over a copy so a handler added or removed while flushing
    # cannot disturb the iteration.
    for handler in list(logging.root.handlers):
        try:
            handler.flush()
        except (OSError, ValueError):
            # ValueError: stream already closed; OSError: broken pipe etc.
            continue

    for stream in (sys.stdout, sys.stderr):
        # The standard streams can be None when the process has no console.
        if stream is None:
            continue
        try:
            stream.flush()
        except (OSError, ValueError):
            continue
|
||||||
|
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def get_url() -> str:
|
def get_url() -> str:
|
||||||
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
||||||
@@ -119,22 +128,41 @@ class Metrics:
|
|||||||
await self.__send_delete_requests_and_reset()
|
await self.__send_delete_requests_and_reset()
|
||||||
|
|
||||||
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
|
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
|
||||||
|
loop_count = 0
|
||||||
|
first_loaded_send_done = False
|
||||||
while True:
|
while True:
|
||||||
await sleep(METRICS_UPDATE_INTERVAL)
|
await sleep(METRICS_UPDATE_INTERVAL)
|
||||||
|
loop_count += 1
|
||||||
elapsed = time.time() - self.last_metric_update
|
elapsed = time.time() - self.last_metric_update
|
||||||
|
# Log heartbeat every 30 loop iterations (30 seconds only when METRICS_UPDATE_INTERVAL is 1) to confirm loop is running
|
||||||
|
if loop_count % 30 == 0:
|
||||||
|
log.debug(f"[heartbeat] metrics loop alive, loop_count={loop_count}, model_loaded={self.system_metrics.model_is_loaded}")
|
||||||
|
_flush_logs()
|
||||||
|
# Extra logging for first few iterations after model loads
|
||||||
|
if self.system_metrics.model_is_loaded and not first_loaded_send_done:
|
||||||
|
log.info(f"[transition] First iteration with model_loaded=True, loop_count={loop_count}, elapsed={elapsed:.1f}")
|
||||||
|
_flush_logs()
|
||||||
if self.system_metrics.model_is_loaded is False and elapsed >= 10:
|
if self.system_metrics.model_is_loaded is False and elapsed >= 10:
|
||||||
log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
|
log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
|
||||||
await self.__send_metrics_and_reset()
|
await self.__send_metrics_and_reset()
|
||||||
elif self.update_pending or elapsed > 10:
|
elif self.update_pending or elapsed > 10:
|
||||||
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
|
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
|
||||||
await self.__send_metrics_and_reset()
|
await self.__send_metrics_and_reset()
|
||||||
|
if self.system_metrics.model_is_loaded and not first_loaded_send_done:
|
||||||
|
first_loaded_send_done = True
|
||||||
|
log.info(f"[transition] First loaded metrics send complete, continuing to next iteration...")
|
||||||
|
_flush_logs()
|
||||||
|
|
||||||
def _model_loaded(self, max_throughput: float) -> None:
    """Record that the model finished loading.

    Stores the elapsed loading time (now minus ``model_loading_start``),
    marks the model as loaded and saves ``max_throughput`` on the model
    metrics. A log line is written and flushed before and after the update.
    """
    log.info(f"MODEL LOADED: Setting model_is_loaded=True, max_throughput={max_throughput}")
    _flush_logs()

    finished_at = time.time()
    self.system_metrics.model_loading_time = finished_at - self.system_metrics.model_loading_start
    self.system_metrics.model_is_loaded = True
    self.model_metrics.max_throughput = max_throughput

    log.info(f"MODEL LOADED: model_loading_time={self.system_metrics.model_loading_time}")
    _flush_logs()
|
||||||
|
|
||||||
def _model_errored(self, error_msg: str) -> None:
|
def _model_errored(self, error_msg: str) -> None:
|
||||||
self.model_metrics.set_errored(error_msg)
|
self.model_metrics.set_errored(error_msg)
|
||||||
@@ -271,6 +299,7 @@ class Metrics:
|
|||||||
###########
|
###########
|
||||||
|
|
||||||
self.system_metrics.update_disk_usage()
|
self.system_metrics.update_disk_usage()
|
||||||
|
had_loadtime = loadtime_snapshot is not None and loadtime_snapshot > 0
|
||||||
|
|
||||||
sent = False
|
sent = False
|
||||||
for report_addr in self.report_addr:
|
for report_addr in self.report_addr:
|
||||||
@@ -279,8 +308,14 @@ class Metrics:
|
|||||||
break
|
break
|
||||||
|
|
||||||
if sent:
|
if sent:
|
||||||
|
if had_loadtime:
|
||||||
|
log.info(f"FIRST LOADTIME METRICS SENT SUCCESSFULLY! loadtime={loadtime_snapshot}")
|
||||||
|
_flush_logs()
|
||||||
# clear the one-shot loadtime only if we actually sent *this* value
|
# clear the one-shot loadtime only if we actually sent *this* value
|
||||||
self.system_metrics.reset(expected=loadtime_snapshot)
|
self.system_metrics.reset(expected=loadtime_snapshot)
|
||||||
self.update_pending = False
|
self.update_pending = False
|
||||||
self.model_metrics.reset()
|
self.model_metrics.reset()
|
||||||
self.last_metric_update = time.time()
|
self.last_metric_update = time.time()
|
||||||
|
if had_loadtime:
|
||||||
|
log.info(f"POST-SEND: reset complete, last_metric_update={self.last_metric_update}, continuing loop...")
|
||||||
|
_flush_logs()
|
||||||
|
|||||||
+65
-25
@@ -1,40 +1,80 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
from typing import List
|
from typing import List
|
||||||
import ssl
|
import ssl
|
||||||
from asyncio import run, gather
|
from asyncio import run, gather
|
||||||
|
import asyncio
|
||||||
|
|
||||||
from lib.backend import Backend
|
from lib.backend import Backend
|
||||||
|
from lib.metrics import Metrics
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
log = logging.getLogger(__file__)
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_signal_handlers():
    """Install handlers that log when the process receives a termination signal.

    SIGTERM, SIGINT and SIGHUP are hooked wherever the platform defines
    them. The handler logs the signal, flushes stdout/stderr and exits with
    status ``128 + signum``.
    """
    def signal_handler(signum, frame):
        sig_name = signal.Signals(signum).name
        log.error(f"SIGNAL RECEIVED: {sig_name} ({signum}) - process is being terminated")
        sys.stdout.flush()
        sys.stderr.flush()
        sys.exit(128 + signum)

    # Handle common termination signals. Look them up by name: SIGHUP is not
    # defined on every platform, and referencing signal.SIGHUP directly would
    # raise AttributeError before any handler gets installed.
    for name in ("SIGTERM", "SIGINT", "SIGHUP"):
        sig = getattr(signal, name, None)
        if sig is None:
            continue
        try:
            signal.signal(sig, signal_handler)
        except (OSError, ValueError) as e:
            # ValueError: not running in the main thread; OSError: the
            # platform rejected this signal. Either way keep starting up.
            log.debug(f"could not install handler for {name}: {e}")
|
||||||
|
|
||||||
|
|
||||||
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
||||||
log.debug("getting certificate...")
|
_setup_signal_handlers()
|
||||||
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
try:
|
||||||
if use_ssl is True:
|
log.debug("getting certificate...")
|
||||||
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
||||||
ssl_context.load_cert_chain(
|
if use_ssl is True:
|
||||||
certfile="/etc/instance.crt",
|
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||||
keyfile="/etc/instance.key",
|
ssl_context.load_cert_chain(
|
||||||
)
|
certfile="/etc/instance.crt",
|
||||||
else:
|
keyfile="/etc/instance.key",
|
||||||
ssl_context = None
|
)
|
||||||
|
else:
|
||||||
|
ssl_context = None
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
log.debug("starting server...")
|
log.debug("starting server...")
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
app.add_routes(routes)
|
app.add_routes(routes)
|
||||||
runner = web.AppRunner(app)
|
runner = web.AppRunner(app)
|
||||||
await runner.setup()
|
await runner.setup()
|
||||||
site = web.TCPSite(
|
site = web.TCPSite(
|
||||||
runner,
|
runner,
|
||||||
ssl_context=ssl_context,
|
ssl_context=ssl_context,
|
||||||
port=int(os.environ["WORKER_PORT"]),
|
port=int(os.environ["WORKER_PORT"]),
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
await gather(site.start(), backend._start_tracking())
|
await gather(site.start(), backend._start_tracking())
|
||||||
|
|
||||||
run(main())
|
run(main())
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
err_msg = f"PyWorker failed to launch: {e}"
|
||||||
|
log.error(err_msg)
|
||||||
|
|
||||||
|
async def beacon():
|
||||||
|
metrics = Metrics()
|
||||||
|
metrics._set_version(getattr(backend, "version", "0"))
|
||||||
|
metrics._set_mtoken(getattr(backend, "mtoken", ""))
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
metrics._model_errored(err_msg)
|
||||||
|
await metrics._Metrics__send_metrics_and_reset()
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
finally:
|
||||||
|
await metrics.aclose()
|
||||||
|
|
||||||
|
run(beacon())
|
||||||
|
|||||||
+40
-2
@@ -132,5 +132,43 @@ cd "$SERVER_DIR"
|
|||||||
|
|
||||||
echo "launching PyWorker server"
|
echo "launching PyWorker server"
|
||||||
|
|
||||||
(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
|
set +e
|
||||||
echo "launching PyWorker server done"
|
python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG"
|
||||||
|
PY_STATUS=${PIPESTATUS[0]}
|
||||||
|
set -e
|
||||||
|
|
||||||
|
if [ "${PY_STATUS}" -ne 0 ]; then
|
||||||
|
echo "PyWorker exited with status ${PY_STATUS}; notifying autoscaler..."
|
||||||
|
ERROR_MSG="PyWorker exited: code ${PY_STATUS}"
|
||||||
|
MTOKEN="${MASTER_TOKEN:-}"
|
||||||
|
VERSION="${PYWORKER_VERSION:-0}"
|
||||||
|
|
||||||
|
IFS=',' read -r -a REPORT_ADDRS <<< "${REPORT_ADDR}"
|
||||||
|
for addr in "${REPORT_ADDRS[@]}"; do
|
||||||
|
curl -sS -X POST -H 'Content-Type: application/json' \
|
||||||
|
-d "$(cat <<JSON
|
||||||
|
{
|
||||||
|
"id": ${CONTAINER_ID:-0},
|
||||||
|
"mtoken": "${MTOKEN}",
|
||||||
|
"version": "${VERSION}",
|
||||||
|
"loadtime": 0,
|
||||||
|
"new_load": 0,
|
||||||
|
"cur_load": 0,
|
||||||
|
"rej_load": 0,
|
||||||
|
"max_perf": 0,
|
||||||
|
"cur_perf": 0,
|
||||||
|
"error_msg": "${ERROR_MSG}",
|
||||||
|
"num_requests_working": 0,
|
||||||
|
"num_requests_recieved": 0,
|
||||||
|
"additional_disk_usage": 0,
|
||||||
|
"working_request_idxs": [],
|
||||||
|
"cur_capacity": 0,
|
||||||
|
"max_capacity": 0,
|
||||||
|
"url": "${URL}"
|
||||||
|
}
|
||||||
|
JSON
|
||||||
|
)" "${addr%/}/worker_status/" || true
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "launching PyWorker server done"
|
||||||
Reference in New Issue
Block a user