Endpoint update pr one (#1)
* Added endpoint flexibility along with existing log. extended the log support * Switched Endpoint back to vast-ai, Added endpoint flexibility along with existing log. extended the log support * Modify the endpoint return type as optional and check via pyright to ensure there are not compilation/type errors * Added endpoint flexibility along with existing log. extended the log support Switched Endpoint back to vast-ai, Added endpoint flexibility along with existing log. extended the log support Modify the endpoint return type as optional and check via pyright to ensure there are not compilation/type errors * Endpoint Utils and API changes
This commit is contained in:
committed by
Nader Arbabian
parent
b1ca68c349
commit
71ed54ebe4
+13
-9
@@ -14,8 +14,15 @@ from .data_types import InputData
|
||||
MODEL_SERVER_URL = "http://0.0.0.0:5001"
|
||||
|
||||
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
||||
MODEL_SERVER_START_LOG_MSG = ['"message":"Connected","target":"text_generation_router"', '"message":"Connected","target":"text_generation_router::server"']
|
||||
MODEL_SERVER_ERROR_LOG_MSGS = ["Error: WebserverFailed", "Error: DownloadError", "Error: ShardCannotStart"]
|
||||
MODEL_SERVER_START_LOG_MSG = [
|
||||
'"message":"Connected","target":"text_generation_router"',
|
||||
'"message":"Connected","target":"text_generation_router::server"',
|
||||
]
|
||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||
"Error: WebserverFailed",
|
||||
"Error: DownloadError",
|
||||
"Error: ShardCannotStart",
|
||||
]
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
@@ -36,7 +43,7 @@ class GenerateHandler(EndpointHandler[InputData]):
|
||||
@property
|
||||
def healthcheck_endpoint(self) -> str:
|
||||
return f"{MODEL_SERVER_URL}/health"
|
||||
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[InputData]:
|
||||
return InputData
|
||||
@@ -62,11 +69,11 @@ class GenerateStreamHandler(EndpointHandler[InputData]):
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/generate_stream"
|
||||
|
||||
|
||||
@property
|
||||
def healthcheck_endpoint(self) -> str:
|
||||
return f"{MODEL_SERVER_URL}/health"
|
||||
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[InputData]:
|
||||
return InputData
|
||||
@@ -99,10 +106,7 @@ backend = Backend(
|
||||
allow_parallel_requests=True,
|
||||
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
|
||||
log_actions=[
|
||||
*[
|
||||
(LogAction.ModelLoaded, info_msg)
|
||||
for info_msg in MODEL_SERVER_START_LOG_MSG
|
||||
],
|
||||
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
|
||||
(LogAction.Info, '"message":"Download'),
|
||||
*[
|
||||
(LogAction.ModelError, error_msg)
|
||||
|
||||
Reference in New Issue
Block a user