Compare commits

..

1 Commits

Author SHA1 Message Date
Edgar Lin 7d43bc8d68 remove redis pubsub from pyworker 2025-10-29 11:46:31 -07:00
5 changed files with 22 additions and 25 deletions
+22 -21
View File
@@ -66,9 +66,6 @@ class Backend:
unsecured: bool = dataclasses.field( unsecured: bool = dataclasses.field(
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))), default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
) )
report_addr: str = dataclasses.field(
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
)
def __post_init__(self): def __post_init__(self):
self.metrics = Metrics() self.metrics = Metrics()
@@ -107,19 +104,23 @@ class Backend:
#######################################Private####################################### #######################################Private#######################################
def _fetch_pubkey(self): def _fetch_pubkey(self):
report_addr = self.report_addr.rstrip("/") command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"] result = subprocess.check_output(command, universal_newlines=True)
try: log.debug("public key:")
result = subprocess.check_output(command, universal_newlines=True) log.debug(result)
log.debug("public key:") key = None
log.debug(result) for _ in range(5):
key = RSA.import_key(result) try:
if key is not None: key = RSA.import_key(result)
return key break
except (ValueError , subprocess.CalledProcessError) as e: except ValueError as e:
log.debug(f"Error downloading key: {e}") log.debug(f"Error downloading key: {e}")
self.backend_errored("Failed to get autoscaler pubkey") time.sleep(15)
if key is None:
self._total_pubkey_fetch_errors += 1
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
self.backend_errored("Failed to get autoscaler pubkey")
return key
async def __handle_request( async def __handle_request(
self, self,
@@ -314,10 +315,10 @@ class Backend:
with open(BENCHMARK_INDICATOR_FILE, "r") as f: with open(BENCHMARK_INDICATOR_FILE, "r") as f:
log.debug("already ran benchmark") log.debug("already ran benchmark")
# trigger model load # trigger model load
# payload = self.benchmark_handler.make_benchmark_payload() payload = self.benchmark_handler.make_benchmark_payload()
# _ = await self.__call_api( _ = await self.__call_api(
# handler=self.benchmark_handler, payload=payload handler=self.benchmark_handler, payload=payload
# ) )
return float(f.readline()) return float(f.readline())
except FileNotFoundError: except FileNotFoundError:
pass pass
@@ -392,7 +393,7 @@ class Backend:
) )
# some backends need a few seconds after logging successful startup before # some backends need a few seconds after logging successful startup before
# they can begin accepting requests # they can begin accepting requests
# await sleep(5) await sleep(5)
try: try:
max_throughput = await run_benchmark() max_throughput = await run_benchmark()
self.__start_healthcheck = True self.__start_healthcheck = True
-1
View File
@@ -98,7 +98,6 @@ def call_text2image_workflow(
endpoint=route_response["endpoint"], endpoint=route_response["endpoint"],
reqnum=route_response["reqnum"], reqnum=route_response["reqnum"],
url=route_response["url"], url=route_response["url"],
request_idx=route_response["request_idx"],
) )
# Build the payload for the worker request # Build the payload for the worker request
-1
View File
@@ -82,7 +82,6 @@ def call_custom_workflow_for_sd3(
endpoint=message["endpoint"], endpoint=message["endpoint"],
reqnum=message["reqnum"], reqnum=message["reqnum"],
url=message["url"], url=message["url"],
request_idx=message["request_idx"],
) )
workflow = { workflow = {
"3": { "3": {
-1
View File
@@ -43,7 +43,6 @@ backend = Backend(
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
], ],
], ],
max_wait_time=600
) )
-1
View File
@@ -113,7 +113,6 @@ backend = Backend(
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
], ],
], ],
max_wait_time=600
) )