From b1ca68c349494f3f73d9f1e0c49dd531b4b9de17 Mon Sep 17 00:00:00 2001 From: Abiola Akinnubi Date: Thu, 29 May 2025 17:22:31 -0700 Subject: [PATCH] Added endpoint flexibility along with existing log. extended the log support Switched Endpoint back to vast-ai, Added endpoint flexibility along with existing log. extended the log support Modify the endpoint return type as optional and check via pyright to ensure there are not compilation/type errors --- lib/backend.py | 27 ++++++++++++++++++++++++++- lib/data_types.py | 19 ++++++++++++++----- workers/comfyui/data_types.py | 2 +- workers/comfyui/server.py | 14 +++++++++++--- workers/hello_world/server.py | 14 +++++++++++--- workers/tgi/server.py | 17 ++++++++++++++--- 6 files changed, 77 insertions(+), 16 deletions(-) diff --git a/lib/backend.py b/lib/backend.py index 416e0ac..a4dc756 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -186,8 +186,33 @@ class Backend: log.debug(f"Exception in main handler loop {e}") return web.Response(status=500) + + async def __healthcheck(self): + health_check_url = self.benchmark_handler.healthcheck_endpoint + if health_check_url is None: + log.debug("No healthcheck endpoint defined, skipping healthcheck") + return + await sleep(5) + try: + log.debug(f"Performing healthcheck on {health_check_url}") + async with self.session.get(health_check_url) as response: + if response.status == 200: + log.debug("Healthcheck successful") + elif response.status == 503: + log.debug(f"Healthcheck failed with status: {response.status}") + self.backend_errored(f"Healthcheck failed with status: {response.status}") + else: + # endpoint not ready yet so bail + log.debug( + f"Healthcheck Endpoint not ready: {response.status}" + ) + except Exception as e: + log.debug(f"Healthcheck failed with exception: {e}") + self.backend_errored(str(e)) + + async def _start_tracking(self) -> None: - await gather(self.__read_logs(), self.metrics._send_metrics_loop()) + await gather(self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck()) def backend_errored(self, msg: str) -> None: self.metrics._model_errored(msg) diff --git a/lib/data_types.py b/lib/data_types.py index 654ad56..c90d472 100644 --- a/lib/data_types.py +++ b/lib/data_types.py @@ -22,13 +22,14 @@ class JsonDataException(Exception): def __init__(self, json_msg: Dict[str, Any]): self.message = json_msg +ApiPayload_T = TypeVar("ApiPayload_T", bound="ApiPayload") @dataclass class ApiPayload(ABC): @classmethod @abstractmethod - def for_test(cls) -> "ApiPayload": + def for_test(cls: Type[ApiPayload_T]) -> ApiPayload_T: """defines how create a payload for load testing""" pass @@ -44,7 +45,7 @@ class ApiPayload(ABC): @classmethod @abstractmethod - def from_json_msg(cls, json_msg: Dict[str, Any]) -> "ApiPayload": + def from_json_msg(cls: Type[ApiPayload_T], json_msg: Dict[str, Any]) -> ApiPayload_T: """ defines how to create an API payload from a JSON message, it should throw an JsonDataException if there are issues with some fields @@ -83,7 +84,6 @@ class AuthData: ) -ApiPayload_T = TypeVar("ApiPayload_T", bound=ApiPayload) @dataclass @@ -102,6 +102,13 @@ class EndpointHandler(ABC, Generic[ApiPayload_T]): """the endpoint on the model API""" pass + @property + @abstractmethod + def healthcheck_endpoint(self) -> Optional[str]: + """the endpoint on the model API that is used for healthchecks""" + pass + + @classmethod @abstractmethod def payload_cls(cls) -> Type[ApiPayload_T]: @@ -127,7 +134,8 @@ class EndpointHandler(ABC, Generic[ApiPayload_T]): cls, req_data: Dict[str, Any] ) -> Tuple[AuthData, ApiPayload_T]: errors = {} - auth_data = payload = None + auth_data: Optional[AuthData] = None + payload: Optional[ApiPayload_T] = None try: if "auth_data" in req_data: auth_data = AuthData.from_json_msg(req_data["auth_data"]) @@ -137,7 +145,8 @@ class EndpointHandler(ABC, Generic[ApiPayload_T]): errors["auth_data"] = e.message try: if "payload" in req_data: - payload = cls.payload_cls().from_json_msg(req_data["payload"]) + payload_cls = cls.payload_cls() + payload = payload_cls.from_json_msg(req_data["payload"]) else: errors["payload"] = "field missing" except JsonDataException as e: diff --git a/workers/comfyui/data_types.py b/workers/comfyui/data_types.py index f228593..82d9c6a 100644 --- a/workers/comfyui/data_types.py +++ b/workers/comfyui/data_types.py @@ -174,7 +174,7 @@ class CustomComfyWorkflowData(ApiPayload): @classmethod def for_test(cls): - raise NotImplemented("Custom comfy workflow is not used for testing") + raise NotImplementedError("Custom comfy workflow is not used for testing") def count_workload(self) -> float: return count_workload( diff --git a/workers/comfyui/server.py b/workers/comfyui/server.py index 8abb04d..a999f74 100644 --- a/workers/comfyui/server.py +++ b/workers/comfyui/server.py @@ -2,7 +2,7 @@ import os import logging import dataclasses import base64 -from typing import Union, Type +from typing import Optional, Union, Type from aiohttp import web, ClientResponse from anyio import open_file @@ -69,7 +69,11 @@ class DefaultComfyWorkflowHandler(EndpointHandler[DefaultComfyWorkflowData]): @property def endpoint(self) -> str: return "/runsync" - + + @property + def healthcheck_endpoint(self) -> Optional[str]: + return None + @classmethod def payload_cls(cls) -> Type[DefaultComfyWorkflowData]: return DefaultComfyWorkflowData @@ -89,7 +93,11 @@ class CustomComfyWorkflowHandler(EndpointHandler[CustomComfyWorkflowData]): @property def endpoint(self) -> str: return "/runsync" - + + @property + def healthcheck_endpoint(self) -> Optional[str]: + return None + @classmethod def payload_cls(cls) -> Type[CustomComfyWorkflowData]: return CustomComfyWorkflowData diff --git a/workers/hello_world/server.py b/workers/hello_world/server.py index 6c08724..d9c439f 100644 --- a/workers/hello_world/server.py +++ b/workers/hello_world/server.py @@ -13,7 +13,7 @@ EndpointHandler. This is useful for endpoints such as healthchecks. See below fo import os import logging import dataclasses -from typing import Dict, Any, Union, Type +from typing import Dict, Any, Optional, Union, Type from aiohttp import web, ClientResponse @@ -49,7 +49,11 @@ class GenerateHandler(EndpointHandler[InputData]): def endpoint(self) -> str: # the API endpoint return "/generate" - + + @property + def healthcheck_endpoint(self) -> Optional[str]: + return None + @classmethod def payload_cls(cls) -> Type[InputData]: return InputData @@ -93,7 +97,11 @@ class GenerateStreamHandler(EndpointHandler[InputData]): @property def endpoint(self) -> str: return "/generate_stream" - + + @property + def healthcheck_endpoint(self) -> Optional[str]: + return None + @classmethod def payload_cls(cls) -> Type[InputData]: return InputData diff --git a/workers/tgi/server.py b/workers/tgi/server.py index 2b881ad..95cc5bc 100644 --- a/workers/tgi/server.py +++ b/workers/tgi/server.py @@ -14,7 +14,7 @@ from .data_types import InputData MODEL_SERVER_URL = "http://0.0.0.0:5001" # This is the last log line that gets emitted once comfyui+extensions have been fully loaded -MODEL_SERVER_START_LOG_MSG = '"message":"Connected","target":"text_generation_router"' +MODEL_SERVER_START_LOG_MSG = ['"message":"Connected","target":"text_generation_router"', '"message":"Connected","target":"text_generation_router::server"'] MODEL_SERVER_ERROR_LOG_MSGS = ["Error: WebserverFailed", "Error: DownloadError", "Error: ShardCannotStart"] @@ -33,6 +33,10 @@ class GenerateHandler(EndpointHandler[InputData]): def endpoint(self) -> str: return "/generate" + @property + def healthcheck_endpoint(self) -> str: + return f"{MODEL_SERVER_URL}/health" + @classmethod def payload_cls(cls) -> Type[InputData]: return InputData @@ -58,7 +62,11 @@ class GenerateStreamHandler(EndpointHandler[InputData]): @property def endpoint(self) -> str: return "/generate_stream" - + + @property + def healthcheck_endpoint(self) -> str: + return f"{MODEL_SERVER_URL}/health" + @classmethod def payload_cls(cls) -> Type[InputData]: return InputData @@ -91,7 +99,10 @@ backend = Backend( allow_parallel_requests=True, benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256), log_actions=[ - (LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG), + *[ + (LogAction.ModelLoaded, info_msg) + for info_msg in MODEL_SERVER_START_LOG_MSG + ], (LogAction.Info, '"message":"Download'), *[ (LogAction.ModelError, error_msg)