Added endpoint flexibility along with existing log. extended the log support

Switched Endpoint back to vast-ai, Added endpoint flexibility along with existing log. extended the log support

Modify the endpoint return type as optional and check via pyright to ensure there are not compilation/type errors
This commit is contained in:
Abiola Akinnubi
2025-05-29 17:22:31 -07:00
committed by Nader Arbabian
parent 6de4ee2b59
commit b1ca68c349
6 changed files with 77 additions and 16 deletions
+1 -1
View File
@@ -174,7 +174,7 @@ class CustomComfyWorkflowData(ApiPayload):
@classmethod
def for_test(cls):
raise NotImplemented("Custom comfy workflow is not used for testing")
raise NotImplementedError("Custom comfy workflow is not used for testing")
def count_workload(self) -> float:
return count_workload(
+11 -3
View File
@@ -2,7 +2,7 @@ import os
import logging
import dataclasses
import base64
from typing import Union, Type
from typing import Optional, Union, Type
from aiohttp import web, ClientResponse
from anyio import open_file
@@ -69,7 +69,11 @@ class DefaultComfyWorkflowHandler(EndpointHandler[DefaultComfyWorkflowData]):
@property
def endpoint(self) -> str:
return "/runsync"
@property
def healthcheck_endpoint(self) -> Optional[str]:
return None
@classmethod
def payload_cls(cls) -> Type[DefaultComfyWorkflowData]:
return DefaultComfyWorkflowData
@@ -89,7 +93,11 @@ class CustomComfyWorkflowHandler(EndpointHandler[CustomComfyWorkflowData]):
@property
def endpoint(self) -> str:
return "/runsync"
@property
def healthcheck_endpoint(self) -> Optional[str]:
return None
@classmethod
def payload_cls(cls) -> Type[CustomComfyWorkflowData]:
return CustomComfyWorkflowData
+11 -3
View File
@@ -13,7 +13,7 @@ EndpointHandler. This is useful for endpoints such as healthchecks. See below fo
import os
import logging
import dataclasses
from typing import Dict, Any, Union, Type
from typing import Dict, Any, Optional, Union, Type
from aiohttp import web, ClientResponse
@@ -49,7 +49,11 @@ class GenerateHandler(EndpointHandler[InputData]):
def endpoint(self) -> str:
# the API endpoint
return "/generate"
@property
def healthcheck_endpoint(self) -> Optional[str]:
return None
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
@@ -93,7 +97,11 @@ class GenerateStreamHandler(EndpointHandler[InputData]):
@property
def endpoint(self) -> str:
return "/generate_stream"
@property
def healthcheck_endpoint(self) -> Optional[str]:
return None
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
+14 -3
View File
@@ -14,7 +14,7 @@ from .data_types import InputData
MODEL_SERVER_URL = "http://0.0.0.0:5001"
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
MODEL_SERVER_START_LOG_MSG = '"message":"Connected","target":"text_generation_router"'
MODEL_SERVER_START_LOG_MSG = ['"message":"Connected","target":"text_generation_router"', '"message":"Connected","target":"text_generation_router::server"']
MODEL_SERVER_ERROR_LOG_MSGS = ["Error: WebserverFailed", "Error: DownloadError", "Error: ShardCannotStart"]
@@ -33,6 +33,10 @@ class GenerateHandler(EndpointHandler[InputData]):
def endpoint(self) -> str:
return "/generate"
@property
def healthcheck_endpoint(self) -> str:
return f"{MODEL_SERVER_URL}/health"
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
@@ -58,7 +62,11 @@ class GenerateStreamHandler(EndpointHandler[InputData]):
@property
def endpoint(self) -> str:
return "/generate_stream"
@property
def healthcheck_endpoint(self) -> str:
return f"{MODEL_SERVER_URL}/health"
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
@@ -91,7 +99,10 @@ backend = Backend(
allow_parallel_requests=True,
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
log_actions=[
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
*[
(LogAction.ModelLoaded, info_msg)
for info_msg in MODEL_SERVER_START_LOG_MSG
],
(LogAction.Info, '"message":"Download'),
*[
(LogAction.ModelError, error_msg)