Added endpoint flexibility along with the existing log handling. Extended the log support.

Switched the endpoint back to vast-ai. Added endpoint flexibility along with the existing log handling. Extended the log support.

Modified the endpoint return type to be optional and checked via pyright to ensure there are no compilation/type errors.
This commit is contained in:
Abiola Akinnubi
2025-05-29 17:22:31 -07:00
committed by Nader Arbabian
parent 6de4ee2b59
commit b1ca68c349
6 changed files with 77 additions and 16 deletions
+26 -1
View File
@@ -186,8 +186,33 @@ class Backend:
log.debug(f"Exception in main handler loop {e}") log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500) return web.Response(status=500)
async def __healthcheck(self):
    """
    Probe the benchmark handler's healthcheck endpoint a single time.

    Does nothing (beyond a debug log) when the benchmark handler's
    ``healthcheck_endpoint`` is None. Otherwise waits 5 seconds, issues one
    GET through ``self.session`` and reacts to the HTTP status:

    * 200 -> logged as successful
    * 503 -> logged and reported through ``self.backend_errored``
    * anything else -> only logged as "not ready"

    Any exception raised while probing is logged and reported through
    ``self.backend_errored`` with the exception text.
    """
    # NOTE(review): there is no loop here, so the endpoint is probed once per
    # call; confirm that a one-shot check is what callers expect.
    target_url = self.benchmark_handler.healthcheck_endpoint
    if target_url is None:
        log.debug("No healthcheck endpoint defined, skipping healthcheck")
        return

    await sleep(5)
    try:
        log.debug(f"Performing healthcheck on {target_url}")
        async with self.session.get(target_url) as response:
            status = response.status
            if status == 200:
                log.debug("Healthcheck successful")
            elif status == 503:
                failure_msg = f"Healthcheck failed with status: {status}"
                log.debug(failure_msg)
                self.backend_errored(failure_msg)
            else:
                # any other status is treated as "endpoint not ready yet":
                # log it and bail without flagging an error
                log.debug(f"Healthcheck Endpoint not ready: {status}")
    except Exception as e:
        # NOTE(review): this also catches connection errors; presumably a
        # server that is still starting would be reported as errored -- verify.
        log.debug(f"Healthcheck failed with exception: {e}")
        self.backend_errored(str(e))
async def _start_tracking(self) -> None: async def _start_tracking(self) -> None:
await gather(self.__read_logs(), self.metrics._send_metrics_loop()) await gather(self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck())
def backend_errored(self, msg: str) -> None: def backend_errored(self, msg: str) -> None:
self.metrics._model_errored(msg) self.metrics._model_errored(msg)
+14 -5
View File
@@ -22,13 +22,14 @@ class JsonDataException(Exception):
def __init__(self, json_msg: Dict[str, Any]): def __init__(self, json_msg: Dict[str, Any]):
self.message = json_msg self.message = json_msg
ApiPayload_T = TypeVar("ApiPayload_T", bound="ApiPayload")
@dataclass @dataclass
class ApiPayload(ABC): class ApiPayload(ABC):
@classmethod @classmethod
@abstractmethod @abstractmethod
def for_test(cls) -> "ApiPayload": def for_test(cls: Type[ApiPayload_T]) -> ApiPayload_T:
"""defines how create a payload for load testing""" """defines how create a payload for load testing"""
pass pass
@@ -44,7 +45,7 @@ class ApiPayload(ABC):
@classmethod @classmethod
@abstractmethod @abstractmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "ApiPayload": def from_json_msg(cls: Type[ApiPayload_T], json_msg: Dict[str, Any]) -> ApiPayload_T:
""" """
defines how to create an API payload from a JSON message, defines how to create an API payload from a JSON message,
it should throw an JsonDataException if there are issues with some fields it should throw an JsonDataException if there are issues with some fields
@@ -83,7 +84,6 @@ class AuthData:
) )
ApiPayload_T = TypeVar("ApiPayload_T", bound=ApiPayload)
@dataclass @dataclass
@@ -102,6 +102,13 @@ class EndpointHandler(ABC, Generic[ApiPayload_T]):
"""the endpoint on the model API""" """the endpoint on the model API"""
pass pass
@property
@abstractmethod
def healthcheck_endpoint(self) -> Optional[str]:
    """
    the endpoint on the model API that is used for healthchecks,
    or None when the handler does not define one
    """
    pass
@classmethod @classmethod
@abstractmethod @abstractmethod
def payload_cls(cls) -> Type[ApiPayload_T]: def payload_cls(cls) -> Type[ApiPayload_T]:
@@ -127,7 +134,8 @@ class EndpointHandler(ABC, Generic[ApiPayload_T]):
cls, req_data: Dict[str, Any] cls, req_data: Dict[str, Any]
) -> Tuple[AuthData, ApiPayload_T]: ) -> Tuple[AuthData, ApiPayload_T]:
errors = {} errors = {}
auth_data = payload = None auth_data: Optional[AuthData] = None
payload: Optional[ApiPayload_T] = None
try: try:
if "auth_data" in req_data: if "auth_data" in req_data:
auth_data = AuthData.from_json_msg(req_data["auth_data"]) auth_data = AuthData.from_json_msg(req_data["auth_data"])
@@ -137,7 +145,8 @@ class EndpointHandler(ABC, Generic[ApiPayload_T]):
errors["auth_data"] = e.message errors["auth_data"] = e.message
try: try:
if "payload" in req_data: if "payload" in req_data:
payload = cls.payload_cls().from_json_msg(req_data["payload"]) payload_cls = cls.payload_cls()
payload = payload_cls.from_json_msg(req_data["payload"])
else: else:
errors["payload"] = "field missing" errors["payload"] = "field missing"
except JsonDataException as e: except JsonDataException as e:
+1 -1
View File
@@ -174,7 +174,7 @@ class CustomComfyWorkflowData(ApiPayload):
@classmethod @classmethod
def for_test(cls): def for_test(cls):
raise NotImplemented("Custom comfy workflow is not used for testing") raise NotImplementedError("Custom comfy workflow is not used for testing")
def count_workload(self) -> float: def count_workload(self) -> float:
return count_workload( return count_workload(
+9 -1
View File
@@ -2,7 +2,7 @@ import os
import logging import logging
import dataclasses import dataclasses
import base64 import base64
from typing import Union, Type from typing import Optional, Union, Type
from aiohttp import web, ClientResponse from aiohttp import web, ClientResponse
from anyio import open_file from anyio import open_file
@@ -70,6 +70,10 @@ class DefaultComfyWorkflowHandler(EndpointHandler[DefaultComfyWorkflowData]):
def endpoint(self) -> str: def endpoint(self) -> str:
return "/runsync" return "/runsync"
@property
def healthcheck_endpoint(self) -> Optional[str]:
    """no healthcheck endpoint is defined for this handler, so this is None"""
    return None
@classmethod @classmethod
def payload_cls(cls) -> Type[DefaultComfyWorkflowData]: def payload_cls(cls) -> Type[DefaultComfyWorkflowData]:
return DefaultComfyWorkflowData return DefaultComfyWorkflowData
@@ -90,6 +94,10 @@ class CustomComfyWorkflowHandler(EndpointHandler[CustomComfyWorkflowData]):
def endpoint(self) -> str: def endpoint(self) -> str:
return "/runsync" return "/runsync"
@property
def healthcheck_endpoint(self) -> Optional[str]:
    """no healthcheck endpoint is defined for this handler, so this is None"""
    return None
@classmethod @classmethod
def payload_cls(cls) -> Type[CustomComfyWorkflowData]: def payload_cls(cls) -> Type[CustomComfyWorkflowData]:
return CustomComfyWorkflowData return CustomComfyWorkflowData
+9 -1
View File
@@ -13,7 +13,7 @@ EndpointHandler. This is useful for endpoints such as healthchecks. See below fo
import os import os
import logging import logging
import dataclasses import dataclasses
from typing import Dict, Any, Union, Type from typing import Dict, Any, Optional, Union, Type
from aiohttp import web, ClientResponse from aiohttp import web, ClientResponse
@@ -50,6 +50,10 @@ class GenerateHandler(EndpointHandler[InputData]):
# the API endpoint # the API endpoint
return "/generate" return "/generate"
@property
def healthcheck_endpoint(self) -> Optional[str]:
    """no healthcheck endpoint is defined for this handler, so this is None"""
    return None
@classmethod @classmethod
def payload_cls(cls) -> Type[InputData]: def payload_cls(cls) -> Type[InputData]:
return InputData return InputData
@@ -94,6 +98,10 @@ class GenerateStreamHandler(EndpointHandler[InputData]):
def endpoint(self) -> str: def endpoint(self) -> str:
return "/generate_stream" return "/generate_stream"
@property
def healthcheck_endpoint(self) -> Optional[str]:
    """no healthcheck endpoint is defined for this handler, so this is None"""
    return None
@classmethod @classmethod
def payload_cls(cls) -> Type[InputData]: def payload_cls(cls) -> Type[InputData]:
return InputData return InputData
+13 -2
View File
@@ -14,7 +14,7 @@ from .data_types import InputData
MODEL_SERVER_URL = "http://0.0.0.0:5001" MODEL_SERVER_URL = "http://0.0.0.0:5001"
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded # This is the last log line that gets emitted once comfyui+extensions have been fully loaded
MODEL_SERVER_START_LOG_MSG = '"message":"Connected","target":"text_generation_router"' MODEL_SERVER_START_LOG_MSG = ['"message":"Connected","target":"text_generation_router"', '"message":"Connected","target":"text_generation_router::server"']
MODEL_SERVER_ERROR_LOG_MSGS = ["Error: WebserverFailed", "Error: DownloadError", "Error: ShardCannotStart"] MODEL_SERVER_ERROR_LOG_MSGS = ["Error: WebserverFailed", "Error: DownloadError", "Error: ShardCannotStart"]
@@ -33,6 +33,10 @@ class GenerateHandler(EndpointHandler[InputData]):
def endpoint(self) -> str: def endpoint(self) -> str:
return "/generate" return "/generate"
@property
def healthcheck_endpoint(self) -> str:
    """the healthcheck URL: the /health path under MODEL_SERVER_URL"""
    return f"{MODEL_SERVER_URL}/health"
@classmethod @classmethod
def payload_cls(cls) -> Type[InputData]: def payload_cls(cls) -> Type[InputData]:
return InputData return InputData
@@ -59,6 +63,10 @@ class GenerateStreamHandler(EndpointHandler[InputData]):
def endpoint(self) -> str: def endpoint(self) -> str:
return "/generate_stream" return "/generate_stream"
@property
def healthcheck_endpoint(self) -> str:
    """the healthcheck URL: the /health path under MODEL_SERVER_URL"""
    return f"{MODEL_SERVER_URL}/health"
@classmethod @classmethod
def payload_cls(cls) -> Type[InputData]: def payload_cls(cls) -> Type[InputData]:
return InputData return InputData
@@ -91,7 +99,10 @@ backend = Backend(
allow_parallel_requests=True, allow_parallel_requests=True,
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256), benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
log_actions=[ log_actions=[
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG), *[
(LogAction.ModelLoaded, info_msg)
for info_msg in MODEL_SERVER_START_LOG_MSG
],
(LogAction.Info, '"message":"Download'), (LogAction.Info, '"message":"Download'),
*[ *[
(LogAction.ModelError, error_msg) (LogAction.ModelError, error_msg)