Added endpoint flexibility alongside the existing log handling and extended the log support
Switched the endpoint back to vast-ai. Added endpoint flexibility alongside the existing log handling and extended the log support. Modified the endpoint return type to be Optional and checked with pyright to ensure there are no compilation/type errors.
This commit is contained in:
committed by
Nader Arbabian
parent
6de4ee2b59
commit
b1ca68c349
+26
-1
@@ -186,8 +186,33 @@ class Backend:
|
|||||||
log.debug(f"Exception in main handler loop {e}")
|
log.debug(f"Exception in main handler loop {e}")
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
|
|
||||||
|
|
||||||
|
async def __healthcheck(self):
|
||||||
|
health_check_url = self.benchmark_handler.healthcheck_endpoint
|
||||||
|
if health_check_url is None:
|
||||||
|
log.debug("No healthcheck endpoint defined, skipping healthcheck")
|
||||||
|
return
|
||||||
|
await sleep(5)
|
||||||
|
try:
|
||||||
|
log.debug(f"Performing healthcheck on {health_check_url}")
|
||||||
|
async with self.session.get(health_check_url) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
log.debug("Healthcheck successful")
|
||||||
|
elif response.status == 503:
|
||||||
|
log.debug(f"Healthcheck failed with status: {response.status}")
|
||||||
|
self.backend_errored(f"Healthcheck failed with status: {response.status}")
|
||||||
|
else:
|
||||||
|
# endpoint not ready yet so bail
|
||||||
|
log.debug(
|
||||||
|
f"Healthcheck Endpoint not ready: {response.status}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.debug(f"Healthcheck failed with exception: {e}")
|
||||||
|
self.backend_errored(str(e))
|
||||||
|
|
||||||
|
|
||||||
async def _start_tracking(self) -> None:
|
async def _start_tracking(self) -> None:
|
||||||
await gather(self.__read_logs(), self.metrics._send_metrics_loop())
|
await gather(self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck())
|
||||||
|
|
||||||
def backend_errored(self, msg: str) -> None:
|
def backend_errored(self, msg: str) -> None:
|
||||||
self.metrics._model_errored(msg)
|
self.metrics._model_errored(msg)
|
||||||
|
|||||||
+14
-5
@@ -22,13 +22,14 @@ class JsonDataException(Exception):
|
|||||||
def __init__(self, json_msg: Dict[str, Any]):
|
def __init__(self, json_msg: Dict[str, Any]):
|
||||||
self.message = json_msg
|
self.message = json_msg
|
||||||
|
|
||||||
|
ApiPayload_T = TypeVar("ApiPayload_T", bound="ApiPayload")
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ApiPayload(ABC):
|
class ApiPayload(ABC):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def for_test(cls) -> "ApiPayload":
|
def for_test(cls: Type[ApiPayload_T]) -> ApiPayload_T:
|
||||||
"""defines how create a payload for load testing"""
|
"""defines how create a payload for load testing"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -44,7 +45,7 @@ class ApiPayload(ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "ApiPayload":
|
def from_json_msg(cls: Type[ApiPayload_T], json_msg: Dict[str, Any]) -> ApiPayload_T:
|
||||||
"""
|
"""
|
||||||
defines how to create an API payload from a JSON message,
|
defines how to create an API payload from a JSON message,
|
||||||
it should throw an JsonDataException if there are issues with some fields
|
it should throw an JsonDataException if there are issues with some fields
|
||||||
@@ -83,7 +84,6 @@ class AuthData:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
ApiPayload_T = TypeVar("ApiPayload_T", bound=ApiPayload)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -102,6 +102,13 @@ class EndpointHandler(ABC, Generic[ApiPayload_T]):
|
|||||||
"""the endpoint on the model API"""
|
"""the endpoint on the model API"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def healthcheck_endpoint(self) -> Optional[str]:
|
||||||
|
"""the endpoint on the model API that is used for healthchecks"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def payload_cls(cls) -> Type[ApiPayload_T]:
|
def payload_cls(cls) -> Type[ApiPayload_T]:
|
||||||
@@ -127,7 +134,8 @@ class EndpointHandler(ABC, Generic[ApiPayload_T]):
|
|||||||
cls, req_data: Dict[str, Any]
|
cls, req_data: Dict[str, Any]
|
||||||
) -> Tuple[AuthData, ApiPayload_T]:
|
) -> Tuple[AuthData, ApiPayload_T]:
|
||||||
errors = {}
|
errors = {}
|
||||||
auth_data = payload = None
|
auth_data: Optional[AuthData] = None
|
||||||
|
payload: Optional[ApiPayload_T] = None
|
||||||
try:
|
try:
|
||||||
if "auth_data" in req_data:
|
if "auth_data" in req_data:
|
||||||
auth_data = AuthData.from_json_msg(req_data["auth_data"])
|
auth_data = AuthData.from_json_msg(req_data["auth_data"])
|
||||||
@@ -137,7 +145,8 @@ class EndpointHandler(ABC, Generic[ApiPayload_T]):
|
|||||||
errors["auth_data"] = e.message
|
errors["auth_data"] = e.message
|
||||||
try:
|
try:
|
||||||
if "payload" in req_data:
|
if "payload" in req_data:
|
||||||
payload = cls.payload_cls().from_json_msg(req_data["payload"])
|
payload_cls = cls.payload_cls()
|
||||||
|
payload = payload_cls.from_json_msg(req_data["payload"])
|
||||||
else:
|
else:
|
||||||
errors["payload"] = "field missing"
|
errors["payload"] = "field missing"
|
||||||
except JsonDataException as e:
|
except JsonDataException as e:
|
||||||
|
|||||||
@@ -174,7 +174,7 @@ class CustomComfyWorkflowData(ApiPayload):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def for_test(cls):
|
def for_test(cls):
|
||||||
raise NotImplemented("Custom comfy workflow is not used for testing")
|
raise NotImplementedError("Custom comfy workflow is not used for testing")
|
||||||
|
|
||||||
def count_workload(self) -> float:
|
def count_workload(self) -> float:
|
||||||
return count_workload(
|
return count_workload(
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import os
|
|||||||
import logging
|
import logging
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import base64
|
import base64
|
||||||
from typing import Union, Type
|
from typing import Optional, Union, Type
|
||||||
|
|
||||||
from aiohttp import web, ClientResponse
|
from aiohttp import web, ClientResponse
|
||||||
from anyio import open_file
|
from anyio import open_file
|
||||||
@@ -70,6 +70,10 @@ class DefaultComfyWorkflowHandler(EndpointHandler[DefaultComfyWorkflowData]):
|
|||||||
def endpoint(self) -> str:
|
def endpoint(self) -> str:
|
||||||
return "/runsync"
|
return "/runsync"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def healthcheck_endpoint(self) -> Optional[str]:
|
||||||
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def payload_cls(cls) -> Type[DefaultComfyWorkflowData]:
|
def payload_cls(cls) -> Type[DefaultComfyWorkflowData]:
|
||||||
return DefaultComfyWorkflowData
|
return DefaultComfyWorkflowData
|
||||||
@@ -90,6 +94,10 @@ class CustomComfyWorkflowHandler(EndpointHandler[CustomComfyWorkflowData]):
|
|||||||
def endpoint(self) -> str:
|
def endpoint(self) -> str:
|
||||||
return "/runsync"
|
return "/runsync"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def healthcheck_endpoint(self) -> Optional[str]:
|
||||||
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def payload_cls(cls) -> Type[CustomComfyWorkflowData]:
|
def payload_cls(cls) -> Type[CustomComfyWorkflowData]:
|
||||||
return CustomComfyWorkflowData
|
return CustomComfyWorkflowData
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ EndpointHandler. This is useful for endpoints such as healthchecks. See below fo
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import Dict, Any, Union, Type
|
from typing import Dict, Any, Optional, Union, Type
|
||||||
|
|
||||||
from aiohttp import web, ClientResponse
|
from aiohttp import web, ClientResponse
|
||||||
|
|
||||||
@@ -50,6 +50,10 @@ class GenerateHandler(EndpointHandler[InputData]):
|
|||||||
# the API endpoint
|
# the API endpoint
|
||||||
return "/generate"
|
return "/generate"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def healthcheck_endpoint(self) -> Optional[str]:
|
||||||
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def payload_cls(cls) -> Type[InputData]:
|
def payload_cls(cls) -> Type[InputData]:
|
||||||
return InputData
|
return InputData
|
||||||
@@ -94,6 +98,10 @@ class GenerateStreamHandler(EndpointHandler[InputData]):
|
|||||||
def endpoint(self) -> str:
|
def endpoint(self) -> str:
|
||||||
return "/generate_stream"
|
return "/generate_stream"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def healthcheck_endpoint(self) -> Optional[str]:
|
||||||
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def payload_cls(cls) -> Type[InputData]:
|
def payload_cls(cls) -> Type[InputData]:
|
||||||
return InputData
|
return InputData
|
||||||
|
|||||||
+13
-2
@@ -14,7 +14,7 @@ from .data_types import InputData
|
|||||||
MODEL_SERVER_URL = "http://0.0.0.0:5001"
|
MODEL_SERVER_URL = "http://0.0.0.0:5001"
|
||||||
|
|
||||||
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
||||||
MODEL_SERVER_START_LOG_MSG = '"message":"Connected","target":"text_generation_router"'
|
MODEL_SERVER_START_LOG_MSG = ['"message":"Connected","target":"text_generation_router"', '"message":"Connected","target":"text_generation_router::server"']
|
||||||
MODEL_SERVER_ERROR_LOG_MSGS = ["Error: WebserverFailed", "Error: DownloadError", "Error: ShardCannotStart"]
|
MODEL_SERVER_ERROR_LOG_MSGS = ["Error: WebserverFailed", "Error: DownloadError", "Error: ShardCannotStart"]
|
||||||
|
|
||||||
|
|
||||||
@@ -33,6 +33,10 @@ class GenerateHandler(EndpointHandler[InputData]):
|
|||||||
def endpoint(self) -> str:
|
def endpoint(self) -> str:
|
||||||
return "/generate"
|
return "/generate"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def healthcheck_endpoint(self) -> str:
|
||||||
|
return f"{MODEL_SERVER_URL}/health"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def payload_cls(cls) -> Type[InputData]:
|
def payload_cls(cls) -> Type[InputData]:
|
||||||
return InputData
|
return InputData
|
||||||
@@ -59,6 +63,10 @@ class GenerateStreamHandler(EndpointHandler[InputData]):
|
|||||||
def endpoint(self) -> str:
|
def endpoint(self) -> str:
|
||||||
return "/generate_stream"
|
return "/generate_stream"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def healthcheck_endpoint(self) -> str:
|
||||||
|
return f"{MODEL_SERVER_URL}/health"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def payload_cls(cls) -> Type[InputData]:
|
def payload_cls(cls) -> Type[InputData]:
|
||||||
return InputData
|
return InputData
|
||||||
@@ -91,7 +99,10 @@ backend = Backend(
|
|||||||
allow_parallel_requests=True,
|
allow_parallel_requests=True,
|
||||||
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
|
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
|
||||||
log_actions=[
|
log_actions=[
|
||||||
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
|
*[
|
||||||
|
(LogAction.ModelLoaded, info_msg)
|
||||||
|
for info_msg in MODEL_SERVER_START_LOG_MSG
|
||||||
|
],
|
||||||
(LogAction.Info, '"message":"Download'),
|
(LogAction.Info, '"message":"Download'),
|
||||||
*[
|
*[
|
||||||
(LogAction.ModelError, error_msg)
|
(LogAction.ModelError, error_msg)
|
||||||
|
|||||||
Reference in New Issue
Block a user