import os import json import time import base64 import subprocess import dataclasses import logging from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional from functools import cached_property from distutils.util import strtobool from anyio import open_file from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector import asyncio import requests from Crypto.Signature import pkcs1_15 from Crypto.Hash import SHA256 from Crypto.PublicKey import RSA from lib.metrics import Metrics from lib.data_types import ( AuthData, EndpointHandler, LogAction, ApiPayload_T, JsonDataException, RequestMetrics, BenchmarkResult ) VERSION = "0.1.0" MSG_HISTORY_LEN = 100 log = logging.getLogger(__file__) # defines the minimum wait time between sending updates to autoscaler LOG_POLL_INTERVAL = 0.1 BENCHMARK_INDICATOR_FILE = ".has_benchmark" MAX_PUBKEY_FETCH_ATTEMPTS = 3 @dataclasses.dataclass class Backend: """ This class is responsible for: 1. Tailing logs and updating load time metrics 2. Taking an EndpointHandler alongside incoming payload, preparing a json to be sent to the model, and sending the request. It also updates metrics as it makes those requests. 3. Running a benchmark from an EndpointHandler """ model_server_url: str model_log_file: str allow_parallel_requests: bool benchmark_handler: ( EndpointHandler # this endpoint handler will be used for benchmarking ) log_actions: List[Tuple[LogAction, str]] max_wait_time: float = 10.0 request_queue = asyncio.Queue() worker_task = asyncio.create_task(_worker()) reqnum = -1 version = VERSION msg_history = [] sem: Semaphore = dataclasses.field(default_factory=Semaphore) unsecured: bool = dataclasses.field( default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))), ) def __post_init__(self): self.metrics = Metrics() self.metrics._set_version(self.version) self._total_pubkey_fetch_errors = 0 self._pubkey = self._fetch_pubkey() self.__start_healthcheck: bool = False @property def pubkey(self) -> Optional[RSA.RsaKey]: if self._pubkey is None: self._pubkey = self._fetch_pubkey() return self._pubkey @cached_property def session(self): log.debug(f"starting session with {self.model_server_url}") connector = TCPConnector( force_close=True, # Required for long running jobs enable_cleanup_closed=True, ) timeout = ClientTimeout(total=None) return ClientSession(self.model_server_url, timeout=timeout, connector=connector) async def _worker(self): while True: handler, request, fut = await self.request_queue.get() try: res = await self.__process_request(handler, request) fut.set_result(res) except Exception as e: fut.set_exception(e) finally: self.request_queue.task_done() def create_handler( self, handler: EndpointHandler[ApiPayload_T], ) -> Callable[[web.Request], Awaitable[Union[web.Response, web.StreamResponse]]]: async def handler_fn( request: web.Request, ) -> Union[web.Response, web.StreamResponse]: return await self.__handle_request(handler=handler, request=request) return handler_fn #######################################Private####################################### def _fetch_pubkey(self): command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"] result = subprocess.check_output(command, universal_newlines=True) log.debug("public key:") log.debug(result) key = None for _ in range(5): try: key = RSA.import_key(result) break except ValueError as e: log.debug(f"Error downloading key: {e}") time.sleep(15) if key is None: self._total_pubkey_fetch_errors += 1 if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS: self.backend_errored("Failed to get autoscaler pubkey") return key async def __handle_request( self, handler: EndpointHandler[ApiPayload_T], request: web.Request, ) -> Union[web.Response, web.StreamResponse]: """use this function to forward requests to the model endpoint""" try: data = await request.json() auth_data, payload = handler.get_data_from_request(data) except JsonDataException as e: return web.json_response(data=e.message, status=422) except json.JSONDecodeError: return web.json_response(dict(error="invalid JSON"), status=422) workload = payload.count_workload() request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created") async def cancel_api_call_if_disconnected() -> web.Response: await request.wait_for_disconnection() log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled") self.metrics._request_canceled(request_metrics) raise asyncio.CancelledError async def make_request() -> Union[web.Response, web.StreamResponse]: try: response = await self.__call_api(handler=handler, payload=payload) status_code = response.status log.debug( " ".join( [ f"request with reqnum:{request_metrics.reqnum}", f"returned status code: {status_code},", ] ) ) res = await handler.generate_client_response(request, response) self.metrics._request_success(request_metrics) return res except requests.exceptions.RequestException as e: log.debug(f"[backend] Request error: {e}") self.metrics._request_errored(request_metrics) return web.Response(status=500) ########### if self.__check_signature(auth_data) is False: self.metrics._request_reject(request_metrics) return web.Response(status=401) if self.metrics.model_metrics.wait_time > self.max_wait_time: self.metrics._request_reject(request_metrics) return web.Response(status=429) acquired = False try: self.metrics._request_start(request_metrics) if self.allow_parallel_requests is False: log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}") await self.sem.acquire() acquired = True log.debug( f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..." ) else: log.debug(f"Starting request for reqnum:{request_metrics.reqnum}") done, pending = await wait( [ create_task(make_request()), create_task(cancel_api_call_if_disconnected()), ], return_when=FIRST_COMPLETED, ) for t in pending: t.cancel() await asyncio.gather(*pending, return_exceptions=True) done_task = done.pop() try: return done_task.result() except Exception as e: log.debug(f"Request task raised exception: {e}") return web.Response(status=500) except asyncio.CancelledError: # Client is gone. Do not write a response; just unwind. return web.Response(status=499) except Exception as e: log.debug(f"Exception in main handler loop {e}") return web.Response(status=500) finally: # Always release the semaphore if it was acquired if acquired: self.sem.release() self.metrics._request_end(request_metrics) @cached_property def healthcheck_session(self): """Dedicated session for healthchecks to avoid conflicts with API session""" log.debug("creating dedicated healthcheck session") connector = TCPConnector( force_close=True, # Keep this for isolation enable_cleanup_closed=True, ) timeout = ClientTimeout(total=10) # Reasonable timeout for healthchecks return ClientSession(timeout=timeout, connector=connector) async def __healthcheck(self): health_check_url = self.benchmark_handler.healthcheck_endpoint if health_check_url is None: log.debug("No healthcheck endpoint defined, skipping healthcheck") return while True: await sleep(10) if self.__start_healthcheck is False: continue try: log.debug(f"Performing healthcheck on {health_check_url}") async with self.healthcheck_session.get(health_check_url) as response: if response.status == 200: log.debug("Healthcheck successful") elif response.status == 503: log.debug(f"Healthcheck failed with status: {response.status}") self.backend_errored( f"Healthcheck failed with status: {response.status}" ) else: log.debug(f"Healthcheck Endpoint not ready: {response.status}") except Exception as e: log.debug(f"Healthcheck failed with exception: {e}") self.backend_errored(str(e)) async def _start_tracking(self) -> None: await gather( self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop() ) def backend_errored(self, msg: str) -> None: self.metrics._model_errored(msg) async def __call_api( self, handler: EndpointHandler[ApiPayload_T], payload: ApiPayload_T ) -> ClientResponse: api_payload = payload.generate_payload_json() log.debug(f"posting to endpoint: '{handler.endpoint}', payload: {api_payload}") return await self.session.post(url=handler.endpoint, json=api_payload) def __check_signature(self, auth_data: AuthData) -> bool: if self.unsecured is True: return True def verify_signature(message, signature): if self.pubkey is None: log.debug(f"No Public Key!") return False h = SHA256.new(message.encode()) try: pkcs1_15.new(self.pubkey).verify(h, base64.b64decode(signature)) return True except (ValueError, TypeError): return False message = { key: value for (key, value) in (dataclasses.asdict(auth_data).items()) if key != "signature" } if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN): log.debug( f"reqnum failure, got {auth_data.reqnum}, current_reqnum: {self.reqnum}" ) return False elif message in self.msg_history: log.debug(f"message: {message} already in message history") return False elif verify_signature(json.dumps(message, indent=4), auth_data.signature): self.reqnum = max(auth_data.reqnum, self.reqnum) self.msg_history.append(message) self.msg_history = self.msg_history[-MSG_HISTORY_LEN:] return True else: log.debug( f"signature verification failed, sig:{auth_data.signature}, message: {message}" ) return False async def __read_logs(self) -> Awaitable[NoReturn]: async def run_benchmark() -> float: log.debug("starting benchmark") try: with open(BENCHMARK_INDICATOR_FILE, "r") as f: log.debug("already ran benchmark") # trigger model load payload = self.benchmark_handler.make_benchmark_payload() _ = await self.__call_api( handler=self.benchmark_handler, payload=payload ) return float(f.readline()) except FileNotFoundError: pass log.debug("Initial run to trigger model loading...") payload = self.benchmark_handler.make_benchmark_payload() await self.__call_api(handler=self.benchmark_handler, payload=payload) max_throughput = 0 sum_throughput = 0 concurrent_requests = 10 if self.allow_parallel_requests else 1 for run in range(1, self.benchmark_handler.benchmark_runs + 1): start = time.time() benchmark_requests = [] for i in range(concurrent_requests): payload = self.benchmark_handler.make_benchmark_payload() workload = payload.count_workload() task = self.__call_api(handler=self.benchmark_handler, payload=payload) benchmark_requests.append( BenchmarkResult(request_idx=i, workload=workload, task=task) ) responses = await gather(*[br.task for br in benchmark_requests]) for br, response in zip(benchmark_requests, responses): br.response = response total_workload = sum(br.workload for br in benchmark_requests if br.is_successful) time_elapsed = time.time() - start successful_responses = sum([1 for br in benchmark_requests if br.is_successful]) if successful_responses == 0: self.backend_errored("No successful responses from benchmark") log.debug(f"benchmark failed: {successful_responses}/{concurrent_requests} successful responses") throughput = total_workload / time_elapsed sum_throughput += throughput max_throughput = max(max_throughput, throughput) # Log results for debugging log.debug( "\n".join( [ "#" * 60, f"Run: {run}, concurrent_requests: {concurrent_requests}", f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s", f"Throughput: {throughput} workload/s", f"Successful responses: {successful_responses}/{concurrent_requests}", "#" * 60, ] ) ) average_throughput = sum_throughput / self.benchmark_handler.benchmark_runs log.debug( f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}" ) with open(BENCHMARK_INDICATOR_FILE, "w") as f: f.write(str(max_throughput)) return max_throughput async def handle_log_line(log_line: str) -> None: """ Implement this function to handle each log line for your model. This function should mutate self.system_metrics and self.model_metrics """ for action, msg in self.log_actions: match action: case LogAction.ModelLoaded if msg in log_line: log.debug( f"Got log line indicating model is loaded: {log_line}" ) # some backends need a few seconds after logging successful startup before # they can begin accepting requests await sleep(5) try: max_throughput = await run_benchmark() self.__start_healthcheck = True self.metrics._model_loaded( max_throughput=max_throughput, ) except ClientConnectorError as e: log.debug( f"failed to connect to comfyui api during benchmark" ) self.backend_errored(str(e)) case LogAction.ModelError if msg in log_line: log.debug(f"Got log line indicating error: {log_line}") self.backend_errored(msg) break case LogAction.Info if msg in log_line: log.debug(f"Info from model logs: {log_line}") async def tail_log(): log.debug(f"tailing file: {self.model_log_file}") async with await open_file(self.model_log_file) as f: while True: line = await f.readline() if line: await handle_log_line(line.rstrip()) else: await asyncio.sleep(LOG_POLL_INTERVAL) ########### while True: if os.path.isfile(self.model_log_file) is True: return await tail_log() else: await sleep(1)