Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7d43bc8d68 |
+22
-21
@@ -66,9 +66,6 @@ class Backend:
|
|||||||
unsecured: bool = dataclasses.field(
|
unsecured: bool = dataclasses.field(
|
||||||
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
||||||
)
|
)
|
||||||
report_addr: str = dataclasses.field(
|
|
||||||
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
|
|
||||||
)
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.metrics = Metrics()
|
self.metrics = Metrics()
|
||||||
@@ -107,19 +104,23 @@ class Backend:
|
|||||||
|
|
||||||
#######################################Private#######################################
|
#######################################Private#######################################
|
||||||
def _fetch_pubkey(self):
|
def _fetch_pubkey(self):
|
||||||
report_addr = self.report_addr.rstrip("/")
|
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
|
||||||
command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"]
|
result = subprocess.check_output(command, universal_newlines=True)
|
||||||
try:
|
log.debug("public key:")
|
||||||
result = subprocess.check_output(command, universal_newlines=True)
|
log.debug(result)
|
||||||
log.debug("public key:")
|
key = None
|
||||||
log.debug(result)
|
for _ in range(5):
|
||||||
key = RSA.import_key(result)
|
try:
|
||||||
if key is not None:
|
key = RSA.import_key(result)
|
||||||
return key
|
break
|
||||||
except (ValueError , subprocess.CalledProcessError) as e:
|
except ValueError as e:
|
||||||
log.debug(f"Error downloading key: {e}")
|
log.debug(f"Error downloading key: {e}")
|
||||||
self.backend_errored("Failed to get autoscaler pubkey")
|
time.sleep(15)
|
||||||
|
if key is None:
|
||||||
|
self._total_pubkey_fetch_errors += 1
|
||||||
|
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
|
||||||
|
self.backend_errored("Failed to get autoscaler pubkey")
|
||||||
|
return key
|
||||||
|
|
||||||
async def __handle_request(
|
async def __handle_request(
|
||||||
self,
|
self,
|
||||||
@@ -314,10 +315,10 @@ class Backend:
|
|||||||
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
||||||
log.debug("already ran benchmark")
|
log.debug("already ran benchmark")
|
||||||
# trigger model load
|
# trigger model load
|
||||||
# payload = self.benchmark_handler.make_benchmark_payload()
|
payload = self.benchmark_handler.make_benchmark_payload()
|
||||||
# _ = await self.__call_api(
|
_ = await self.__call_api(
|
||||||
# handler=self.benchmark_handler, payload=payload
|
handler=self.benchmark_handler, payload=payload
|
||||||
# )
|
)
|
||||||
return float(f.readline())
|
return float(f.readline())
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
@@ -392,7 +393,7 @@ class Backend:
|
|||||||
)
|
)
|
||||||
# some backends need a few seconds after logging successful startup before
|
# some backends need a few seconds after logging successful startup before
|
||||||
# they can begin accepting requests
|
# they can begin accepting requests
|
||||||
# await sleep(5)
|
await sleep(5)
|
||||||
try:
|
try:
|
||||||
max_throughput = await run_benchmark()
|
max_throughput = await run_benchmark()
|
||||||
self.__start_healthcheck = True
|
self.__start_healthcheck = True
|
||||||
|
|||||||
@@ -98,7 +98,6 @@ def call_text2image_workflow(
|
|||||||
endpoint=route_response["endpoint"],
|
endpoint=route_response["endpoint"],
|
||||||
reqnum=route_response["reqnum"],
|
reqnum=route_response["reqnum"],
|
||||||
url=route_response["url"],
|
url=route_response["url"],
|
||||||
request_idx=route_response["request_idx"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build the payload for the worker request
|
# Build the payload for the worker request
|
||||||
|
|||||||
@@ -82,7 +82,6 @@ def call_custom_workflow_for_sd3(
|
|||||||
endpoint=message["endpoint"],
|
endpoint=message["endpoint"],
|
||||||
reqnum=message["reqnum"],
|
reqnum=message["reqnum"],
|
||||||
url=message["url"],
|
url=message["url"],
|
||||||
request_idx=message["request_idx"],
|
|
||||||
)
|
)
|
||||||
workflow = {
|
workflow = {
|
||||||
"3": {
|
"3": {
|
||||||
|
|||||||
@@ -43,7 +43,6 @@ backend = Backend(
|
|||||||
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
||||||
],
|
],
|
||||||
],
|
],
|
||||||
max_wait_time=600
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -113,7 +113,6 @@ backend = Backend(
|
|||||||
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
||||||
],
|
],
|
||||||
],
|
],
|
||||||
max_wait_time=600
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user