refactor, handle zombie tasks

This commit is contained in:
Lucas Armand
2025-11-12 15:23:42 -08:00
parent d8bb1fcc68
commit 249ca2eb99
+26 -45
View File
@@ -146,21 +146,23 @@ class Backend:
def advance_queue_after_completion(event: asyncio.Event):
    """Drop ``event`` from the FIFO queue and hand the turn to the next waiter.

    When ``event`` is the current head it is popped, and the new head (if
    any) is set so that its waiter wakes up. When it is not the head it is
    simply removed; an event that is no longer queued is ignored.
    """
    waiters = self.queue
    holds_head = bool(waiters) and waiters[0] is event
    if not holds_head:
        # Not at the front: just take it out; it may already be gone.
        try:
            waiters.remove(event)
        except ValueError:
            pass
        return

    waiters.popleft()
    if waiters:
        # Signal the request that is now first in line.
        waiters[0].set()
async def cancel_api_call_if_disconnected() -> None:
    """Wait for the client to disconnect, then record the cancellation.

    The coroutine finishes normally instead of raising, so the caller can
    detect the disconnect by seeing this task among the completed ones
    when racing it against the work task.
    """
    await request.wait_for_disconnection()
    log.debug(f"Request with reqnum: {request_metrics.reqnum} was canceled")
    self.metrics._request_canceled(request_metrics)
async def make_request() -> Union[web.Response, web.StreamResponse]: async def make_request() -> Union[web.Response, web.StreamResponse]:
try: try:
@@ -177,6 +179,8 @@ class Backend:
res = await handler.generate_client_response(request, response) res = await handler.generate_client_response(request, response)
self.metrics._request_success(request_metrics) self.metrics._request_success(request_metrics)
return res return res
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
log.debug(f"[backend] Request error: {e}") log.debug(f"[backend] Request error: {e}")
self.metrics._request_errored(request_metrics) self.metrics._request_errored(request_metrics)
@@ -193,10 +197,14 @@ class Backend:
return web.Response(status=429) return web.Response(status=429)
disconnect_task = create_task(cancel_api_call_if_disconnected()) disconnect_task = create_task(cancel_api_call_if_disconnected())
next_request_task = None
work_task = None
event = asyncio.Event() # Used in finally block, so initialize here
self.metrics._request_start(request_metrics) self.metrics._request_start(request_metrics)
if self.allow_parallel_requests:
try: try:
if self.allow_parallel_requests:
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}") log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
work_task = create_task(make_request()) work_task = create_task(make_request())
done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED) done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED)
@@ -206,33 +214,19 @@ class Backend:
await asyncio.gather(*pending, return_exceptions=True) await asyncio.gather(*pending, return_exceptions=True)
if disconnect_task in done: if disconnect_task in done:
# Make sure work_task is settled/cancelled
try:
await work_task
except Exception:
pass
return web.Response(status=499) return web.Response(status=499)
# otherwise work_task completed # otherwise work_task completed
return await work_task return await work_task
except asyncio.CancelledError: # FIFO-queue branch
return web.Response(status=499)
except Exception as e:
log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500)
finally:
self.metrics._request_end(request_metrics)
else: else:
# Insert an Event into the queue for this request # Insert an Event into the queue for this request
# Event.set() == our request is up next # Event.set() == our request is up next
event = asyncio.Event()
self.queue.append(event) self.queue.append(event)
if self.queue and self.queue[0] is event: if self.queue and self.queue[0] is event:
event.set() event.set()
try:
# Race between our request being next and request being cancelled # Race between our request being next and request being cancelled
next_request_task = create_task(event.wait()) next_request_task = create_task(event.wait())
first_done, first_pending = await wait( first_done, first_pending = await wait(
@@ -240,15 +234,8 @@ class Backend:
) )
# If the disconnect task wins the race # If the disconnect task wins the race
if disconnect_task in first_done and not event.is_set(): if disconnect_task in first_done:
was_head = (self.queue and self.queue[0] is event) # Clean up the next_request_task, then exit
try:
self.queue.remove(event)
except ValueError:
pass
if was_head and self.queue:
self.queue[0].set()
for t in first_pending: for t in first_pending:
t.cancel() t.cancel()
await asyncio.gather(*first_pending, return_exceptions=True) await asyncio.gather(*first_pending, return_exceptions=True)
@@ -259,32 +246,19 @@ class Backend:
# Race the backend API call with the disconnect task # Race the backend API call with the disconnect task
work_task = create_task(make_request()) work_task = create_task(make_request())
done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED) done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED)
for t in pending: for t in pending:
t.cancel() t.cancel()
await asyncio.gather(*pending, return_exceptions=True) await asyncio.gather(*pending, return_exceptions=True)
if disconnect_task in done: if disconnect_task in done:
# ensure work is cancelled and accounted for
try:
await work_task
except Exception:
pass
return web.Response(status=499) return web.Response(status=499)
# otherwise work_task completed # otherwise work_task completed
return await work_task return await work_task
except asyncio.CancelledError: except asyncio.CancelledError:
# Cleanup if request was cancelled
was_head = (self.queue and self.queue[0] is event)
try:
self.queue.remove(event)
except ValueError:
pass
if was_head and self.queue:
self.queue[0].set()
return web.Response(status=499) return web.Response(status=499)
except Exception as e: except Exception as e:
@@ -292,11 +266,18 @@ class Backend:
return web.Response(status=500) return web.Response(status=500)
finally: finally:
self.metrics._request_end(request_metrics) if not self.allow_parallel_requests:
if event.is_set():
# The request is done, advance the queue
advance_queue_after_completion(event) advance_queue_after_completion(event)
self.metrics._request_end(request_metrics)
cleanup_tasks = [t for t in (next_request_task, work_task, disconnect_task) if t]
for t in cleanup_tasks:
if not t.done():
t.cancel()
if cleanup_tasks:
await asyncio.gather(*cleanup_tasks, return_exceptions=True)
@cached_property @cached_property
def healthcheck_session(self): def healthcheck_session(self):
"""Dedicated session for healthchecks to avoid conflicts with API session""" """Dedicated session for healthchecks to avoid conflicts with API session"""