Added endpoint flexibility along with existing log. extended the log support

Switched Endpoint back to vast-ai, Added endpoint flexibility along with existing log. extended the log support

Modify the endpoint return type as optional and check via pyright to ensure there are not compilation/type errors
This commit is contained in:
Abiola Akinnubi
2025-05-29 17:22:31 -07:00
committed by Nader Arbabian
parent 6de4ee2b59
commit b1ca68c349
6 changed files with 77 additions and 16 deletions
+26 -1
View File
@@ -186,8 +186,33 @@ class Backend:
log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500)
async def __healthcheck(self):
health_check_url = self.benchmark_handler.healthcheck_endpoint
if health_check_url is None:
log.debug("No healthcheck endpoint defined, skipping healthcheck")
return
await sleep(5)
try:
log.debug(f"Performing healthcheck on {health_check_url}")
async with self.session.get(health_check_url) as response:
if response.status == 200:
log.debug("Healthcheck successful")
elif response.status == 503:
log.debug(f"Healthcheck failed with status: {response.status}")
self.backend_errored(f"Healthcheck failed with status: {response.status}")
else:
# endpoint not ready yet so bail
log.debug(
f"Healthcheck Endpoint not ready: {response.status}"
)
except Exception as e:
log.debug(f"Healthcheck failed with exception: {e}")
self.backend_errored(str(e))
async def _start_tracking(self) -> None:
await gather(self.__read_logs(), self.metrics._send_metrics_loop())
await gather(self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck())
def backend_errored(self, msg: str) -> None:
self.metrics._model_errored(msg)
+14 -5
View File
@@ -22,13 +22,14 @@ class JsonDataException(Exception):
def __init__(self, json_msg: Dict[str, Any]):
self.message = json_msg
ApiPayload_T = TypeVar("ApiPayload_T", bound="ApiPayload")
@dataclass
class ApiPayload(ABC):
@classmethod
@abstractmethod
def for_test(cls) -> "ApiPayload":
def for_test(cls: Type[ApiPayload_T]) -> ApiPayload_T:
"""defines how create a payload for load testing"""
pass
@@ -44,7 +45,7 @@ class ApiPayload(ABC):
@classmethod
@abstractmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "ApiPayload":
def from_json_msg(cls: Type[ApiPayload_T], json_msg: Dict[str, Any]) -> ApiPayload_T:
"""
defines how to create an API payload from a JSON message,
it should throw an JsonDataException if there are issues with some fields
@@ -83,7 +84,6 @@ class AuthData:
)
ApiPayload_T = TypeVar("ApiPayload_T", bound=ApiPayload)
@dataclass
@@ -102,6 +102,13 @@ class EndpointHandler(ABC, Generic[ApiPayload_T]):
"""the endpoint on the model API"""
pass
@property
@abstractmethod
def healthcheck_endpoint(self) -> Optional[str]:
"""the endpoint on the model API that is used for healthchecks"""
pass
@classmethod
@abstractmethod
def payload_cls(cls) -> Type[ApiPayload_T]:
@@ -127,7 +134,8 @@ class EndpointHandler(ABC, Generic[ApiPayload_T]):
cls, req_data: Dict[str, Any]
) -> Tuple[AuthData, ApiPayload_T]:
errors = {}
auth_data = payload = None
auth_data: Optional[AuthData] = None
payload: Optional[ApiPayload_T] = None
try:
if "auth_data" in req_data:
auth_data = AuthData.from_json_msg(req_data["auth_data"])
@@ -137,7 +145,8 @@ class EndpointHandler(ABC, Generic[ApiPayload_T]):
errors["auth_data"] = e.message
try:
if "payload" in req_data:
payload = cls.payload_cls().from_json_msg(req_data["payload"])
payload_cls = cls.payload_cls()
payload = payload_cls.from_json_msg(req_data["payload"])
else:
errors["payload"] = "field missing"
except JsonDataException as e: