Endpoint update pr one (#1)

* Added endpoint flexibility along with existing log. extended the log support

* Switched Endpoint back to vast-ai, Added endpoint flexibility along with existing log. extended the log support

* Modify the endpoint return type as optional and check via pyright to ensure there are not compilation/type errors

* Added endpoint flexibility along with existing log. extended the log support

Switched Endpoint back to vast-ai, Added endpoint flexibility along with existing log. extended the log support

Modify the endpoint return type as optional and check via pyright to ensure there are not compilation/type errors

* Endpoint Utils and API changes
This commit is contained in:
Abiola Akinnubi
2025-06-02 17:13:25 -07:00
committed by Nader Arbabian
parent b1ca68c349
commit 71ed54ebe4
11 changed files with 212 additions and 61 deletions
+13 -9
View File
@@ -14,8 +14,15 @@ from .data_types import InputData
MODEL_SERVER_URL = "http://0.0.0.0:5001"
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
MODEL_SERVER_START_LOG_MSG = ['"message":"Connected","target":"text_generation_router"', '"message":"Connected","target":"text_generation_router::server"']
MODEL_SERVER_ERROR_LOG_MSGS = ["Error: WebserverFailed", "Error: DownloadError", "Error: ShardCannotStart"]
MODEL_SERVER_START_LOG_MSG = [
'"message":"Connected","target":"text_generation_router"',
'"message":"Connected","target":"text_generation_router::server"',
]
MODEL_SERVER_ERROR_LOG_MSGS = [
"Error: WebserverFailed",
"Error: DownloadError",
"Error: ShardCannotStart",
]
logging.basicConfig(
@@ -36,7 +43,7 @@ class GenerateHandler(EndpointHandler[InputData]):
@property
def healthcheck_endpoint(self) -> str:
return f"{MODEL_SERVER_URL}/health"
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
@@ -62,11 +69,11 @@ class GenerateStreamHandler(EndpointHandler[InputData]):
@property
def endpoint(self) -> str:
return "/generate_stream"
@property
def healthcheck_endpoint(self) -> str:
return f"{MODEL_SERVER_URL}/health"
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
@@ -99,10 +106,7 @@ backend = Backend(
allow_parallel_requests=True,
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
log_actions=[
*[
(LogAction.ModelLoaded, info_msg)
for info_msg in MODEL_SERVER_START_LOG_MSG
],
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
(LogAction.Info, '"message":"Download'),
*[
(LogAction.ModelError, error_msg)