Added endpoint flexibility alongside the existing log; extended the log support

Switched the endpoint back to vast-ai and added endpoint flexibility alongside the existing log. Extended the log support.

Make the endpoint return type optional and check via pyright to ensure there are no compilation/type errors.
This commit is contained in:
Abiola Akinnubi
2025-05-29 17:22:31 -07:00
committed by Nader Arbabian
parent 6de4ee2b59
commit b1ca68c349
6 changed files with 77 additions and 16 deletions
+26 -1
View File
@@ -186,8 +186,33 @@ class Backend:
log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500)
async def __healthcheck(self):
    """Probe the benchmark handler's healthcheck endpoint a single time.

    Returns immediately when the benchmark handler exposes no healthcheck
    endpoint (``None``). Otherwise waits five seconds, issues one GET via
    ``self.session`` and:

    * logs success on HTTP 200,
    * logs and reports the backend as errored on HTTP 503,
    * only logs any other status.

    Any exception raised while performing the check is logged and reported
    through ``self.backend_errored``.
    """
    url = self.benchmark_handler.healthcheck_endpoint
    if url is None:
        log.debug("No healthcheck endpoint defined, skipping healthcheck")
        return

    await sleep(5)
    try:
        log.debug(f"Performing healthcheck on {url}")
        async with self.session.get(url) as response:
            status = response.status
            if status == 200:
                log.debug("Healthcheck successful")
                return
            if status != 503:
                # endpoint not ready yet so bail
                log.debug(f"Healthcheck Endpoint not ready: {status}")
                return
            # 503 is the only status reported as a backend error.
            failure = f"Healthcheck failed with status: {status}"
            log.debug(failure)
            self.backend_errored(failure)
    except Exception as e:
        log.debug(f"Healthcheck failed with exception: {e}")
        self.backend_errored(str(e))
async def _start_tracking(self) -> None:
await gather(self.__read_logs(), self.metrics._send_metrics_loop())
await gather(self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck())
def backend_errored(self, msg: str) -> None:
self.metrics._model_errored(msg)
+14 -5
View File
@@ -22,13 +22,14 @@ class JsonDataException(Exception):
def __init__(self, json_msg: Dict[str, Any]):
self.message = json_msg
ApiPayload_T = TypeVar("ApiPayload_T", bound="ApiPayload")
@dataclass
class ApiPayload(ABC):
@classmethod
@abstractmethod
def for_test(cls) -> "ApiPayload":
def for_test(cls: Type[ApiPayload_T]) -> ApiPayload_T:
"""defines how create a payload for load testing"""
pass
@@ -44,7 +45,7 @@ class ApiPayload(ABC):
@classmethod
@abstractmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "ApiPayload":
def from_json_msg(cls: Type[ApiPayload_T], json_msg: Dict[str, Any]) -> ApiPayload_T:
"""
defines how to create an API payload from a JSON message,
it should throw an JsonDataException if there are issues with some fields
@@ -83,7 +84,6 @@ class AuthData:
)
ApiPayload_T = TypeVar("ApiPayload_T", bound=ApiPayload)
@dataclass
@@ -102,6 +102,13 @@ class EndpointHandler(ABC, Generic[ApiPayload_T]):
"""the endpoint on the model API"""
pass
@property
@abstractmethod
def healthcheck_endpoint(self) -> Optional[str]:
"""the endpoint on the model API that is used for healthchecks"""
pass
@classmethod
@abstractmethod
def payload_cls(cls) -> Type[ApiPayload_T]:
@@ -127,7 +134,8 @@ class EndpointHandler(ABC, Generic[ApiPayload_T]):
cls, req_data: Dict[str, Any]
) -> Tuple[AuthData, ApiPayload_T]:
errors = {}
auth_data = payload = None
auth_data: Optional[AuthData] = None
payload: Optional[ApiPayload_T] = None
try:
if "auth_data" in req_data:
auth_data = AuthData.from_json_msg(req_data["auth_data"])
@@ -137,7 +145,8 @@ class EndpointHandler(ABC, Generic[ApiPayload_T]):
errors["auth_data"] = e.message
try:
if "payload" in req_data:
payload = cls.payload_cls().from_json_msg(req_data["payload"])
payload_cls = cls.payload_cls()
payload = payload_cls.from_json_msg(req_data["payload"])
else:
errors["payload"] = "field missing"
except JsonDataException as e:
+1 -1
View File
@@ -174,7 +174,7 @@ class CustomComfyWorkflowData(ApiPayload):
@classmethod
def for_test(cls):
raise NotImplemented("Custom comfy workflow is not used for testing")
raise NotImplementedError("Custom comfy workflow is not used for testing")
def count_workload(self) -> float:
return count_workload(
+9 -1
View File
@@ -2,7 +2,7 @@ import os
import logging
import dataclasses
import base64
from typing import Union, Type
from typing import Optional, Union, Type
from aiohttp import web, ClientResponse
from anyio import open_file
@@ -70,6 +70,10 @@ class DefaultComfyWorkflowHandler(EndpointHandler[DefaultComfyWorkflowData]):
def endpoint(self) -> str:
return "/runsync"
@property
def healthcheck_endpoint(self) -> Optional[str]:
return None
@classmethod
def payload_cls(cls) -> Type[DefaultComfyWorkflowData]:
return DefaultComfyWorkflowData
@@ -90,6 +94,10 @@ class CustomComfyWorkflowHandler(EndpointHandler[CustomComfyWorkflowData]):
def endpoint(self) -> str:
return "/runsync"
@property
def healthcheck_endpoint(self) -> Optional[str]:
return None
@classmethod
def payload_cls(cls) -> Type[CustomComfyWorkflowData]:
return CustomComfyWorkflowData
+9 -1
View File
@@ -13,7 +13,7 @@ EndpointHandler. This is useful for endpoints such as healthchecks. See below fo
import os
import logging
import dataclasses
from typing import Dict, Any, Union, Type
from typing import Dict, Any, Optional, Union, Type
from aiohttp import web, ClientResponse
@@ -50,6 +50,10 @@ class GenerateHandler(EndpointHandler[InputData]):
# the API endpoint
return "/generate"
@property
def healthcheck_endpoint(self) -> Optional[str]:
return None
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
@@ -94,6 +98,10 @@ class GenerateStreamHandler(EndpointHandler[InputData]):
def endpoint(self) -> str:
return "/generate_stream"
@property
def healthcheck_endpoint(self) -> Optional[str]:
return None
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
+13 -2
View File
@@ -14,7 +14,7 @@ from .data_types import InputData
MODEL_SERVER_URL = "http://0.0.0.0:5001"
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
MODEL_SERVER_START_LOG_MSG = '"message":"Connected","target":"text_generation_router"'
MODEL_SERVER_START_LOG_MSG = ['"message":"Connected","target":"text_generation_router"', '"message":"Connected","target":"text_generation_router::server"']
MODEL_SERVER_ERROR_LOG_MSGS = ["Error: WebserverFailed", "Error: DownloadError", "Error: ShardCannotStart"]
@@ -33,6 +33,10 @@ class GenerateHandler(EndpointHandler[InputData]):
def endpoint(self) -> str:
return "/generate"
@property
def healthcheck_endpoint(self) -> str:
return f"{MODEL_SERVER_URL}/health"
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
@@ -59,6 +63,10 @@ class GenerateStreamHandler(EndpointHandler[InputData]):
def endpoint(self) -> str:
return "/generate_stream"
@property
def healthcheck_endpoint(self) -> str:
return f"{MODEL_SERVER_URL}/health"
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
@@ -91,7 +99,10 @@ backend = Backend(
allow_parallel_requests=True,
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
log_actions=[
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
*[
(LogAction.ModelLoaded, info_msg)
for info_msg in MODEL_SERVER_START_LOG_MSG
],
(LogAction.Info, '"message":"Download'),
*[
(LogAction.ModelError, error_msg)