Compare commits
49 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a47c9d1ed0 | |||
| 0b14562a63 | |||
| de9b50abb9 | |||
| c510801723 | |||
| a12523b1d2 | |||
| 7db54f3bd7 | |||
| d63a060202 | |||
| c6521cb6d4 | |||
| b7fe4ebb91 | |||
| 8ae7b74605 | |||
| 106067d716 | |||
| f5134d4bf5 | |||
| 47e5460532 | |||
| ec2ac0a21a | |||
| 2cde573c56 | |||
| b2e4a5db0c | |||
| 7437028cb2 | |||
| 02c8307af7 | |||
| 7c0f316eeb | |||
| b4025a744f | |||
| d190308329 | |||
| 9f5a432513 | |||
| e09f1fa953 | |||
| ba6f1c2e4b | |||
| 944f83fc03 | |||
| 298590fb88 | |||
| 814c3acd4c | |||
| 22bca74087 | |||
| 9c795e2a01 | |||
| 830b532781 | |||
| d6a6e34c6b | |||
| ac1e109c48 | |||
| d6eb498ee4 | |||
| f56bbc0ebe | |||
| 70d51bafe1 | |||
| 63909736bb | |||
| f4f7080df1 | |||
| d51a338e8f | |||
| 92a04bd7af | |||
| c98d661513 | |||
| f6fd1c6ac1 | |||
| 055e346c8c | |||
| 1cedb28acf | |||
| ec25dda3ad | |||
| 0397af719d | |||
| 3786cf978d | |||
| a86d4bcf9c | |||
| e9b6a14a5e | |||
| cadac033e1 |
+31
-25
@@ -30,7 +30,7 @@ from lib.data_types import (
|
|||||||
BenchmarkResult
|
BenchmarkResult
|
||||||
)
|
)
|
||||||
|
|
||||||
VERSION = "0.1.0"
|
VERSION = "0.2.0"
|
||||||
|
|
||||||
MSG_HISTORY_LEN = 100
|
MSG_HISTORY_LEN = 100
|
||||||
log = logging.getLogger(__file__)
|
log = logging.getLogger(__file__)
|
||||||
@@ -66,10 +66,17 @@ class Backend:
|
|||||||
unsecured: bool = dataclasses.field(
|
unsecured: bool = dataclasses.field(
|
||||||
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
||||||
)
|
)
|
||||||
|
report_addr: str = dataclasses.field(
|
||||||
|
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
|
||||||
|
)
|
||||||
|
mtoken: str = dataclasses.field(
|
||||||
|
default_factory=lambda: os.environ.get("MASTER_TOKEN", "")
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.metrics = Metrics()
|
self.metrics = Metrics()
|
||||||
self.metrics._set_version(self.version)
|
self.metrics._set_version(self.version)
|
||||||
|
self.metrics._set_mtoken(self.mtoken)
|
||||||
self._total_pubkey_fetch_errors = 0
|
self._total_pubkey_fetch_errors = 0
|
||||||
self._pubkey = self._fetch_pubkey()
|
self._pubkey = self._fetch_pubkey()
|
||||||
self.__start_healthcheck: bool = False
|
self.__start_healthcheck: bool = False
|
||||||
@@ -104,23 +111,19 @@ class Backend:
|
|||||||
|
|
||||||
#######################################Private#######################################
|
#######################################Private#######################################
|
||||||
def _fetch_pubkey(self):
|
def _fetch_pubkey(self):
|
||||||
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
|
report_addr = self.report_addr.rstrip("/")
|
||||||
result = subprocess.check_output(command, universal_newlines=True)
|
command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"]
|
||||||
log.debug("public key:")
|
try:
|
||||||
log.debug(result)
|
result = subprocess.check_output(command, universal_newlines=True)
|
||||||
key = None
|
log.debug("public key:")
|
||||||
for _ in range(5):
|
log.debug(result)
|
||||||
try:
|
key = RSA.import_key(result)
|
||||||
key = RSA.import_key(result)
|
if key is not None:
|
||||||
break
|
return key
|
||||||
except ValueError as e:
|
except (ValueError , subprocess.CalledProcessError) as e:
|
||||||
log.debug(f"Error downloading key: {e}")
|
log.debug(f"Error downloading key: {e}")
|
||||||
time.sleep(15)
|
self.backend_errored("Failed to get autoscaler pubkey")
|
||||||
if key is None:
|
|
||||||
self._total_pubkey_fetch_errors += 1
|
|
||||||
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
|
|
||||||
self.backend_errored("Failed to get autoscaler pubkey")
|
|
||||||
return key
|
|
||||||
|
|
||||||
async def __handle_request(
|
async def __handle_request(
|
||||||
self,
|
self,
|
||||||
@@ -286,7 +289,7 @@ class Backend:
|
|||||||
message = {
|
message = {
|
||||||
key: value
|
key: value
|
||||||
for (key, value) in (dataclasses.asdict(auth_data).items())
|
for (key, value) in (dataclasses.asdict(auth_data).items())
|
||||||
if key != "signature"
|
if key != "signature" and key != "__request_id"
|
||||||
}
|
}
|
||||||
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
|
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
|
||||||
log.debug(
|
log.debug(
|
||||||
@@ -296,7 +299,7 @@ class Backend:
|
|||||||
elif message in self.msg_history:
|
elif message in self.msg_history:
|
||||||
log.debug(f"message: {message} already in message history")
|
log.debug(f"message: {message} already in message history")
|
||||||
return False
|
return False
|
||||||
elif verify_signature(json.dumps(message, indent=4), auth_data.signature):
|
elif verify_signature(json.dumps(message, indent=4, sort_keys=True), auth_data.signature):
|
||||||
self.reqnum = max(auth_data.reqnum, self.reqnum)
|
self.reqnum = max(auth_data.reqnum, self.reqnum)
|
||||||
self.msg_history.append(message)
|
self.msg_history.append(message)
|
||||||
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
|
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
|
||||||
@@ -315,10 +318,10 @@ class Backend:
|
|||||||
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
||||||
log.debug("already ran benchmark")
|
log.debug("already ran benchmark")
|
||||||
# trigger model load
|
# trigger model load
|
||||||
payload = self.benchmark_handler.make_benchmark_payload()
|
# payload = self.benchmark_handler.make_benchmark_payload()
|
||||||
_ = await self.__call_api(
|
# _ = await self.__call_api(
|
||||||
handler=self.benchmark_handler, payload=payload
|
# handler=self.benchmark_handler, payload=payload
|
||||||
)
|
# )
|
||||||
return float(f.readline())
|
return float(f.readline())
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
@@ -350,6 +353,9 @@ class Backend:
|
|||||||
total_workload = sum(br.workload for br in benchmark_requests if br.is_successful)
|
total_workload = sum(br.workload for br in benchmark_requests if br.is_successful)
|
||||||
time_elapsed = time.time() - start
|
time_elapsed = time.time() - start
|
||||||
successful_responses = sum([1 for br in benchmark_requests if br.is_successful])
|
successful_responses = sum([1 for br in benchmark_requests if br.is_successful])
|
||||||
|
if successful_responses == 0:
|
||||||
|
self.backend_errored("No successful responses from benchmark")
|
||||||
|
log.debug(f"benchmark failed: {successful_responses}/{concurrent_requests} successful responses")
|
||||||
|
|
||||||
throughput = total_workload / time_elapsed
|
throughput = total_workload / time_elapsed
|
||||||
sum_throughput += throughput
|
sum_throughput += throughput
|
||||||
@@ -390,7 +396,7 @@ class Backend:
|
|||||||
)
|
)
|
||||||
# some backends need a few seconds after logging successful startup before
|
# some backends need a few seconds after logging successful startup before
|
||||||
# they can begin accepting requests
|
# they can begin accepting requests
|
||||||
await sleep(5)
|
# await sleep(5)
|
||||||
try:
|
try:
|
||||||
max_throughput = await run_benchmark()
|
max_throughput = await run_benchmark()
|
||||||
self.__start_healthcheck = True
|
self.__start_healthcheck = True
|
||||||
|
|||||||
+7
-5
@@ -65,12 +65,12 @@ class ApiPayload(ABC):
|
|||||||
class AuthData:
|
class AuthData:
|
||||||
"""data used to authenticate requester"""
|
"""data used to authenticate requester"""
|
||||||
|
|
||||||
signature: str
|
|
||||||
cost: str
|
cost: str
|
||||||
endpoint: str
|
endpoint: str
|
||||||
reqnum: int
|
reqnum: int
|
||||||
url: str
|
|
||||||
request_idx: int
|
request_idx: int
|
||||||
|
signature: str
|
||||||
|
url: str
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_json_msg(cls, json_msg: Dict[str, Any]):
|
def from_json_msg(cls, json_msg: Dict[str, Any]):
|
||||||
@@ -190,11 +190,12 @@ class SystemMetrics:
|
|||||||
self.additional_disk_usage = disk_usage - self.last_disk_usage
|
self.additional_disk_usage = disk_usage - self.last_disk_usage
|
||||||
self.last_disk_usage = disk_usage
|
self.last_disk_usage = disk_usage
|
||||||
|
|
||||||
def reset(self):
|
def reset(self, expected: float | None) -> None:
|
||||||
# autoscaler excepts model_loading_time to be populated only once, when the instance has
|
# autoscaler excepts model_loading_time to be populated only once, when the instance has
|
||||||
# finished benchmarking and is ready to receive requests. This applies to restarted instances
|
# finished benchmarking and is ready to receive requests. This applies to restarted instances
|
||||||
# as well: they should send model_loading_time once when they are done loading
|
# as well: they should send model_loading_time once when they are done loading
|
||||||
self.model_loading_time = None
|
if self.model_loading_time == expected:
|
||||||
|
self.model_loading_time = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -257,7 +258,7 @@ class ModelMetrics:
|
|||||||
def wait_time(self) -> float:
|
def wait_time(self) -> float:
|
||||||
if (len(self.requests_working) == 0):
|
if (len(self.requests_working) == 0):
|
||||||
return 0.0
|
return 0.0
|
||||||
return sum([request.workload for request in self.requests_working.values()]) / self.max_throughput
|
return sum([request.workload for request in self.requests_working.values()]) / max(self.max_throughput, 0.00001)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cur_load(self) -> float:
|
def cur_load(self) -> float:
|
||||||
@@ -285,6 +286,7 @@ class AutoScalerData:
|
|||||||
"""Data that is reported to autoscaler"""
|
"""Data that is reported to autoscaler"""
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
|
mtoken: str
|
||||||
version: str
|
version: str
|
||||||
loadtime: float
|
loadtime: float
|
||||||
cur_load: float
|
cur_load: float
|
||||||
|
|||||||
+65
-18
@@ -28,6 +28,7 @@ def get_url() -> str:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class Metrics:
|
class Metrics:
|
||||||
version: str = "0"
|
version: str = "0"
|
||||||
|
mtoken: str = ""
|
||||||
last_metric_update: float = 0.0
|
last_metric_update: float = 0.0
|
||||||
last_request_served: float = 0.0
|
last_request_served: float = 0.0
|
||||||
update_pending: bool = False
|
update_pending: bool = False
|
||||||
@@ -142,17 +143,22 @@ class Metrics:
|
|||||||
def _set_version(self, version: str) -> None:
|
def _set_version(self, version: str) -> None:
|
||||||
self.version = version
|
self.version = version
|
||||||
|
|
||||||
|
def _set_mtoken(self, mtoken: str) -> None:
|
||||||
|
self.mtoken = mtoken
|
||||||
|
|
||||||
#######################################Private#######################################
|
#######################################Private#######################################
|
||||||
|
|
||||||
async def __send_delete_requests_and_reset(self):
|
async def __send_delete_requests_and_reset(self):
|
||||||
|
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
|
||||||
async def send_data(report_addr: str, success: bool) -> bool:
|
|
||||||
data = {
|
data = {
|
||||||
"worker_id": self.id,
|
"worker_id": self.id,
|
||||||
"request_idxs": [r.request_idx for r in self.model_metrics.requests_deleting if r.success == success],
|
"mtoken": self.mtoken,
|
||||||
"success": success
|
"request_idxs": idxs,
|
||||||
|
"success": success_flag,
|
||||||
}
|
}
|
||||||
log.debug(f"Deleting requests that {'succeeded' if success else 'failed'}: {data['request_idxs']}")
|
log.debug(
|
||||||
|
f"Deleting requests that {'succeeded' if success_flag else 'failed'}: {data['request_idxs']}"
|
||||||
|
)
|
||||||
full_path = report_addr.rstrip("/") + "/delete_requests/"
|
full_path = report_addr.rstrip("/") + "/delete_requests/"
|
||||||
for attempt in range(1, 4):
|
for attempt in range(1, 4):
|
||||||
try:
|
try:
|
||||||
@@ -162,26 +168,55 @@ class Metrics:
|
|||||||
res.raise_for_status()
|
res.raise_for_status()
|
||||||
return True
|
return True
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
log.debug(f"delete_requests timed out")
|
log.debug("delete_requests timed out")
|
||||||
except (ClientResponseError, Exception) as e:
|
except (ClientResponseError, Exception) as e:
|
||||||
log.debug(f"delete_requests failed with error: {e}")
|
log.debug(f"delete_requests failed with error: {e}")
|
||||||
await asyncio.sleep(2)
|
await asyncio.sleep(2)
|
||||||
log.debug(f"retrying delete_request, attempt: {attempt}")
|
log.debug(f"retrying delete_request, attempt: {attempt}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Take a snapshot of what we plan to send this tick.
|
||||||
|
# New arrivals after this snapshot will remain in the queue for the next tick.
|
||||||
|
snapshot = list(self.model_metrics.requests_deleting)
|
||||||
|
success_idxs = [r.request_idx for r in snapshot if r.success is True]
|
||||||
|
failed_idxs = [r.request_idx for r in snapshot if r.success is False]
|
||||||
|
|
||||||
|
if not success_idxs and not failed_idxs:
|
||||||
|
return # nothing to do
|
||||||
|
|
||||||
for report_addr in self.report_addr:
|
for report_addr in self.report_addr:
|
||||||
success = await send_data(report_addr, success=True) and await send_data(report_addr, success=False)
|
# TODO: Add a Redis subscriber queue for delete_requests
|
||||||
if success is True:
|
if report_addr == "https://cloud.vast.ai/api/v0":
|
||||||
self.model_metrics.requests_deleting.clear()
|
# Patch: ignore the Redis API report_addr
|
||||||
|
continue
|
||||||
|
sent_success = True
|
||||||
|
sent_failed = True
|
||||||
|
|
||||||
|
if success_idxs:
|
||||||
|
sent_success = await post(report_addr, success_idxs, True)
|
||||||
|
if failed_idxs:
|
||||||
|
sent_failed = await post(report_addr, failed_idxs, False)
|
||||||
|
|
||||||
|
if sent_success and sent_failed:
|
||||||
|
# Remove only the items we actually sent from the live queue.
|
||||||
|
sent_set = set(success_idxs) | set(failed_idxs)
|
||||||
|
self.model_metrics.requests_deleting[:] = [
|
||||||
|
r for r in self.model_metrics.requests_deleting
|
||||||
|
if r.request_idx not in sent_set
|
||||||
|
]
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
async def __send_metrics_and_reset(self):
|
async def __send_metrics_and_reset(self):
|
||||||
|
|
||||||
|
loadtime_snapshot = self.system_metrics.model_loading_time
|
||||||
|
|
||||||
def compute_autoscaler_data() -> AutoScalerData:
|
def compute_autoscaler_data() -> AutoScalerData:
|
||||||
return AutoScalerData(
|
return AutoScalerData(
|
||||||
id=self.id,
|
id=self.id,
|
||||||
|
mtoken=self.mtoken,
|
||||||
version=self.version,
|
version=self.version,
|
||||||
loadtime=(self.system_metrics.model_loading_time or 0.0),
|
loadtime=(loadtime_snapshot or 0.0),
|
||||||
new_load=self.model_metrics.workload_processing,
|
new_load=self.model_metrics.workload_processing,
|
||||||
cur_load=self.model_metrics.cur_load,
|
cur_load=self.model_metrics.cur_load,
|
||||||
rej_load=self.model_metrics.workload_rejected,
|
rej_load=self.model_metrics.workload_rejected,
|
||||||
@@ -199,17 +234,25 @@ class Metrics:
|
|||||||
|
|
||||||
async def send_data(report_addr: str) -> bool:
|
async def send_data(report_addr: str) -> bool:
|
||||||
data = compute_autoscaler_data()
|
data = compute_autoscaler_data()
|
||||||
full_path = report_addr.rstrip("/") + "/worker_status/"
|
log_data = asdict(data)
|
||||||
|
def obfuscate(secret: str) -> str:
|
||||||
|
if secret is None:
|
||||||
|
return ""
|
||||||
|
return secret[:7] + "..." if len(secret) > 7 else ("*" * len(secret))
|
||||||
|
|
||||||
|
log_data["mtoken"] = obfuscate(log_data.get("mtoken"))
|
||||||
log.debug(
|
log.debug(
|
||||||
"\n".join(
|
"\n".join(
|
||||||
[
|
[
|
||||||
"#" * 60,
|
"#" * 60,
|
||||||
f"sending data to autoscaler",
|
f"sending data to autoscaler",
|
||||||
f"{json.dumps((asdict(data)), indent=2)}",
|
f"{json.dumps(log_data, indent=2)}",
|
||||||
"#" * 60,
|
"#" * 60,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
full_path = report_addr.rstrip("/") + "/worker_status/"
|
||||||
for attempt in range(1, 4):
|
for attempt in range(1, 4):
|
||||||
try:
|
try:
|
||||||
session = await self.http()
|
session = await self.http()
|
||||||
@@ -229,11 +272,15 @@ class Metrics:
|
|||||||
|
|
||||||
self.system_metrics.update_disk_usage()
|
self.system_metrics.update_disk_usage()
|
||||||
|
|
||||||
|
sent = False
|
||||||
for report_addr in self.report_addr:
|
for report_addr in self.report_addr:
|
||||||
success = await send_data(report_addr)
|
if await send_data(report_addr):
|
||||||
if success is True:
|
sent = True
|
||||||
break
|
break
|
||||||
self.update_pending = False
|
|
||||||
self.model_metrics.reset()
|
if sent:
|
||||||
self.system_metrics.reset()
|
# clear the one-shot loadtime only if we actually sent *this* value
|
||||||
self.last_metric_update = time.time()
|
self.system_metrics.reset(expected=loadtime_snapshot)
|
||||||
|
self.update_pending = False
|
||||||
|
self.model_metrics.reset()
|
||||||
|
self.last_metric_update = time.time()
|
||||||
|
|||||||
+45
-25
@@ -3,38 +3,58 @@ import logging
|
|||||||
from typing import List
|
from typing import List
|
||||||
import ssl
|
import ssl
|
||||||
from asyncio import run, gather
|
from asyncio import run, gather
|
||||||
|
import asyncio
|
||||||
|
|
||||||
from lib.backend import Backend
|
from lib.backend import Backend
|
||||||
|
from lib.metrics import Metrics
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
log = logging.getLogger(__file__)
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
||||||
log.debug("getting certificate...")
|
try:
|
||||||
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
log.debug("getting certificate...")
|
||||||
if use_ssl is True:
|
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
||||||
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
if use_ssl is True:
|
||||||
ssl_context.load_cert_chain(
|
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||||
certfile="/etc/instance.crt",
|
ssl_context.load_cert_chain(
|
||||||
keyfile="/etc/instance.key",
|
certfile="/etc/instance.crt",
|
||||||
)
|
keyfile="/etc/instance.key",
|
||||||
else:
|
)
|
||||||
ssl_context = None
|
else:
|
||||||
|
ssl_context = None
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
log.debug("starting server...")
|
log.debug("starting server...")
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
app.add_routes(routes)
|
app.add_routes(routes)
|
||||||
runner = web.AppRunner(app)
|
runner = web.AppRunner(app)
|
||||||
await runner.setup()
|
await runner.setup()
|
||||||
site = web.TCPSite(
|
site = web.TCPSite(
|
||||||
runner,
|
runner,
|
||||||
ssl_context=ssl_context,
|
ssl_context=ssl_context,
|
||||||
port=int(os.environ["WORKER_PORT"]),
|
port=int(os.environ["WORKER_PORT"]),
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
await gather(site.start(), backend._start_tracking())
|
await gather(site.start(), backend._start_tracking())
|
||||||
|
|
||||||
run(main())
|
run(main())
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
err_msg = f"PyWorker failed to launch: {e}"
|
||||||
|
log.error(err_msg)
|
||||||
|
|
||||||
|
async def beacon():
|
||||||
|
metrics = Metrics()
|
||||||
|
metrics._set_version(getattr(backend, "version", "0"))
|
||||||
|
metrics._set_mtoken(getattr(backend, "mtoken", ""))
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
metrics._model_errored(err_msg)
|
||||||
|
await metrics._Metrics__send_metrics_and_reset()
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
finally:
|
||||||
|
await metrics.aclose()
|
||||||
|
|
||||||
|
run(beacon())
|
||||||
|
|||||||
+42
-3
@@ -9,7 +9,7 @@ ENV_PATH="$WORKSPACE_DIR/worker-env"
|
|||||||
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
|
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
|
||||||
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
|
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
|
||||||
|
|
||||||
REPORT_ADDR="${REPORT_ADDR:-https://cloud.vast.ai/api/v0,https://run.vast.ai}"
|
REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}"
|
||||||
USE_SSL="${USE_SSL:-true}"
|
USE_SSL="${USE_SSL:-true}"
|
||||||
WORKER_PORT="${WORKER_PORT:-3000}"
|
WORKER_PORT="${WORKER_PORT:-3000}"
|
||||||
mkdir -p "$WORKSPACE_DIR"
|
mkdir -p "$WORKSPACE_DIR"
|
||||||
@@ -128,5 +128,44 @@ echo "launching PyWorker server"
|
|||||||
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
|
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
|
||||||
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
|
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
|
||||||
|
|
||||||
(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
|
|
||||||
echo "launching PyWorker server done"
|
set +e
|
||||||
|
python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG"
|
||||||
|
PY_STATUS=${PIPESTATUS[0]}
|
||||||
|
set -e
|
||||||
|
|
||||||
|
if [ "${PY_STATUS}" -ne 0 ]; then
|
||||||
|
echo "PyWorker exited with status ${PY_STATUS}; notifying autoscaler..."
|
||||||
|
ERROR_MSG="PyWorker exited: code ${PY_STATUS}"
|
||||||
|
MTOKEN="${MASTER_TOKEN:-}"
|
||||||
|
VERSION="${PYWORKER_VERSION:-0}"
|
||||||
|
|
||||||
|
IFS=',' read -r -a REPORT_ADDRS <<< "${REPORT_ADDR}"
|
||||||
|
for addr in "${REPORT_ADDRS[@]}"; do
|
||||||
|
curl -sS -X POST -H 'Content-Type: application/json' \
|
||||||
|
-d "$(cat <<JSON
|
||||||
|
{
|
||||||
|
"id": ${CONTAINER_ID:-0},
|
||||||
|
"mtoken": "${MTOKEN}",
|
||||||
|
"version": "${VERSION}",
|
||||||
|
"loadtime": 0,
|
||||||
|
"new_load": 0,
|
||||||
|
"cur_load": 0,
|
||||||
|
"rej_load": 0,
|
||||||
|
"max_perf": 0,
|
||||||
|
"cur_perf": 0,
|
||||||
|
"error_msg": "${ERROR_MSG}",
|
||||||
|
"num_requests_working": 0,
|
||||||
|
"num_requests_recieved": 0,
|
||||||
|
"additional_disk_usage": 0,
|
||||||
|
"working_request_idxs": [],
|
||||||
|
"cur_capacity": 0,
|
||||||
|
"max_capacity": 0,
|
||||||
|
"url": "${URL}"
|
||||||
|
}
|
||||||
|
JSON
|
||||||
|
)" "${addr%/}/worker_status/" || true
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "launching PyWorker server done"
|
||||||
@@ -12,9 +12,21 @@ A docker image is provided but you may use any if the above requirements are met
|
|||||||
|
|
||||||
## Benchmarking
|
## Benchmarking
|
||||||
|
|
||||||
A simple image generation benchmark runs when each worker initializes to validate GPU performance and identify underperforming machines.
|
### Custom Benchmark Workflows
|
||||||
|
|
||||||
The benchmark uses Stable Diffusion v1.5 with ComfyUI's default text-to-image workflow. Configure the benchmark complexity and duration using these variables:
|
You can provide a custom ComfyUI workflow for benchmarking by creating `workers/comfyui-json/misc/benchmark.json`. This allows you to test performance using your preferred models and workflow complexity.
|
||||||
|
|
||||||
|
**Ways to provide the benchmark file:**
|
||||||
|
- Fork this repository and add your `benchmark.json` file
|
||||||
|
- Write the file during worker provisioning (onstart script or setup phase)
|
||||||
|
|
||||||
|
An example file is provided in the repository. To ensure varied generations, use the placeholder `__RANDOM_INT__` in place of static seed values - it will be replaced with a random integer for each benchmark run.
|
||||||
|
|
||||||
|
### Default Benchmark (Fallback)
|
||||||
|
|
||||||
|
If `benchmark.json` is not available, a simple image generation benchmark runs when each worker initializes. This validates GPU performance and helps identify underperforming machines.
|
||||||
|
|
||||||
|
The default benchmark uses Stable Diffusion v1.5 with ComfyUI's standard text-to-image workflow. Configure it using these environment variables:
|
||||||
|
|
||||||
| Environment Variable | Default Value | Description |
|
| Environment Variable | Default Value | Description |
|
||||||
| -------------------- | ------------- | ----------- |
|
| -------------------- | ------------- | ----------- |
|
||||||
@@ -24,7 +36,7 @@ The benchmark uses Stable Diffusion v1.5 with ComfyUI's default text-to-image wo
|
|||||||
|
|
||||||
Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns.
|
Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns.
|
||||||
|
|
||||||
### Calibrating Benchmark Duration
|
#### Calibrating Fallback Benchmark Duration
|
||||||
|
|
||||||
To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements.
|
To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements.
|
||||||
|
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ def call_text2image_workflow(
|
|||||||
endpoint=route_response["endpoint"],
|
endpoint=route_response["endpoint"],
|
||||||
reqnum=route_response["reqnum"],
|
reqnum=route_response["reqnum"],
|
||||||
url=route_response["url"],
|
url=route_response["url"],
|
||||||
|
request_idx=route_response["request_idx"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build the payload for the worker request
|
# Build the payload for the worker request
|
||||||
|
|||||||
@@ -5,12 +5,13 @@ import dataclasses
|
|||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
from functools import cache
|
from functools import cache
|
||||||
from math import ceil
|
from math import ceil
|
||||||
|
from pathlib import Path
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
from lib.data_types import ApiPayload, JsonDataException
|
from lib.data_types import ApiPayload, JsonDataException
|
||||||
|
|
||||||
|
log = logging.getLogger(__file__)
|
||||||
with open("workers/comfyui/misc/test_prompts.txt", "r") as f:
|
|
||||||
test_prompts = f.readlines()
|
|
||||||
|
|
||||||
def count_workload() -> float:
|
def count_workload() -> float:
|
||||||
# Always 100.0 where there is a single instance of ComfyUI handling requests
|
# Always 100.0 where there is a single instance of ComfyUI handling requests
|
||||||
@@ -24,9 +25,32 @@ class ComfyWorkflowData(ApiPayload):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def for_test(cls):
|
def for_test(cls):
|
||||||
"""
|
"""
|
||||||
Use the variables available to simulate workflows of the required running time
|
If the user has provided a benchmark workflow we can use it here to properly gauge performance.
|
||||||
|
Otherwise, use the variables available to simulate workflows of the required running time
|
||||||
Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090)
|
Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090)
|
||||||
"""
|
"""
|
||||||
|
# Try to load benchmark.json
|
||||||
|
benchmark_file = Path("workers/comfyui-json/misc/benchmark.json")
|
||||||
|
|
||||||
|
if benchmark_file.exists():
|
||||||
|
try:
|
||||||
|
with open(benchmark_file, "r") as f:
|
||||||
|
benchmark_workflow = json.load(f)
|
||||||
|
return cls(
|
||||||
|
input={
|
||||||
|
"request_id": f"test-{random.randint(1000, 99999)}",
|
||||||
|
"workflow_json": benchmark_workflow
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except (json.JSONDecodeError, IOError):
|
||||||
|
# JSON is malformed or file can't be read, fall through to default
|
||||||
|
log.error(f"Failed to benchmark using {benchmark_file}")
|
||||||
|
|
||||||
|
# Fallback: read prompts and construct payload
|
||||||
|
log.info("Using fallback method for benchmarking")
|
||||||
|
with open("workers/comfyui-json/misc/test_prompts.txt", "r") as f:
|
||||||
|
test_prompts = f.readlines()
|
||||||
|
|
||||||
test_prompt = random.choice(test_prompts).rstrip()
|
test_prompt = random.choice(test_prompts).rstrip()
|
||||||
return cls(
|
return cls(
|
||||||
input={
|
input={
|
||||||
|
|||||||
@@ -0,0 +1,107 @@
|
|||||||
|
{
|
||||||
|
"3": {
|
||||||
|
"inputs": {
|
||||||
|
"seed": "__RANDOM_INT__",
|
||||||
|
"steps": 20,
|
||||||
|
"cfg": 8,
|
||||||
|
"sampler_name": "euler",
|
||||||
|
"scheduler": "normal",
|
||||||
|
"denoise": 1,
|
||||||
|
"model": [
|
||||||
|
"4",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"positive": [
|
||||||
|
"6",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"negative": [
|
||||||
|
"7",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"latent_image": [
|
||||||
|
"5",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "KSampler",
|
||||||
|
"_meta": {
|
||||||
|
"title": "KSampler"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"inputs": {
|
||||||
|
"ckpt_name": "v1-5-pruned-emaonly-fp16.safetensors"
|
||||||
|
},
|
||||||
|
"class_type": "CheckpointLoaderSimple",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Load Checkpoint"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"5": {
|
||||||
|
"inputs": {
|
||||||
|
"width": 512,
|
||||||
|
"height": 512,
|
||||||
|
"batch_size": 1
|
||||||
|
},
|
||||||
|
"class_type": "EmptyLatentImage",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Empty Latent Image"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"6": {
|
||||||
|
"inputs": {
|
||||||
|
"text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,",
|
||||||
|
"clip": [
|
||||||
|
"4",
|
||||||
|
1
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "CLIP Text Encode (Prompt)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"7": {
|
||||||
|
"inputs": {
|
||||||
|
"text": "text, watermark",
|
||||||
|
"clip": [
|
||||||
|
"4",
|
||||||
|
1
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "CLIP Text Encode (Prompt)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"inputs": {
|
||||||
|
"samples": [
|
||||||
|
"3",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"vae": [
|
||||||
|
"4",
|
||||||
|
2
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "VAEDecode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "VAE Decode"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"9": {
|
||||||
|
"inputs": {
|
||||||
|
"filename_prefix": "ComfyUI",
|
||||||
|
"images": [
|
||||||
|
"8",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "SaveImage",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Save Image"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -19,6 +19,7 @@ MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: "
|
|||||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||||
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
|
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
|
||||||
"Value not in list: ", # This error is emitted when the model file is not there at all
|
"Value not in list: ", # This error is emitted when the model file is not there at all
|
||||||
|
"[ERROR] Provisioning Script failed", # Error inserted by provisioning script if models/nodes fail to download
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ def call_custom_workflow_for_sd3(
|
|||||||
endpoint=message["endpoint"],
|
endpoint=message["endpoint"],
|
||||||
reqnum=message["reqnum"],
|
reqnum=message["reqnum"],
|
||||||
url=message["url"],
|
url=message["url"],
|
||||||
|
request_idx=message["request_idx"],
|
||||||
)
|
)
|
||||||
workflow = {
|
workflow = {
|
||||||
"3": {
|
"3": {
|
||||||
|
|||||||
Reference in New Issue
Block a user