Added endpoint flexibility along with existing log. extended the log support
Switched Endpoint back to vast-ai, Added endpoint flexibility along with existing log. extended the log support Modify the endpoint return type as optional and check via pyright to ensure there are not compilation/type errors
This commit is contained in:
committed by
Nader Arbabian
parent
6de4ee2b59
commit
b1ca68c349
@@ -174,7 +174,7 @@ class CustomComfyWorkflowData(ApiPayload):
|
||||
|
||||
@classmethod
|
||||
def for_test(cls):
|
||||
raise NotImplemented("Custom comfy workflow is not used for testing")
|
||||
raise NotImplementedError("Custom comfy workflow is not used for testing")
|
||||
|
||||
def count_workload(self) -> float:
|
||||
return count_workload(
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
import logging
|
||||
import dataclasses
|
||||
import base64
|
||||
from typing import Union, Type
|
||||
from typing import Optional, Union, Type
|
||||
|
||||
from aiohttp import web, ClientResponse
|
||||
from anyio import open_file
|
||||
@@ -69,7 +69,11 @@ class DefaultComfyWorkflowHandler(EndpointHandler[DefaultComfyWorkflowData]):
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/runsync"
|
||||
|
||||
|
||||
@property
|
||||
def healthcheck_endpoint(self) -> Optional[str]:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[DefaultComfyWorkflowData]:
|
||||
return DefaultComfyWorkflowData
|
||||
@@ -89,7 +93,11 @@ class CustomComfyWorkflowHandler(EndpointHandler[CustomComfyWorkflowData]):
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/runsync"
|
||||
|
||||
|
||||
@property
|
||||
def healthcheck_endpoint(self) -> Optional[str]:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[CustomComfyWorkflowData]:
|
||||
return CustomComfyWorkflowData
|
||||
|
||||
@@ -13,7 +13,7 @@ EndpointHandler. This is useful for endpoints such as healthchecks. See below fo
|
||||
import os
|
||||
import logging
|
||||
import dataclasses
|
||||
from typing import Dict, Any, Union, Type
|
||||
from typing import Dict, Any, Optional, Union, Type
|
||||
|
||||
from aiohttp import web, ClientResponse
|
||||
|
||||
@@ -49,7 +49,11 @@ class GenerateHandler(EndpointHandler[InputData]):
|
||||
def endpoint(self) -> str:
|
||||
# the API endpoint
|
||||
return "/generate"
|
||||
|
||||
|
||||
@property
|
||||
def healthcheck_endpoint(self) -> Optional[str]:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[InputData]:
|
||||
return InputData
|
||||
@@ -93,7 +97,11 @@ class GenerateStreamHandler(EndpointHandler[InputData]):
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/generate_stream"
|
||||
|
||||
|
||||
@property
|
||||
def healthcheck_endpoint(self) -> Optional[str]:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[InputData]:
|
||||
return InputData
|
||||
|
||||
+14
-3
@@ -14,7 +14,7 @@ from .data_types import InputData
|
||||
MODEL_SERVER_URL = "http://0.0.0.0:5001"
|
||||
|
||||
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
||||
MODEL_SERVER_START_LOG_MSG = '"message":"Connected","target":"text_generation_router"'
|
||||
MODEL_SERVER_START_LOG_MSG = ['"message":"Connected","target":"text_generation_router"', '"message":"Connected","target":"text_generation_router::server"']
|
||||
MODEL_SERVER_ERROR_LOG_MSGS = ["Error: WebserverFailed", "Error: DownloadError", "Error: ShardCannotStart"]
|
||||
|
||||
|
||||
@@ -33,6 +33,10 @@ class GenerateHandler(EndpointHandler[InputData]):
|
||||
def endpoint(self) -> str:
|
||||
return "/generate"
|
||||
|
||||
@property
|
||||
def healthcheck_endpoint(self) -> str:
|
||||
return f"{MODEL_SERVER_URL}/health"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[InputData]:
|
||||
return InputData
|
||||
@@ -58,7 +62,11 @@ class GenerateStreamHandler(EndpointHandler[InputData]):
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/generate_stream"
|
||||
|
||||
|
||||
@property
|
||||
def healthcheck_endpoint(self) -> str:
|
||||
return f"{MODEL_SERVER_URL}/health"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[InputData]:
|
||||
return InputData
|
||||
@@ -91,7 +99,10 @@ backend = Backend(
|
||||
allow_parallel_requests=True,
|
||||
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
|
||||
log_actions=[
|
||||
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
|
||||
*[
|
||||
(LogAction.ModelLoaded, info_msg)
|
||||
for info_msg in MODEL_SERVER_START_LOG_MSG
|
||||
],
|
||||
(LogAction.Info, '"message":"Download'),
|
||||
*[
|
||||
(LogAction.ModelError, error_msg)
|
||||
|
||||
Reference in New Issue
Block a user