Compare commits

..

2 Commits

Author SHA1 Message Date
Lucas Armand 249ca2eb99 refactor, handle zombie tasks 2025-11-12 15:23:42 -08:00
Lucas Armand d8bb1fcc68 add fifo queue
Bump pyworker version
2025-11-12 12:26:15 -08:00
3 changed files with 125 additions and 124 deletions
+98 -38
View File
@@ -9,6 +9,7 @@ from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
from functools import cached_property from functools import cached_property
from distutils.util import strtobool from distutils.util import strtobool
from collections import deque
from anyio import open_file from anyio import open_file
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
@@ -30,7 +31,7 @@ from lib.data_types import (
BenchmarkResult BenchmarkResult
) )
VERSION = "0.2.0" VERSION = "0.2.1"
MSG_HISTORY_LEN = 100 MSG_HISTORY_LEN = 100
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
@@ -63,6 +64,7 @@ class Backend:
version = VERSION version = VERSION
msg_history = [] msg_history = []
sem: Semaphore = dataclasses.field(default_factory=Semaphore) sem: Semaphore = dataclasses.field(default_factory=Semaphore)
queue: deque = dataclasses.field(default_factory=deque, repr=False)
unsecured: bool = dataclasses.field( unsecured: bool = dataclasses.field(
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))), default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
) )
@@ -141,11 +143,26 @@ class Backend:
workload = payload.count_workload() workload = payload.count_workload()
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created") request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
async def cancel_api_call_if_disconnected() -> web.Response:
def advance_queue_after_completion(event: asyncio.Event):
"""Pop current head and wake next waiter, if any."""
# If this event is current head, wake next waiter
if self.queue and self.queue[0] is event:
self.queue.popleft()
if self.queue:
self.queue[0].set()
else:
# Else, remove it from the queue
try:
self.queue.remove(event)
except ValueError:
pass
async def cancel_api_call_if_disconnected() -> None:
await request.wait_for_disconnection() await request.wait_for_disconnection()
log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled") log.debug(f"Request with reqnum: {request_metrics.reqnum} was canceled")
self.metrics._request_canceled(request_metrics) self.metrics._request_canceled(request_metrics)
raise asyncio.CancelledError return
async def make_request() -> Union[web.Response, web.StreamResponse]: async def make_request() -> Union[web.Response, web.StreamResponse]:
try: try:
@@ -162,7 +179,9 @@ class Backend:
res = await handler.generate_client_response(request, response) res = await handler.generate_client_response(request, response)
self.metrics._request_success(request_metrics) self.metrics._request_success(request_metrics)
return res return res
except requests.exceptions.RequestException as e: except asyncio.CancelledError:
raise
except Exception as e:
log.debug(f"[backend] Request error: {e}") log.debug(f"[backend] Request error: {e}")
self.metrics._request_errored(request_metrics) self.metrics._request_errored(request_metrics)
return web.Response(status=500) return web.Response(status=500)
@@ -177,46 +196,87 @@ class Backend:
self.metrics._request_reject(request_metrics) self.metrics._request_reject(request_metrics)
return web.Response(status=429) return web.Response(status=429)
acquired = False disconnect_task = create_task(cancel_api_call_if_disconnected())
try: next_request_task = None
self.metrics._request_start(request_metrics) work_task = None
if self.allow_parallel_requests is False: event = asyncio.Event() # Used in finally block, so initialize here
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
await self.sem.acquire()
acquired = True
log.debug(
f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..."
)
else:
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
done, pending = await wait(
[
create_task(make_request()),
create_task(cancel_api_call_if_disconnected()),
],
return_when=FIRST_COMPLETED,
)
for t in pending:
t.cancel()
await asyncio.gather(*pending, return_exceptions=True)
done_task = done.pop() self.metrics._request_start(request_metrics)
try:
return done_task.result() try:
except Exception as e: if self.allow_parallel_requests:
log.debug(f"Request task raised exception: {e}") log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
return web.Response(status=500) work_task = create_task(make_request())
done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED)
for t in pending:
t.cancel()
await asyncio.gather(*pending, return_exceptions=True)
if disconnect_task in done:
return web.Response(status=499)
# otherwise work_task completed
return await work_task
# FIFO-queue branch
else:
# Insert a Event into the queue for this request
# Event.set() == our request is up next
self.queue.append(event)
if self.queue and self.queue[0] is event:
event.set()
# Race between our request being next and request being cancelled
next_request_task = create_task(event.wait())
first_done, first_pending = await wait(
[next_request_task, disconnect_task], return_when=FIRST_COMPLETED
)
# If the disconnect task wins the race
if disconnect_task in first_done:
# Clean up the next_request_task, then exit
for t in first_pending:
t.cancel()
await asyncio.gather(*first_pending, return_exceptions=True)
return web.Response(status=499)
# We are the next-up request in the queue
log.debug(f"Starting work on request {request_metrics.reqnum}...")
# Race the backend API call with the disconnect task
work_task = create_task(make_request())
done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED)
for t in pending:
t.cancel()
await asyncio.gather(*pending, return_exceptions=True)
if disconnect_task in done:
return web.Response(status=499)
# otherwise work_task completed
return await work_task
except asyncio.CancelledError: except asyncio.CancelledError:
# Client is gone. Do not write a response; just unwind. return web.Response(status=499)
return web.Response(status=499)
except Exception as e: except Exception as e:
log.debug(f"Exception in main handler loop {e}") log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500) return web.Response(status=500)
finally: finally:
# Always release the semaphore if it was acquired if not self.allow_parallel_requests:
if acquired: advance_queue_after_completion(event)
self.sem.release()
self.metrics._request_end(request_metrics) self.metrics._request_end(request_metrics)
cleanup_tasks = [t for t in (next_request_task, work_task, disconnect_task) if t]
for t in cleanup_tasks:
if not t.done():
t.cancel()
if cleanup_tasks:
await asyncio.gather(*cleanup_tasks, return_exceptions=True)
@cached_property @cached_property
def healthcheck_session(self): def healthcheck_session(self):
+25 -45
View File
@@ -3,58 +3,38 @@ import logging
from typing import List from typing import List
import ssl import ssl
from asyncio import run, gather from asyncio import run, gather
import asyncio
from lib.backend import Backend from lib.backend import Backend
from lib.metrics import Metrics
from aiohttp import web from aiohttp import web
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs): def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
try: log.debug("getting certificate...")
log.debug("getting certificate...") use_ssl = os.environ.get("USE_SSL", "false") == "true"
use_ssl = os.environ.get("USE_SSL", "false") == "true" if use_ssl is True:
if use_ssl is True: ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) ssl_context.load_cert_chain(
ssl_context.load_cert_chain( certfile="/etc/instance.crt",
certfile="/etc/instance.crt", keyfile="/etc/instance.key",
keyfile="/etc/instance.key", )
) else:
else: ssl_context = None
ssl_context = None
async def main(): async def main():
log.debug("starting server...") log.debug("starting server...")
app = web.Application() app = web.Application()
app.add_routes(routes) app.add_routes(routes)
runner = web.AppRunner(app) runner = web.AppRunner(app)
await runner.setup() await runner.setup()
site = web.TCPSite( site = web.TCPSite(
runner, runner,
ssl_context=ssl_context, ssl_context=ssl_context,
port=int(os.environ["WORKER_PORT"]), port=int(os.environ["WORKER_PORT"]),
**kwargs **kwargs
) )
await gather(site.start(), backend._start_tracking()) await gather(site.start(), backend._start_tracking())
run(main()) run(main())
except Exception as e:
err_msg = f"PyWorker failed to launch: {e}"
log.error(err_msg)
async def beacon():
metrics = Metrics()
metrics._set_version(getattr(backend, "version", "0"))
metrics._set_mtoken(getattr(backend, "mtoken", ""))
try:
while True:
metrics._model_errored(err_msg)
await metrics._Metrics__send_metrics_and_reset()
await asyncio.sleep(10)
finally:
await metrics.aclose()
run(beacon())
+2 -41
View File
@@ -128,44 +128,5 @@ echo "launching PyWorker server"
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only # from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG" [ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
set +e echo "launching PyWorker server done"
python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG"
PY_STATUS=${PIPESTATUS[0]}
set -e
if [ "${PY_STATUS}" -ne 0 ]; then
echo "PyWorker exited with status ${PY_STATUS}; notifying autoscaler..."
ERROR_MSG="PyWorker exited: code ${PY_STATUS}"
MTOKEN="${MASTER_TOKEN:-}"
VERSION="${PYWORKER_VERSION:-0}"
IFS=',' read -r -a REPORT_ADDRS <<< "${REPORT_ADDR}"
for addr in "${REPORT_ADDRS[@]}"; do
curl -sS -X POST -H 'Content-Type: application/json' \
-d "$(cat <<JSON
{
"id": ${CONTAINER_ID:-0},
"mtoken": "${MTOKEN}",
"version": "${VERSION}",
"loadtime": 0,
"new_load": 0,
"cur_load": 0,
"rej_load": 0,
"max_perf": 0,
"cur_perf": 0,
"error_msg": "${ERROR_MSG}",
"num_requests_working": 0,
"num_requests_recieved": 0,
"additional_disk_usage": 0,
"working_request_idxs": [],
"cur_capacity": 0,
"max_capacity": 0,
"url": "${URL}"
}
JSON
)" "${addr%/}/worker_status/" || true
done
fi
echo "launching PyWorker server done"