Endpoint update pr one (#1)

* Added endpoint flexibility along with existing log. extended the log support

* Switched Endpoint back to vast-ai, Added endpoint flexibility along with existing log. extended the log support

* Modify the endpoint return type to be Optional and check via pyright to ensure there are no compilation/type errors

* Added endpoint flexibility along with existing log. extended the log support

Switched Endpoint back to vast-ai, Added endpoint flexibility along with existing log. extended the log support

Modify the endpoint return type to be Optional and check via pyright to ensure there are no compilation/type errors

* Endpoint Utils and API changes
This commit is contained in:
Abiola Akinnubi
2025-06-02 17:13:25 -07:00
committed by Nader Arbabian
parent b1ca68c349
commit 71ed54ebe4
11 changed files with 212 additions and 61 deletions
+2
View File
@@ -1,3 +1,5 @@
.direnv .direnv
.envrc .envrc
__pycache__ __pycache__
bin/
lib64
+7 -7
View File
@@ -186,7 +186,6 @@ class Backend:
log.debug(f"Exception in main handler loop {e}") log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500) return web.Response(status=500)
async def __healthcheck(self): async def __healthcheck(self):
health_check_url = self.benchmark_handler.healthcheck_endpoint health_check_url = self.benchmark_handler.healthcheck_endpoint
if health_check_url is None: if health_check_url is None:
@@ -200,19 +199,20 @@ class Backend:
log.debug("Healthcheck successful") log.debug("Healthcheck successful")
elif response.status == 503: elif response.status == 503:
log.debug(f"Healthcheck failed with status: {response.status}") log.debug(f"Healthcheck failed with status: {response.status}")
self.backend_errored(f"Healthcheck failed with status: {response.status}") self.backend_errored(
f"Healthcheck failed with status: {response.status}"
)
else: else:
# endpoint not ready yet so bail # endpoint not ready yet so bail
log.debug( log.debug(f"Healthcheck Endpoint not ready: {response.status}")
f"Healthcheck Endpoint not ready: {response.status}"
)
except Exception as e: except Exception as e:
log.debug(f"Healthcheck failed with exception: {e}") log.debug(f"Healthcheck failed with exception: {e}")
self.backend_errored(str(e)) self.backend_errored(str(e))
async def _start_tracking(self) -> None: async def _start_tracking(self) -> None:
await gather(self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck()) await gather(
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck()
)
def backend_errored(self, msg: str) -> None: def backend_errored(self, msg: str) -> None:
self.metrics._model_errored(msg) self.metrics._model_errored(msg)
+6 -4
View File
@@ -8,6 +8,7 @@ from aiohttp import web, ClientResponse
import inspect import inspect
import psutil import psutil
import requests
""" """
@@ -22,8 +23,10 @@ class JsonDataException(Exception):
def __init__(self, json_msg: Dict[str, Any]): def __init__(self, json_msg: Dict[str, Any]):
self.message = json_msg self.message = json_msg
ApiPayload_T = TypeVar("ApiPayload_T", bound="ApiPayload") ApiPayload_T = TypeVar("ApiPayload_T", bound="ApiPayload")
@dataclass @dataclass
class ApiPayload(ABC): class ApiPayload(ABC):
@@ -45,7 +48,9 @@ class ApiPayload(ABC):
@classmethod @classmethod
@abstractmethod @abstractmethod
def from_json_msg(cls: Type[ApiPayload_T], json_msg: Dict[str, Any]) -> ApiPayload_T: def from_json_msg(
cls: Type[ApiPayload_T], json_msg: Dict[str, Any]
) -> ApiPayload_T:
""" """
defines how to create an API payload from a JSON message, defines how to create an API payload from a JSON message,
it should throw an JsonDataException if there are issues with some fields it should throw an JsonDataException if there are issues with some fields
@@ -84,8 +89,6 @@ class AuthData:
) )
@dataclass @dataclass
class EndpointHandler(ABC, Generic[ApiPayload_T]): class EndpointHandler(ABC, Generic[ApiPayload_T]):
""" """
@@ -108,7 +111,6 @@ class EndpointHandler(ABC, Generic[ApiPayload_T]):
"""the endpoint on the model API that is used for healthchecks""" """the endpoint on the model API that is used for healthchecks"""
pass pass
@classmethod @classmethod
@abstractmethod @abstractmethod
def payload_cls(cls) -> Type[ApiPayload_T]: def payload_cls(cls) -> Type[ApiPayload_T]:
+27 -2
View File
@@ -1,3 +1,4 @@
import logging
import os import os
import time import time
import argparse import argparse
@@ -13,6 +14,13 @@ import requests
from lib.data_types import AuthData, ApiPayload from lib.data_types import AuthData, ApiPayload
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
class ClientStatus(Enum): class ClientStatus(Enum):
FetchEndpoint = 1 FetchEndpoint = 1
@@ -71,9 +79,19 @@ class ClientState:
def make_call(self): def make_call(self):
self.status = ClientStatus.FetchEndpoint self.status = ClientStatus.FetchEndpoint
endpoint_api_key = AuthData.get_endpoint_api_key(
endpoint_name=self.endpoint_group_name,
account_api_key=self.api_key,
)
if not endpoint_api_key:
self.as_error.append(
f"Endpoint {self.endpoint_group_name} not found for API key",
)
self.status = ClientStatus.Error
return
route_payload = { route_payload = {
"endpoint": self.endpoint_group_name, "endpoint": self.endpoint_group_name,
"api_key": self.api_key, "api_key": endpoint_api_key,
"cost": self.payload.count_workload(), "cost": self.payload.count_workload(),
} }
response = requests.post( response = requests.post(
@@ -215,11 +233,18 @@ def run_test(
print_thread = threading.Thread(target=print_state, args=(clients, num_requests)) print_thread = threading.Thread(target=print_state, args=(clients, num_requests))
print_thread.daemon = True # makes threads get killed on program exit print_thread.daemon = True # makes threads get killed on program exit
print_thread.start() print_thread.start()
endpoint_api_key = AuthData.get_endpoint_api_key(
endpoint_name=endpoint_group_name,
account_api_key=api_key,
)
if not endpoint_api_key:
log.debug(f"Endpoint {endpoint_group_name} not found for API key")
return
try: try:
for _ in range(num_requests): for _ in range(num_requests):
client = ClientState( client = ClientState(
endpoint_group_name=endpoint_group_name, endpoint_group_name=endpoint_group_name,
api_key=api_key, api_key=endpoint_api_key,
server_url=server_url, server_url=server_url,
worker_endpoint=worker_endpoint, worker_endpoint=worker_endpoint,
payload=payload_cls.for_test(), payload=payload_cls.for_test(),
+2 -3
View File
@@ -8,11 +8,10 @@ SERVER_DIR="$WORKSPACE_DIR/vast-pyworker"
ENV_PATH="$WORKSPACE_DIR/worker-env" ENV_PATH="$WORKSPACE_DIR/worker-env"
DEBUG_LOG="$WORKSPACE_DIR/debug.log" DEBUG_LOG="$WORKSPACE_DIR/debug.log"
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log" PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
MODEL_LOG="$WORKSPACE_DIR/model.log"
REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}" REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}"
USE_SSL="${USE_SSL:-true}" USE_SSL="${USE_SSL:-true}"
WORKER_PORT="${WORKER_PORT:-3000}" WORKER_PORT="${WORKER_PORT:-3000}"
mkdir -p "$WORKSPACE_DIR" mkdir -p "$WORKSPACE_DIR"
cd "$WORKSPACE_DIR" cd "$WORKSPACE_DIR"
@@ -47,7 +46,7 @@ env | grep _ >> /etc/environment;
if [ ! -d "$ENV_PATH" ] if [ ! -d "$ENV_PATH" ]
then then
apt install -y python3.10-venv apt install -y python3.12-venv
echo "setting up venv" echo "setting up venv"
git clone https://github.com/vast-ai/pyworker "$SERVER_DIR" git clone https://github.com/vast-ai/pyworker "$SERVER_DIR"
+69
View File
@@ -0,0 +1,69 @@
import logging
from logging import log
from typing import Any, Dict, Optional

import requests
class Endpoint:
    """
    Utility class for handling endpoint operations.
    """

    # Upper bound (seconds) on the console request so a stalled connection
    # cannot block the caller forever.
    _REQUEST_TIMEOUT_SECONDS = 10

    @staticmethod
    def get_endpoint_api_key(endpoint_name: str, account_api_key: str) -> Optional[str]:
        """
        Fetch endpoint API key from the VastAI console.

        Requests the account's endpoint-job list and returns the ``api_key``
        of the entry whose ``endpoint_name`` equals ``endpoint_name``. Every
        failure (non-200 status, invalid JSON, missing endpoint, missing key,
        network error) is logged at debug level and reported as ``None``;
        no exception is propagated to the caller.

        Args:
            endpoint_name: Name of the endpoint
            account_api_key: Account API key for authentication

        Returns:
            Endpoint API key if successful, None otherwise
        """
        # `from logging import log` binds the function logging.log(level, msg),
        # which has no `.debug`; a real Logger instance is required here.
        logger = logging.getLogger(__name__)

        vast_console_url = "https://console.vast.ai/api/v0/endptjobs/"
        headers = {"Authorization": f"Bearer {account_api_key}"}

        try:
            logger.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}")
            response = requests.get(
                vast_console_url,
                headers=headers,
                timeout=Endpoint._REQUEST_TIMEOUT_SECONDS,
            )

            if response.status_code != 200:
                error_msg = f"Failed to fetch endpoint API key: {response.status_code} - {response.text}"
                logger.debug(error_msg)
                return None

            try:
                data = response.json()
            except ValueError as e:
                # Response.json() raises a ValueError subclass on every
                # requests version; requests.exceptions.JSONDecodeError only
                # exists from 2.27 onwards.
                logger.debug(f"Failed to parse JSON response: {e}")
                return None

            # Treat a missing or null "results" field as an empty list.
            result = data.get("results") or []
            # Skip malformed entries instead of aborting the whole search.
            endpoint: Optional[Dict[str, Any]] = next(
                (
                    item
                    for item in result
                    if isinstance(item, dict)
                    and item.get("endpoint_name") == endpoint_name
                ),
                None,
            )

            if not endpoint:
                error_msg = f"Endpoint '{endpoint_name}' not found."
                logger.debug(error_msg)
                return None

            endpoint_api_key = endpoint.get("api_key")
            if not endpoint_api_key:
                error_msg = f"API key for endpoint '{endpoint_name}' not found."
                logger.debug(error_msg)
                return None

            logger.debug(f"Successfully retrieved API key for endpoint: {endpoint_name}")
            return endpoint_api_key

        except requests.exceptions.RequestException as e:
            error_msg = f"Request error while fetching endpoint API key: {e}"
            logger.debug(error_msg)
            return None
        except Exception as e:
            # Deliberate best-effort boundary: callers only distinguish
            # "got a key" from "did not", so unexpected errors become None.
            error_msg = f"Unexpected error while fetching endpoint API key: {e}"
            logger.debug(error_msg)
            return None
+24 -2
View File
@@ -1,12 +1,20 @@
import logging
from urllib.parse import urljoin from urllib.parse import urljoin
import requests import requests
from lib.test_utils import print_truncate_res from lib.test_utils import print_truncate_res
from utils.endpoint_util import Endpoint
""" """
NOTE: this client example uses a custom comfy workflow compatible with SD3 only NOTE: this client example uses a custom comfy workflow compatible with SD3 only
""" """
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
def call_default_workflow( def call_default_workflow(
@@ -24,6 +32,7 @@ def call_default_workflow(
json=route_payload, json=route_payload,
timeout=4, timeout=4,
) )
response.raise_for_status()
message = response.json() message = response.json()
url = message["url"] url = message["url"]
auth_data = dict( auth_data = dict(
@@ -43,6 +52,7 @@ def call_default_workflow(
url, url,
json=req_data, json=req_data,
) )
response.raise_for_status()
print_truncate_res(str(response.json())) print_truncate_res(str(response.json()))
@@ -61,6 +71,7 @@ def call_custom_workflow_for_sd3(
json=route_payload, json=route_payload,
timeout=4, timeout=4,
) )
response.raise_for_status()
message = response.json() message = response.json()
url = message["url"] url = message["url"]
auth_data = dict( auth_data = dict(
@@ -131,6 +142,7 @@ def call_custom_workflow_for_sd3(
url, url,
json=req_data, json=req_data,
) )
response.raise_for_status()
print_truncate_res(str(response.json())) print_truncate_res(str(response.json()))
@@ -138,13 +150,23 @@ if __name__ == "__main__":
from lib.test_utils import test_args from lib.test_utils import test_args
args = test_args.parse_args() args = test_args.parse_args()
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
)
if endpoint_api_key:
try:
call_default_workflow( call_default_workflow(
api_key=args.api_key, api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name, endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url, server_url=args.server_url,
) )
call_custom_workflow_for_sd3( call_custom_workflow_for_sd3(
api_key=args.api_key, api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name, endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url, server_url=args.server_url,
) )
except Exception as e:
log.error(f"Error during API call: {e}")
else:
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ")
+37 -9
View File
@@ -1,8 +1,16 @@
import logging
import sys import sys
import json import json
from urllib.parse import urljoin from urllib.parse import urljoin
import requests import requests
from utils.endpoint_util import Endpoint
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> None: def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> None:
@@ -18,28 +26,31 @@ def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> No
json=route_payload, json=route_payload,
timeout=4, timeout=4,
) )
response.raise_for_status() # Raise an exception for bad status codes
message = response.json() message = response.json()
url = message["url"] url = message["url"]
auth_data = dict( auth_data = dict(
signature=message["signature"], signature=message["signature"],
cost=message["cost"], cost=message["cost"],
endpoint=message["endpoint"], endpoint=message["endpoint"],
reqnum=message["reqnum"], reqnum=message["reqnum"],
url=message["url"], url=url,
) )
payload = dict(inputs="tell me about cats", parameters=dict(max_new_tokens=500)) payload = dict(inputs="tell me about cats", parameters=dict(max_new_tokens=500))
req_data = dict(payload=payload, auth_data=auth_data) req_data = dict(payload=payload, auth_data=auth_data)
url = urljoin(url, WORKER_ENDPOINT) url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {url}") print(f"url: {url}")
response = requests.post( response = requests.post(url, json=req_data)
url, response.raise_for_status()
json=req_data,
)
res = response.json() res = response.json()
print(res) print(res)
def call_generate_stream(endpoint_group_name: str, api_key: str, server_url: str): def call_generate_stream(
endpoint_group_name: str, api_key: str, server_url: str
) -> None:
WORKER_ENDPOINT = "/generate_stream" WORKER_ENDPOINT = "/generate_stream"
COST = 100 COST = 100
route_payload = { route_payload = {
@@ -52,6 +63,7 @@ def call_generate_stream(endpoint_group_name: str, api_key: str, server_url: str
json=route_payload, json=route_payload,
timeout=4, timeout=4,
) )
response.raise_for_status() # Raise an exception for bad status codes
message = response.json() message = response.json()
url = message["url"] url = message["url"]
print(f"url: {url}") print(f"url: {url}")
@@ -66,12 +78,17 @@ def call_generate_stream(endpoint_group_name: str, api_key: str, server_url: str
req_data = dict(payload=payload, auth_data=auth_data) req_data = dict(payload=payload, auth_data=auth_data)
url = urljoin(url, WORKER_ENDPOINT) url = urljoin(url, WORKER_ENDPOINT)
response = requests.post(url, json=req_data, stream=True) response = requests.post(url, json=req_data, stream=True)
response.raise_for_status() # Raise an exception for bad status codes
for line in response.iter_lines(): for line in response.iter_lines():
payload = line.decode().lstrip("data:").rstrip() payload = line.decode().lstrip("data:").rstrip()
if payload: if payload:
try:
data = json.loads(payload) data = json.loads(payload)
print(data["token"]["text"], end="") print(data["token"]["text"], end="")
sys.stdout.flush() sys.stdout.flush()
except (json.JSONDecodeError, KeyError) as e:
log.warning(f"Failed to parse streaming response: {e}")
continue
print() print()
@@ -79,13 +96,24 @@ if __name__ == "__main__":
from lib.test_utils import test_args from lib.test_utils import test_args
args = test_args.parse_args() args = test_args.parse_args()
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
)
if endpoint_api_key:
try:
call_generate( call_generate(
api_key=args.api_key, api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name, endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url, server_url=args.server_url,
) )
call_generate_stream( call_generate_stream(
api_key=args.api_key, api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name, endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url, server_url=args.server_url,
) )
except Exception as e:
log.error(f"Error during API call: {e}")
else:
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ")
+10 -6
View File
@@ -14,8 +14,15 @@ from .data_types import InputData
MODEL_SERVER_URL = "http://0.0.0.0:5001" MODEL_SERVER_URL = "http://0.0.0.0:5001"
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded # This is the last log line that gets emitted once comfyui+extensions have been fully loaded
MODEL_SERVER_START_LOG_MSG = ['"message":"Connected","target":"text_generation_router"', '"message":"Connected","target":"text_generation_router::server"'] MODEL_SERVER_START_LOG_MSG = [
MODEL_SERVER_ERROR_LOG_MSGS = ["Error: WebserverFailed", "Error: DownloadError", "Error: ShardCannotStart"] '"message":"Connected","target":"text_generation_router"',
'"message":"Connected","target":"text_generation_router::server"',
]
MODEL_SERVER_ERROR_LOG_MSGS = [
"Error: WebserverFailed",
"Error: DownloadError",
"Error: ShardCannotStart",
]
logging.basicConfig( logging.basicConfig(
@@ -99,10 +106,7 @@ backend = Backend(
allow_parallel_requests=True, allow_parallel_requests=True,
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256), benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
log_actions=[ log_actions=[
*[ *[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
(LogAction.ModelLoaded, info_msg)
for info_msg in MODEL_SERVER_START_LOG_MSG
],
(LogAction.Info, '"message":"Download'), (LogAction.Info, '"message":"Download'),
*[ *[
(LogAction.ModelError, error_msg) (LogAction.ModelError, error_msg)