ed7a952bf4
add pyworker v2
136 lines
4.3 KiB
Python
136 lines
4.3 KiB
Python
import os
|
|
import logging
|
|
import dataclasses
|
|
import base64
|
|
from typing import Union, Type
|
|
|
|
from aiohttp import web, ClientResponse
|
|
from anyio import open_file
|
|
|
|
from lib.backend import Backend, LogAction
|
|
from lib.data_types import EndpointHandler
|
|
from lib.server import start_server
|
|
from .data_types import DefaultComfyWorkflowData, CustomComfyWorkflowData
|
|
|
|
|
|
MODEL_SERVER_URL = "http://0.0.0.0:38188"
|
|
|
|
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
|
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: http://127.0.0.1:18188"
|
|
MODEL_SERVER_ERROR_LOG_MSGS = [
|
|
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
|
|
"Value not in list: unet_name", # This error is emitted when the model file is not there at all
|
|
]
|
|
|
|
|
|
logging.basicConfig(
|
|
level=logging.DEBUG,
|
|
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
|
datefmt="%Y-%m-%d %H:%M:%S",
|
|
)
|
|
log = logging.getLogger(__file__)
|
|
|
|
|
|
async def generate_client_response(
|
|
request: web.Request, response: ClientResponse
|
|
) -> Union[web.Response, web.StreamResponse]:
|
|
_ = request
|
|
match response.status:
|
|
case 200:
|
|
log.debug("SUCCESS")
|
|
res = await response.json()
|
|
if "output" not in res:
|
|
return web.json_response(
|
|
data=dict(error="there was an error in the workflow"),
|
|
status=422,
|
|
)
|
|
image_paths = [path["local_path"] for path in res["output"]["images"]]
|
|
if not image_paths:
|
|
return web.json_response(
|
|
data=dict(error="workflow did not produce any images"),
|
|
status=422,
|
|
)
|
|
images = []
|
|
for image_path in image_paths:
|
|
async with await open_file(image_path, mode="rb") as f:
|
|
contents = await f.read()
|
|
images.append(
|
|
f"data:image/png;base64,{base64.b64encode(contents).decode('utf-8')}"
|
|
)
|
|
return web.json_response(data=dict(images=images))
|
|
case code:
|
|
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
|
return web.Response(status=code)
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class DefaultComfyWorkflowHandler(EndpointHandler[DefaultComfyWorkflowData]):
|
|
|
|
@property
|
|
def endpoint(self) -> str:
|
|
return "/runsync"
|
|
|
|
@classmethod
|
|
def payload_cls(cls) -> Type[DefaultComfyWorkflowData]:
|
|
return DefaultComfyWorkflowData
|
|
|
|
def make_benchmark_payload(self) -> DefaultComfyWorkflowData:
|
|
return DefaultComfyWorkflowData.for_test()
|
|
|
|
async def generate_client_response(
|
|
self, client_request: web.Request, model_response: ClientResponse
|
|
) -> Union[web.Response, web.StreamResponse]:
|
|
return await generate_client_response(client_request, model_response)
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class CustomComfyWorkflowHandler(EndpointHandler[CustomComfyWorkflowData]):
|
|
|
|
@property
|
|
def endpoint(self) -> str:
|
|
return "/runsync"
|
|
|
|
@classmethod
|
|
def payload_cls(cls) -> Type[CustomComfyWorkflowData]:
|
|
return CustomComfyWorkflowData
|
|
|
|
def make_benchmark_payload(self) -> CustomComfyWorkflowData:
|
|
return CustomComfyWorkflowData.for_test()
|
|
|
|
async def generate_client_response(
|
|
self, client_request: web.Request, model_response: ClientResponse
|
|
) -> Union[web.Response, web.StreamResponse]:
|
|
return await generate_client_response(client_request, model_response)
|
|
|
|
|
|
backend = Backend(
|
|
model_server_url=MODEL_SERVER_URL,
|
|
model_log_file=os.environ["MODEL_LOG"],
|
|
allow_parallel_requests=False,
|
|
benchmark_handler=DefaultComfyWorkflowHandler(
|
|
benchmark_runs=3, benchmark_words=100
|
|
),
|
|
log_actions=[
|
|
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
|
|
(LogAction.Info, "Downloading:"),
|
|
*[
|
|
(LogAction.ModelError, error_msg)
|
|
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
|
],
|
|
],
|
|
)
|
|
|
|
|
|
async def handle_ping(_):
|
|
return web.Response(body="pong")
|
|
|
|
|
|
routes = [
|
|
web.post("/prompt", backend.create_handler(DefaultComfyWorkflowHandler())),
|
|
web.post("/custom-workflow", backend.create_handler(CustomComfyWorkflowHandler())),
|
|
web.get("/ping", handle_ping),
|
|
]
|
|
|
|
if __name__ == "__main__":
|
|
start_server(backend, routes)
|