Compare commits

..

5 Commits

Author SHA1 Message Date
Lucas Armand a47c9d1ed0 remove test bugs 2025-11-11 18:13:46 -08:00
Lucas Armand 0b14562a63 dont exit on pyworker fail 2025-11-11 17:57:08 -08:00
Lucas Armand de9b50abb9 use set +e 2025-11-11 17:53:36 -08:00
Lucas Armand c510801723 fix 2025-11-11 17:49:34 -08:00
Lucas Armand a12523b1d2 Added bad code to tgi server to test 2025-11-11 17:41:12 -08:00
3 changed files with 122 additions and 123 deletions
+32 -92
View File
@@ -9,7 +9,6 @@ from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
from functools import cached_property from functools import cached_property
from distutils.util import strtobool from distutils.util import strtobool
from collections import deque
from anyio import open_file from anyio import open_file
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
@@ -31,7 +30,7 @@ from lib.data_types import (
BenchmarkResult BenchmarkResult
) )
VERSION = "0.2.1" VERSION = "0.2.0"
MSG_HISTORY_LEN = 100 MSG_HISTORY_LEN = 100
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
@@ -64,7 +63,6 @@ class Backend:
version = VERSION version = VERSION
msg_history = [] msg_history = []
sem: Semaphore = dataclasses.field(default_factory=Semaphore) sem: Semaphore = dataclasses.field(default_factory=Semaphore)
queue: deque = dataclasses.field(default_factory=deque, repr=False)
unsecured: bool = dataclasses.field( unsecured: bool = dataclasses.field(
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))), default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
) )
@@ -143,26 +141,11 @@ class Backend:
workload = payload.count_workload() workload = payload.count_workload()
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created") request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
async def cancel_api_call_if_disconnected() -> web.Response:
def advance_queue_after_completion(event: asyncio.Event):
"""Pop current head and wake next waiter, if any."""
# If this event is current head, wake next waiter
if self.queue and self.queue[0] is event:
self.queue.popleft()
if self.queue:
self.queue[0].set()
else:
# Else, remove it from the queue
try:
self.queue.remove(event)
except ValueError:
pass
async def cancel_api_call_if_disconnected() -> None:
await request.wait_for_disconnection() await request.wait_for_disconnection()
log.debug(f"Request with reqnum: {request_metrics.reqnum} was canceled") log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled")
self.metrics._request_canceled(request_metrics) self.metrics._request_canceled(request_metrics)
return raise asyncio.CancelledError
async def make_request() -> Union[web.Response, web.StreamResponse]: async def make_request() -> Union[web.Response, web.StreamResponse]:
try: try:
@@ -179,9 +162,7 @@ class Backend:
res = await handler.generate_client_response(request, response) res = await handler.generate_client_response(request, response)
self.metrics._request_success(request_metrics) self.metrics._request_success(request_metrics)
return res return res
except asyncio.CancelledError: except requests.exceptions.RequestException as e:
raise
except Exception as e:
log.debug(f"[backend] Request error: {e}") log.debug(f"[backend] Request error: {e}")
self.metrics._request_errored(request_metrics) self.metrics._request_errored(request_metrics)
return web.Response(status=500) return web.Response(status=500)
@@ -196,87 +177,46 @@ class Backend:
self.metrics._request_reject(request_metrics) self.metrics._request_reject(request_metrics)
return web.Response(status=429) return web.Response(status=429)
disconnect_task = create_task(cancel_api_call_if_disconnected()) acquired = False
next_request_task = None
work_task = None
event = asyncio.Event() # Used in finally block, so initialize here
self.metrics._request_start(request_metrics)
try: try:
if self.allow_parallel_requests: self.metrics._request_start(request_metrics)
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}") if self.allow_parallel_requests is False:
work_task = create_task(make_request()) log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED) await self.sem.acquire()
acquired = True
for t in pending: log.debug(
t.cancel() f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..."
await asyncio.gather(*pending, return_exceptions=True) )
else:
if disconnect_task in done: log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
return web.Response(status=499) done, pending = await wait(
[
# otherwise work_task completed create_task(make_request()),
return await work_task create_task(cancel_api_call_if_disconnected()),
],
# FIFO-queue branch return_when=FIRST_COMPLETED,
else:
# Insert a Event into the queue for this request
# Event.set() == our request is up next
self.queue.append(event)
if self.queue and self.queue[0] is event:
event.set()
# Race between our request being next and request being cancelled
next_request_task = create_task(event.wait())
first_done, first_pending = await wait(
[next_request_task, disconnect_task], return_when=FIRST_COMPLETED
) )
# If the disconnect task wins the race
if disconnect_task in first_done:
# Clean up the next_request_task, then exit
for t in first_pending:
t.cancel()
await asyncio.gather(*first_pending, return_exceptions=True)
return web.Response(status=499)
# We are the next-up request in the queue
log.debug(f"Starting work on request {request_metrics.reqnum}...")
# Race the backend API call with the disconnect task
work_task = create_task(make_request())
done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED)
for t in pending: for t in pending:
t.cancel() t.cancel()
await asyncio.gather(*pending, return_exceptions=True) await asyncio.gather(*pending, return_exceptions=True)
if disconnect_task in done: done_task = done.pop()
return web.Response(status=499) try:
return done_task.result()
# otherwise work_task completed except Exception as e:
return await work_task log.debug(f"Request task raised exception: {e}")
return web.Response(status=500)
except asyncio.CancelledError: except asyncio.CancelledError:
# Client is gone. Do not write a response; just unwind.
return web.Response(status=499) return web.Response(status=499)
except Exception as e: except Exception as e:
log.debug(f"Exception in main handler loop {e}") log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500) return web.Response(status=500)
finally: finally:
if not self.allow_parallel_requests: # Always release the semaphore if it was acquired
advance_queue_after_completion(event) if acquired:
self.sem.release()
self.metrics._request_end(request_metrics) self.metrics._request_end(request_metrics)
cleanup_tasks = [t for t in (next_request_task, work_task, disconnect_task) if t]
for t in cleanup_tasks:
if not t.done():
t.cancel()
if cleanup_tasks:
await asyncio.gather(*cleanup_tasks, return_exceptions=True)
@cached_property @cached_property
def healthcheck_session(self): def healthcheck_session(self):
+21 -1
View File
@@ -3,15 +3,17 @@ import logging
from typing import List from typing import List
import ssl import ssl
from asyncio import run, gather from asyncio import run, gather
import asyncio
from lib.backend import Backend from lib.backend import Backend
from lib.metrics import Metrics
from aiohttp import web from aiohttp import web
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs): def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
try:
log.debug("getting certificate...") log.debug("getting certificate...")
use_ssl = os.environ.get("USE_SSL", "false") == "true" use_ssl = os.environ.get("USE_SSL", "false") == "true"
if use_ssl is True: if use_ssl is True:
@@ -38,3 +40,21 @@ def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
await gather(site.start(), backend._start_tracking()) await gather(site.start(), backend._start_tracking())
run(main()) run(main())
except Exception as e:
err_msg = f"PyWorker failed to launch: {e}"
log.error(err_msg)
async def beacon():
metrics = Metrics()
metrics._set_version(getattr(backend, "version", "0"))
metrics._set_mtoken(getattr(backend, "mtoken", ""))
try:
while True:
metrics._model_errored(err_msg)
await metrics._Metrics__send_metrics_and_reset()
await asyncio.sleep(10)
finally:
await metrics.aclose()
run(beacon())
+40 -1
View File
@@ -128,5 +128,44 @@ echo "launching PyWorker server"
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only # from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG" [ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
set +e
python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG"
PY_STATUS=${PIPESTATUS[0]}
set -e
if [ "${PY_STATUS}" -ne 0 ]; then
echo "PyWorker exited with status ${PY_STATUS}; notifying autoscaler..."
ERROR_MSG="PyWorker exited: code ${PY_STATUS}"
MTOKEN="${MASTER_TOKEN:-}"
VERSION="${PYWORKER_VERSION:-0}"
IFS=',' read -r -a REPORT_ADDRS <<< "${REPORT_ADDR}"
for addr in "${REPORT_ADDRS[@]}"; do
curl -sS -X POST -H 'Content-Type: application/json' \
-d "$(cat <<JSON
{
"id": ${CONTAINER_ID:-0},
"mtoken": "${MTOKEN}",
"version": "${VERSION}",
"loadtime": 0,
"new_load": 0,
"cur_load": 0,
"rej_load": 0,
"max_perf": 0,
"cur_perf": 0,
"error_msg": "${ERROR_MSG}",
"num_requests_working": 0,
"num_requests_recieved": 0,
"additional_disk_usage": 0,
"working_request_idxs": [],
"cur_capacity": 0,
"max_capacity": 0,
"url": "${URL}"
}
JSON
)" "${addr%/}/worker_status/" || true
done
fi
echo "launching PyWorker server done" echo "launching PyWorker server done"