Endpoint update pr one (#1)
* Added endpoint flexibility along with existing log; extended the log support. * Switched endpoint back to vast-ai; added endpoint flexibility along with existing log; extended the log support. * Modified the endpoint return type to be optional and checked via pyright to ensure there are no compilation/type errors. * Added endpoint flexibility along with existing log; extended the log support. Switched endpoint back to vast-ai; added endpoint flexibility along with existing log; extended the log support. Modified the endpoint return type to be optional and checked via pyright to ensure there are no compilation/type errors. * Endpoint utils and API changes.
This commit is contained in:
committed by
Nader Arbabian
parent
b1ca68c349
commit
71ed54ebe4
+7
-7
@@ -186,7 +186,6 @@ class Backend:
|
||||
log.debug(f"Exception in main handler loop {e}")
|
||||
return web.Response(status=500)
|
||||
|
||||
|
||||
async def __healthcheck(self):
|
||||
health_check_url = self.benchmark_handler.healthcheck_endpoint
|
||||
if health_check_url is None:
|
||||
@@ -200,19 +199,20 @@ class Backend:
|
||||
log.debug("Healthcheck successful")
|
||||
elif response.status == 503:
|
||||
log.debug(f"Healthcheck failed with status: {response.status}")
|
||||
self.backend_errored(f"Healthcheck failed with status: {response.status}")
|
||||
self.backend_errored(
|
||||
f"Healthcheck failed with status: {response.status}"
|
||||
)
|
||||
else:
|
||||
# endpoint not ready yet so bail
|
||||
log.debug(
|
||||
f"Healthcheck Endpoint not ready: {response.status}"
|
||||
)
|
||||
log.debug(f"Healthcheck Endpoint not ready: {response.status}")
|
||||
except Exception as e:
|
||||
log.debug(f"Healthcheck failed with exception: {e}")
|
||||
self.backend_errored(str(e))
|
||||
|
||||
|
||||
async def _start_tracking(self) -> None:
|
||||
await gather(self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck())
|
||||
await gather(
|
||||
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck()
|
||||
)
|
||||
|
||||
def backend_errored(self, msg: str) -> None:
|
||||
self.metrics._model_errored(msg)
|
||||
|
||||
+6
-4
@@ -8,6 +8,7 @@ from aiohttp import web, ClientResponse
|
||||
import inspect
|
||||
|
||||
import psutil
|
||||
import requests
|
||||
|
||||
|
||||
"""
|
||||
@@ -22,8 +23,10 @@ class JsonDataException(Exception):
|
||||
def __init__(self, json_msg: Dict[str, Any]):
|
||||
self.message = json_msg
|
||||
|
||||
|
||||
ApiPayload_T = TypeVar("ApiPayload_T", bound="ApiPayload")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ApiPayload(ABC):
|
||||
|
||||
@@ -45,7 +48,9 @@ class ApiPayload(ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_json_msg(cls: Type[ApiPayload_T], json_msg: Dict[str, Any]) -> ApiPayload_T:
|
||||
def from_json_msg(
|
||||
cls: Type[ApiPayload_T], json_msg: Dict[str, Any]
|
||||
) -> ApiPayload_T:
|
||||
"""
|
||||
defines how to create an API payload from a JSON message,
|
||||
it should throw an JsonDataException if there are issues with some fields
|
||||
@@ -84,8 +89,6 @@ class AuthData:
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class EndpointHandler(ABC, Generic[ApiPayload_T]):
|
||||
"""
|
||||
@@ -108,7 +111,6 @@ class EndpointHandler(ABC, Generic[ApiPayload_T]):
|
||||
"""the endpoint on the model API that is used for healthchecks"""
|
||||
pass
|
||||
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def payload_cls(cls) -> Type[ApiPayload_T]:
|
||||
|
||||
+27
-2
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
@@ -13,6 +14,13 @@ import requests
|
||||
|
||||
from lib.data_types import AuthData, ApiPayload
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
class ClientStatus(Enum):
|
||||
FetchEndpoint = 1
|
||||
@@ -71,9 +79,19 @@ class ClientState:
|
||||
|
||||
def make_call(self):
|
||||
self.status = ClientStatus.FetchEndpoint
|
||||
endpoint_api_key = AuthData.get_endpoint_api_key(
|
||||
endpoint_name=self.endpoint_group_name,
|
||||
account_api_key=self.api_key,
|
||||
)
|
||||
if not endpoint_api_key:
|
||||
self.as_error.append(
|
||||
f"Endpoint {self.endpoint_group_name} not found for API key",
|
||||
)
|
||||
self.status = ClientStatus.Error
|
||||
return
|
||||
route_payload = {
|
||||
"endpoint": self.endpoint_group_name,
|
||||
"api_key": self.api_key,
|
||||
"api_key": endpoint_api_key,
|
||||
"cost": self.payload.count_workload(),
|
||||
}
|
||||
response = requests.post(
|
||||
@@ -215,11 +233,18 @@ def run_test(
|
||||
print_thread = threading.Thread(target=print_state, args=(clients, num_requests))
|
||||
print_thread.daemon = True # makes threads get killed on program exit
|
||||
print_thread.start()
|
||||
endpoint_api_key = AuthData.get_endpoint_api_key(
|
||||
endpoint_name=endpoint_group_name,
|
||||
account_api_key=api_key,
|
||||
)
|
||||
if not endpoint_api_key:
|
||||
log.debug(f"Endpoint {endpoint_group_name} not found for API key")
|
||||
return
|
||||
try:
|
||||
for _ in range(num_requests):
|
||||
client = ClientState(
|
||||
endpoint_group_name=endpoint_group_name,
|
||||
api_key=api_key,
|
||||
api_key=endpoint_api_key,
|
||||
server_url=server_url,
|
||||
worker_endpoint=worker_endpoint,
|
||||
payload=payload_cls.for_test(),
|
||||
|
||||
Reference in New Issue
Block a user