Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 353462ecb8 |
+36
-96
@@ -9,7 +9,6 @@ from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
|
||||
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
|
||||
from functools import cached_property
|
||||
from distutils.util import strtobool
|
||||
from collections import deque
|
||||
|
||||
from anyio import open_file
|
||||
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
|
||||
@@ -31,7 +30,7 @@ from lib.data_types import (
|
||||
BenchmarkResult
|
||||
)
|
||||
|
||||
VERSION = "0.2.1"
|
||||
VERSION = "0.2.0"
|
||||
|
||||
MSG_HISTORY_LEN = 100
|
||||
log = logging.getLogger(__file__)
|
||||
@@ -64,7 +63,6 @@ class Backend:
|
||||
version = VERSION
|
||||
msg_history = []
|
||||
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
|
||||
queue: deque = dataclasses.field(default_factory=deque, repr=False)
|
||||
unsecured: bool = dataclasses.field(
|
||||
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
||||
)
|
||||
@@ -143,26 +141,11 @@ class Backend:
|
||||
workload = payload.count_workload()
|
||||
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
|
||||
|
||||
|
||||
def advance_queue_after_completion(event: asyncio.Event):
|
||||
"""Pop current head and wake next waiter, if any."""
|
||||
# If this event is current head, wake next waiter
|
||||
if self.queue and self.queue[0] is event:
|
||||
self.queue.popleft()
|
||||
if self.queue:
|
||||
self.queue[0].set()
|
||||
else:
|
||||
# Else, remove it from the queue
|
||||
try:
|
||||
self.queue.remove(event)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
async def cancel_api_call_if_disconnected() -> None:
|
||||
async def cancel_api_call_if_disconnected() -> web.Response:
|
||||
await request.wait_for_disconnection()
|
||||
log.debug(f"Request with reqnum: {request_metrics.reqnum} was canceled")
|
||||
log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled")
|
||||
self.metrics._request_canceled(request_metrics)
|
||||
return
|
||||
raise asyncio.CancelledError
|
||||
|
||||
async def make_request() -> Union[web.Response, web.StreamResponse]:
|
||||
try:
|
||||
@@ -179,9 +162,7 @@ class Backend:
|
||||
res = await handler.generate_client_response(request, response)
|
||||
self.metrics._request_success(request_metrics)
|
||||
return res
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
except requests.exceptions.RequestException as e:
|
||||
log.debug(f"[backend] Request error: {e}")
|
||||
self.metrics._request_errored(request_metrics)
|
||||
return web.Response(status=500)
|
||||
@@ -196,87 +177,46 @@ class Backend:
|
||||
self.metrics._request_reject(request_metrics)
|
||||
return web.Response(status=429)
|
||||
|
||||
disconnect_task = create_task(cancel_api_call_if_disconnected())
|
||||
next_request_task = None
|
||||
work_task = None
|
||||
event = asyncio.Event() # Used in finally block, so initialize here
|
||||
|
||||
self.metrics._request_start(request_metrics)
|
||||
|
||||
acquired = False
|
||||
try:
|
||||
if self.allow_parallel_requests:
|
||||
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
|
||||
work_task = create_task(make_request())
|
||||
done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED)
|
||||
|
||||
for t in pending:
|
||||
t.cancel()
|
||||
await asyncio.gather(*pending, return_exceptions=True)
|
||||
|
||||
if disconnect_task in done:
|
||||
return web.Response(status=499)
|
||||
|
||||
# otherwise work_task completed
|
||||
return await work_task
|
||||
|
||||
# FIFO-queue branch
|
||||
else:
|
||||
# Insert a Event into the queue for this request
|
||||
# Event.set() == our request is up next
|
||||
self.queue.append(event)
|
||||
if self.queue and self.queue[0] is event:
|
||||
event.set()
|
||||
|
||||
# Race between our request being next and request being cancelled
|
||||
next_request_task = create_task(event.wait())
|
||||
first_done, first_pending = await wait(
|
||||
[next_request_task, disconnect_task], return_when=FIRST_COMPLETED
|
||||
self.metrics._request_start(request_metrics)
|
||||
if self.allow_parallel_requests is False:
|
||||
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
|
||||
await self.sem.acquire()
|
||||
acquired = True
|
||||
log.debug(
|
||||
f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..."
|
||||
)
|
||||
else:
|
||||
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
|
||||
done, pending = await wait(
|
||||
[
|
||||
create_task(make_request()),
|
||||
create_task(cancel_api_call_if_disconnected()),
|
||||
],
|
||||
return_when=FIRST_COMPLETED,
|
||||
)
|
||||
for t in pending:
|
||||
t.cancel()
|
||||
await asyncio.gather(*pending, return_exceptions=True)
|
||||
|
||||
# If the disconnect task wins the race
|
||||
if disconnect_task in first_done:
|
||||
# Clean up the next_request_task, then exit
|
||||
for t in first_pending:
|
||||
t.cancel()
|
||||
await asyncio.gather(*first_pending, return_exceptions=True)
|
||||
return web.Response(status=499)
|
||||
|
||||
# We are the next-up request in the queue
|
||||
log.debug(f"Starting work on request {request_metrics.reqnum}...")
|
||||
|
||||
# Race the backend API call with the disconnect task
|
||||
work_task = create_task(make_request())
|
||||
|
||||
done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED)
|
||||
for t in pending:
|
||||
t.cancel()
|
||||
await asyncio.gather(*pending, return_exceptions=True)
|
||||
|
||||
if disconnect_task in done:
|
||||
return web.Response(status=499)
|
||||
|
||||
# otherwise work_task completed
|
||||
return await work_task
|
||||
|
||||
done_task = done.pop()
|
||||
try:
|
||||
return done_task.result()
|
||||
except Exception as e:
|
||||
log.debug(f"Request task raised exception: {e}")
|
||||
return web.Response(status=500)
|
||||
except asyncio.CancelledError:
|
||||
return web.Response(status=499)
|
||||
|
||||
# Client is gone. Do not write a response; just unwind.
|
||||
return web.Response(status=499)
|
||||
except Exception as e:
|
||||
log.debug(f"Exception in main handler loop {e}")
|
||||
return web.Response(status=500)
|
||||
|
||||
finally:
|
||||
if not self.allow_parallel_requests:
|
||||
advance_queue_after_completion(event)
|
||||
|
||||
# Always release the semaphore if it was acquired
|
||||
if acquired:
|
||||
self.sem.release()
|
||||
self.metrics._request_end(request_metrics)
|
||||
cleanup_tasks = [t for t in (next_request_task, work_task, disconnect_task) if t]
|
||||
for t in cleanup_tasks:
|
||||
if not t.done():
|
||||
t.cancel()
|
||||
if cleanup_tasks:
|
||||
await asyncio.gather(*cleanup_tasks, return_exceptions=True)
|
||||
|
||||
|
||||
@cached_property
|
||||
def healthcheck_session(self):
|
||||
|
||||
@@ -89,7 +89,7 @@ class ComfyWorkflowHandler(EndpointHandler[ComfyWorkflowData]):
|
||||
backend = Backend(
|
||||
model_server_url=MODEL_SERVER_URL,
|
||||
model_log_file=os.environ["MODEL_LOG"],
|
||||
allow_parallel_requests=False,
|
||||
allow_parallel_requests=True,
|
||||
benchmark_handler=ComfyWorkflowHandler(
|
||||
benchmark_runs=3, benchmark_words=100
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user