Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 249ca2eb99 | |||
| d8bb1fcc68 |
+97
-37
@@ -9,6 +9,7 @@ from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
|
|||||||
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
|
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from distutils.util import strtobool
|
from distutils.util import strtobool
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
from anyio import open_file
|
from anyio import open_file
|
||||||
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
|
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
|
||||||
@@ -30,7 +31,7 @@ from lib.data_types import (
|
|||||||
BenchmarkResult
|
BenchmarkResult
|
||||||
)
|
)
|
||||||
|
|
||||||
VERSION = "0.2.0"
|
VERSION = "0.2.1"
|
||||||
|
|
||||||
MSG_HISTORY_LEN = 100
|
MSG_HISTORY_LEN = 100
|
||||||
log = logging.getLogger(__file__)
|
log = logging.getLogger(__file__)
|
||||||
@@ -63,6 +64,7 @@ class Backend:
|
|||||||
version = VERSION
|
version = VERSION
|
||||||
msg_history = []
|
msg_history = []
|
||||||
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
|
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
|
||||||
|
queue: deque = dataclasses.field(default_factory=deque, repr=False)
|
||||||
unsecured: bool = dataclasses.field(
|
unsecured: bool = dataclasses.field(
|
||||||
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
||||||
)
|
)
|
||||||
@@ -141,11 +143,26 @@ class Backend:
|
|||||||
workload = payload.count_workload()
|
workload = payload.count_workload()
|
||||||
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
|
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
|
||||||
|
|
||||||
async def cancel_api_call_if_disconnected() -> web.Response:
|
|
||||||
|
def advance_queue_after_completion(event: asyncio.Event):
|
||||||
|
"""Pop current head and wake next waiter, if any."""
|
||||||
|
# If this event is current head, wake next waiter
|
||||||
|
if self.queue and self.queue[0] is event:
|
||||||
|
self.queue.popleft()
|
||||||
|
if self.queue:
|
||||||
|
self.queue[0].set()
|
||||||
|
else:
|
||||||
|
# Else, remove it from the queue
|
||||||
|
try:
|
||||||
|
self.queue.remove(event)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def cancel_api_call_if_disconnected() -> None:
|
||||||
await request.wait_for_disconnection()
|
await request.wait_for_disconnection()
|
||||||
log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled")
|
log.debug(f"Request with reqnum: {request_metrics.reqnum} was canceled")
|
||||||
self.metrics._request_canceled(request_metrics)
|
self.metrics._request_canceled(request_metrics)
|
||||||
raise asyncio.CancelledError
|
return
|
||||||
|
|
||||||
async def make_request() -> Union[web.Response, web.StreamResponse]:
|
async def make_request() -> Union[web.Response, web.StreamResponse]:
|
||||||
try:
|
try:
|
||||||
@@ -162,7 +179,9 @@ class Backend:
|
|||||||
res = await handler.generate_client_response(request, response)
|
res = await handler.generate_client_response(request, response)
|
||||||
self.metrics._request_success(request_metrics)
|
self.metrics._request_success(request_metrics)
|
||||||
return res
|
return res
|
||||||
except requests.exceptions.RequestException as e:
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
log.debug(f"[backend] Request error: {e}")
|
log.debug(f"[backend] Request error: {e}")
|
||||||
self.metrics._request_errored(request_metrics)
|
self.metrics._request_errored(request_metrics)
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
@@ -177,46 +196,87 @@ class Backend:
|
|||||||
self.metrics._request_reject(request_metrics)
|
self.metrics._request_reject(request_metrics)
|
||||||
return web.Response(status=429)
|
return web.Response(status=429)
|
||||||
|
|
||||||
acquired = False
|
disconnect_task = create_task(cancel_api_call_if_disconnected())
|
||||||
try:
|
next_request_task = None
|
||||||
self.metrics._request_start(request_metrics)
|
work_task = None
|
||||||
if self.allow_parallel_requests is False:
|
event = asyncio.Event() # Used in finally block, so initialize here
|
||||||
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
|
|
||||||
await self.sem.acquire()
|
self.metrics._request_start(request_metrics)
|
||||||
acquired = True
|
|
||||||
log.debug(
|
try:
|
||||||
f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..."
|
if self.allow_parallel_requests:
|
||||||
)
|
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
|
||||||
else:
|
work_task = create_task(make_request())
|
||||||
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
|
done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED)
|
||||||
done, pending = await wait(
|
|
||||||
[
|
for t in pending:
|
||||||
create_task(make_request()),
|
t.cancel()
|
||||||
create_task(cancel_api_call_if_disconnected()),
|
await asyncio.gather(*pending, return_exceptions=True)
|
||||||
],
|
|
||||||
return_when=FIRST_COMPLETED,
|
if disconnect_task in done:
|
||||||
)
|
return web.Response(status=499)
|
||||||
for t in pending:
|
|
||||||
t.cancel()
|
# otherwise work_task completed
|
||||||
await asyncio.gather(*pending, return_exceptions=True)
|
return await work_task
|
||||||
|
|
||||||
|
# FIFO-queue branch
|
||||||
|
else:
|
||||||
|
# Insert a Event into the queue for this request
|
||||||
|
# Event.set() == our request is up next
|
||||||
|
self.queue.append(event)
|
||||||
|
if self.queue and self.queue[0] is event:
|
||||||
|
event.set()
|
||||||
|
|
||||||
|
# Race between our request being next and request being cancelled
|
||||||
|
next_request_task = create_task(event.wait())
|
||||||
|
first_done, first_pending = await wait(
|
||||||
|
[next_request_task, disconnect_task], return_when=FIRST_COMPLETED
|
||||||
|
)
|
||||||
|
|
||||||
|
# If the disconnect task wins the race
|
||||||
|
if disconnect_task in first_done:
|
||||||
|
# Clean up the next_request_task, then exit
|
||||||
|
for t in first_pending:
|
||||||
|
t.cancel()
|
||||||
|
await asyncio.gather(*first_pending, return_exceptions=True)
|
||||||
|
return web.Response(status=499)
|
||||||
|
|
||||||
|
# We are the next-up request in the queue
|
||||||
|
log.debug(f"Starting work on request {request_metrics.reqnum}...")
|
||||||
|
|
||||||
|
# Race the backend API call with the disconnect task
|
||||||
|
work_task = create_task(make_request())
|
||||||
|
|
||||||
|
done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED)
|
||||||
|
for t in pending:
|
||||||
|
t.cancel()
|
||||||
|
await asyncio.gather(*pending, return_exceptions=True)
|
||||||
|
|
||||||
|
if disconnect_task in done:
|
||||||
|
return web.Response(status=499)
|
||||||
|
|
||||||
|
# otherwise work_task completed
|
||||||
|
return await work_task
|
||||||
|
|
||||||
done_task = done.pop()
|
|
||||||
try:
|
|
||||||
return done_task.result()
|
|
||||||
except Exception as e:
|
|
||||||
log.debug(f"Request task raised exception: {e}")
|
|
||||||
return web.Response(status=500)
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
# Client is gone. Do not write a response; just unwind.
|
|
||||||
return web.Response(status=499)
|
return web.Response(status=499)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug(f"Exception in main handler loop {e}")
|
log.debug(f"Exception in main handler loop {e}")
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Always release the semaphore if it was acquired
|
if not self.allow_parallel_requests:
|
||||||
if acquired:
|
advance_queue_after_completion(event)
|
||||||
self.sem.release()
|
|
||||||
self.metrics._request_end(request_metrics)
|
self.metrics._request_end(request_metrics)
|
||||||
|
cleanup_tasks = [t for t in (next_request_task, work_task, disconnect_task) if t]
|
||||||
|
for t in cleanup_tasks:
|
||||||
|
if not t.done():
|
||||||
|
t.cancel()
|
||||||
|
if cleanup_tasks:
|
||||||
|
await asyncio.gather(*cleanup_tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def healthcheck_session(self):
|
def healthcheck_session(self):
|
||||||
|
|||||||
+25
-45
@@ -3,58 +3,38 @@ import logging
|
|||||||
from typing import List
|
from typing import List
|
||||||
import ssl
|
import ssl
|
||||||
from asyncio import run, gather
|
from asyncio import run, gather
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from lib.backend import Backend
|
from lib.backend import Backend
|
||||||
from lib.metrics import Metrics
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
log = logging.getLogger(__file__)
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
||||||
try:
|
log.debug("getting certificate...")
|
||||||
log.debug("getting certificate...")
|
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
||||||
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
if use_ssl is True:
|
||||||
if use_ssl is True:
|
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||||
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
ssl_context.load_cert_chain(
|
||||||
ssl_context.load_cert_chain(
|
certfile="/etc/instance.crt",
|
||||||
certfile="/etc/instance.crt",
|
keyfile="/etc/instance.key",
|
||||||
keyfile="/etc/instance.key",
|
)
|
||||||
)
|
else:
|
||||||
else:
|
ssl_context = None
|
||||||
ssl_context = None
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
log.debug("starting server...")
|
log.debug("starting server...")
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
app.add_routes(routes)
|
app.add_routes(routes)
|
||||||
runner = web.AppRunner(app)
|
runner = web.AppRunner(app)
|
||||||
await runner.setup()
|
await runner.setup()
|
||||||
site = web.TCPSite(
|
site = web.TCPSite(
|
||||||
runner,
|
runner,
|
||||||
ssl_context=ssl_context,
|
ssl_context=ssl_context,
|
||||||
port=int(os.environ["WORKER_PORT"]),
|
port=int(os.environ["WORKER_PORT"]),
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
await gather(site.start(), backend._start_tracking())
|
await gather(site.start(), backend._start_tracking())
|
||||||
|
|
||||||
run(main())
|
run(main())
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
err_msg = f"PyWorker failed to launch: {e}"
|
|
||||||
log.error(err_msg)
|
|
||||||
|
|
||||||
async def beacon():
|
|
||||||
metrics = Metrics()
|
|
||||||
metrics._set_version(getattr(backend, "version", "0"))
|
|
||||||
metrics._set_mtoken(getattr(backend, "mtoken", ""))
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
metrics._model_errored(err_msg)
|
|
||||||
await metrics._Metrics__send_metrics_and_reset()
|
|
||||||
await asyncio.sleep(10)
|
|
||||||
finally:
|
|
||||||
await metrics.aclose()
|
|
||||||
|
|
||||||
run(beacon())
|
|
||||||
|
|||||||
+1
-40
@@ -128,44 +128,5 @@ echo "launching PyWorker server"
|
|||||||
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
|
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
|
||||||
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
|
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
|
||||||
|
|
||||||
|
(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
|
||||||
set +e
|
|
||||||
python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG"
|
|
||||||
PY_STATUS=${PIPESTATUS[0]}
|
|
||||||
set -e
|
|
||||||
|
|
||||||
if [ "${PY_STATUS}" -ne 0 ]; then
|
|
||||||
echo "PyWorker exited with status ${PY_STATUS}; notifying autoscaler..."
|
|
||||||
ERROR_MSG="PyWorker exited: code ${PY_STATUS}"
|
|
||||||
MTOKEN="${MASTER_TOKEN:-}"
|
|
||||||
VERSION="${PYWORKER_VERSION:-0}"
|
|
||||||
|
|
||||||
IFS=',' read -r -a REPORT_ADDRS <<< "${REPORT_ADDR}"
|
|
||||||
for addr in "${REPORT_ADDRS[@]}"; do
|
|
||||||
curl -sS -X POST -H 'Content-Type: application/json' \
|
|
||||||
-d "$(cat <<JSON
|
|
||||||
{
|
|
||||||
"id": ${CONTAINER_ID:-0},
|
|
||||||
"mtoken": "${MTOKEN}",
|
|
||||||
"version": "${VERSION}",
|
|
||||||
"loadtime": 0,
|
|
||||||
"new_load": 0,
|
|
||||||
"cur_load": 0,
|
|
||||||
"rej_load": 0,
|
|
||||||
"max_perf": 0,
|
|
||||||
"cur_perf": 0,
|
|
||||||
"error_msg": "${ERROR_MSG}",
|
|
||||||
"num_requests_working": 0,
|
|
||||||
"num_requests_recieved": 0,
|
|
||||||
"additional_disk_usage": 0,
|
|
||||||
"working_request_idxs": [],
|
|
||||||
"cur_capacity": 0,
|
|
||||||
"max_capacity": 0,
|
|
||||||
"url": "${URL}"
|
|
||||||
}
|
|
||||||
JSON
|
|
||||||
)" "${addr%/}/worker_status/" || true
|
|
||||||
done
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo "launching PyWorker server done"
|
echo "launching PyWorker server done"
|
||||||
Reference in New Issue
Block a user