Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7d43bc8d68 |
+23
-26
@@ -30,7 +30,7 @@ from lib.data_types import (
|
|||||||
BenchmarkResult
|
BenchmarkResult
|
||||||
)
|
)
|
||||||
|
|
||||||
VERSION = "0.2.0"
|
VERSION = "0.1.0"
|
||||||
|
|
||||||
MSG_HISTORY_LEN = 100
|
MSG_HISTORY_LEN = 100
|
||||||
log = logging.getLogger(__file__)
|
log = logging.getLogger(__file__)
|
||||||
@@ -66,17 +66,10 @@ class Backend:
|
|||||||
unsecured: bool = dataclasses.field(
|
unsecured: bool = dataclasses.field(
|
||||||
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
||||||
)
|
)
|
||||||
report_addr: str = dataclasses.field(
|
|
||||||
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
|
|
||||||
)
|
|
||||||
mtoken: str = dataclasses.field(
|
|
||||||
default_factory=lambda: os.environ.get("MASTER_TOKEN", "")
|
|
||||||
)
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.metrics = Metrics()
|
self.metrics = Metrics()
|
||||||
self.metrics._set_version(self.version)
|
self.metrics._set_version(self.version)
|
||||||
self.metrics._set_mtoken(self.mtoken)
|
|
||||||
self._total_pubkey_fetch_errors = 0
|
self._total_pubkey_fetch_errors = 0
|
||||||
self._pubkey = self._fetch_pubkey()
|
self._pubkey = self._fetch_pubkey()
|
||||||
self.__start_healthcheck: bool = False
|
self.__start_healthcheck: bool = False
|
||||||
@@ -111,19 +104,23 @@ class Backend:
|
|||||||
|
|
||||||
#######################################Private#######################################
|
#######################################Private#######################################
|
||||||
def _fetch_pubkey(self):
|
def _fetch_pubkey(self):
|
||||||
report_addr = self.report_addr.rstrip("/")
|
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
|
||||||
command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"]
|
result = subprocess.check_output(command, universal_newlines=True)
|
||||||
try:
|
log.debug("public key:")
|
||||||
result = subprocess.check_output(command, universal_newlines=True)
|
log.debug(result)
|
||||||
log.debug("public key:")
|
key = None
|
||||||
log.debug(result)
|
for _ in range(5):
|
||||||
key = RSA.import_key(result)
|
try:
|
||||||
if key is not None:
|
key = RSA.import_key(result)
|
||||||
return key
|
break
|
||||||
except (ValueError , subprocess.CalledProcessError) as e:
|
except ValueError as e:
|
||||||
log.debug(f"Error downloading key: {e}")
|
log.debug(f"Error downloading key: {e}")
|
||||||
self.backend_errored("Failed to get autoscaler pubkey")
|
time.sleep(15)
|
||||||
|
if key is None:
|
||||||
|
self._total_pubkey_fetch_errors += 1
|
||||||
|
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
|
||||||
|
self.backend_errored("Failed to get autoscaler pubkey")
|
||||||
|
return key
|
||||||
|
|
||||||
async def __handle_request(
|
async def __handle_request(
|
||||||
self,
|
self,
|
||||||
@@ -318,10 +315,10 @@ class Backend:
|
|||||||
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
||||||
log.debug("already ran benchmark")
|
log.debug("already ran benchmark")
|
||||||
# trigger model load
|
# trigger model load
|
||||||
# payload = self.benchmark_handler.make_benchmark_payload()
|
payload = self.benchmark_handler.make_benchmark_payload()
|
||||||
# _ = await self.__call_api(
|
_ = await self.__call_api(
|
||||||
# handler=self.benchmark_handler, payload=payload
|
handler=self.benchmark_handler, payload=payload
|
||||||
# )
|
)
|
||||||
return float(f.readline())
|
return float(f.readline())
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
@@ -396,7 +393,7 @@ class Backend:
|
|||||||
)
|
)
|
||||||
# some backends need a few seconds after logging successful startup before
|
# some backends need a few seconds after logging successful startup before
|
||||||
# they can begin accepting requests
|
# they can begin accepting requests
|
||||||
# await sleep(5)
|
await sleep(5)
|
||||||
try:
|
try:
|
||||||
max_throughput = await run_benchmark()
|
max_throughput = await run_benchmark()
|
||||||
self.__start_healthcheck = True
|
self.__start_healthcheck = True
|
||||||
|
|||||||
@@ -286,7 +286,6 @@ class AutoScalerData:
|
|||||||
"""Data that is reported to autoscaler"""
|
"""Data that is reported to autoscaler"""
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
mtoken: str
|
|
||||||
version: str
|
version: str
|
||||||
loadtime: float
|
loadtime: float
|
||||||
cur_load: float
|
cur_load: float
|
||||||
|
|||||||
+2
-16
@@ -28,7 +28,6 @@ def get_url() -> str:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class Metrics:
|
class Metrics:
|
||||||
version: str = "0"
|
version: str = "0"
|
||||||
mtoken: str = ""
|
|
||||||
last_metric_update: float = 0.0
|
last_metric_update: float = 0.0
|
||||||
last_request_served: float = 0.0
|
last_request_served: float = 0.0
|
||||||
update_pending: bool = False
|
update_pending: bool = False
|
||||||
@@ -143,16 +142,12 @@ class Metrics:
|
|||||||
def _set_version(self, version: str) -> None:
|
def _set_version(self, version: str) -> None:
|
||||||
self.version = version
|
self.version = version
|
||||||
|
|
||||||
def _set_mtoken(self, mtoken: str) -> None:
|
|
||||||
self.mtoken = mtoken
|
|
||||||
|
|
||||||
#######################################Private#######################################
|
#######################################Private#######################################
|
||||||
|
|
||||||
async def __send_delete_requests_and_reset(self):
|
async def __send_delete_requests_and_reset(self):
|
||||||
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
|
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
|
||||||
data = {
|
data = {
|
||||||
"worker_id": self.id,
|
"worker_id": self.id,
|
||||||
"mtoken": self.mtoken,
|
|
||||||
"request_idxs": idxs,
|
"request_idxs": idxs,
|
||||||
"success": success_flag,
|
"success": success_flag,
|
||||||
}
|
}
|
||||||
@@ -214,7 +209,6 @@ class Metrics:
|
|||||||
def compute_autoscaler_data() -> AutoScalerData:
|
def compute_autoscaler_data() -> AutoScalerData:
|
||||||
return AutoScalerData(
|
return AutoScalerData(
|
||||||
id=self.id,
|
id=self.id,
|
||||||
mtoken=self.mtoken,
|
|
||||||
version=self.version,
|
version=self.version,
|
||||||
loadtime=(loadtime_snapshot or 0.0),
|
loadtime=(loadtime_snapshot or 0.0),
|
||||||
new_load=self.model_metrics.workload_processing,
|
new_load=self.model_metrics.workload_processing,
|
||||||
@@ -234,25 +228,17 @@ class Metrics:
|
|||||||
|
|
||||||
async def send_data(report_addr: str) -> bool:
|
async def send_data(report_addr: str) -> bool:
|
||||||
data = compute_autoscaler_data()
|
data = compute_autoscaler_data()
|
||||||
log_data = asdict(data)
|
full_path = report_addr.rstrip("/") + "/worker_status/"
|
||||||
def obfuscate(secret: str) -> str:
|
|
||||||
if secret is None:
|
|
||||||
return ""
|
|
||||||
return secret[:7] + "..." if len(secret) > 7 else ("*" * len(secret))
|
|
||||||
|
|
||||||
log_data["mtoken"] = obfuscate(log_data.get("mtoken"))
|
|
||||||
log.debug(
|
log.debug(
|
||||||
"\n".join(
|
"\n".join(
|
||||||
[
|
[
|
||||||
"#" * 60,
|
"#" * 60,
|
||||||
f"sending data to autoscaler",
|
f"sending data to autoscaler",
|
||||||
f"{json.dumps(log_data, indent=2)}",
|
f"{json.dumps((asdict(data)), indent=2)}",
|
||||||
"#" * 60,
|
"#" * 60,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
full_path = report_addr.rstrip("/") + "/worker_status/"
|
|
||||||
for attempt in range(1, 4):
|
for attempt in range(1, 4):
|
||||||
try:
|
try:
|
||||||
session = await self.http()
|
session = await self.http()
|
||||||
|
|||||||
+25
-45
@@ -3,58 +3,38 @@ import logging
|
|||||||
from typing import List
|
from typing import List
|
||||||
import ssl
|
import ssl
|
||||||
from asyncio import run, gather
|
from asyncio import run, gather
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from lib.backend import Backend
|
from lib.backend import Backend
|
||||||
from lib.metrics import Metrics
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
log = logging.getLogger(__file__)
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
||||||
try:
|
log.debug("getting certificate...")
|
||||||
log.debug("getting certificate...")
|
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
||||||
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
if use_ssl is True:
|
||||||
if use_ssl is True:
|
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||||
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
ssl_context.load_cert_chain(
|
||||||
ssl_context.load_cert_chain(
|
certfile="/etc/instance.crt",
|
||||||
certfile="/etc/instance.crt",
|
keyfile="/etc/instance.key",
|
||||||
keyfile="/etc/instance.key",
|
)
|
||||||
)
|
else:
|
||||||
else:
|
ssl_context = None
|
||||||
ssl_context = None
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
log.debug("starting server...")
|
log.debug("starting server...")
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
app.add_routes(routes)
|
app.add_routes(routes)
|
||||||
runner = web.AppRunner(app)
|
runner = web.AppRunner(app)
|
||||||
await runner.setup()
|
await runner.setup()
|
||||||
site = web.TCPSite(
|
site = web.TCPSite(
|
||||||
runner,
|
runner,
|
||||||
ssl_context=ssl_context,
|
ssl_context=ssl_context,
|
||||||
port=int(os.environ["WORKER_PORT"]),
|
port=int(os.environ["WORKER_PORT"]),
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
await gather(site.start(), backend._start_tracking())
|
await gather(site.start(), backend._start_tracking())
|
||||||
|
|
||||||
run(main())
|
run(main())
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
err_msg = f"PyWorker failed to launch: {e}"
|
|
||||||
log.error(err_msg)
|
|
||||||
|
|
||||||
async def beacon():
|
|
||||||
metrics = Metrics()
|
|
||||||
metrics._set_version(getattr(backend, "version", "0"))
|
|
||||||
metrics._set_mtoken(getattr(backend, "mtoken", ""))
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
metrics._model_errored(err_msg)
|
|
||||||
await metrics._Metrics__send_metrics_and_reset()
|
|
||||||
await asyncio.sleep(10)
|
|
||||||
finally:
|
|
||||||
await metrics.aclose()
|
|
||||||
|
|
||||||
run(beacon())
|
|
||||||
|
|||||||
+2
-41
@@ -128,44 +128,5 @@ echo "launching PyWorker server"
|
|||||||
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
|
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
|
||||||
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
|
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
|
||||||
|
|
||||||
|
(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
|
||||||
set +e
|
echo "launching PyWorker server done"
|
||||||
python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG"
|
|
||||||
PY_STATUS=${PIPESTATUS[0]}
|
|
||||||
set -e
|
|
||||||
|
|
||||||
if [ "${PY_STATUS}" -ne 0 ]; then
|
|
||||||
echo "PyWorker exited with status ${PY_STATUS}; notifying autoscaler..."
|
|
||||||
ERROR_MSG="PyWorker exited: code ${PY_STATUS}"
|
|
||||||
MTOKEN="${MASTER_TOKEN:-}"
|
|
||||||
VERSION="${PYWORKER_VERSION:-0}"
|
|
||||||
|
|
||||||
IFS=',' read -r -a REPORT_ADDRS <<< "${REPORT_ADDR}"
|
|
||||||
for addr in "${REPORT_ADDRS[@]}"; do
|
|
||||||
curl -sS -X POST -H 'Content-Type: application/json' \
|
|
||||||
-d "$(cat <<JSON
|
|
||||||
{
|
|
||||||
"id": ${CONTAINER_ID:-0},
|
|
||||||
"mtoken": "${MTOKEN}",
|
|
||||||
"version": "${VERSION}",
|
|
||||||
"loadtime": 0,
|
|
||||||
"new_load": 0,
|
|
||||||
"cur_load": 0,
|
|
||||||
"rej_load": 0,
|
|
||||||
"max_perf": 0,
|
|
||||||
"cur_perf": 0,
|
|
||||||
"error_msg": "${ERROR_MSG}",
|
|
||||||
"num_requests_working": 0,
|
|
||||||
"num_requests_recieved": 0,
|
|
||||||
"additional_disk_usage": 0,
|
|
||||||
"working_request_idxs": [],
|
|
||||||
"cur_capacity": 0,
|
|
||||||
"max_capacity": 0,
|
|
||||||
"url": "${URL}"
|
|
||||||
}
|
|
||||||
JSON
|
|
||||||
)" "${addr%/}/worker_status/" || true
|
|
||||||
done
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo "launching PyWorker server done"
|
|
||||||
|
|||||||
@@ -98,7 +98,6 @@ def call_text2image_workflow(
|
|||||||
endpoint=route_response["endpoint"],
|
endpoint=route_response["endpoint"],
|
||||||
reqnum=route_response["reqnum"],
|
reqnum=route_response["reqnum"],
|
||||||
url=route_response["url"],
|
url=route_response["url"],
|
||||||
request_idx=route_response["request_idx"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build the payload for the worker request
|
# Build the payload for the worker request
|
||||||
|
|||||||
@@ -82,7 +82,6 @@ def call_custom_workflow_for_sd3(
|
|||||||
endpoint=message["endpoint"],
|
endpoint=message["endpoint"],
|
||||||
reqnum=message["reqnum"],
|
reqnum=message["reqnum"],
|
||||||
url=message["url"],
|
url=message["url"],
|
||||||
request_idx=message["request_idx"],
|
|
||||||
)
|
)
|
||||||
workflow = {
|
workflow = {
|
||||||
"3": {
|
"3": {
|
||||||
|
|||||||
Reference in New Issue
Block a user