diff --git a/.gitignore b/.gitignore index c9558f0..226869e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ .direnv .envrc __pycache__ +bin/ +lib64 \ No newline at end of file diff --git a/lib/backend.py b/lib/backend.py index a4dc756..2a92bd0 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -186,7 +186,6 @@ class Backend: log.debug(f"Exception in main handler loop {e}") return web.Response(status=500) - async def __healthcheck(self): health_check_url = self.benchmark_handler.healthcheck_endpoint if health_check_url is None: @@ -200,19 +199,20 @@ class Backend: log.debug("Healthcheck successful") elif response.status == 503: log.debug(f"Healthcheck failed with status: {response.status}") - self.backend_errored(f"Healthcheck failed with status: {response.status}") + self.backend_errored( + f"Healthcheck failed with status: {response.status}" + ) else: # endpoint not ready yet so bail - log.debug( - f"Healthcheck Endpoint not ready: {response.status}" - ) + log.debug(f"Healthcheck Endpoint not ready: {response.status}") except Exception as e: log.debug(f"Healthcheck failed with exception: {e}") self.backend_errored(str(e)) - async def _start_tracking(self) -> None: - await gather(self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck()) + await gather( + self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck() + ) def backend_errored(self, msg: str) -> None: self.metrics._model_errored(msg) diff --git a/lib/data_types.py b/lib/data_types.py index c90d472..9af8d03 100644 --- a/lib/data_types.py +++ b/lib/data_types.py @@ -8,6 +8,7 @@ from aiohttp import web, ClientResponse import inspect import psutil +import requests """ @@ -22,8 +23,10 @@ class JsonDataException(Exception): def __init__(self, json_msg: Dict[str, Any]): self.message = json_msg + ApiPayload_T = TypeVar("ApiPayload_T", bound="ApiPayload") + @dataclass class ApiPayload(ABC): @@ -45,7 +48,9 @@ class ApiPayload(ABC): @classmethod @abstractmethod - def from_json_msg(cls: Type[ApiPayload_T], json_msg: Dict[str, Any]) -> ApiPayload_T: + def from_json_msg( + cls: Type[ApiPayload_T], json_msg: Dict[str, Any] + ) -> ApiPayload_T: """ defines how to create an API payload from a JSON message, it should throw an JsonDataException if there are issues with some fields @@ -84,8 +89,6 @@ class AuthData: ) - - @dataclass class EndpointHandler(ABC, Generic[ApiPayload_T]): """ @@ -108,7 +111,6 @@ class EndpointHandler(ABC, Generic[ApiPayload_T]): """the endpoint on the model API that is used for healthchecks""" pass - @classmethod @abstractmethod def payload_cls(cls) -> Type[ApiPayload_T]: diff --git a/lib/test_utils.py b/lib/test_utils.py index fdb635c..3002dce 100644 --- a/lib/test_utils.py +++ b/lib/test_utils.py @@ -1,3 +1,4 @@ +import logging import os import time import argparse @@ -13,6 +14,13 @@ import requests from lib.data_types import AuthData, ApiPayload +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s[%(levelname)-5s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +log = logging.getLogger(__file__) + class ClientStatus(Enum): FetchEndpoint = 1 @@ -71,9 +79,19 @@ class ClientState: def make_call(self): self.status = ClientStatus.FetchEndpoint + endpoint_api_key = AuthData.get_endpoint_api_key( + endpoint_name=self.endpoint_group_name, + account_api_key=self.api_key, + ) + if not endpoint_api_key: + self.as_error.append( + f"Endpoint {self.endpoint_group_name} not found for API key", + ) + self.status = ClientStatus.Error + return route_payload = { "endpoint": self.endpoint_group_name, - "api_key": self.api_key, + "api_key": endpoint_api_key, "cost": self.payload.count_workload(), } response = requests.post( @@ -215,11 +233,18 @@ def run_test( print_thread = threading.Thread(target=print_state, args=(clients, num_requests)) print_thread.daemon = True # makes threads get killed on program exit print_thread.start() + endpoint_api_key = AuthData.get_endpoint_api_key( + endpoint_name=endpoint_group_name, + account_api_key=api_key, + ) + if not endpoint_api_key: + log.debug(f"Endpoint {endpoint_group_name} not found for API key") + return try: for _ in range(num_requests): client = ClientState( endpoint_group_name=endpoint_group_name, - api_key=api_key, + api_key=endpoint_api_key, server_url=server_url, worker_endpoint=worker_endpoint, payload=payload_cls.for_test(), diff --git a/start_server.sh b/start_server.sh index 45a40ee..d7da082 100755 --- a/start_server.sh +++ b/start_server.sh @@ -8,11 +8,10 @@ SERVER_DIR="$WORKSPACE_DIR/vast-pyworker" ENV_PATH="$WORKSPACE_DIR/worker-env" DEBUG_LOG="$WORKSPACE_DIR/debug.log" PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log" - +MODEL_LOG="$WORKSPACE_DIR/model.log" REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}" USE_SSL="${USE_SSL:-true}" WORKER_PORT="${WORKER_PORT:-3000}" - mkdir -p "$WORKSPACE_DIR" cd "$WORKSPACE_DIR" @@ -47,7 +46,7 @@ env | grep _ >> /etc/environment; if [ ! -d "$ENV_PATH" ] then - apt install -y python3.10-venv + apt install -y python3.12-venv echo "setting up venv" git clone https://github.com/vast-ai/pyworker "$SERVER_DIR" diff --git a/utils/endpoint_util.py b/utils/endpoint_util.py new file mode 100644 index 0000000..4886dd9 --- /dev/null +++ b/utils/endpoint_util.py @@ -0,0 +1,69 @@ +from logging import log +from typing import Any, Dict, Optional + +import requests + + +class Endpoint: + """ + Utility class for handling endpoint operations. + """ + + @staticmethod + def get_endpoint_api_key(endpoint_name: str, account_api_key: str) -> Optional[str]: + """ + Fetch endpoint API key from VastAI console following the healthcheck pattern. + + Args: + endpoint_name: Name of the endpoint + account_api_key: Account API key for authentication + + Returns: + Endpoint API key if successful, None otherwise + """ + vast_console_url = "https://console.vast.ai/api/v0/endptjobs/" + headers = {"Authorization": f"Bearer {account_api_key}"} + + try: + log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}") + response = requests.get(vast_console_url, headers=headers) + + if response.status_code != 200: + error_msg = f"Failed to fetch endpoint API key: {response.status_code} - {response.text}" + log.debug(error_msg) + return None + + try: + data = response.json() + except requests.exceptions.JSONDecodeError as e: + log.debug(f"Failed to parse JSON response: {e}") + return None + + result = data.get("results", []) + + endpoint: Optional[Dict[str, Any]] = next( + (item for item in result if item["endpoint_name"] == endpoint_name), + None, + ) + if not endpoint: + error_msg = f"Endpoint '{endpoint_name}' not found." + log.debug(error_msg) + return None + + endpoint_api_key = endpoint.get("api_key") + if not endpoint_api_key: + error_msg = f"API key for endpoint '{endpoint_name}' not found." + log.debug(error_msg) + return None + + log.debug(f"Successfully retrieved API key for endpoint: {endpoint_name}") + return endpoint_api_key + + except requests.exceptions.RequestException as e: + error_msg = f"Request error while fetching endpoint API key: {e}" + log.debug(error_msg) + return None + except Exception as e: + error_msg = f"Unexpected error while fetching endpoint API key: {e}" + log.debug(error_msg) + return None diff --git a/workers/comfyui/client.py b/workers/comfyui/client.py index 2c0a9b7..23c8ca0 100644 --- a/workers/comfyui/client.py +++ b/workers/comfyui/client.py @@ -1,12 +1,20 @@ +import logging from urllib.parse import urljoin import requests from lib.test_utils import print_truncate_res +from utils.endpoint_util import Endpoint """ NOTE: this client example uses a custom comfy workflow compatible with SD3 only """ +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s[%(levelname)-5s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +log = logging.getLogger(__file__) def call_default_workflow( @@ -24,6 +32,7 @@ def call_default_workflow( json=route_payload, timeout=4, ) + response.raise_for_status() message = response.json() url = message["url"] auth_data = dict( @@ -43,6 +52,7 @@ def call_default_workflow( url, json=req_data, ) + response.raise_for_status() print_truncate_res(str(response.json())) @@ -61,6 +71,7 @@ def call_custom_workflow_for_sd3( json=route_payload, timeout=4, ) + response.raise_for_status() message = response.json() url = message["url"] auth_data = dict( @@ -131,6 +142,7 @@ def call_custom_workflow_for_sd3( url, json=req_data, ) + response.raise_for_status() print_truncate_res(str(response.json())) @@ -138,13 +150,23 @@ if __name__ == "__main__": from lib.test_utils import test_args args = test_args.parse_args() - call_default_workflow( - api_key=args.api_key, - endpoint_group_name=args.endpoint_group_name, - server_url=args.server_url, - ) - call_custom_workflow_for_sd3( - api_key=args.api_key, - endpoint_group_name=args.endpoint_group_name, - server_url=args.server_url, + endpoint_api_key = Endpoint.get_endpoint_api_key( + endpoint_name=args.endpoint_group_name, + account_api_key=args.api_key, ) + if endpoint_api_key: + try: + call_default_workflow( + api_key=endpoint_api_key, + endpoint_group_name=args.endpoint_group_name, + server_url=args.server_url, + ) + call_custom_workflow_for_sd3( + api_key=endpoint_api_key, + endpoint_group_name=args.endpoint_group_name, + server_url=args.server_url, + ) + except Exception as e: + log.error(f"Error during API call: {e}") + else: + log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ") diff --git a/workers/comfyui/server.py b/workers/comfyui/server.py index a999f74..40ee389 100644 --- a/workers/comfyui/server.py +++ b/workers/comfyui/server.py @@ -69,11 +69,11 @@ class DefaultComfyWorkflowHandler(EndpointHandler[DefaultComfyWorkflowData]): @property def endpoint(self) -> str: return "/runsync" - + @property def healthcheck_endpoint(self) -> Optional[str]: return None - + @classmethod def payload_cls(cls) -> Type[DefaultComfyWorkflowData]: return DefaultComfyWorkflowData @@ -93,11 +93,11 @@ class CustomComfyWorkflowHandler(EndpointHandler[CustomComfyWorkflowData]): @property def endpoint(self) -> str: return "/runsync" - + @property def healthcheck_endpoint(self) -> Optional[str]: return None - + @classmethod def payload_cls(cls) -> Type[CustomComfyWorkflowData]: return CustomComfyWorkflowData diff --git a/workers/hello_world/server.py b/workers/hello_world/server.py index d9c439f..91fb9a5 100644 --- a/workers/hello_world/server.py +++ b/workers/hello_world/server.py @@ -49,11 +49,11 @@ class GenerateHandler(EndpointHandler[InputData]): def endpoint(self) -> str: # the API endpoint return "/generate" - + @property def healthcheck_endpoint(self) -> Optional[str]: return None - + @classmethod def payload_cls(cls) -> Type[InputData]: return InputData @@ -97,11 +97,11 @@ class GenerateStreamHandler(EndpointHandler[InputData]): @property def endpoint(self) -> str: return "/generate_stream" - + @property def healthcheck_endpoint(self) -> Optional[str]: return None - + @classmethod def payload_cls(cls) -> Type[InputData]: return InputData diff --git a/workers/tgi/client.py b/workers/tgi/client.py index 966798a..cfa91f8 100644 --- a/workers/tgi/client.py +++ b/workers/tgi/client.py @@ -1,8 +1,16 @@ +import logging import sys import json from urllib.parse import urljoin - import requests +from utils.endpoint_util import Endpoint + +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s[%(levelname)-5s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +log = logging.getLogger(__file__) def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> None: @@ -18,28 +26,31 @@ def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> No json=route_payload, timeout=4, ) + response.raise_for_status() # Raise an exception for bad status codes message = response.json() url = message["url"] + auth_data = dict( signature=message["signature"], cost=message["cost"], endpoint=message["endpoint"], reqnum=message["reqnum"], - url=message["url"], + url=url, ) + payload = dict(inputs="tell me about cats", parameters=dict(max_new_tokens=500)) req_data = dict(payload=payload, auth_data=auth_data) url = urljoin(url, WORKER_ENDPOINT) print(f"url: {url}") - response = requests.post( - url, - json=req_data, - ) + response = requests.post(url, json=req_data) + response.raise_for_status() res = response.json() print(res) -def call_generate_stream(endpoint_group_name: str, api_key: str, server_url: str): +def call_generate_stream( + endpoint_group_name: str, api_key: str, server_url: str +) -> None: WORKER_ENDPOINT = "/generate_stream" COST = 100 route_payload = { @@ -52,6 +63,7 @@ def call_generate_stream(endpoint_group_name: str, api_key: str, server_url: str json=route_payload, timeout=4, ) + response.raise_for_status() # Raise an exception for bad status codes message = response.json() url = message["url"] print(f"url: {url}") @@ -66,12 +78,17 @@ def call_generate_stream(endpoint_group_name: str, api_key: str, server_url: str req_data = dict(payload=payload, auth_data=auth_data) url = urljoin(url, WORKER_ENDPOINT) response = requests.post(url, json=req_data, stream=True) + response.raise_for_status() # Raise an exception for bad status codes for line in response.iter_lines(): payload = line.decode().lstrip("data:").rstrip() if payload: - data = json.loads(payload) - print(data["token"]["text"], end="") - sys.stdout.flush() + try: + data = json.loads(payload) + print(data["token"]["text"], end="") + sys.stdout.flush() + except (json.JSONDecodeError, KeyError) as e: + log.warning(f"Failed to parse streaming response: {e}") + continue print() @@ -79,13 +96,24 @@ if __name__ == "__main__": from lib.test_utils import test_args args = test_args.parse_args() - call_generate( - api_key=args.api_key, - endpoint_group_name=args.endpoint_group_name, - server_url=args.server_url, - ) - call_generate_stream( - api_key=args.api_key, - endpoint_group_name=args.endpoint_group_name, - server_url=args.server_url, + + endpoint_api_key = Endpoint.get_endpoint_api_key( + endpoint_name=args.endpoint_group_name, + account_api_key=args.api_key, ) + if endpoint_api_key: + try: + call_generate( + api_key=endpoint_api_key, + endpoint_group_name=args.endpoint_group_name, + server_url=args.server_url, + ) + call_generate_stream( + api_key=endpoint_api_key, + endpoint_group_name=args.endpoint_group_name, + server_url=args.server_url, + ) + except Exception as e: + log.error(f"Error during API call: {e}") + else: + log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ") diff --git a/workers/tgi/server.py b/workers/tgi/server.py index 95cc5bc..99fc810 100644 --- a/workers/tgi/server.py +++ b/workers/tgi/server.py @@ -14,8 +14,15 @@ from .data_types import InputData MODEL_SERVER_URL = "http://0.0.0.0:5001" # This is the last log line that gets emitted once comfyui+extensions have been fully loaded -MODEL_SERVER_START_LOG_MSG = ['"message":"Connected","target":"text_generation_router"', '"message":"Connected","target":"text_generation_router::server"'] -MODEL_SERVER_ERROR_LOG_MSGS = ["Error: WebserverFailed", "Error: DownloadError", "Error: ShardCannotStart"] +MODEL_SERVER_START_LOG_MSG = [ + '"message":"Connected","target":"text_generation_router"', + '"message":"Connected","target":"text_generation_router::server"', +] +MODEL_SERVER_ERROR_LOG_MSGS = [ + "Error: WebserverFailed", + "Error: DownloadError", + "Error: ShardCannotStart", +] logging.basicConfig( @@ -36,7 +43,7 @@ class GenerateHandler(EndpointHandler[InputData]): @property def healthcheck_endpoint(self) -> str: return f"{MODEL_SERVER_URL}/health" - + @classmethod def payload_cls(cls) -> Type[InputData]: return InputData @@ -62,11 +69,11 @@ class GenerateStreamHandler(EndpointHandler[InputData]): @property def endpoint(self) -> str: return "/generate_stream" - + @property def healthcheck_endpoint(self) -> str: return f"{MODEL_SERVER_URL}/health" - + @classmethod def payload_cls(cls) -> Type[InputData]: return InputData @@ -99,10 +106,7 @@ backend = Backend( allow_parallel_requests=True, benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256), log_actions=[ - *[ - (LogAction.ModelLoaded, info_msg) - for info_msg in MODEL_SERVER_START_LOG_MSG - ], + *[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG], (LogAction.Info, '"message":"Download'), *[ (LogAction.ModelError, error_msg)