Endpoint update pr one (#1)

* Added endpoint flexibility along with existing log. extended the log support

* Switched Endpoint back to vast-ai, Added endpoint flexibility along with existing log. extended the log support

* Modify the endpoint return type as optional and check via pyright to ensure there are not compilation/type errors

* Added endpoint flexibility along with existing log. extended the log support

Switched Endpoint back to vast-ai, Added endpoint flexibility along with existing log. extended the log support

Modify the endpoint return type as optional and check via pyright to ensure there are not compilation/type errors

* Endpoint Utils and API changes
This commit is contained in:
Abiola Akinnubi
2025-06-02 17:13:25 -07:00
committed by Nader Arbabian
parent b1ca68c349
commit 71ed54ebe4
11 changed files with 212 additions and 61 deletions
+7 -7
View File
@@ -186,7 +186,6 @@ class Backend:
log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500)
async def __healthcheck(self):
health_check_url = self.benchmark_handler.healthcheck_endpoint
if health_check_url is None:
@@ -200,19 +199,20 @@ class Backend:
log.debug("Healthcheck successful")
elif response.status == 503:
log.debug(f"Healthcheck failed with status: {response.status}")
self.backend_errored(f"Healthcheck failed with status: {response.status}")
self.backend_errored(
f"Healthcheck failed with status: {response.status}"
)
else:
# endpoint not ready yet so bail
log.debug(
f"Healthcheck Endpoint not ready: {response.status}"
)
log.debug(f"Healthcheck Endpoint not ready: {response.status}")
except Exception as e:
log.debug(f"Healthcheck failed with exception: {e}")
self.backend_errored(str(e))
async def _start_tracking(self) -> None:
await gather(self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck())
await gather(
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck()
)
def backend_errored(self, msg: str) -> None:
self.metrics._model_errored(msg)
+6 -4
View File
@@ -8,6 +8,7 @@ from aiohttp import web, ClientResponse
import inspect
import psutil
import requests
"""
@@ -22,8 +23,10 @@ class JsonDataException(Exception):
def __init__(self, json_msg: Dict[str, Any]):
self.message = json_msg
ApiPayload_T = TypeVar("ApiPayload_T", bound="ApiPayload")
@dataclass
class ApiPayload(ABC):
@@ -45,7 +48,9 @@ class ApiPayload(ABC):
@classmethod
@abstractmethod
def from_json_msg(cls: Type[ApiPayload_T], json_msg: Dict[str, Any]) -> ApiPayload_T:
def from_json_msg(
cls: Type[ApiPayload_T], json_msg: Dict[str, Any]
) -> ApiPayload_T:
"""
defines how to create an API payload from a JSON message,
it should throw an JsonDataException if there are issues with some fields
@@ -84,8 +89,6 @@ class AuthData:
)
@dataclass
class EndpointHandler(ABC, Generic[ApiPayload_T]):
"""
@@ -108,7 +111,6 @@ class EndpointHandler(ABC, Generic[ApiPayload_T]):
"""the endpoint on the model API that is used for healthchecks"""
pass
@classmethod
@abstractmethod
def payload_cls(cls) -> Type[ApiPayload_T]:
+27 -2
View File
@@ -1,3 +1,4 @@
import logging
import os
import time
import argparse
@@ -13,6 +14,13 @@ import requests
from lib.data_types import AuthData, ApiPayload
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
class ClientStatus(Enum):
FetchEndpoint = 1
@@ -71,9 +79,19 @@ class ClientState:
def make_call(self):
self.status = ClientStatus.FetchEndpoint
endpoint_api_key = AuthData.get_endpoint_api_key(
endpoint_name=self.endpoint_group_name,
account_api_key=self.api_key,
)
if not endpoint_api_key:
self.as_error.append(
f"Endpoint {self.endpoint_group_name} not found for API key",
)
self.status = ClientStatus.Error
return
route_payload = {
"endpoint": self.endpoint_group_name,
"api_key": self.api_key,
"api_key": endpoint_api_key,
"cost": self.payload.count_workload(),
}
response = requests.post(
@@ -215,11 +233,18 @@ def run_test(
print_thread = threading.Thread(target=print_state, args=(clients, num_requests))
print_thread.daemon = True # makes threads get killed on program exit
print_thread.start()
endpoint_api_key = AuthData.get_endpoint_api_key(
endpoint_name=endpoint_group_name,
account_api_key=api_key,
)
if not endpoint_api_key:
log.debug(f"Endpoint {endpoint_group_name} not found for API key")
return
try:
for _ in range(num_requests):
client = ClientState(
endpoint_group_name=endpoint_group_name,
api_key=api_key,
api_key=endpoint_api_key,
server_url=server_url,
worker_endpoint=worker_endpoint,
payload=payload_cls.for_test(),