remove legacy pyworker
This commit is contained in:
-434
@@ -1,434 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import base64
|
||||
import subprocess
|
||||
import dataclasses
|
||||
import logging
|
||||
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
|
||||
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
|
||||
from functools import cached_property
|
||||
from distutils.util import strtobool
|
||||
|
||||
from anyio import open_file
|
||||
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
|
||||
import asyncio
|
||||
|
||||
import requests
|
||||
from Crypto.Signature import pkcs1_15
|
||||
from Crypto.Hash import SHA256
|
||||
from Crypto.PublicKey import RSA
|
||||
|
||||
from lib.metrics import Metrics
|
||||
from lib.data_types import (
|
||||
AuthData,
|
||||
EndpointHandler,
|
||||
LogAction,
|
||||
ApiPayload_T,
|
||||
JsonDataException,
|
||||
RequestMetrics,
|
||||
BenchmarkResult
|
||||
)
|
||||
|
||||
VERSION = "0.2.1"
|
||||
|
||||
MSG_HISTORY_LEN = 100
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
# defines the minimum wait time between sending updates to autoscaler
|
||||
LOG_POLL_INTERVAL = 0.1
|
||||
BENCHMARK_INDICATOR_FILE = ".has_benchmark"
|
||||
MAX_PUBKEY_FETCH_ATTEMPTS = 3
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Backend:
|
||||
"""
|
||||
This class is responsible for:
|
||||
1. Tailing logs and updating load time metrics
|
||||
2. Taking an EndpointHandler alongside incoming payload, preparing a json to be sent to the model, and
|
||||
sending the request. It also updates metrics as it makes those requests.
|
||||
3. Running a benchmark from an EndpointHandler
|
||||
"""
|
||||
|
||||
model_server_url: str
|
||||
model_log_file: str
|
||||
allow_parallel_requests: bool
|
||||
benchmark_handler: (
|
||||
EndpointHandler # this endpoint handler will be used for benchmarking
|
||||
)
|
||||
log_actions: List[Tuple[LogAction, str]]
|
||||
max_wait_time: float = 10.0
|
||||
reqnum = -1
|
||||
version = VERSION
|
||||
msg_history = []
|
||||
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
|
||||
unsecured: bool = dataclasses.field(
|
||||
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
||||
)
|
||||
report_addr: str = dataclasses.field(
|
||||
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
|
||||
)
|
||||
mtoken: str = dataclasses.field(
|
||||
default_factory=lambda: os.environ.get("MASTER_TOKEN", "")
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
self.metrics = Metrics()
|
||||
self.metrics._set_version(self.version)
|
||||
self.metrics._set_mtoken(self.mtoken)
|
||||
self._total_pubkey_fetch_errors = 0
|
||||
self._pubkey = self._fetch_pubkey()
|
||||
self.__start_healthcheck: bool = False
|
||||
|
||||
@property
|
||||
def pubkey(self) -> Optional[RSA.RsaKey]:
|
||||
if self._pubkey is None:
|
||||
self._pubkey = self._fetch_pubkey()
|
||||
return self._pubkey
|
||||
|
||||
@cached_property
|
||||
def session(self):
|
||||
log.debug(f"starting session with {self.model_server_url}")
|
||||
connector = TCPConnector(
|
||||
force_close=True, # Required for long running jobs
|
||||
enable_cleanup_closed=True,
|
||||
)
|
||||
|
||||
timeout = ClientTimeout(total=None)
|
||||
return ClientSession(self.model_server_url, timeout=timeout, connector=connector)
|
||||
|
||||
def create_handler(
|
||||
self,
|
||||
handler: EndpointHandler[ApiPayload_T],
|
||||
) -> Callable[[web.Request], Awaitable[Union[web.Response, web.StreamResponse]]]:
|
||||
async def handler_fn(
|
||||
request: web.Request,
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
return await self.__handle_request(handler=handler, request=request)
|
||||
|
||||
return handler_fn
|
||||
|
||||
#######################################Private#######################################
|
||||
def _fetch_pubkey(self):
|
||||
report_addr = self.report_addr.rstrip("/")
|
||||
command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"]
|
||||
try:
|
||||
result = subprocess.check_output(command, universal_newlines=True)
|
||||
log.debug("public key:")
|
||||
log.debug(result)
|
||||
key = RSA.import_key(result)
|
||||
if key is not None:
|
||||
return key
|
||||
except (ValueError , subprocess.CalledProcessError) as e:
|
||||
log.debug(f"Error downloading key: {e}")
|
||||
self.backend_errored("Failed to get autoscaler pubkey")
|
||||
|
||||
|
||||
async def __handle_request(
|
||||
self,
|
||||
handler: EndpointHandler[ApiPayload_T],
|
||||
request: web.Request,
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
"""use this function to forward requests to the model endpoint"""
|
||||
try:
|
||||
data = await request.json()
|
||||
auth_data, payload = handler.get_data_from_request(data)
|
||||
except JsonDataException as e:
|
||||
return web.json_response(data=e.message, status=422)
|
||||
except json.JSONDecodeError:
|
||||
return web.json_response(dict(error="invalid JSON"), status=422)
|
||||
workload = payload.count_workload()
|
||||
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
|
||||
|
||||
async def cancel_api_call_if_disconnected() -> web.Response:
|
||||
await request.wait_for_disconnection()
|
||||
log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled")
|
||||
self.metrics._request_canceled(request_metrics)
|
||||
raise asyncio.CancelledError
|
||||
|
||||
async def make_request() -> Union[web.Response, web.StreamResponse]:
|
||||
try:
|
||||
response = await self.__call_api(handler=handler, payload=payload)
|
||||
status_code = response.status
|
||||
log.debug(
|
||||
" ".join(
|
||||
[
|
||||
f"request with reqnum:{request_metrics.reqnum}",
|
||||
f"returned status code: {status_code},",
|
||||
]
|
||||
)
|
||||
)
|
||||
res = await handler.generate_client_response(request, response)
|
||||
self.metrics._request_success(request_metrics)
|
||||
return res
|
||||
except requests.exceptions.RequestException as e:
|
||||
log.debug(f"[backend] Request error: {e}")
|
||||
self.metrics._request_errored(request_metrics)
|
||||
return web.Response(status=500)
|
||||
|
||||
###########
|
||||
|
||||
if self.__check_signature(auth_data) is False:
|
||||
self.metrics._request_reject(request_metrics)
|
||||
return web.Response(status=401)
|
||||
|
||||
if self.metrics.model_metrics.wait_time > self.max_wait_time:
|
||||
self.metrics._request_reject(request_metrics)
|
||||
return web.Response(status=429)
|
||||
|
||||
acquired = False
|
||||
try:
|
||||
self.metrics._request_start(request_metrics)
|
||||
if self.allow_parallel_requests is False:
|
||||
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
|
||||
await self.sem.acquire()
|
||||
acquired = True
|
||||
log.debug(
|
||||
f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..."
|
||||
)
|
||||
else:
|
||||
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
|
||||
done, pending = await wait(
|
||||
[
|
||||
create_task(make_request()),
|
||||
create_task(cancel_api_call_if_disconnected()),
|
||||
],
|
||||
return_when=FIRST_COMPLETED,
|
||||
)
|
||||
for t in pending:
|
||||
t.cancel()
|
||||
await asyncio.gather(*pending, return_exceptions=True)
|
||||
|
||||
done_task = done.pop()
|
||||
try:
|
||||
return done_task.result()
|
||||
except Exception as e:
|
||||
log.debug(f"Request task raised exception: {e}")
|
||||
return web.Response(status=500)
|
||||
except asyncio.CancelledError:
|
||||
# Client is gone. Do not write a response; just unwind.
|
||||
return web.Response(status=499)
|
||||
except Exception as e:
|
||||
log.debug(f"Exception in main handler loop {e}")
|
||||
return web.Response(status=500)
|
||||
finally:
|
||||
# Always release the semaphore if it was acquired
|
||||
if acquired:
|
||||
self.sem.release()
|
||||
self.metrics._request_end(request_metrics)
|
||||
|
||||
@cached_property
|
||||
def healthcheck_session(self):
|
||||
"""Dedicated session for healthchecks to avoid conflicts with API session"""
|
||||
log.debug("creating dedicated healthcheck session")
|
||||
connector = TCPConnector(
|
||||
force_close=True, # Keep this for isolation
|
||||
enable_cleanup_closed=True,
|
||||
)
|
||||
timeout = ClientTimeout(total=10) # Reasonable timeout for healthchecks
|
||||
return ClientSession(timeout=timeout, connector=connector)
|
||||
|
||||
async def __healthcheck(self):
|
||||
health_check_url = self.benchmark_handler.healthcheck_endpoint
|
||||
if health_check_url is None:
|
||||
log.debug("No healthcheck endpoint defined, skipping healthcheck")
|
||||
return
|
||||
|
||||
while True:
|
||||
await sleep(10)
|
||||
if self.__start_healthcheck is False:
|
||||
continue
|
||||
try:
|
||||
log.debug(f"Performing healthcheck on {health_check_url}")
|
||||
async with self.healthcheck_session.get(health_check_url) as response:
|
||||
if response.status == 200:
|
||||
log.debug("Healthcheck successful")
|
||||
elif response.status == 503:
|
||||
log.debug(f"Healthcheck failed with status: {response.status}")
|
||||
self.backend_errored(
|
||||
f"Healthcheck failed with status: {response.status}"
|
||||
)
|
||||
else:
|
||||
log.debug(f"Healthcheck Endpoint not ready: {response.status}")
|
||||
except Exception as e:
|
||||
log.debug(f"Healthcheck failed with exception: {e}")
|
||||
self.backend_errored(str(e))
|
||||
|
||||
async def _start_tracking(self) -> None:
|
||||
await gather(
|
||||
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop()
|
||||
)
|
||||
|
||||
def backend_errored(self, msg: str) -> None:
|
||||
self.metrics._model_errored(msg)
|
||||
|
||||
async def __call_api(
|
||||
self, handler: EndpointHandler[ApiPayload_T], payload: ApiPayload_T
|
||||
) -> ClientResponse:
|
||||
api_payload = payload.generate_payload_json()
|
||||
log.debug(f"posting to endpoint: '{handler.endpoint}', payload: {api_payload}")
|
||||
return await self.session.post(url=handler.endpoint, json=api_payload)
|
||||
|
||||
def __check_signature(self, auth_data: AuthData) -> bool:
|
||||
if self.unsecured is True:
|
||||
return True
|
||||
|
||||
def verify_signature(message, signature):
|
||||
if self.pubkey is None:
|
||||
log.debug(f"No Public Key!")
|
||||
return False
|
||||
|
||||
h = SHA256.new(message.encode())
|
||||
try:
|
||||
pkcs1_15.new(self.pubkey).verify(h, base64.b64decode(signature))
|
||||
return True
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
message = {
|
||||
key: value
|
||||
for (key, value) in (dataclasses.asdict(auth_data).items())
|
||||
if key != "signature" and key != "__request_id"
|
||||
}
|
||||
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
|
||||
log.debug(
|
||||
f"reqnum failure, got {auth_data.reqnum}, current_reqnum: {self.reqnum}"
|
||||
)
|
||||
return False
|
||||
elif message in self.msg_history:
|
||||
log.debug(f"message: {message} already in message history")
|
||||
return False
|
||||
elif verify_signature(json.dumps(message, indent=4, sort_keys=True), auth_data.signature):
|
||||
self.reqnum = max(auth_data.reqnum, self.reqnum)
|
||||
self.msg_history.append(message)
|
||||
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
|
||||
return True
|
||||
else:
|
||||
log.debug(
|
||||
f"signature verification failed, sig:{auth_data.signature}, message: {message}"
|
||||
)
|
||||
return False
|
||||
|
||||
async def __read_logs(self) -> Awaitable[NoReturn]:
|
||||
|
||||
async def run_benchmark() -> float:
|
||||
log.debug("starting benchmark")
|
||||
try:
|
||||
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
||||
log.debug("already ran benchmark")
|
||||
# trigger model load
|
||||
# payload = self.benchmark_handler.make_benchmark_payload()
|
||||
# _ = await self.__call_api(
|
||||
# handler=self.benchmark_handler, payload=payload
|
||||
# )
|
||||
return float(f.readline())
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
log.debug("Initial run to trigger model loading...")
|
||||
payload = self.benchmark_handler.make_benchmark_payload()
|
||||
await self.__call_api(handler=self.benchmark_handler, payload=payload)
|
||||
|
||||
max_throughput = 0
|
||||
sum_throughput = 0
|
||||
concurrent_requests = 10 if self.allow_parallel_requests else 1
|
||||
|
||||
for run in range(1, self.benchmark_handler.benchmark_runs + 1):
|
||||
start = time.time()
|
||||
benchmark_requests = []
|
||||
|
||||
for i in range(concurrent_requests):
|
||||
payload = self.benchmark_handler.make_benchmark_payload()
|
||||
workload = payload.count_workload()
|
||||
task = self.__call_api(handler=self.benchmark_handler, payload=payload)
|
||||
benchmark_requests.append(
|
||||
BenchmarkResult(request_idx=i, workload=workload, task=task)
|
||||
)
|
||||
|
||||
responses = await gather(*[br.task for br in benchmark_requests])
|
||||
for br, response in zip(benchmark_requests, responses):
|
||||
br.response = response
|
||||
|
||||
total_workload = sum(br.workload for br in benchmark_requests if br.is_successful)
|
||||
time_elapsed = time.time() - start
|
||||
successful_responses = sum([1 for br in benchmark_requests if br.is_successful])
|
||||
if successful_responses == 0:
|
||||
self.backend_errored("No successful responses from benchmark")
|
||||
log.debug(f"benchmark failed: {successful_responses}/{concurrent_requests} successful responses")
|
||||
|
||||
throughput = total_workload / time_elapsed
|
||||
sum_throughput += throughput
|
||||
max_throughput = max(max_throughput, throughput)
|
||||
|
||||
# Log results for debugging
|
||||
log.debug(
|
||||
"\n".join(
|
||||
[
|
||||
"#" * 60,
|
||||
f"Run: {run}, concurrent_requests: {concurrent_requests}",
|
||||
f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s",
|
||||
f"Throughput: {throughput} workload/s",
|
||||
f"Successful responses: {successful_responses}/{concurrent_requests}",
|
||||
"#" * 60,
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
average_throughput = sum_throughput / self.benchmark_handler.benchmark_runs
|
||||
log.debug(
|
||||
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
|
||||
)
|
||||
with open(BENCHMARK_INDICATOR_FILE, "w") as f:
|
||||
f.write(str(max_throughput))
|
||||
return max_throughput
|
||||
|
||||
async def handle_log_line(log_line: str) -> None:
|
||||
"""
|
||||
Implement this function to handle each log line for your model.
|
||||
This function should mutate self.system_metrics and self.model_metrics
|
||||
"""
|
||||
for action, msg in self.log_actions:
|
||||
match action:
|
||||
case LogAction.ModelLoaded if msg in log_line:
|
||||
log.debug(
|
||||
f"Got log line indicating model is loaded: {log_line}"
|
||||
)
|
||||
# some backends need a few seconds after logging successful startup before
|
||||
# they can begin accepting requests
|
||||
# await sleep(5)
|
||||
try:
|
||||
max_throughput = await run_benchmark()
|
||||
self.__start_healthcheck = True
|
||||
self.metrics._model_loaded(
|
||||
max_throughput=max_throughput,
|
||||
)
|
||||
except ClientConnectorError as e:
|
||||
log.debug(
|
||||
f"failed to connect to comfyui api during benchmark"
|
||||
)
|
||||
self.backend_errored(str(e))
|
||||
case LogAction.ModelError if msg in log_line:
|
||||
log.debug(f"Got log line indicating error: {log_line}")
|
||||
self.backend_errored(msg)
|
||||
break
|
||||
case LogAction.Info if msg in log_line:
|
||||
log.debug(f"Info from model logs: {log_line}")
|
||||
|
||||
async def tail_log():
|
||||
log.debug(f"tailing file: {self.model_log_file}")
|
||||
async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore') as f:
|
||||
while True:
|
||||
line = await f.readline()
|
||||
if line:
|
||||
await handle_log_line(line.rstrip())
|
||||
else:
|
||||
await asyncio.sleep(LOG_POLL_INTERVAL)
|
||||
|
||||
###########
|
||||
|
||||
while True:
|
||||
if os.path.isfile(self.model_log_file) is True:
|
||||
return await tail_log()
|
||||
else:
|
||||
await sleep(1)
|
||||
@@ -1,324 +0,0 @@
|
||||
import time
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type, Awaitable
|
||||
from aiohttp import web, ClientResponse
|
||||
import inspect
|
||||
|
||||
import psutil
|
||||
|
||||
|
||||
"""
|
||||
type variable representing an incoming payload to pyworker that will used to calculate load and will then
|
||||
be forwarded to the model
|
||||
"""
|
||||
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
class JsonDataException(Exception):
|
||||
def __init__(self, json_msg: Dict[str, Any]):
|
||||
self.message = json_msg
|
||||
|
||||
|
||||
ApiPayload_T = TypeVar("ApiPayload_T", bound="ApiPayload")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ApiPayload(ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def for_test(cls: Type[ApiPayload_T]) -> ApiPayload_T:
|
||||
"""defines how create a payload for load testing"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def generate_payload_json(self) -> Dict[str, Any]:
|
||||
"""defines how to convert an ApiPayload to JSON that will be sent to model API"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def count_workload(self) -> float:
|
||||
"""defines how to calculate workload for a payload"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_json_msg(
|
||||
cls: Type[ApiPayload_T], json_msg: Dict[str, Any]
|
||||
) -> ApiPayload_T:
|
||||
"""
|
||||
defines how to create an API payload from a JSON message,
|
||||
it should throw an JsonDataException if there are issues with some fields
|
||||
or they are missing in the format of
|
||||
{
|
||||
"field": "error msg"
|
||||
}
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthData:
|
||||
"""data used to authenticate requester"""
|
||||
|
||||
cost: str
|
||||
endpoint: str
|
||||
reqnum: int
|
||||
request_idx: int
|
||||
signature: str
|
||||
url: str
|
||||
|
||||
@classmethod
|
||||
def from_json_msg(cls, json_msg: Dict[str, Any]):
|
||||
errors = {}
|
||||
for param in inspect.signature(cls).parameters:
|
||||
if param not in json_msg:
|
||||
errors[param] = "missing parameter"
|
||||
if errors:
|
||||
raise JsonDataException(errors)
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in json_msg.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EndpointHandler(ABC, Generic[ApiPayload_T]):
|
||||
"""
|
||||
Each model endpoint will have a handler responsible for counting workload from the incoming ApiPayload
|
||||
and converting it to json to be forwarded to model API
|
||||
"""
|
||||
|
||||
benchmark_runs: int = 8
|
||||
benchmark_words: int = 100
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def endpoint(self) -> str:
|
||||
"""the endpoint on the model API"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def healthcheck_endpoint(self) -> Optional[str]:
|
||||
"""the endpoint on the model API that is used for healthchecks"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def payload_cls(cls) -> Type[ApiPayload_T]:
|
||||
"""ApiPayload class"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def make_benchmark_payload(self) -> ApiPayload_T:
|
||||
"""defines how to create an ApiPayload for benchmarking."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
"""
|
||||
defines how to convert a model API response to a response to PyWorker client
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_data_from_request(
|
||||
cls, req_data: Dict[str, Any]
|
||||
) -> Tuple[AuthData, ApiPayload_T]:
|
||||
errors = {}
|
||||
auth_data: Optional[AuthData] = None
|
||||
payload: Optional[ApiPayload_T] = None
|
||||
try:
|
||||
if "auth_data" in req_data:
|
||||
auth_data = AuthData.from_json_msg(req_data["auth_data"])
|
||||
else:
|
||||
errors["auth_data"] = "field missing"
|
||||
except JsonDataException as e:
|
||||
errors["auth_data"] = e.message
|
||||
try:
|
||||
if "payload" in req_data:
|
||||
payload_cls = cls.payload_cls()
|
||||
payload = payload_cls.from_json_msg(req_data["payload"])
|
||||
else:
|
||||
errors["payload"] = "field missing"
|
||||
except JsonDataException as e:
|
||||
errors["payload"] = e.message
|
||||
if errors:
|
||||
raise JsonDataException(errors)
|
||||
if auth_data and payload:
|
||||
return (auth_data, payload)
|
||||
else:
|
||||
raise Exception("error deserializing request data")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SystemMetrics:
|
||||
"""General system metrics"""
|
||||
|
||||
model_loading_start: float
|
||||
model_loading_time: Union[float, None]
|
||||
last_disk_usage: float
|
||||
additional_disk_usage: float
|
||||
model_is_loaded: bool
|
||||
|
||||
@staticmethod
|
||||
def get_disk_usage_GB():
|
||||
return psutil.disk_usage("/").used / (2**30) # want units of GB
|
||||
|
||||
@classmethod
|
||||
def empty(cls):
|
||||
return cls(
|
||||
model_loading_start=time.time(),
|
||||
model_loading_time=None,
|
||||
last_disk_usage=SystemMetrics.get_disk_usage_GB(),
|
||||
additional_disk_usage=0.0,
|
||||
model_is_loaded=False,
|
||||
)
|
||||
|
||||
def update_disk_usage(self):
|
||||
disk_usage = SystemMetrics.get_disk_usage_GB()
|
||||
self.additional_disk_usage = disk_usage - self.last_disk_usage
|
||||
self.last_disk_usage = disk_usage
|
||||
|
||||
def reset(self, expected: float | None) -> None:
|
||||
# autoscaler excepts model_loading_time to be populated only once, when the instance has
|
||||
# finished benchmarking and is ready to receive requests. This applies to restarted instances
|
||||
# as well: they should send model_loading_time once when they are done loading
|
||||
if self.model_loading_time == expected:
|
||||
self.model_loading_time = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestMetrics:
|
||||
"""Tracks metrics for an active request."""
|
||||
request_idx: int
|
||||
reqnum: int
|
||||
workload: float
|
||||
status: str
|
||||
success: bool = False
|
||||
|
||||
@dataclass
|
||||
class BenchmarkResult:
|
||||
request_idx: int
|
||||
workload: float
|
||||
task: Awaitable[ClientResponse]
|
||||
response: Optional[ClientResponse] = None
|
||||
|
||||
@property
|
||||
def is_successful(self) -> bool:
|
||||
return self.response is not None and self.response.status == 200
|
||||
|
||||
@dataclass
|
||||
class ModelMetrics:
|
||||
"""Model specific metrics"""
|
||||
|
||||
# these are reset after being sent to autoscaler
|
||||
workload_served: float
|
||||
workload_received: float
|
||||
workload_cancelled: float
|
||||
workload_errored: float
|
||||
workload_rejected: float
|
||||
# these are not
|
||||
workload_pending: float
|
||||
error_msg: Optional[str]
|
||||
max_throughput: float
|
||||
requests_recieved: Set[int] = field(default_factory=set)
|
||||
requests_working: dict[int, RequestMetrics] = field(default_factory=dict)
|
||||
requests_deleting: list[RequestMetrics] = field(default_factory=list)
|
||||
last_update: float = field(default_factory=time.time)
|
||||
|
||||
@classmethod
|
||||
def empty(cls):
|
||||
return cls(
|
||||
workload_pending=0.0,
|
||||
workload_served=0.0,
|
||||
workload_cancelled=0.0,
|
||||
workload_errored=0.0,
|
||||
workload_rejected=0.0,
|
||||
workload_received=0.0,
|
||||
error_msg=None,
|
||||
max_throughput=0.0,
|
||||
)
|
||||
|
||||
@property
|
||||
def workload_processing(self) -> float:
|
||||
return max(self.workload_received - self.workload_cancelled, 0.0)
|
||||
|
||||
@property
|
||||
def wait_time(self) -> float:
|
||||
if (len(self.requests_working) == 0):
|
||||
return 0.0
|
||||
return sum([request.workload for request in self.requests_working.values()]) / max(self.max_throughput, 0.00001)
|
||||
|
||||
@property
|
||||
def cur_load(self) -> float:
|
||||
return sum([request.workload for request in self.requests_working.values()])
|
||||
|
||||
@property
|
||||
def working_request_idxs(self) -> list[int]:
|
||||
return [req.request_idx for req in self.requests_working.values()]
|
||||
|
||||
def set_errored(self, error_msg):
|
||||
self.reset()
|
||||
self.error_msg = error_msg
|
||||
|
||||
def reset(self):
|
||||
self.workload_served = 0
|
||||
self.workload_received = 0
|
||||
self.workload_cancelled = 0
|
||||
self.workload_errored = 0
|
||||
self.workload_rejected = 0
|
||||
self.last_update = time.time()
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoScalerData:
|
||||
"""Data that is reported to autoscaler"""
|
||||
|
||||
id: int
|
||||
mtoken: str
|
||||
version: str
|
||||
loadtime: float
|
||||
cur_load: float
|
||||
rej_load: float
|
||||
new_load: float
|
||||
error_msg: str
|
||||
max_perf: float
|
||||
cur_perf: float
|
||||
cur_capacity: float
|
||||
max_capacity: float
|
||||
num_requests_working: int
|
||||
num_requests_recieved: int
|
||||
additional_disk_usage: float
|
||||
working_request_idxs: list[int]
|
||||
url: str
|
||||
|
||||
|
||||
class LogAction(Enum):
|
||||
"""
|
||||
These actions tell the backend what a log value means, for example:
|
||||
actions [
|
||||
# this marks the model server as loaded
|
||||
(LogAction.ModelLoaded, "Starting server"),
|
||||
# these mark the model server as errored
|
||||
(LogAction.ModelError, "Exception loading model"),
|
||||
(LogAction.ModelError, "Server failed to bind to port"),
|
||||
# this tells the backend to print any logs containing the string into its own logs
|
||||
# which are visible in the vast console instance logs
|
||||
(LogAction.Info, "Starting model download"),
|
||||
]
|
||||
"""
|
||||
|
||||
ModelLoaded = 1
|
||||
ModelError = 2
|
||||
Info = 3
|
||||
-286
@@ -1,286 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import json
|
||||
from asyncio import sleep
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from functools import cache
|
||||
import asyncio
|
||||
from aiohttp import ClientSession, ClientTimeout, TCPConnector, ClientResponseError
|
||||
|
||||
from lib.data_types import AutoScalerData, SystemMetrics, ModelMetrics, RequestMetrics
|
||||
from typing import Awaitable, NoReturn, List
|
||||
|
||||
METRICS_UPDATE_INTERVAL = 1
|
||||
DELETE_REQUESTS_INTERVAL = 1
|
||||
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
@cache
|
||||
def get_url() -> str:
|
||||
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
||||
worker_port = os.environ[f"VAST_TCP_PORT_{os.environ['WORKER_PORT']}"]
|
||||
public_ip = os.environ["PUBLIC_IPADDR"]
|
||||
return f"http{'s' if use_ssl else ''}://{public_ip}:{worker_port}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Metrics:
|
||||
version: str = "0"
|
||||
mtoken: str = ""
|
||||
last_metric_update: float = 0.0
|
||||
last_request_served: float = 0.0
|
||||
update_pending: bool = False
|
||||
id: int = field(default_factory=lambda: int(os.environ["CONTAINER_ID"]))
|
||||
report_addr: List[str] = field(
|
||||
default_factory=lambda: os.environ["REPORT_ADDR"].split(",")
|
||||
)
|
||||
url: str = field(default_factory=get_url)
|
||||
system_metrics: SystemMetrics = field(default_factory=SystemMetrics.empty)
|
||||
model_metrics: ModelMetrics = field(default_factory=ModelMetrics.empty)
|
||||
_session: ClientSession | None = field(default=None, init=False, repr=False)
|
||||
|
||||
async def http(self) -> ClientSession:
|
||||
if self._session is None:
|
||||
self._session = ClientSession(
|
||||
timeout=ClientTimeout(total=10),
|
||||
connector=TCPConnector(limit=8, limit_per_host=4, force_close=True, enable_cleanup_closed=True)
|
||||
)
|
||||
return self._session
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if self._session is not None:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
def _request_start(self, request: RequestMetrics) -> None:
|
||||
"""
|
||||
this function is called prior to forwarding a request to a model API.
|
||||
"""
|
||||
log.debug("request start")
|
||||
request.status = "Started"
|
||||
self.model_metrics.workload_pending += request.workload
|
||||
self.model_metrics.workload_received += request.workload
|
||||
self.model_metrics.requests_recieved.add(request.reqnum)
|
||||
self.model_metrics.requests_working[request.reqnum] = request
|
||||
self.update_pending = True
|
||||
|
||||
def _request_end(self, request: RequestMetrics) -> None:
|
||||
"""
|
||||
this function is called after handling of a request ends, regardless of the outcome
|
||||
"""
|
||||
self.model_metrics.workload_pending -= request.workload
|
||||
self.model_metrics.requests_working.pop(request.reqnum, None)
|
||||
self.model_metrics.requests_deleting.append(request)
|
||||
self.last_request_served = time.time()
|
||||
|
||||
def _request_success(self, request: RequestMetrics) -> None:
|
||||
"""
|
||||
this function is called after a response from model API is received and forwarded.
|
||||
"""
|
||||
self.model_metrics.workload_served += request.workload
|
||||
request.status = "Success"
|
||||
request.success = True
|
||||
self.update_pending = True
|
||||
|
||||
def _request_errored(self, request: RequestMetrics) -> None:
|
||||
"""
|
||||
this function is called if model API returns an error
|
||||
"""
|
||||
self.model_metrics.workload_errored += request.workload
|
||||
request.status = "Error"
|
||||
request.success = False
|
||||
self.update_pending = True
|
||||
|
||||
def _request_canceled(self, request: RequestMetrics) -> None:
|
||||
"""
|
||||
this function is called if client drops connection before model API has responded
|
||||
"""
|
||||
self.model_metrics.workload_cancelled += request.workload
|
||||
request.success = True
|
||||
request.status = "Cancelled"
|
||||
|
||||
def _request_reject(self, request: RequestMetrics):
|
||||
"""
|
||||
this function is called if the current wait time for the model is above max_wait_time
|
||||
"""
|
||||
self.model_metrics.requests_recieved.add(request.reqnum)
|
||||
self.model_metrics.requests_deleting.append(request)
|
||||
self.model_metrics.workload_rejected += request.workload
|
||||
request.success = False
|
||||
request.status = "Rejected"
|
||||
self.update_pending = True
|
||||
|
||||
async def _send_delete_requests_loop(self) -> Awaitable[NoReturn]:
|
||||
while True:
|
||||
await sleep(DELETE_REQUESTS_INTERVAL)
|
||||
if len(self.model_metrics.requests_deleting) > 0:
|
||||
await self.__send_delete_requests_and_reset()
|
||||
|
||||
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
|
||||
while True:
|
||||
await sleep(METRICS_UPDATE_INTERVAL)
|
||||
elapsed = time.time() - self.last_metric_update
|
||||
if self.system_metrics.model_is_loaded is False and elapsed >= 10:
|
||||
log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
|
||||
await self.__send_metrics_and_reset()
|
||||
elif self.update_pending or elapsed > 10:
|
||||
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
|
||||
await self.__send_metrics_and_reset()
|
||||
|
||||
def _model_loaded(self, max_throughput: float) -> None:
|
||||
self.system_metrics.model_loading_time = (
|
||||
time.time() - self.system_metrics.model_loading_start
|
||||
)
|
||||
self.system_metrics.model_is_loaded = True
|
||||
self.model_metrics.max_throughput = max_throughput
|
||||
|
||||
def _model_errored(self, error_msg: str) -> None:
|
||||
self.model_metrics.set_errored(error_msg)
|
||||
self.system_metrics.model_is_loaded = True
|
||||
|
||||
def _set_version(self, version: str) -> None:
|
||||
self.version = version
|
||||
|
||||
def _set_mtoken(self, mtoken: str) -> None:
|
||||
self.mtoken = mtoken
|
||||
|
||||
#######################################Private#######################################
|
||||
|
||||
async def __send_delete_requests_and_reset(self):
|
||||
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
|
||||
data = {
|
||||
"worker_id": self.id,
|
||||
"mtoken": self.mtoken,
|
||||
"request_idxs": idxs,
|
||||
"success": success_flag,
|
||||
}
|
||||
log.debug(
|
||||
f"Deleting requests that {'succeeded' if success_flag else 'failed'}: {data['request_idxs']}"
|
||||
)
|
||||
full_path = report_addr.rstrip("/") + "/delete_requests/"
|
||||
for attempt in range(1, 4):
|
||||
try:
|
||||
session = await self.http()
|
||||
async with session.post(full_path, json=data) as res:
|
||||
log.debug(f"delete_requests response: {res.status}")
|
||||
res.raise_for_status()
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
log.debug("delete_requests timed out")
|
||||
except (ClientResponseError, Exception) as e:
|
||||
log.debug(f"delete_requests failed with error: {e}")
|
||||
await asyncio.sleep(2)
|
||||
log.debug(f"retrying delete_request, attempt: {attempt}")
|
||||
return False
|
||||
|
||||
# Take a snapshot of what we plan to send this tick.
|
||||
# New arrivals after this snapshot will remain in the queue for the next tick.
|
||||
snapshot = list(self.model_metrics.requests_deleting)
|
||||
success_idxs = [r.request_idx for r in snapshot if r.success is True]
|
||||
failed_idxs = [r.request_idx for r in snapshot if r.success is False]
|
||||
|
||||
if not success_idxs and not failed_idxs:
|
||||
return # nothing to do
|
||||
|
||||
for report_addr in self.report_addr:
|
||||
# TODO: Add a Redis subscriber queue for delete_requests
|
||||
if report_addr == "https://cloud.vast.ai/api/v0":
|
||||
# Patch: ignore the Redis API report_addr
|
||||
continue
|
||||
sent_success = True
|
||||
sent_failed = True
|
||||
|
||||
if success_idxs:
|
||||
sent_success = await post(report_addr, success_idxs, True)
|
||||
if failed_idxs:
|
||||
sent_failed = await post(report_addr, failed_idxs, False)
|
||||
|
||||
if sent_success and sent_failed:
|
||||
# Remove only the items we actually sent from the live queue.
|
||||
sent_set = set(success_idxs) | set(failed_idxs)
|
||||
self.model_metrics.requests_deleting[:] = [
|
||||
r for r in self.model_metrics.requests_deleting
|
||||
if r.request_idx not in sent_set
|
||||
]
|
||||
break
|
||||
|
||||
|
||||
async def __send_metrics_and_reset(self):
|
||||
|
||||
loadtime_snapshot = self.system_metrics.model_loading_time
|
||||
|
||||
def compute_autoscaler_data() -> AutoScalerData:
|
||||
return AutoScalerData(
|
||||
id=self.id,
|
||||
mtoken=self.mtoken,
|
||||
version=self.version,
|
||||
loadtime=(loadtime_snapshot or 0.0),
|
||||
new_load=self.model_metrics.workload_processing,
|
||||
cur_load=self.model_metrics.cur_load,
|
||||
rej_load=self.model_metrics.workload_rejected,
|
||||
max_perf=self.model_metrics.max_throughput,
|
||||
cur_perf=self.model_metrics.workload_served,
|
||||
error_msg=self.model_metrics.error_msg or "",
|
||||
num_requests_working=len(self.model_metrics.requests_working),
|
||||
num_requests_recieved=len(self.model_metrics.requests_recieved),
|
||||
additional_disk_usage=self.system_metrics.additional_disk_usage,
|
||||
working_request_idxs=self.model_metrics.working_request_idxs,
|
||||
cur_capacity=0,
|
||||
max_capacity=0,
|
||||
url=self.url,
|
||||
)
|
||||
|
||||
async def send_data(report_addr: str) -> bool:
|
||||
data = compute_autoscaler_data()
|
||||
log_data = asdict(data)
|
||||
def obfuscate(secret: str) -> str:
|
||||
if secret is None:
|
||||
return ""
|
||||
return secret[:7] + "..." if len(secret) > 7 else ("*" * len(secret))
|
||||
|
||||
log_data["mtoken"] = obfuscate(log_data.get("mtoken"))
|
||||
log.debug(
|
||||
"\n".join(
|
||||
[
|
||||
"#" * 60,
|
||||
f"sending data to autoscaler",
|
||||
f"{json.dumps(log_data, indent=2)}",
|
||||
"#" * 60,
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
full_path = report_addr.rstrip("/") + "/worker_status/"
|
||||
for attempt in range(1, 4):
|
||||
try:
|
||||
session = await self.http()
|
||||
async with session.post(full_path, json=asdict(data)) as res:
|
||||
res.raise_for_status()
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
log.debug(f"autoscaler status update timed out")
|
||||
except (ClientResponseError, Exception) as e:
|
||||
log.debug(f"autoscaler status update failed with error: {e}")
|
||||
await asyncio.sleep(2)
|
||||
log.debug(f"retrying autoscaler status update, attempt: {attempt}")
|
||||
log.debug(f"failed to send update through {report_addr}")
|
||||
return False
|
||||
|
||||
###########
|
||||
|
||||
self.system_metrics.update_disk_usage()
|
||||
|
||||
sent = False
|
||||
for report_addr in self.report_addr:
|
||||
if await send_data(report_addr):
|
||||
sent = True
|
||||
break
|
||||
|
||||
if sent:
|
||||
# clear the one-shot loadtime only if we actually sent *this* value
|
||||
self.system_metrics.reset(expected=loadtime_snapshot)
|
||||
self.update_pending = False
|
||||
self.model_metrics.reset()
|
||||
self.last_metric_update = time.time()
|
||||
@@ -1,60 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import List
|
||||
import ssl
|
||||
from asyncio import run, gather
|
||||
import asyncio
|
||||
|
||||
from lib.backend import Backend
|
||||
from lib.metrics import Metrics
|
||||
from aiohttp import web
|
||||
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
||||
try:
|
||||
log.debug("getting certificate...")
|
||||
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
||||
if use_ssl is True:
|
||||
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
ssl_context.load_cert_chain(
|
||||
certfile="/etc/instance.crt",
|
||||
keyfile="/etc/instance.key",
|
||||
)
|
||||
else:
|
||||
ssl_context = None
|
||||
|
||||
async def main():
|
||||
log.debug("starting server...")
|
||||
app = web.Application()
|
||||
app.add_routes(routes)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(
|
||||
runner,
|
||||
ssl_context=ssl_context,
|
||||
port=int(os.environ["WORKER_PORT"]),
|
||||
**kwargs
|
||||
)
|
||||
await gather(site.start(), backend._start_tracking())
|
||||
|
||||
run(main())
|
||||
|
||||
except Exception as e:
|
||||
err_msg = f"PyWorker failed to launch: {e}"
|
||||
log.error(err_msg)
|
||||
|
||||
async def beacon():
|
||||
metrics = Metrics()
|
||||
metrics._set_version(getattr(backend, "version", "0"))
|
||||
metrics._set_mtoken(getattr(backend, "mtoken", ""))
|
||||
try:
|
||||
while True:
|
||||
metrics._model_errored(err_msg)
|
||||
await metrics._Metrics__send_metrics_and_reset()
|
||||
await asyncio.sleep(10)
|
||||
finally:
|
||||
await metrics.aclose()
|
||||
|
||||
run(beacon())
|
||||
@@ -1,310 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
from typing import Callable, List, Dict, Tuple, Dict, Any, Type
|
||||
from time import sleep
|
||||
import threading
|
||||
from enum import Enum
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from urllib.parse import urljoin
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
import requests
|
||||
|
||||
from lib.data_types import AuthData, ApiPayload
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
class ClientStatus(Enum):
|
||||
FetchEndpoint = 1
|
||||
Generating = 2
|
||||
Done = 3
|
||||
Error = 4
|
||||
|
||||
|
||||
total_success = 0
|
||||
last_res = []
|
||||
stop_event = threading.Event()
|
||||
|
||||
start_time = time.time()
|
||||
test_args = argparse.ArgumentParser(description="Test inference endpoint")
|
||||
test_args.add_argument(
|
||||
"-k", dest="api_key", type=str, required=True, help="Your vast account API key"
|
||||
)
|
||||
test_args.add_argument(
|
||||
"-e",
|
||||
dest="endpoint_group_name",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Endpoint group name",
|
||||
)
|
||||
test_args.add_argument(
|
||||
"-l",
|
||||
dest="server_url",
|
||||
action="store_const",
|
||||
const="http://localhost:8081",
|
||||
default="https://run.vast.ai",
|
||||
help="Call local autoscaler instead of prod, for dev use only",
|
||||
)
|
||||
test_args.add_argument(
|
||||
"-i",
|
||||
dest="instance",
|
||||
type=str,
|
||||
default="prod",
|
||||
help="Autoscaler shard to run the command against, default: prod",
|
||||
)
|
||||
|
||||
GetPayloadAndWorkload = Callable[[], Tuple[Dict[str, Any], float]]
|
||||
|
||||
|
||||
def print_truncate_res(res: str):
|
||||
if len(res) > 150:
|
||||
print(f"{res[:50]}....{res[-100:]}")
|
||||
else:
|
||||
print(res)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClientState:
|
||||
endpoint_group_name: str
|
||||
api_key: str
|
||||
server_url: str
|
||||
worker_endpoint: str
|
||||
instance: str
|
||||
payload: ApiPayload
|
||||
url: str = ""
|
||||
status: ClientStatus = ClientStatus.FetchEndpoint
|
||||
as_error: List[str] = field(default_factory=list)
|
||||
infer_error: List[str] = field(default_factory=list)
|
||||
conn_errors: Counter = field(default_factory=Counter)
|
||||
|
||||
def make_call(self):
|
||||
self.status = ClientStatus.FetchEndpoint
|
||||
if not self.api_key:
|
||||
self.as_error.append(
|
||||
f"Endpoint {self.endpoint_group_name} not found for API key",
|
||||
)
|
||||
self.status = ClientStatus.Error
|
||||
return
|
||||
route_payload = {
|
||||
"endpoint": self.endpoint_group_name,
|
||||
"api_key": self.api_key,
|
||||
"cost": self.payload.count_workload(),
|
||||
}
|
||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
response = requests.post(
|
||||
urljoin(self.server_url, "/route/"),
|
||||
json=route_payload,
|
||||
headers=headers,
|
||||
timeout=4,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
self.as_error.append(
|
||||
f"code: {response.status_code}, body: {response.text}",
|
||||
)
|
||||
self.status = ClientStatus.Error
|
||||
return
|
||||
message = response.json()
|
||||
worker_address = message["url"]
|
||||
req_data = dict(
|
||||
payload=asdict(self.payload),
|
||||
auth_data=asdict(AuthData.from_json_msg(message)),
|
||||
)
|
||||
self.url = worker_address
|
||||
url = urljoin(worker_address, self.worker_endpoint)
|
||||
self.status = ClientStatus.Generating
|
||||
|
||||
response = requests.post(
|
||||
url,
|
||||
json=req_data,
|
||||
verify=get_cert_file_path(),
|
||||
)
|
||||
if response.status_code != 200:
|
||||
self.infer_error.append(
|
||||
f"code: {response.status_code}, body: {response.text}, url: {url}",
|
||||
)
|
||||
self.status = ClientStatus.Error
|
||||
return
|
||||
res = str(response.json())
|
||||
global total_success
|
||||
global last_res
|
||||
total_success += 1
|
||||
last_res.append(res)
|
||||
self.status = ClientStatus.Done
|
||||
|
||||
def simulate_user(self) -> None:
|
||||
try:
|
||||
self.make_call()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
self.status = ClientStatus.Error
|
||||
_ = e
|
||||
self.conn_errors[self.url] += 1
|
||||
|
||||
|
||||
def print_state(clients: List[ClientState], num_clients: int) -> None:
|
||||
print("starting up...")
|
||||
sleep(2)
|
||||
center_size = 14
|
||||
global start_time
|
||||
while len(clients) < num_clients or (
|
||||
any(
|
||||
map(
|
||||
lambda client: client.status
|
||||
in [ClientStatus.FetchEndpoint, ClientStatus.Generating],
|
||||
clients,
|
||||
)
|
||||
)
|
||||
):
|
||||
sleep(0.5)
|
||||
os.system("clear")
|
||||
print(
|
||||
" | ".join(
|
||||
[member.name.center(center_size) for member in ClientStatus]
|
||||
+ [
|
||||
item.center(center_size)
|
||||
for item in [
|
||||
"urls",
|
||||
"as_error",
|
||||
"infer_error",
|
||||
"conn_error",
|
||||
"total_success",
|
||||
]
|
||||
]
|
||||
)
|
||||
)
|
||||
unique_urls = len(set([c.url for c in clients if c.url != ""]))
|
||||
as_errors = sum(
|
||||
map(
|
||||
lambda client: len(client.as_error),
|
||||
[client for client in clients],
|
||||
)
|
||||
)
|
||||
infer_errors = sum(
|
||||
map(
|
||||
lambda client: len(client.infer_error),
|
||||
[client for client in clients],
|
||||
)
|
||||
)
|
||||
conn_errors = sum([client.conn_errors for client in clients], start=Counter())
|
||||
conn_errors_str = ",".join(map(str, conn_errors.values())) or "0"
|
||||
elapsed = time.time() - start_time
|
||||
print(
|
||||
" | ".join(
|
||||
map(
|
||||
lambda item: str(item).center(center_size),
|
||||
[
|
||||
len(list(filter(lambda x: x.status == member, clients)))
|
||||
for member in ClientStatus
|
||||
]
|
||||
+ [
|
||||
unique_urls,
|
||||
as_errors,
|
||||
infer_errors,
|
||||
conn_errors_str,
|
||||
f"{total_success}({((total_success/elapsed) * 60):.2f}/minute)",
|
||||
],
|
||||
)
|
||||
)
|
||||
)
|
||||
if conn_errors:
|
||||
print("conn_errors:")
|
||||
for url, count in conn_errors.items():
|
||||
print(url.ljust(28), ": ", str(count))
|
||||
elapsed = time.time() - start_time
|
||||
print(f"\n elapsed: {int(elapsed // 60)}:{int(elapsed % 60)}")
|
||||
if last_res:
|
||||
for i, res in enumerate(last_res[-10:]):
|
||||
print_truncate_res(f"res #{1+i+max(len(last_res )-10,0)}: {res}")
|
||||
if stop_event.is_set():
|
||||
print("\n### waiting for existing connections to close ###")
|
||||
|
||||
|
||||
def run_test(
|
||||
num_requests: int,
|
||||
requests_per_second: int,
|
||||
endpoint_group_name: str,
|
||||
api_key: str,
|
||||
server_url: str,
|
||||
worker_endpoint: str,
|
||||
payload_cls: Type[ApiPayload],
|
||||
instance: str,
|
||||
):
|
||||
threads = []
|
||||
|
||||
clients = []
|
||||
print_thread = threading.Thread(target=print_state, args=(clients, num_requests))
|
||||
print_thread.daemon = True # makes threads get killed on program exit
|
||||
print_thread.start()
|
||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||
endpoint_name=endpoint_group_name, account_api_key=api_key, instance=instance
|
||||
)
|
||||
if not endpoint_api_key:
|
||||
log.debug(f"Endpoint {endpoint_group_name} not found for API key")
|
||||
return
|
||||
try:
|
||||
for _ in range(num_requests):
|
||||
client = ClientState(
|
||||
endpoint_group_name=endpoint_group_name,
|
||||
api_key=endpoint_api_key,
|
||||
server_url=server_url,
|
||||
worker_endpoint=worker_endpoint,
|
||||
payload=payload_cls.for_test(),
|
||||
instance=instance,
|
||||
)
|
||||
clients.append(client)
|
||||
thread = threading.Thread(target=client.simulate_user, args=())
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
sleep(1 / requests_per_second)
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
print("done spawning workers")
|
||||
except KeyboardInterrupt:
|
||||
stop_event.set()
|
||||
|
||||
|
||||
def test_load_cmd(
|
||||
payload_cls: Type[ApiPayload], endpoint: str, arg_parser: argparse.ArgumentParser
|
||||
):
|
||||
arg_parser.add_argument(
|
||||
"-n",
|
||||
dest="num_requests",
|
||||
type=int,
|
||||
required=True,
|
||||
help="total number of requests",
|
||||
)
|
||||
arg_parser.add_argument(
|
||||
"-rps",
|
||||
dest="requests_per_second",
|
||||
type=float,
|
||||
required=True,
|
||||
help="requests per second",
|
||||
)
|
||||
args = arg_parser.parse_args()
|
||||
if hasattr(args, "comfy_model"):
|
||||
os.environ["COMFY_MODEL"] = args.comfy_model
|
||||
server_url = {
|
||||
"prod": "https://run.vast.ai",
|
||||
"alpha": "https://run-alpha.vast.ai",
|
||||
"candidate": "https://run-candidate.vast.ai",
|
||||
"local": "http://localhost:8080",
|
||||
}.get(args.instance, "http://localhost:8080")
|
||||
run_test(
|
||||
num_requests=args.num_requests,
|
||||
requests_per_second=args.requests_per_second,
|
||||
api_key=args.api_key,
|
||||
server_url=server_url,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
worker_endpoint=endpoint,
|
||||
payload_cls=payload_cls,
|
||||
instance=args.instance,
|
||||
)
|
||||
Reference in New Issue
Block a user