remove legacy pyworker
This commit is contained in:
-434
@@ -1,434 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import base64
|
||||
import subprocess
|
||||
import dataclasses
|
||||
import logging
|
||||
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
|
||||
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
|
||||
from functools import cached_property
|
||||
from distutils.util import strtobool
|
||||
|
||||
from anyio import open_file
|
||||
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
|
||||
import asyncio
|
||||
|
||||
import requests
|
||||
from Crypto.Signature import pkcs1_15
|
||||
from Crypto.Hash import SHA256
|
||||
from Crypto.PublicKey import RSA
|
||||
|
||||
from lib.metrics import Metrics
|
||||
from lib.data_types import (
|
||||
AuthData,
|
||||
EndpointHandler,
|
||||
LogAction,
|
||||
ApiPayload_T,
|
||||
JsonDataException,
|
||||
RequestMetrics,
|
||||
BenchmarkResult
|
||||
)
|
||||
|
||||
VERSION = "0.2.1"
|
||||
|
||||
MSG_HISTORY_LEN = 100
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
# defines the minimum wait time between sending updates to autoscaler
|
||||
LOG_POLL_INTERVAL = 0.1
|
||||
BENCHMARK_INDICATOR_FILE = ".has_benchmark"
|
||||
MAX_PUBKEY_FETCH_ATTEMPTS = 3
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Backend:
|
||||
"""
|
||||
This class is responsible for:
|
||||
1. Tailing logs and updating load time metrics
|
||||
2. Taking an EndpointHandler alongside incoming payload, preparing a json to be sent to the model, and
|
||||
sending the request. It also updates metrics as it makes those requests.
|
||||
3. Running a benchmark from an EndpointHandler
|
||||
"""
|
||||
|
||||
model_server_url: str
|
||||
model_log_file: str
|
||||
allow_parallel_requests: bool
|
||||
benchmark_handler: (
|
||||
EndpointHandler # this endpoint handler will be used for benchmarking
|
||||
)
|
||||
log_actions: List[Tuple[LogAction, str]]
|
||||
max_wait_time: float = 10.0
|
||||
reqnum = -1
|
||||
version = VERSION
|
||||
msg_history = []
|
||||
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
|
||||
unsecured: bool = dataclasses.field(
|
||||
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
||||
)
|
||||
report_addr: str = dataclasses.field(
|
||||
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
|
||||
)
|
||||
mtoken: str = dataclasses.field(
|
||||
default_factory=lambda: os.environ.get("MASTER_TOKEN", "")
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
self.metrics = Metrics()
|
||||
self.metrics._set_version(self.version)
|
||||
self.metrics._set_mtoken(self.mtoken)
|
||||
self._total_pubkey_fetch_errors = 0
|
||||
self._pubkey = self._fetch_pubkey()
|
||||
self.__start_healthcheck: bool = False
|
||||
|
||||
@property
|
||||
def pubkey(self) -> Optional[RSA.RsaKey]:
|
||||
if self._pubkey is None:
|
||||
self._pubkey = self._fetch_pubkey()
|
||||
return self._pubkey
|
||||
|
||||
@cached_property
|
||||
def session(self):
|
||||
log.debug(f"starting session with {self.model_server_url}")
|
||||
connector = TCPConnector(
|
||||
force_close=True, # Required for long running jobs
|
||||
enable_cleanup_closed=True,
|
||||
)
|
||||
|
||||
timeout = ClientTimeout(total=None)
|
||||
return ClientSession(self.model_server_url, timeout=timeout, connector=connector)
|
||||
|
||||
def create_handler(
|
||||
self,
|
||||
handler: EndpointHandler[ApiPayload_T],
|
||||
) -> Callable[[web.Request], Awaitable[Union[web.Response, web.StreamResponse]]]:
|
||||
async def handler_fn(
|
||||
request: web.Request,
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
return await self.__handle_request(handler=handler, request=request)
|
||||
|
||||
return handler_fn
|
||||
|
||||
#######################################Private#######################################
|
||||
def _fetch_pubkey(self):
|
||||
report_addr = self.report_addr.rstrip("/")
|
||||
command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"]
|
||||
try:
|
||||
result = subprocess.check_output(command, universal_newlines=True)
|
||||
log.debug("public key:")
|
||||
log.debug(result)
|
||||
key = RSA.import_key(result)
|
||||
if key is not None:
|
||||
return key
|
||||
except (ValueError , subprocess.CalledProcessError) as e:
|
||||
log.debug(f"Error downloading key: {e}")
|
||||
self.backend_errored("Failed to get autoscaler pubkey")
|
||||
|
||||
|
||||
async def __handle_request(
|
||||
self,
|
||||
handler: EndpointHandler[ApiPayload_T],
|
||||
request: web.Request,
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
"""use this function to forward requests to the model endpoint"""
|
||||
try:
|
||||
data = await request.json()
|
||||
auth_data, payload = handler.get_data_from_request(data)
|
||||
except JsonDataException as e:
|
||||
return web.json_response(data=e.message, status=422)
|
||||
except json.JSONDecodeError:
|
||||
return web.json_response(dict(error="invalid JSON"), status=422)
|
||||
workload = payload.count_workload()
|
||||
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
|
||||
|
||||
async def cancel_api_call_if_disconnected() -> web.Response:
|
||||
await request.wait_for_disconnection()
|
||||
log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled")
|
||||
self.metrics._request_canceled(request_metrics)
|
||||
raise asyncio.CancelledError
|
||||
|
||||
async def make_request() -> Union[web.Response, web.StreamResponse]:
|
||||
try:
|
||||
response = await self.__call_api(handler=handler, payload=payload)
|
||||
status_code = response.status
|
||||
log.debug(
|
||||
" ".join(
|
||||
[
|
||||
f"request with reqnum:{request_metrics.reqnum}",
|
||||
f"returned status code: {status_code},",
|
||||
]
|
||||
)
|
||||
)
|
||||
res = await handler.generate_client_response(request, response)
|
||||
self.metrics._request_success(request_metrics)
|
||||
return res
|
||||
except requests.exceptions.RequestException as e:
|
||||
log.debug(f"[backend] Request error: {e}")
|
||||
self.metrics._request_errored(request_metrics)
|
||||
return web.Response(status=500)
|
||||
|
||||
###########
|
||||
|
||||
if self.__check_signature(auth_data) is False:
|
||||
self.metrics._request_reject(request_metrics)
|
||||
return web.Response(status=401)
|
||||
|
||||
if self.metrics.model_metrics.wait_time > self.max_wait_time:
|
||||
self.metrics._request_reject(request_metrics)
|
||||
return web.Response(status=429)
|
||||
|
||||
acquired = False
|
||||
try:
|
||||
self.metrics._request_start(request_metrics)
|
||||
if self.allow_parallel_requests is False:
|
||||
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
|
||||
await self.sem.acquire()
|
||||
acquired = True
|
||||
log.debug(
|
||||
f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..."
|
||||
)
|
||||
else:
|
||||
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
|
||||
done, pending = await wait(
|
||||
[
|
||||
create_task(make_request()),
|
||||
create_task(cancel_api_call_if_disconnected()),
|
||||
],
|
||||
return_when=FIRST_COMPLETED,
|
||||
)
|
||||
for t in pending:
|
||||
t.cancel()
|
||||
await asyncio.gather(*pending, return_exceptions=True)
|
||||
|
||||
done_task = done.pop()
|
||||
try:
|
||||
return done_task.result()
|
||||
except Exception as e:
|
||||
log.debug(f"Request task raised exception: {e}")
|
||||
return web.Response(status=500)
|
||||
except asyncio.CancelledError:
|
||||
# Client is gone. Do not write a response; just unwind.
|
||||
return web.Response(status=499)
|
||||
except Exception as e:
|
||||
log.debug(f"Exception in main handler loop {e}")
|
||||
return web.Response(status=500)
|
||||
finally:
|
||||
# Always release the semaphore if it was acquired
|
||||
if acquired:
|
||||
self.sem.release()
|
||||
self.metrics._request_end(request_metrics)
|
||||
|
||||
@cached_property
|
||||
def healthcheck_session(self):
|
||||
"""Dedicated session for healthchecks to avoid conflicts with API session"""
|
||||
log.debug("creating dedicated healthcheck session")
|
||||
connector = TCPConnector(
|
||||
force_close=True, # Keep this for isolation
|
||||
enable_cleanup_closed=True,
|
||||
)
|
||||
timeout = ClientTimeout(total=10) # Reasonable timeout for healthchecks
|
||||
return ClientSession(timeout=timeout, connector=connector)
|
||||
|
||||
async def __healthcheck(self):
|
||||
health_check_url = self.benchmark_handler.healthcheck_endpoint
|
||||
if health_check_url is None:
|
||||
log.debug("No healthcheck endpoint defined, skipping healthcheck")
|
||||
return
|
||||
|
||||
while True:
|
||||
await sleep(10)
|
||||
if self.__start_healthcheck is False:
|
||||
continue
|
||||
try:
|
||||
log.debug(f"Performing healthcheck on {health_check_url}")
|
||||
async with self.healthcheck_session.get(health_check_url) as response:
|
||||
if response.status == 200:
|
||||
log.debug("Healthcheck successful")
|
||||
elif response.status == 503:
|
||||
log.debug(f"Healthcheck failed with status: {response.status}")
|
||||
self.backend_errored(
|
||||
f"Healthcheck failed with status: {response.status}"
|
||||
)
|
||||
else:
|
||||
log.debug(f"Healthcheck Endpoint not ready: {response.status}")
|
||||
except Exception as e:
|
||||
log.debug(f"Healthcheck failed with exception: {e}")
|
||||
self.backend_errored(str(e))
|
||||
|
||||
async def _start_tracking(self) -> None:
|
||||
await gather(
|
||||
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop()
|
||||
)
|
||||
|
||||
def backend_errored(self, msg: str) -> None:
|
||||
self.metrics._model_errored(msg)
|
||||
|
||||
async def __call_api(
|
||||
self, handler: EndpointHandler[ApiPayload_T], payload: ApiPayload_T
|
||||
) -> ClientResponse:
|
||||
api_payload = payload.generate_payload_json()
|
||||
log.debug(f"posting to endpoint: '{handler.endpoint}', payload: {api_payload}")
|
||||
return await self.session.post(url=handler.endpoint, json=api_payload)
|
||||
|
||||
def __check_signature(self, auth_data: AuthData) -> bool:
|
||||
if self.unsecured is True:
|
||||
return True
|
||||
|
||||
def verify_signature(message, signature):
|
||||
if self.pubkey is None:
|
||||
log.debug(f"No Public Key!")
|
||||
return False
|
||||
|
||||
h = SHA256.new(message.encode())
|
||||
try:
|
||||
pkcs1_15.new(self.pubkey).verify(h, base64.b64decode(signature))
|
||||
return True
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
message = {
|
||||
key: value
|
||||
for (key, value) in (dataclasses.asdict(auth_data).items())
|
||||
if key != "signature" and key != "__request_id"
|
||||
}
|
||||
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
|
||||
log.debug(
|
||||
f"reqnum failure, got {auth_data.reqnum}, current_reqnum: {self.reqnum}"
|
||||
)
|
||||
return False
|
||||
elif message in self.msg_history:
|
||||
log.debug(f"message: {message} already in message history")
|
||||
return False
|
||||
elif verify_signature(json.dumps(message, indent=4, sort_keys=True), auth_data.signature):
|
||||
self.reqnum = max(auth_data.reqnum, self.reqnum)
|
||||
self.msg_history.append(message)
|
||||
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
|
||||
return True
|
||||
else:
|
||||
log.debug(
|
||||
f"signature verification failed, sig:{auth_data.signature}, message: {message}"
|
||||
)
|
||||
return False
|
||||
|
||||
async def __read_logs(self) -> Awaitable[NoReturn]:
|
||||
|
||||
async def run_benchmark() -> float:
|
||||
log.debug("starting benchmark")
|
||||
try:
|
||||
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
||||
log.debug("already ran benchmark")
|
||||
# trigger model load
|
||||
# payload = self.benchmark_handler.make_benchmark_payload()
|
||||
# _ = await self.__call_api(
|
||||
# handler=self.benchmark_handler, payload=payload
|
||||
# )
|
||||
return float(f.readline())
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
log.debug("Initial run to trigger model loading...")
|
||||
payload = self.benchmark_handler.make_benchmark_payload()
|
||||
await self.__call_api(handler=self.benchmark_handler, payload=payload)
|
||||
|
||||
max_throughput = 0
|
||||
sum_throughput = 0
|
||||
concurrent_requests = 10 if self.allow_parallel_requests else 1
|
||||
|
||||
for run in range(1, self.benchmark_handler.benchmark_runs + 1):
|
||||
start = time.time()
|
||||
benchmark_requests = []
|
||||
|
||||
for i in range(concurrent_requests):
|
||||
payload = self.benchmark_handler.make_benchmark_payload()
|
||||
workload = payload.count_workload()
|
||||
task = self.__call_api(handler=self.benchmark_handler, payload=payload)
|
||||
benchmark_requests.append(
|
||||
BenchmarkResult(request_idx=i, workload=workload, task=task)
|
||||
)
|
||||
|
||||
responses = await gather(*[br.task for br in benchmark_requests])
|
||||
for br, response in zip(benchmark_requests, responses):
|
||||
br.response = response
|
||||
|
||||
total_workload = sum(br.workload for br in benchmark_requests if br.is_successful)
|
||||
time_elapsed = time.time() - start
|
||||
successful_responses = sum([1 for br in benchmark_requests if br.is_successful])
|
||||
if successful_responses == 0:
|
||||
self.backend_errored("No successful responses from benchmark")
|
||||
log.debug(f"benchmark failed: {successful_responses}/{concurrent_requests} successful responses")
|
||||
|
||||
throughput = total_workload / time_elapsed
|
||||
sum_throughput += throughput
|
||||
max_throughput = max(max_throughput, throughput)
|
||||
|
||||
# Log results for debugging
|
||||
log.debug(
|
||||
"\n".join(
|
||||
[
|
||||
"#" * 60,
|
||||
f"Run: {run}, concurrent_requests: {concurrent_requests}",
|
||||
f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s",
|
||||
f"Throughput: {throughput} workload/s",
|
||||
f"Successful responses: {successful_responses}/{concurrent_requests}",
|
||||
"#" * 60,
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
average_throughput = sum_throughput / self.benchmark_handler.benchmark_runs
|
||||
log.debug(
|
||||
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
|
||||
)
|
||||
with open(BENCHMARK_INDICATOR_FILE, "w") as f:
|
||||
f.write(str(max_throughput))
|
||||
return max_throughput
|
||||
|
||||
async def handle_log_line(log_line: str) -> None:
|
||||
"""
|
||||
Implement this function to handle each log line for your model.
|
||||
This function should mutate self.system_metrics and self.model_metrics
|
||||
"""
|
||||
for action, msg in self.log_actions:
|
||||
match action:
|
||||
case LogAction.ModelLoaded if msg in log_line:
|
||||
log.debug(
|
||||
f"Got log line indicating model is loaded: {log_line}"
|
||||
)
|
||||
# some backends need a few seconds after logging successful startup before
|
||||
# they can begin accepting requests
|
||||
# await sleep(5)
|
||||
try:
|
||||
max_throughput = await run_benchmark()
|
||||
self.__start_healthcheck = True
|
||||
self.metrics._model_loaded(
|
||||
max_throughput=max_throughput,
|
||||
)
|
||||
except ClientConnectorError as e:
|
||||
log.debug(
|
||||
f"failed to connect to comfyui api during benchmark"
|
||||
)
|
||||
self.backend_errored(str(e))
|
||||
case LogAction.ModelError if msg in log_line:
|
||||
log.debug(f"Got log line indicating error: {log_line}")
|
||||
self.backend_errored(msg)
|
||||
break
|
||||
case LogAction.Info if msg in log_line:
|
||||
log.debug(f"Info from model logs: {log_line}")
|
||||
|
||||
async def tail_log():
|
||||
log.debug(f"tailing file: {self.model_log_file}")
|
||||
async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore') as f:
|
||||
while True:
|
||||
line = await f.readline()
|
||||
if line:
|
||||
await handle_log_line(line.rstrip())
|
||||
else:
|
||||
await asyncio.sleep(LOG_POLL_INTERVAL)
|
||||
|
||||
###########
|
||||
|
||||
while True:
|
||||
if os.path.isfile(self.model_log_file) is True:
|
||||
return await tail_log()
|
||||
else:
|
||||
await sleep(1)
|
||||
@@ -1,324 +0,0 @@
|
||||
import time
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type, Awaitable
|
||||
from aiohttp import web, ClientResponse
|
||||
import inspect
|
||||
|
||||
import psutil
|
||||
|
||||
|
||||
"""
|
||||
type variable representing an incoming payload to pyworker that will used to calculate load and will then
|
||||
be forwarded to the model
|
||||
"""
|
||||
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
class JsonDataException(Exception):
|
||||
def __init__(self, json_msg: Dict[str, Any]):
|
||||
self.message = json_msg
|
||||
|
||||
|
||||
ApiPayload_T = TypeVar("ApiPayload_T", bound="ApiPayload")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ApiPayload(ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def for_test(cls: Type[ApiPayload_T]) -> ApiPayload_T:
|
||||
"""defines how create a payload for load testing"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def generate_payload_json(self) -> Dict[str, Any]:
|
||||
"""defines how to convert an ApiPayload to JSON that will be sent to model API"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def count_workload(self) -> float:
|
||||
"""defines how to calculate workload for a payload"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_json_msg(
|
||||
cls: Type[ApiPayload_T], json_msg: Dict[str, Any]
|
||||
) -> ApiPayload_T:
|
||||
"""
|
||||
defines how to create an API payload from a JSON message,
|
||||
it should throw an JsonDataException if there are issues with some fields
|
||||
or they are missing in the format of
|
||||
{
|
||||
"field": "error msg"
|
||||
}
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthData:
|
||||
"""data used to authenticate requester"""
|
||||
|
||||
cost: str
|
||||
endpoint: str
|
||||
reqnum: int
|
||||
request_idx: int
|
||||
signature: str
|
||||
url: str
|
||||
|
||||
@classmethod
|
||||
def from_json_msg(cls, json_msg: Dict[str, Any]):
|
||||
errors = {}
|
||||
for param in inspect.signature(cls).parameters:
|
||||
if param not in json_msg:
|
||||
errors[param] = "missing parameter"
|
||||
if errors:
|
||||
raise JsonDataException(errors)
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in json_msg.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EndpointHandler(ABC, Generic[ApiPayload_T]):
|
||||
"""
|
||||
Each model endpoint will have a handler responsible for counting workload from the incoming ApiPayload
|
||||
and converting it to json to be forwarded to model API
|
||||
"""
|
||||
|
||||
benchmark_runs: int = 8
|
||||
benchmark_words: int = 100
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def endpoint(self) -> str:
|
||||
"""the endpoint on the model API"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def healthcheck_endpoint(self) -> Optional[str]:
|
||||
"""the endpoint on the model API that is used for healthchecks"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def payload_cls(cls) -> Type[ApiPayload_T]:
|
||||
"""ApiPayload class"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def make_benchmark_payload(self) -> ApiPayload_T:
|
||||
"""defines how to create an ApiPayload for benchmarking."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
"""
|
||||
defines how to convert a model API response to a response to PyWorker client
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_data_from_request(
|
||||
cls, req_data: Dict[str, Any]
|
||||
) -> Tuple[AuthData, ApiPayload_T]:
|
||||
errors = {}
|
||||
auth_data: Optional[AuthData] = None
|
||||
payload: Optional[ApiPayload_T] = None
|
||||
try:
|
||||
if "auth_data" in req_data:
|
||||
auth_data = AuthData.from_json_msg(req_data["auth_data"])
|
||||
else:
|
||||
errors["auth_data"] = "field missing"
|
||||
except JsonDataException as e:
|
||||
errors["auth_data"] = e.message
|
||||
try:
|
||||
if "payload" in req_data:
|
||||
payload_cls = cls.payload_cls()
|
||||
payload = payload_cls.from_json_msg(req_data["payload"])
|
||||
else:
|
||||
errors["payload"] = "field missing"
|
||||
except JsonDataException as e:
|
||||
errors["payload"] = e.message
|
||||
if errors:
|
||||
raise JsonDataException(errors)
|
||||
if auth_data and payload:
|
||||
return (auth_data, payload)
|
||||
else:
|
||||
raise Exception("error deserializing request data")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SystemMetrics:
|
||||
"""General system metrics"""
|
||||
|
||||
model_loading_start: float
|
||||
model_loading_time: Union[float, None]
|
||||
last_disk_usage: float
|
||||
additional_disk_usage: float
|
||||
model_is_loaded: bool
|
||||
|
||||
@staticmethod
|
||||
def get_disk_usage_GB():
|
||||
return psutil.disk_usage("/").used / (2**30) # want units of GB
|
||||
|
||||
@classmethod
|
||||
def empty(cls):
|
||||
return cls(
|
||||
model_loading_start=time.time(),
|
||||
model_loading_time=None,
|
||||
last_disk_usage=SystemMetrics.get_disk_usage_GB(),
|
||||
additional_disk_usage=0.0,
|
||||
model_is_loaded=False,
|
||||
)
|
||||
|
||||
def update_disk_usage(self):
|
||||
disk_usage = SystemMetrics.get_disk_usage_GB()
|
||||
self.additional_disk_usage = disk_usage - self.last_disk_usage
|
||||
self.last_disk_usage = disk_usage
|
||||
|
||||
def reset(self, expected: float | None) -> None:
|
||||
# autoscaler excepts model_loading_time to be populated only once, when the instance has
|
||||
# finished benchmarking and is ready to receive requests. This applies to restarted instances
|
||||
# as well: they should send model_loading_time once when they are done loading
|
||||
if self.model_loading_time == expected:
|
||||
self.model_loading_time = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestMetrics:
|
||||
"""Tracks metrics for an active request."""
|
||||
request_idx: int
|
||||
reqnum: int
|
||||
workload: float
|
||||
status: str
|
||||
success: bool = False
|
||||
|
||||
@dataclass
|
||||
class BenchmarkResult:
|
||||
request_idx: int
|
||||
workload: float
|
||||
task: Awaitable[ClientResponse]
|
||||
response: Optional[ClientResponse] = None
|
||||
|
||||
@property
|
||||
def is_successful(self) -> bool:
|
||||
return self.response is not None and self.response.status == 200
|
||||
|
||||
@dataclass
|
||||
class ModelMetrics:
|
||||
"""Model specific metrics"""
|
||||
|
||||
# these are reset after being sent to autoscaler
|
||||
workload_served: float
|
||||
workload_received: float
|
||||
workload_cancelled: float
|
||||
workload_errored: float
|
||||
workload_rejected: float
|
||||
# these are not
|
||||
workload_pending: float
|
||||
error_msg: Optional[str]
|
||||
max_throughput: float
|
||||
requests_recieved: Set[int] = field(default_factory=set)
|
||||
requests_working: dict[int, RequestMetrics] = field(default_factory=dict)
|
||||
requests_deleting: list[RequestMetrics] = field(default_factory=list)
|
||||
last_update: float = field(default_factory=time.time)
|
||||
|
||||
@classmethod
|
||||
def empty(cls):
|
||||
return cls(
|
||||
workload_pending=0.0,
|
||||
workload_served=0.0,
|
||||
workload_cancelled=0.0,
|
||||
workload_errored=0.0,
|
||||
workload_rejected=0.0,
|
||||
workload_received=0.0,
|
||||
error_msg=None,
|
||||
max_throughput=0.0,
|
||||
)
|
||||
|
||||
@property
|
||||
def workload_processing(self) -> float:
|
||||
return max(self.workload_received - self.workload_cancelled, 0.0)
|
||||
|
||||
@property
|
||||
def wait_time(self) -> float:
|
||||
if (len(self.requests_working) == 0):
|
||||
return 0.0
|
||||
return sum([request.workload for request in self.requests_working.values()]) / max(self.max_throughput, 0.00001)
|
||||
|
||||
@property
|
||||
def cur_load(self) -> float:
|
||||
return sum([request.workload for request in self.requests_working.values()])
|
||||
|
||||
@property
|
||||
def working_request_idxs(self) -> list[int]:
|
||||
return [req.request_idx for req in self.requests_working.values()]
|
||||
|
||||
def set_errored(self, error_msg):
|
||||
self.reset()
|
||||
self.error_msg = error_msg
|
||||
|
||||
def reset(self):
|
||||
self.workload_served = 0
|
||||
self.workload_received = 0
|
||||
self.workload_cancelled = 0
|
||||
self.workload_errored = 0
|
||||
self.workload_rejected = 0
|
||||
self.last_update = time.time()
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoScalerData:
|
||||
"""Data that is reported to autoscaler"""
|
||||
|
||||
id: int
|
||||
mtoken: str
|
||||
version: str
|
||||
loadtime: float
|
||||
cur_load: float
|
||||
rej_load: float
|
||||
new_load: float
|
||||
error_msg: str
|
||||
max_perf: float
|
||||
cur_perf: float
|
||||
cur_capacity: float
|
||||
max_capacity: float
|
||||
num_requests_working: int
|
||||
num_requests_recieved: int
|
||||
additional_disk_usage: float
|
||||
working_request_idxs: list[int]
|
||||
url: str
|
||||
|
||||
|
||||
class LogAction(Enum):
|
||||
"""
|
||||
These actions tell the backend what a log value means, for example:
|
||||
actions [
|
||||
# this marks the model server as loaded
|
||||
(LogAction.ModelLoaded, "Starting server"),
|
||||
# these mark the model server as errored
|
||||
(LogAction.ModelError, "Exception loading model"),
|
||||
(LogAction.ModelError, "Server failed to bind to port"),
|
||||
# this tells the backend to print any logs containing the string into its own logs
|
||||
# which are visible in the vast console instance logs
|
||||
(LogAction.Info, "Starting model download"),
|
||||
]
|
||||
"""
|
||||
|
||||
ModelLoaded = 1
|
||||
ModelError = 2
|
||||
Info = 3
|
||||
-286
@@ -1,286 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import json
|
||||
from asyncio import sleep
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from functools import cache
|
||||
import asyncio
|
||||
from aiohttp import ClientSession, ClientTimeout, TCPConnector, ClientResponseError
|
||||
|
||||
from lib.data_types import AutoScalerData, SystemMetrics, ModelMetrics, RequestMetrics
|
||||
from typing import Awaitable, NoReturn, List
|
||||
|
||||
METRICS_UPDATE_INTERVAL = 1
|
||||
DELETE_REQUESTS_INTERVAL = 1
|
||||
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
@cache
|
||||
def get_url() -> str:
|
||||
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
||||
worker_port = os.environ[f"VAST_TCP_PORT_{os.environ['WORKER_PORT']}"]
|
||||
public_ip = os.environ["PUBLIC_IPADDR"]
|
||||
return f"http{'s' if use_ssl else ''}://{public_ip}:{worker_port}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Metrics:
|
||||
version: str = "0"
|
||||
mtoken: str = ""
|
||||
last_metric_update: float = 0.0
|
||||
last_request_served: float = 0.0
|
||||
update_pending: bool = False
|
||||
id: int = field(default_factory=lambda: int(os.environ["CONTAINER_ID"]))
|
||||
report_addr: List[str] = field(
|
||||
default_factory=lambda: os.environ["REPORT_ADDR"].split(",")
|
||||
)
|
||||
url: str = field(default_factory=get_url)
|
||||
system_metrics: SystemMetrics = field(default_factory=SystemMetrics.empty)
|
||||
model_metrics: ModelMetrics = field(default_factory=ModelMetrics.empty)
|
||||
_session: ClientSession | None = field(default=None, init=False, repr=False)
|
||||
|
||||
async def http(self) -> ClientSession:
|
||||
if self._session is None:
|
||||
self._session = ClientSession(
|
||||
timeout=ClientTimeout(total=10),
|
||||
connector=TCPConnector(limit=8, limit_per_host=4, force_close=True, enable_cleanup_closed=True)
|
||||
)
|
||||
return self._session
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if self._session is not None:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
def _request_start(self, request: RequestMetrics) -> None:
|
||||
"""
|
||||
this function is called prior to forwarding a request to a model API.
|
||||
"""
|
||||
log.debug("request start")
|
||||
request.status = "Started"
|
||||
self.model_metrics.workload_pending += request.workload
|
||||
self.model_metrics.workload_received += request.workload
|
||||
self.model_metrics.requests_recieved.add(request.reqnum)
|
||||
self.model_metrics.requests_working[request.reqnum] = request
|
||||
self.update_pending = True
|
||||
|
||||
def _request_end(self, request: RequestMetrics) -> None:
|
||||
"""
|
||||
this function is called after handling of a request ends, regardless of the outcome
|
||||
"""
|
||||
self.model_metrics.workload_pending -= request.workload
|
||||
self.model_metrics.requests_working.pop(request.reqnum, None)
|
||||
self.model_metrics.requests_deleting.append(request)
|
||||
self.last_request_served = time.time()
|
||||
|
||||
def _request_success(self, request: RequestMetrics) -> None:
|
||||
"""
|
||||
this function is called after a response from model API is received and forwarded.
|
||||
"""
|
||||
self.model_metrics.workload_served += request.workload
|
||||
request.status = "Success"
|
||||
request.success = True
|
||||
self.update_pending = True
|
||||
|
||||
def _request_errored(self, request: RequestMetrics) -> None:
|
||||
"""
|
||||
this function is called if model API returns an error
|
||||
"""
|
||||
self.model_metrics.workload_errored += request.workload
|
||||
request.status = "Error"
|
||||
request.success = False
|
||||
self.update_pending = True
|
||||
|
||||
def _request_canceled(self, request: RequestMetrics) -> None:
|
||||
"""
|
||||
this function is called if client drops connection before model API has responded
|
||||
"""
|
||||
self.model_metrics.workload_cancelled += request.workload
|
||||
request.success = True
|
||||
request.status = "Cancelled"
|
||||
|
||||
def _request_reject(self, request: RequestMetrics):
|
||||
"""
|
||||
this function is called if the current wait time for the model is above max_wait_time
|
||||
"""
|
||||
self.model_metrics.requests_recieved.add(request.reqnum)
|
||||
self.model_metrics.requests_deleting.append(request)
|
||||
self.model_metrics.workload_rejected += request.workload
|
||||
request.success = False
|
||||
request.status = "Rejected"
|
||||
self.update_pending = True
|
||||
|
||||
async def _send_delete_requests_loop(self) -> Awaitable[NoReturn]:
|
||||
while True:
|
||||
await sleep(DELETE_REQUESTS_INTERVAL)
|
||||
if len(self.model_metrics.requests_deleting) > 0:
|
||||
await self.__send_delete_requests_and_reset()
|
||||
|
||||
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
|
||||
while True:
|
||||
await sleep(METRICS_UPDATE_INTERVAL)
|
||||
elapsed = time.time() - self.last_metric_update
|
||||
if self.system_metrics.model_is_loaded is False and elapsed >= 10:
|
||||
log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
|
||||
await self.__send_metrics_and_reset()
|
||||
elif self.update_pending or elapsed > 10:
|
||||
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
|
||||
await self.__send_metrics_and_reset()
|
||||
|
||||
def _model_loaded(self, max_throughput: float) -> None:
|
||||
self.system_metrics.model_loading_time = (
|
||||
time.time() - self.system_metrics.model_loading_start
|
||||
)
|
||||
self.system_metrics.model_is_loaded = True
|
||||
self.model_metrics.max_throughput = max_throughput
|
||||
|
||||
def _model_errored(self, error_msg: str) -> None:
|
||||
self.model_metrics.set_errored(error_msg)
|
||||
self.system_metrics.model_is_loaded = True
|
||||
|
||||
def _set_version(self, version: str) -> None:
|
||||
self.version = version
|
||||
|
||||
def _set_mtoken(self, mtoken: str) -> None:
|
||||
self.mtoken = mtoken
|
||||
|
||||
#######################################Private#######################################
|
||||
|
||||
async def __send_delete_requests_and_reset(self):
|
||||
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
|
||||
data = {
|
||||
"worker_id": self.id,
|
||||
"mtoken": self.mtoken,
|
||||
"request_idxs": idxs,
|
||||
"success": success_flag,
|
||||
}
|
||||
log.debug(
|
||||
f"Deleting requests that {'succeeded' if success_flag else 'failed'}: {data['request_idxs']}"
|
||||
)
|
||||
full_path = report_addr.rstrip("/") + "/delete_requests/"
|
||||
for attempt in range(1, 4):
|
||||
try:
|
||||
session = await self.http()
|
||||
async with session.post(full_path, json=data) as res:
|
||||
log.debug(f"delete_requests response: {res.status}")
|
||||
res.raise_for_status()
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
log.debug("delete_requests timed out")
|
||||
except (ClientResponseError, Exception) as e:
|
||||
log.debug(f"delete_requests failed with error: {e}")
|
||||
await asyncio.sleep(2)
|
||||
log.debug(f"retrying delete_request, attempt: {attempt}")
|
||||
return False
|
||||
|
||||
# Take a snapshot of what we plan to send this tick.
|
||||
# New arrivals after this snapshot will remain in the queue for the next tick.
|
||||
snapshot = list(self.model_metrics.requests_deleting)
|
||||
success_idxs = [r.request_idx for r in snapshot if r.success is True]
|
||||
failed_idxs = [r.request_idx for r in snapshot if r.success is False]
|
||||
|
||||
if not success_idxs and not failed_idxs:
|
||||
return # nothing to do
|
||||
|
||||
for report_addr in self.report_addr:
|
||||
# TODO: Add a Redis subscriber queue for delete_requests
|
||||
if report_addr == "https://cloud.vast.ai/api/v0":
|
||||
# Patch: ignore the Redis API report_addr
|
||||
continue
|
||||
sent_success = True
|
||||
sent_failed = True
|
||||
|
||||
if success_idxs:
|
||||
sent_success = await post(report_addr, success_idxs, True)
|
||||
if failed_idxs:
|
||||
sent_failed = await post(report_addr, failed_idxs, False)
|
||||
|
||||
if sent_success and sent_failed:
|
||||
# Remove only the items we actually sent from the live queue.
|
||||
sent_set = set(success_idxs) | set(failed_idxs)
|
||||
self.model_metrics.requests_deleting[:] = [
|
||||
r for r in self.model_metrics.requests_deleting
|
||||
if r.request_idx not in sent_set
|
||||
]
|
||||
break
|
||||
|
||||
|
||||
async def __send_metrics_and_reset(self):
|
||||
|
||||
loadtime_snapshot = self.system_metrics.model_loading_time
|
||||
|
||||
def compute_autoscaler_data() -> AutoScalerData:
|
||||
return AutoScalerData(
|
||||
id=self.id,
|
||||
mtoken=self.mtoken,
|
||||
version=self.version,
|
||||
loadtime=(loadtime_snapshot or 0.0),
|
||||
new_load=self.model_metrics.workload_processing,
|
||||
cur_load=self.model_metrics.cur_load,
|
||||
rej_load=self.model_metrics.workload_rejected,
|
||||
max_perf=self.model_metrics.max_throughput,
|
||||
cur_perf=self.model_metrics.workload_served,
|
||||
error_msg=self.model_metrics.error_msg or "",
|
||||
num_requests_working=len(self.model_metrics.requests_working),
|
||||
num_requests_recieved=len(self.model_metrics.requests_recieved),
|
||||
additional_disk_usage=self.system_metrics.additional_disk_usage,
|
||||
working_request_idxs=self.model_metrics.working_request_idxs,
|
||||
cur_capacity=0,
|
||||
max_capacity=0,
|
||||
url=self.url,
|
||||
)
|
||||
|
||||
async def send_data(report_addr: str) -> bool:
|
||||
data = compute_autoscaler_data()
|
||||
log_data = asdict(data)
|
||||
def obfuscate(secret: str) -> str:
|
||||
if secret is None:
|
||||
return ""
|
||||
return secret[:7] + "..." if len(secret) > 7 else ("*" * len(secret))
|
||||
|
||||
log_data["mtoken"] = obfuscate(log_data.get("mtoken"))
|
||||
log.debug(
|
||||
"\n".join(
|
||||
[
|
||||
"#" * 60,
|
||||
f"sending data to autoscaler",
|
||||
f"{json.dumps(log_data, indent=2)}",
|
||||
"#" * 60,
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
full_path = report_addr.rstrip("/") + "/worker_status/"
|
||||
for attempt in range(1, 4):
|
||||
try:
|
||||
session = await self.http()
|
||||
async with session.post(full_path, json=asdict(data)) as res:
|
||||
res.raise_for_status()
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
log.debug(f"autoscaler status update timed out")
|
||||
except (ClientResponseError, Exception) as e:
|
||||
log.debug(f"autoscaler status update failed with error: {e}")
|
||||
await asyncio.sleep(2)
|
||||
log.debug(f"retrying autoscaler status update, attempt: {attempt}")
|
||||
log.debug(f"failed to send update through {report_addr}")
|
||||
return False
|
||||
|
||||
###########
|
||||
|
||||
self.system_metrics.update_disk_usage()
|
||||
|
||||
sent = False
|
||||
for report_addr in self.report_addr:
|
||||
if await send_data(report_addr):
|
||||
sent = True
|
||||
break
|
||||
|
||||
if sent:
|
||||
# clear the one-shot loadtime only if we actually sent *this* value
|
||||
self.system_metrics.reset(expected=loadtime_snapshot)
|
||||
self.update_pending = False
|
||||
self.model_metrics.reset()
|
||||
self.last_metric_update = time.time()
|
||||
@@ -1,60 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import List
|
||||
import ssl
|
||||
from asyncio import run, gather
|
||||
import asyncio
|
||||
|
||||
from lib.backend import Backend
|
||||
from lib.metrics import Metrics
|
||||
from aiohttp import web
|
||||
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
||||
try:
|
||||
log.debug("getting certificate...")
|
||||
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
||||
if use_ssl is True:
|
||||
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
ssl_context.load_cert_chain(
|
||||
certfile="/etc/instance.crt",
|
||||
keyfile="/etc/instance.key",
|
||||
)
|
||||
else:
|
||||
ssl_context = None
|
||||
|
||||
async def main():
|
||||
log.debug("starting server...")
|
||||
app = web.Application()
|
||||
app.add_routes(routes)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(
|
||||
runner,
|
||||
ssl_context=ssl_context,
|
||||
port=int(os.environ["WORKER_PORT"]),
|
||||
**kwargs
|
||||
)
|
||||
await gather(site.start(), backend._start_tracking())
|
||||
|
||||
run(main())
|
||||
|
||||
except Exception as e:
|
||||
err_msg = f"PyWorker failed to launch: {e}"
|
||||
log.error(err_msg)
|
||||
|
||||
async def beacon():
|
||||
metrics = Metrics()
|
||||
metrics._set_version(getattr(backend, "version", "0"))
|
||||
metrics._set_mtoken(getattr(backend, "mtoken", ""))
|
||||
try:
|
||||
while True:
|
||||
metrics._model_errored(err_msg)
|
||||
await metrics._Metrics__send_metrics_and_reset()
|
||||
await asyncio.sleep(10)
|
||||
finally:
|
||||
await metrics.aclose()
|
||||
|
||||
run(beacon())
|
||||
@@ -1,310 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
from typing import Callable, List, Dict, Tuple, Dict, Any, Type
|
||||
from time import sleep
|
||||
import threading
|
||||
from enum import Enum
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from urllib.parse import urljoin
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
import requests
|
||||
|
||||
from lib.data_types import AuthData, ApiPayload
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
class ClientStatus(Enum):
|
||||
FetchEndpoint = 1
|
||||
Generating = 2
|
||||
Done = 3
|
||||
Error = 4
|
||||
|
||||
|
||||
total_success = 0
|
||||
last_res = []
|
||||
stop_event = threading.Event()
|
||||
|
||||
start_time = time.time()
|
||||
test_args = argparse.ArgumentParser(description="Test inference endpoint")
|
||||
test_args.add_argument(
|
||||
"-k", dest="api_key", type=str, required=True, help="Your vast account API key"
|
||||
)
|
||||
test_args.add_argument(
|
||||
"-e",
|
||||
dest="endpoint_group_name",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Endpoint group name",
|
||||
)
|
||||
test_args.add_argument(
|
||||
"-l",
|
||||
dest="server_url",
|
||||
action="store_const",
|
||||
const="http://localhost:8081",
|
||||
default="https://run.vast.ai",
|
||||
help="Call local autoscaler instead of prod, for dev use only",
|
||||
)
|
||||
test_args.add_argument(
|
||||
"-i",
|
||||
dest="instance",
|
||||
type=str,
|
||||
default="prod",
|
||||
help="Autoscaler shard to run the command against, default: prod",
|
||||
)
|
||||
|
||||
GetPayloadAndWorkload = Callable[[], Tuple[Dict[str, Any], float]]
|
||||
|
||||
|
||||
def print_truncate_res(res: str):
|
||||
if len(res) > 150:
|
||||
print(f"{res[:50]}....{res[-100:]}")
|
||||
else:
|
||||
print(res)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClientState:
|
||||
endpoint_group_name: str
|
||||
api_key: str
|
||||
server_url: str
|
||||
worker_endpoint: str
|
||||
instance: str
|
||||
payload: ApiPayload
|
||||
url: str = ""
|
||||
status: ClientStatus = ClientStatus.FetchEndpoint
|
||||
as_error: List[str] = field(default_factory=list)
|
||||
infer_error: List[str] = field(default_factory=list)
|
||||
conn_errors: Counter = field(default_factory=Counter)
|
||||
|
||||
def make_call(self):
|
||||
self.status = ClientStatus.FetchEndpoint
|
||||
if not self.api_key:
|
||||
self.as_error.append(
|
||||
f"Endpoint {self.endpoint_group_name} not found for API key",
|
||||
)
|
||||
self.status = ClientStatus.Error
|
||||
return
|
||||
route_payload = {
|
||||
"endpoint": self.endpoint_group_name,
|
||||
"api_key": self.api_key,
|
||||
"cost": self.payload.count_workload(),
|
||||
}
|
||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
response = requests.post(
|
||||
urljoin(self.server_url, "/route/"),
|
||||
json=route_payload,
|
||||
headers=headers,
|
||||
timeout=4,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
self.as_error.append(
|
||||
f"code: {response.status_code}, body: {response.text}",
|
||||
)
|
||||
self.status = ClientStatus.Error
|
||||
return
|
||||
message = response.json()
|
||||
worker_address = message["url"]
|
||||
req_data = dict(
|
||||
payload=asdict(self.payload),
|
||||
auth_data=asdict(AuthData.from_json_msg(message)),
|
||||
)
|
||||
self.url = worker_address
|
||||
url = urljoin(worker_address, self.worker_endpoint)
|
||||
self.status = ClientStatus.Generating
|
||||
|
||||
response = requests.post(
|
||||
url,
|
||||
json=req_data,
|
||||
verify=get_cert_file_path(),
|
||||
)
|
||||
if response.status_code != 200:
|
||||
self.infer_error.append(
|
||||
f"code: {response.status_code}, body: {response.text}, url: {url}",
|
||||
)
|
||||
self.status = ClientStatus.Error
|
||||
return
|
||||
res = str(response.json())
|
||||
global total_success
|
||||
global last_res
|
||||
total_success += 1
|
||||
last_res.append(res)
|
||||
self.status = ClientStatus.Done
|
||||
|
||||
def simulate_user(self) -> None:
|
||||
try:
|
||||
self.make_call()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
self.status = ClientStatus.Error
|
||||
_ = e
|
||||
self.conn_errors[self.url] += 1
|
||||
|
||||
|
||||
def print_state(clients: List[ClientState], num_clients: int) -> None:
|
||||
print("starting up...")
|
||||
sleep(2)
|
||||
center_size = 14
|
||||
global start_time
|
||||
while len(clients) < num_clients or (
|
||||
any(
|
||||
map(
|
||||
lambda client: client.status
|
||||
in [ClientStatus.FetchEndpoint, ClientStatus.Generating],
|
||||
clients,
|
||||
)
|
||||
)
|
||||
):
|
||||
sleep(0.5)
|
||||
os.system("clear")
|
||||
print(
|
||||
" | ".join(
|
||||
[member.name.center(center_size) for member in ClientStatus]
|
||||
+ [
|
||||
item.center(center_size)
|
||||
for item in [
|
||||
"urls",
|
||||
"as_error",
|
||||
"infer_error",
|
||||
"conn_error",
|
||||
"total_success",
|
||||
]
|
||||
]
|
||||
)
|
||||
)
|
||||
unique_urls = len(set([c.url for c in clients if c.url != ""]))
|
||||
as_errors = sum(
|
||||
map(
|
||||
lambda client: len(client.as_error),
|
||||
[client for client in clients],
|
||||
)
|
||||
)
|
||||
infer_errors = sum(
|
||||
map(
|
||||
lambda client: len(client.infer_error),
|
||||
[client for client in clients],
|
||||
)
|
||||
)
|
||||
conn_errors = sum([client.conn_errors for client in clients], start=Counter())
|
||||
conn_errors_str = ",".join(map(str, conn_errors.values())) or "0"
|
||||
elapsed = time.time() - start_time
|
||||
print(
|
||||
" | ".join(
|
||||
map(
|
||||
lambda item: str(item).center(center_size),
|
||||
[
|
||||
len(list(filter(lambda x: x.status == member, clients)))
|
||||
for member in ClientStatus
|
||||
]
|
||||
+ [
|
||||
unique_urls,
|
||||
as_errors,
|
||||
infer_errors,
|
||||
conn_errors_str,
|
||||
f"{total_success}({((total_success/elapsed) * 60):.2f}/minute)",
|
||||
],
|
||||
)
|
||||
)
|
||||
)
|
||||
if conn_errors:
|
||||
print("conn_errors:")
|
||||
for url, count in conn_errors.items():
|
||||
print(url.ljust(28), ": ", str(count))
|
||||
elapsed = time.time() - start_time
|
||||
print(f"\n elapsed: {int(elapsed // 60)}:{int(elapsed % 60)}")
|
||||
if last_res:
|
||||
for i, res in enumerate(last_res[-10:]):
|
||||
print_truncate_res(f"res #{1+i+max(len(last_res )-10,0)}: {res}")
|
||||
if stop_event.is_set():
|
||||
print("\n### waiting for existing connections to close ###")
|
||||
|
||||
|
||||
def run_test(
|
||||
num_requests: int,
|
||||
requests_per_second: int,
|
||||
endpoint_group_name: str,
|
||||
api_key: str,
|
||||
server_url: str,
|
||||
worker_endpoint: str,
|
||||
payload_cls: Type[ApiPayload],
|
||||
instance: str,
|
||||
):
|
||||
threads = []
|
||||
|
||||
clients = []
|
||||
print_thread = threading.Thread(target=print_state, args=(clients, num_requests))
|
||||
print_thread.daemon = True # makes threads get killed on program exit
|
||||
print_thread.start()
|
||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||
endpoint_name=endpoint_group_name, account_api_key=api_key, instance=instance
|
||||
)
|
||||
if not endpoint_api_key:
|
||||
log.debug(f"Endpoint {endpoint_group_name} not found for API key")
|
||||
return
|
||||
try:
|
||||
for _ in range(num_requests):
|
||||
client = ClientState(
|
||||
endpoint_group_name=endpoint_group_name,
|
||||
api_key=endpoint_api_key,
|
||||
server_url=server_url,
|
||||
worker_endpoint=worker_endpoint,
|
||||
payload=payload_cls.for_test(),
|
||||
instance=instance,
|
||||
)
|
||||
clients.append(client)
|
||||
thread = threading.Thread(target=client.simulate_user, args=())
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
sleep(1 / requests_per_second)
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
print("done spawning workers")
|
||||
except KeyboardInterrupt:
|
||||
stop_event.set()
|
||||
|
||||
|
||||
def test_load_cmd(
|
||||
payload_cls: Type[ApiPayload], endpoint: str, arg_parser: argparse.ArgumentParser
|
||||
):
|
||||
arg_parser.add_argument(
|
||||
"-n",
|
||||
dest="num_requests",
|
||||
type=int,
|
||||
required=True,
|
||||
help="total number of requests",
|
||||
)
|
||||
arg_parser.add_argument(
|
||||
"-rps",
|
||||
dest="requests_per_second",
|
||||
type=float,
|
||||
required=True,
|
||||
help="requests per second",
|
||||
)
|
||||
args = arg_parser.parse_args()
|
||||
if hasattr(args, "comfy_model"):
|
||||
os.environ["COMFY_MODEL"] = args.comfy_model
|
||||
server_url = {
|
||||
"prod": "https://run.vast.ai",
|
||||
"alpha": "https://run-alpha.vast.ai",
|
||||
"candidate": "https://run-candidate.vast.ai",
|
||||
"local": "http://localhost:8080",
|
||||
}.get(args.instance, "http://localhost:8080")
|
||||
run_test(
|
||||
num_requests=args.num_requests,
|
||||
requests_per_second=args.requests_per_second,
|
||||
api_key=args.api_key,
|
||||
server_url=server_url,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
worker_endpoint=endpoint,
|
||||
payload_cls=payload_cls,
|
||||
instance=args.instance,
|
||||
)
|
||||
@@ -1,136 +0,0 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
class Endpoint:
|
||||
"""
|
||||
Utility class for handling endpoint operations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_endpoint_info(
|
||||
endpoint_name: str, account_api_key: str, instance: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
headers = {"Authorization": f"Bearer {account_api_key}"}
|
||||
url = f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}"
|
||||
# Retry a few times to smooth over transient propagation/network delays
|
||||
for attempt in range(4):
|
||||
try:
|
||||
response = requests.get(url, headers=headers, timeout=8)
|
||||
if response.status_code != 200:
|
||||
# brief backoff and retry
|
||||
time.sleep(0.3 * (attempt + 1))
|
||||
continue
|
||||
try:
|
||||
data = response.json()
|
||||
except Exception:
|
||||
# JSON parse failed; backoff and retry
|
||||
time.sleep(0.3 * (attempt + 1))
|
||||
continue
|
||||
result = data.get("results", []) if isinstance(data, dict) else []
|
||||
endpoint = next(
|
||||
(item for item in result if item.get("endpoint_name") == endpoint_name),
|
||||
None,
|
||||
)
|
||||
if endpoint and endpoint.get("id") and endpoint.get("api_key"):
|
||||
return {"id": endpoint.get("id"), "api_key": endpoint.get("api_key")}
|
||||
except Exception:
|
||||
# network or other transient error; retry
|
||||
time.sleep(0.3 * (attempt + 1))
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_autoscaler_server_url(instance: str) -> str:
|
||||
endpoints = {
|
||||
"alpha": "run-alpha",
|
||||
"candidate": "run-candidate",
|
||||
"prod": "run",
|
||||
}
|
||||
host = endpoints.get(instance)
|
||||
if host:
|
||||
return f"https://{host}.vast.ai/"
|
||||
return "http://localhost:8080"
|
||||
|
||||
@staticmethod
|
||||
def get_server_url(instance: str) -> str:
|
||||
endpoints = {
|
||||
"alpha": "alpha",
|
||||
"candidate": "candidate",
|
||||
"prod": "console",
|
||||
}
|
||||
host = endpoints.get(instance, "alpha")
|
||||
return f"https://{host}.vast.ai/api/v0/endptjobs/"
|
||||
|
||||
@staticmethod
|
||||
def get_endpoint_api_key(
|
||||
endpoint_name: str, account_api_key: str, instance: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Fetch endpoint API key from VastAI console following the healthcheck pattern.
|
||||
|
||||
Args:
|
||||
endpoint_name: Name of the endpoint
|
||||
account_api_key: Account API key for authentication
|
||||
|
||||
Returns:
|
||||
Endpoint API key if successful, None otherwise
|
||||
"""
|
||||
headers = {"Authorization": f"Bearer {account_api_key}"}
|
||||
|
||||
try:
|
||||
log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}")
|
||||
response = requests.get(
|
||||
f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}",
|
||||
headers=headers,
|
||||
timeout=8,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_msg = f"Failed to fetch endpoint API key: {response.status_code} - {response.text}"
|
||||
log.debug(error_msg)
|
||||
return None
|
||||
|
||||
try:
|
||||
data = response.json()
|
||||
except Exception as e:
|
||||
log.debug(f"Failed to parse JSON response: {e}")
|
||||
return None
|
||||
|
||||
result = data.get("results", [])
|
||||
|
||||
endpoint: Optional[Dict[str, Any]] = next(
|
||||
(item for item in result if item.get("endpoint_name") == endpoint_name),
|
||||
None,
|
||||
)
|
||||
if not endpoint:
|
||||
error_msg = f"Endpoint '{endpoint_name}' not found."
|
||||
log.debug(error_msg)
|
||||
return None
|
||||
|
||||
endpoint_api_key = endpoint.get("api_key")
|
||||
if not endpoint_api_key:
|
||||
error_msg = f"API key for endpoint '{endpoint_name}' not found."
|
||||
log.debug(error_msg)
|
||||
return None
|
||||
|
||||
log.debug(f"Successfully retrieved API key for endpoint: {endpoint_name}")
|
||||
return endpoint_api_key
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
error_msg = f"Request error while fetching endpoint API key: {e}"
|
||||
log.debug(error_msg)
|
||||
return None
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error while fetching endpoint API key: {e}"
|
||||
log.debug(error_msg)
|
||||
return None
|
||||
@@ -1,15 +0,0 @@
|
||||
import tempfile
|
||||
from functools import cache
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
@cache
|
||||
def get_cert_file_path():
|
||||
cert_url = "https://console.vast.ai/static/jvastai_root.cer"
|
||||
response = requests.get(cert_url)
|
||||
response.raise_for_status()
|
||||
# Use a temporary file that is not deleted on close
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".cer", mode="wb") as f:
|
||||
f.write(response.content)
|
||||
return f.name
|
||||
@@ -1,84 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import random
|
||||
import dataclasses
|
||||
from typing import Dict, Any
|
||||
from functools import cache
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
import json
|
||||
import logging
|
||||
|
||||
from lib.data_types import ApiPayload, JsonDataException
|
||||
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
def count_workload() -> float:
|
||||
# Always 100.0 where there is a single instance of ComfyUI handling requests
|
||||
# Results will indicate % or a job completed per second. Avoids sub 0.1 sec performance indication
|
||||
return 100.0
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ComfyWorkflowData(ApiPayload):
|
||||
input: dict
|
||||
|
||||
@classmethod
|
||||
def for_test(cls):
|
||||
"""
|
||||
If the user has provided a benchmark workflow we can use it here to properly gauge performance.
|
||||
Otherwise, use the variables available to simulate workflows of the required running time
|
||||
Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090)
|
||||
"""
|
||||
# Try to load benchmark.json
|
||||
benchmark_file = Path("workers/comfyui-json/misc/benchmark.json")
|
||||
|
||||
if benchmark_file.exists():
|
||||
try:
|
||||
with open(benchmark_file, "r") as f:
|
||||
benchmark_workflow = json.load(f)
|
||||
return cls(
|
||||
input={
|
||||
"request_id": f"test-{random.randint(1000, 99999)}",
|
||||
"workflow_json": benchmark_workflow
|
||||
}
|
||||
)
|
||||
except (json.JSONDecodeError, IOError):
|
||||
# JSON is malformed or file can't be read, fall through to default
|
||||
log.error(f"Failed to benchmark using {benchmark_file}")
|
||||
|
||||
# Fallback: read prompts and construct payload
|
||||
log.info("Using fallback method for benchmarking")
|
||||
with open("workers/comfyui-json/misc/test_prompts.txt", "r") as f:
|
||||
test_prompts = f.readlines()
|
||||
|
||||
test_prompt = random.choice(test_prompts).rstrip()
|
||||
return cls(
|
||||
input={
|
||||
"request_id": f"test-{random.randint(1000, 99999)}",
|
||||
"modifier": "Text2Image",
|
||||
"modifications": {
|
||||
"prompt": test_prompt,
|
||||
"width": os.getenv('BENCHMARK_TEST_WIDTH', 512),
|
||||
"height": os.getenv('BENCHMARK_TEST_HEIGHT', 512),
|
||||
"steps": os.getenv('BENCHMARK_TEST_STEPS', 20),
|
||||
"seed": random.randint(0, sys.maxsize),
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
def generate_payload_json(self) -> Dict[str, Any]:
|
||||
# input is already a dict, just return it wrapped in the expected structure
|
||||
return {"input": self.input}
|
||||
|
||||
def count_workload(self) -> float:
|
||||
return count_workload()
|
||||
|
||||
@classmethod
|
||||
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "ComfyWorkflowData":
|
||||
# Extract required fields
|
||||
if "input" not in json_msg:
|
||||
raise JsonDataException({"input": "missing parameter"})
|
||||
|
||||
return cls(
|
||||
input=json_msg["input"]
|
||||
)
|
||||
@@ -1,107 +0,0 @@
|
||||
{
|
||||
"3": {
|
||||
"inputs": {
|
||||
"seed": "__RANDOM_INT__",
|
||||
"steps": 20,
|
||||
"cfg": 8,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "normal",
|
||||
"denoise": 1,
|
||||
"model": [
|
||||
"4",
|
||||
0
|
||||
],
|
||||
"positive": [
|
||||
"6",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"7",
|
||||
0
|
||||
],
|
||||
"latent_image": [
|
||||
"5",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "KSampler",
|
||||
"_meta": {
|
||||
"title": "KSampler"
|
||||
}
|
||||
},
|
||||
"4": {
|
||||
"inputs": {
|
||||
"ckpt_name": "v1-5-pruned-emaonly-fp16.safetensors"
|
||||
},
|
||||
"class_type": "CheckpointLoaderSimple",
|
||||
"_meta": {
|
||||
"title": "Load Checkpoint"
|
||||
}
|
||||
},
|
||||
"5": {
|
||||
"inputs": {
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
"batch_size": 1
|
||||
},
|
||||
"class_type": "EmptyLatentImage",
|
||||
"_meta": {
|
||||
"title": "Empty Latent Image"
|
||||
}
|
||||
},
|
||||
"6": {
|
||||
"inputs": {
|
||||
"text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,",
|
||||
"clip": [
|
||||
"4",
|
||||
1
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"7": {
|
||||
"inputs": {
|
||||
"text": "text, watermark",
|
||||
"clip": [
|
||||
"4",
|
||||
1
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"8": {
|
||||
"inputs": {
|
||||
"samples": [
|
||||
"3",
|
||||
0
|
||||
],
|
||||
"vae": [
|
||||
"4",
|
||||
2
|
||||
]
|
||||
},
|
||||
"class_type": "VAEDecode",
|
||||
"_meta": {
|
||||
"title": "VAE Decode"
|
||||
}
|
||||
},
|
||||
"9": {
|
||||
"inputs": {
|
||||
"filename_prefix": "ComfyUI",
|
||||
"images": [
|
||||
"8",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "SaveImage",
|
||||
"_meta": {
|
||||
"title": "Save Image"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
cartoon character of a person with a hoodie , in style of cytus and deemo, ork, gold chains, realistic anime cat, dripping black goo, lineage revolution style, thug life, cute anthropomorphic bunny, balrog, arknights, aliased, very buff, black and red and yellow paint, painting illustration collage style, character composition in vector with white background
|
||||
stardew valley, fine details
|
||||
2D Vector Illustration of a child with soccer ball Art for Sublimation, Design Art, Chrome Art, Painting and Stunning Artwork, Highly Detailed Digital Painting, Airbrush Art, Highly Detailed Digital Artwork, Dramatic Artwork, stained antique yellow copper paint, digital airbrush art, detailed by Mark Brooks, Chicano airbrush art, Swagger! snake Culture
|
||||
realistic futuristic city-downtown with short buildings, sunset
|
||||
seascape by Ray Collins and artgerm, front view of a perfect wave, sunny background, ultra detailed water
|
||||
inspired by realflow-cinema4d editor features, create image of a transparent luxury cup with ice fruits and mint, connected with white, yellow and pink cream, Slow - High Speed MO Photography, YouTube Video Screenshot, Abstract Clay, Transparent Cup , molecular gastronomy, wheel, 3D fluid,Simulation rendering, still video, 4k polymer clay futras photography, very surreal, Houdini Fluid Simulation, hyperrealistic CGI and FLUIDS & MULTIPHYSICS SIMULATION effect, with Somali Stain Lurex, Metallic Jacquard, Gold Thread, Mulberry Silk, Toub Saree, Warm background, a fantastic image worthy of an award.
|
||||
biker with backpack on his back riding a motorcycle, Style by Ade Santora, Oilpunk, Cover photo, craig mullins style, on the cover of a magazine, Outdoor Magazine, inspired by Alex Petruk APe, image of a male biker, Cover of an award-winning magazine, the man has a backpack, photo for magazine, with a backpack, magazine cover
|
||||
generate a collage-style illustration inspired by the Procreate raster graphic editor, photographic illustration with the theme, 2D vector, art for textile sublimation, containing surrealistic cartoon cat wearing a baseball cap and jeans standing in front of a poster, inspired by Sadao Watanabe, Doraemon, Japanese cartoon style, Eichiro Oda, Iconic high detail character, Director: Nakahara Nantenbō, Kastuhiro Otomo, image detailed, by Miyamoto, Hidetaka Miyazaki, Katsuhiro illustration, 8k, masterpiece, Minimize noise and grain in photo quality without lose quality and increase brightness and lighting,Symmetry and Alignment, Avoid asymmetrical shapes and out-of-focus points. Focus and Sharpness: Make sure the image is focused and sharp and encourages the viewer to see it as a work of art printed on fabric.
|
||||
fantasy medieval village world inside a glass sphere , high detail, fantasy, realistic, light effect, hyper detail, volumetric lighting, cinematic, macro, depth of field, blur, red light and clouds from the back, highly detailed epic cinematic concept art cg render made in maya, blender and photoshop, octane render, excellent composition, dynamic dramatic cinematic lighting, aesthetic, very inspirational, world inside a glass sphere by james gurney by artgerm with james jean, joe fenton and tristan eaton by ross tran, fine details
|
||||
Iron Man, (Arnold Tsang, Toru Nakayama), Masterpiece, Studio Quality, 6k , toa, toaair, 1boy, glowing, axe, mecha, science_fiction, solo, weapon, jungle , green_background, nature, outdoors, solo, tree, weapon, mask, dynamic lighting, detailed shading, digital texture painting
|
||||
(Pope Francis) wearing leather jacket is a DJ in a nightclub, mixing live on stage, giant mixing table, a masterpiece
|
||||
Pope Francis wearing biker (leather jacket), a masterpiece
|
||||
Luke Skywalker ordering a burger and fries from the Death Star canteen.
|
||||
I want to generate a group avatar for a Feishu group chat. The role of this group is daily software technical communication. Now the subject technology stacks that members of this group discuss daily include: algorithms, data structures, optimization, functional programming, and the programming languages often discussed are: TypeScript, Java, python, etc. I hope this avatar has a simple aesthetic, this avatar is a single person avatar
|
||||
portrait Anime black girl cute-fine-face, pretty face, realistic shaded Perfect face, fine details. Anime. realistic shaded lighting by Ilya Kuvshinov Giuseppe Dangelico Pino and Michael Garmash and Rob Rey, IAMAG premiere, WLOP matte print, cute freckles, masterpiece
|
||||
young Disney socialite wearing a beige miniskirt, dark brown turtleneck sweater, small neckless, cute-fine-face, anime. illustration, realistic shaded perfect face, brown hair, grey eyes, fine details, realistic shaded lighting by ilya kuvshinov giuseppe dangelico pino and michael garmash and rob rey, iamag premiere, wlop matte print, a masterpiece
|
||||
Cute small cat sitting in a movie theater eating chicken wiggs watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
|
||||
Cute small dog sitting in a movie theater eating popcorn watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
|
||||
fox bracelet made of buckskin with fox features, rich details, fine carvings, studio lighting
|
||||
crane buckskin bracelet with crane features, rich details, fine carvings, studio lighting
|
||||
london luxurious interior living-room, light walls
|
||||
Parisian luxurious interior penthouse bedroom, dark walls, wooden panels
|
||||
cute girl, crop-top, blond hair, black glasses, stretching, with background by greg rutkowski makoto shinkai kyoto animation key art feminine mid shot
|
||||
houses in front, houses background, straight houses, digital art, smooth, sharp focus, gravity falls style, doraemon style, shinchan style, anime style
|
||||
Simplified technical drawing, Leonardo da Vinci, Mechanical Dinosaur Skeleton, Minimalistic annotations, Hand-drawn illustrations, Basic design and engineering, Wonder and curiosity
|
||||
High quality 8K painting impressionist style of a Japanese modern city street with a girl on the foreground wearing a traditional wedding dress with a fox mask, staring at the sky, daylight
|
||||
a landscape from the Moon with the Earth setting on the horizon, realistic, detailed
|
||||
Isometric Atlantis city,great architecture with columns, great details, ornaments,seaweed, blue ambiance, 3D cartoon style, soft light, 45° view
|
||||
A hyper realistic avatar of a guy riding on a black honda cbr 650r in leather suit,high detail, high quality,8K,photo realism
|
||||
the street of amedieval fantasy town, at dawn, dark, highly detailed
|
||||
overwhelmingly beautiful eagle framed with vector flowers, long shiny wavy flowing hair, polished, ultra detailed vector floral illustration mixed with hyper realism, muted pastel colors, vector floral details in background, muted colors, hyper detailed ultra intricate overwhelming realism in detailed complex scene with magical fantasy atmosphere, no signature, no watermark
|
||||
a highly detailed matte painting of a man on a hill watching a rocket launch in the distance by studio ghibli, makoto shinkai, by artgerm, by wlop, by greg rutkowski, volumetric lighting, octane render, 4 k resolution, trending on artstation, masterpiece | hyperrealism| highly detailed| insanely detailed| intricate| cinematic lighting| depth of field
|
||||
electronik robot and ofice ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
|
||||
exquisitely intricately detailed illustration, of a small world with a lake and a rainbow, inside a closed glass jar.
|
||||
@@ -1,117 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
import dataclasses
|
||||
import base64
|
||||
from typing import Optional, Union, Type
|
||||
|
||||
from aiohttp import web, ClientResponse
|
||||
|
||||
from lib.backend import Backend, LogAction
|
||||
from lib.data_types import EndpointHandler
|
||||
from lib.server import start_server
|
||||
from .data_types import ComfyWorkflowData
|
||||
|
||||
|
||||
MODEL_SERVER_URL = os.getenv("MODEL_SERVER_URL", "http://127.0.0.1:18288")
|
||||
|
||||
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
||||
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: "
|
||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
|
||||
"Value not in list: ", # This error is emitted when the model file is not there at all
|
||||
"[ERROR] Provisioning Script failed", # Error inserted by provisioning script if models/nodes fail to download
|
||||
]
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
async def generate_client_response(
|
||||
client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
# Check if the response is actually streaming based on response headers/content-type
|
||||
is_streaming_response = (
|
||||
model_response.content_type == "text/event-stream"
|
||||
or model_response.content_type == "application/x-ndjson"
|
||||
or model_response.headers.get("Transfer-Encoding") == "chunked"
|
||||
or "stream" in model_response.content_type.lower()
|
||||
)
|
||||
|
||||
if is_streaming_response:
|
||||
log.debug("Detected streaming response...")
|
||||
res = web.StreamResponse()
|
||||
res.content_type = model_response.content_type
|
||||
await res.prepare(client_request)
|
||||
async for chunk in model_response.content:
|
||||
await res.write(chunk)
|
||||
await res.write_eof()
|
||||
log.debug("Done streaming response")
|
||||
return res
|
||||
else:
|
||||
log.debug("Detected non-streaming response...")
|
||||
content = await model_response.read()
|
||||
return web.Response(
|
||||
body=content,
|
||||
status=model_response.status,
|
||||
content_type=model_response.content_type
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ComfyWorkflowHandler(EndpointHandler[ComfyWorkflowData]):
|
||||
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/generate/sync"
|
||||
|
||||
@property
|
||||
def healthcheck_endpoint(self) -> Optional[str]:
|
||||
return f"{MODEL_SERVER_URL}/health"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[ComfyWorkflowData]:
|
||||
return ComfyWorkflowData
|
||||
|
||||
def make_benchmark_payload(self) -> ComfyWorkflowData:
|
||||
return ComfyWorkflowData.for_test()
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
return await generate_client_response(client_request, model_response)
|
||||
|
||||
|
||||
backend = Backend(
|
||||
model_server_url=MODEL_SERVER_URL,
|
||||
model_log_file=os.environ["MODEL_LOG"],
|
||||
allow_parallel_requests=False,
|
||||
benchmark_handler=ComfyWorkflowHandler(
|
||||
benchmark_runs=3, benchmark_words=100
|
||||
),
|
||||
log_actions=[
|
||||
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
|
||||
(LogAction.Info, "Downloading:"),
|
||||
*[
|
||||
(LogAction.ModelError, error_msg)
|
||||
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def handle_ping(_):
|
||||
return web.Response(body="pong")
|
||||
|
||||
|
||||
routes = [
|
||||
web.post("/generate/sync", backend.create_handler(ComfyWorkflowHandler())),
|
||||
web.get("/ping", handle_ping),
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_server(backend, routes)
|
||||
@@ -1,8 +0,0 @@
|
||||
from lib.test_utils import test_load_cmd, test_args
|
||||
from .data_types import ComfyWorkflowData
|
||||
|
||||
WORKER_ENDPOINT = "/generate/sync"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_load_cmd(ComfyWorkflowData, WORKER_ENDPOINT, arg_parser=test_args)
|
||||
@@ -1,92 +0,0 @@
|
||||
This is the base PyWorker for comfyui. It can be used to create PyWorker that use various models and
|
||||
workflows. It provides two endpoints:
|
||||
|
||||
1. `/prompt`: Uses the default comfy workflow defined under `misc/default_workflows`
|
||||
2. `/custom_workflow`: Allows the client to send their own comfy workflow with each API request.
|
||||
|
||||
To use the comfyui PyWorker, `$COMFY_MODEL` env variable must be set in the template. Current options are
|
||||
`sd3` and `flux`. Each have example clients.
|
||||
|
||||
To add new models, a JSON with name `$COMFY_MODEL.json` must be created under `misc/default_workflows`
|
||||
|
||||
NOTE: default workflows follow this format:
|
||||
|
||||
```json
|
||||
{
|
||||
"input": {
|
||||
"handler": "RawWorkflow",
|
||||
"aws_access_key_id": "your-s3-access-key",
|
||||
"aws_secret_access_key": "your-s3-secret-access-key",
|
||||
"aws_endpoint_url": "https://my-endpoint.backblaze.com",
|
||||
"aws_bucket_name": "your-bucket",
|
||||
"webhook_url": "your-webhook-url",
|
||||
"webhook_extra_params": {},
|
||||
"workflow_json": {}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
You can ignore all of these fields except for `workflow_json`.
|
||||
|
||||
Fields written as "{{FOO}}" will be replaced using data from a user request. For example, SD3's workflow has the
|
||||
following nodes:
|
||||
|
||||
```json
|
||||
"5": {
|
||||
"inputs": {
|
||||
"width": "{{WIDTH}}",
|
||||
"height": "{{HEIGHT}}",
|
||||
"batch_size": 1
|
||||
},
|
||||
|
||||
"6": {
|
||||
"inputs": {
|
||||
"text": "{{PROMPT}}",
|
||||
"clip": ["11", 0]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
...
|
||||
"17": {
|
||||
"inputs": {
|
||||
"scheduler": "simple",
|
||||
"steps": "{{STEPS}}",
|
||||
"denoise": 1,
|
||||
"model": ["12", 0]
|
||||
},
|
||||
"class_type": "BasicScheduler",
|
||||
"_meta": {
|
||||
"title": "BasicScheduler"
|
||||
}
|
||||
},
|
||||
...
|
||||
"25": {
|
||||
"inputs": {
|
||||
"noise_seed": "{{SEED}}"
|
||||
},
|
||||
"class_type": "RandomNoise",
|
||||
"_meta": {
|
||||
"title": "RandomNoise"
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
Incoming requests have the following JSON format:
|
||||
|
||||
```json
|
||||
{
|
||||
prompt: str
|
||||
width: int
|
||||
height: int
|
||||
steps: int
|
||||
seed: int
|
||||
}
|
||||
```
|
||||
|
||||
Each value in those fields with replace the placeholder of the same name in the default workflow.
|
||||
|
||||
See Vast's serverless documentation for more details on how to use comfyui with autoscaler
|
||||
@@ -1,170 +0,0 @@
|
||||
import logging
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
|
||||
from lib.test_utils import print_truncate_res
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
|
||||
from vastai import Serverless
|
||||
|
||||
|
||||
ENDPOINT_NAME = "my-comfyui-endpoint"
|
||||
COST = 100 # Use a constant cost for image generation
|
||||
|
||||
def call_default_workflow(client: Serverless) -> None:
|
||||
WORKER_ENDPOINT = "/prompt"
|
||||
COST = 100
|
||||
route_payload = {
|
||||
"endpoint": endpoint_group_name,
|
||||
"api_key": api_key,
|
||||
"cost": COST,
|
||||
}
|
||||
response = requests.post(
|
||||
urljoin(server_url, "/route/"),
|
||||
json=route_payload,
|
||||
timeout=4,
|
||||
)
|
||||
response.raise_for_status()
|
||||
message = response.json()
|
||||
url = message["url"]
|
||||
auth_data = dict(
|
||||
signature=message["signature"],
|
||||
cost=message["cost"],
|
||||
endpoint=message["endpoint"],
|
||||
reqnum=message["reqnum"],
|
||||
url=message["url"],
|
||||
)
|
||||
payload = dict(
|
||||
prompt="a fat fluffy cat", width=1024, height=1024, steps=20, seed=123456789
|
||||
)
|
||||
req_data = dict(payload=payload, auth_data=auth_data)
|
||||
url = urljoin(url, WORKER_ENDPOINT)
|
||||
print(f"url: {url}")
|
||||
response = requests.post(
|
||||
url,
|
||||
json=req_data,
|
||||
verify=get_cert_file_path(),
|
||||
)
|
||||
response.raise_for_status()
|
||||
print_truncate_res(str(response.json()))
|
||||
|
||||
|
||||
def call_custom_workflow_for_sd3(
|
||||
endpoint_group_name: str, api_key: str, server_url: str
|
||||
) -> None:
|
||||
WORKER_ENDPOINT = "/custom-workflow"
|
||||
COST = 100
|
||||
route_payload = {
|
||||
"endpoint": endpoint_group_name,
|
||||
"api_key": api_key,
|
||||
"cost": COST,
|
||||
}
|
||||
response = requests.post(
|
||||
urljoin(server_url, "/route/"),
|
||||
json=route_payload,
|
||||
timeout=4,
|
||||
)
|
||||
response.raise_for_status()
|
||||
message = response.json()
|
||||
url = message["url"]
|
||||
auth_data = dict(
|
||||
signature=message["signature"],
|
||||
cost=message["cost"],
|
||||
endpoint=message["endpoint"],
|
||||
reqnum=message["reqnum"],
|
||||
url=message["url"],
|
||||
request_idx=message["request_idx"],
|
||||
)
|
||||
workflow = {
|
||||
"3": {
|
||||
"inputs": {
|
||||
"seed": 156680208700286,
|
||||
"steps": 20,
|
||||
"cfg": 8,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "normal",
|
||||
"denoise": 1,
|
||||
"model": ["4", 0],
|
||||
"positive": ["6", 0],
|
||||
"negative": ["7", 0],
|
||||
"latent_image": ["5", 0],
|
||||
},
|
||||
"class_type": "KSampler",
|
||||
},
|
||||
"4": {
|
||||
"inputs": {"ckpt_name": "sd3_medium_incl_clips_t5xxlfp16.safetensors"},
|
||||
"class_type": "CheckpointLoaderSimple",
|
||||
},
|
||||
"5": {
|
||||
"inputs": {"width": 512, "height": 512, "batch_size": 1},
|
||||
"class_type": "EmptyLatentImage",
|
||||
},
|
||||
"6": {
|
||||
"inputs": {
|
||||
"text": "beautiful scenery nature glass bottle landscape, purple galaxy bottle",
|
||||
"clip": ["4", 1],
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
},
|
||||
"7": {
|
||||
"inputs": {"text": "text, watermark", "clip": ["4", 1]},
|
||||
"class_type": "CLIPTextEncode",
|
||||
},
|
||||
"8": {
|
||||
"inputs": {"samples": ["3", 0], "vae": ["4", 2]},
|
||||
"class_type": "VAEDecode",
|
||||
},
|
||||
"9": {
|
||||
"inputs": {"filename_prefix": "ComfyUI", "images": ["8", 0]},
|
||||
"class_type": "SaveImage",
|
||||
},
|
||||
}
|
||||
# these values should match the values in the custom workflow above,
|
||||
# they are used to calculate workload
|
||||
custom_fields = dict(
|
||||
steps=20,
|
||||
width=512,
|
||||
height=512,
|
||||
)
|
||||
req_data = dict(
|
||||
payload=dict(custom_fields=custom_fields, workflow=workflow),
|
||||
auth_data=auth_data,
|
||||
)
|
||||
url = urljoin(url, WORKER_ENDPOINT)
|
||||
print(f"url: {url}")
|
||||
response = requests.post(
|
||||
url,
|
||||
json=req_data,
|
||||
verify=get_cert_file_path(),
|
||||
)
|
||||
response.raise_for_status()
|
||||
print_truncate_res(str(response.json()))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from lib.test_utils import test_args
|
||||
|
||||
args = test_args.parse_args()
|
||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||
endpoint_name=args.endpoint_group_name,
|
||||
account_api_key=args.api_key,
|
||||
instance=args.instance,
|
||||
)
|
||||
if endpoint_api_key:
|
||||
try:
|
||||
call_default_workflow(
|
||||
api_key=endpoint_api_key,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
server_url=args.server_url,
|
||||
)
|
||||
call_custom_workflow_for_sd3(
|
||||
api_key=endpoint_api_key,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
server_url=args.server_url,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error during API call: {e}")
|
||||
else:
|
||||
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ")
|
||||
@@ -1,205 +0,0 @@
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
import random
|
||||
import dataclasses
|
||||
import inspect
|
||||
from typing import Dict, Any
|
||||
from functools import cache
|
||||
from math import ceil
|
||||
from enum import Enum
|
||||
|
||||
from lib.data_types import ApiPayload, JsonDataException
|
||||
|
||||
|
||||
with open("workers/comfyui/misc/test_prompts.txt", "r") as f:
|
||||
test_prompts = f.readlines()
|
||||
|
||||
|
||||
class Model(Enum):
|
||||
Flux = "flux"
|
||||
Sd3 = "sd3"
|
||||
|
||||
def get_request_time(self) -> int:
|
||||
match self:
|
||||
case Model.Flux:
|
||||
return 23
|
||||
case Model.Sd3:
|
||||
return 6
|
||||
|
||||
|
||||
@cache
|
||||
def get_model() -> Model:
|
||||
match os.environ.get("COMFY_MODEL"):
|
||||
case "flux":
|
||||
return Model.Flux
|
||||
case "sd3":
|
||||
return Model.Sd3
|
||||
case None:
|
||||
raise Exception(
|
||||
"For comfyui pyworker, $COMFY_MODEL must be set in the vast template"
|
||||
)
|
||||
case model:
|
||||
raise Exception(f"Unsupported comfyui model: {model}")
|
||||
|
||||
|
||||
@cache
|
||||
def get_request_template() -> str:
|
||||
with open(f"workers/comfyui/misc/default_workflows/{get_model().value}.json") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def count_workload(width: int, height: int, steps: int) -> float:
|
||||
"""
|
||||
we want to normalize the workload is a number such that cur_perf(tokens/second) for 1024x1024 image with
|
||||
28 steps is 200 tokens on a 4090.
|
||||
|
||||
in order get that we calculate the
|
||||
|
||||
A = ( absolute workload based on given data )
|
||||
B = ( absolute workload for a 1024x1024 image with 28 steps )
|
||||
|
||||
and adjust the workload to 200 tokens by A/B.
|
||||
|
||||
we then adjust for difference between Flux and SD3 by multiplying this value by expected request time for a
|
||||
standard image(23s for Flux, 6s for SD3).
|
||||
On a 4090, this would give us a workload that would give a cur_perf(workload / request_time) of around 200
|
||||
"""
|
||||
|
||||
def _calculate_absolute_tokens(width_: int, height_: int, steps_: int) -> float:
|
||||
"""
|
||||
This is based on how openai counts image generation tokens, see: https://openai.com/api/pricing/
|
||||
|
||||
we count how many 512x512 grids are needed to cover the image.
|
||||
each tile is then counted as 175 tokens.
|
||||
each image generation also has constant of 85 base tokens.
|
||||
|
||||
we then adjust the count based on the number of steps. The baseline number of steps is assumed to be 28.
|
||||
Some testing with flux gave me this data:
|
||||
|
||||
steps(X) | request time(Y)
|
||||
__________|_________________
|
||||
07(0.25x) | 11s (0.47x)
|
||||
14(0.50x) | 15s (0.65x)
|
||||
21(0.75x) | 20s (0.86x)
|
||||
28(1.00x) | 23s (1.00x)
|
||||
35(1.25x) | 28s (1.21x)
|
||||
42(1.50x) | 32s (1.39x)
|
||||
49(1.75x) | 37s (1.60x)
|
||||
|
||||
this gives a linear regression of Y = 0.61*X + 6.57
|
||||
|
||||
we can use this as an adjustment_factor for token count
|
||||
|
||||
adjustment_factor = (0.61 * steps + 6.57)
|
||||
"""
|
||||
|
||||
width_grids = ceil(width_ / 512)
|
||||
height_grids = ceil(height_ / 512)
|
||||
tokens = 85 + width_grids * height_grids * 175
|
||||
adjustment_factor = 0.61 * steps_ + 6.57
|
||||
return tokens * adjustment_factor
|
||||
|
||||
REQUEST_TIME_FOR_STANDARD_IMAGE = get_model().get_request_time()
|
||||
|
||||
absolute_tokens = _calculate_absolute_tokens(
|
||||
width_=width, height_=height, steps_=steps
|
||||
)
|
||||
absolute_tokens_standard_image = _calculate_absolute_tokens(
|
||||
width_=1024, height_=1024, steps_=28
|
||||
)
|
||||
return REQUEST_TIME_FOR_STANDARD_IMAGE * (
|
||||
(absolute_tokens / absolute_tokens_standard_image) * 200
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DefaultComfyWorkflowData(ApiPayload):
|
||||
prompt: str
|
||||
width: int
|
||||
height: int
|
||||
steps: int
|
||||
seed: int
|
||||
|
||||
@classmethod
|
||||
def for_test(cls):
|
||||
|
||||
test_prompt = random.choice(test_prompts).rstrip()
|
||||
return cls(
|
||||
prompt=test_prompt,
|
||||
width=1024,
|
||||
height=1024,
|
||||
steps=28,
|
||||
seed=random.randint(0, sys.maxsize),
|
||||
)
|
||||
|
||||
def generate_payload_json(
|
||||
self,
|
||||
) -> Dict[str, Any]:
|
||||
return json.loads(
|
||||
get_request_template()
|
||||
.replace("{{PROMPT}}", self.prompt)
|
||||
# these values should be of int type. Since "{{VAR}}" is wrapped with " in the template
|
||||
# to make the JSON valid, we must replace the double quotes. i.e. "{{WIDTH}}" -> 1024 and not "1024"
|
||||
.replace('"{{WIDTH}}"', str(self.width))
|
||||
.replace('"{{HEIGHT}}"', str(self.height))
|
||||
.replace('"{{STEPS}}"', str(self.steps))
|
||||
.replace('"{{SEED}}"', str(self.seed))
|
||||
)
|
||||
|
||||
def count_workload(self) -> float:
|
||||
return count_workload(width=self.width, height=self.height, steps=self.steps)
|
||||
|
||||
@classmethod
|
||||
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "DefaultComfyWorkflowData":
|
||||
errors = {}
|
||||
for param in inspect.signature(cls).parameters:
|
||||
if param not in json_msg:
|
||||
errors[param] = "missing parameter"
|
||||
if errors:
|
||||
raise JsonDataException(errors)
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in json_msg.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CustomComfyWorkflowData(ApiPayload):
|
||||
custom_fields: Dict[str, int]
|
||||
workflow: Dict[str, Any]
|
||||
|
||||
@classmethod
|
||||
def for_test(cls):
|
||||
raise NotImplementedError("Custom comfy workflow is not used for testing")
|
||||
|
||||
def count_workload(self) -> float:
|
||||
return count_workload(
|
||||
width=int(self.custom_fields.get("width", 1024)),
|
||||
height=int(self.custom_fields.get("height", 1024)),
|
||||
steps=int(self.custom_fields.get("steps", 28)),
|
||||
)
|
||||
|
||||
def generate_payload_json(self) -> Dict[str, Any]:
|
||||
template_json = json.loads(get_request_template())
|
||||
template_json["input"]["workflow_json"] = self.workflow
|
||||
return template_json
|
||||
|
||||
@classmethod
|
||||
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "CustomComfyWorkflowData":
|
||||
errors = {}
|
||||
for param in inspect.signature(cls).parameters:
|
||||
if param not in json_msg:
|
||||
errors[param] = "missing parameter"
|
||||
if errors:
|
||||
raise JsonDataException(errors)
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in json_msg.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
@@ -1,137 +0,0 @@
|
||||
{
|
||||
"input": {
|
||||
"handler": "RawWorkflow",
|
||||
"aws_access_key_id": "your-s3-access-key",
|
||||
"aws_secret_access_key": "your-s3-secret-access-key",
|
||||
"aws_endpoint_url": "https://my-endpoint.backblaze.com",
|
||||
"aws_bucket_name": "your-bucket",
|
||||
"webhook_url": "your-webhook-url",
|
||||
"webhook_extra_params": {},
|
||||
"workflow_json": {
|
||||
"5": {
|
||||
"inputs": {
|
||||
"width": "{{WIDTH}}",
|
||||
"height": "{{HEIGHT}}",
|
||||
"batch_size": 1
|
||||
},
|
||||
"class_type": "EmptyLatentImage",
|
||||
"_meta": {
|
||||
"title": "Empty Latent Image"
|
||||
}
|
||||
},
|
||||
"6": {
|
||||
"inputs": {
|
||||
"text": "{{PROMPT}}",
|
||||
"clip": ["11", 0]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"8": {
|
||||
"inputs": {
|
||||
"samples": ["13", 0],
|
||||
"vae": ["10", 0]
|
||||
},
|
||||
"class_type": "VAEDecode",
|
||||
"_meta": {
|
||||
"title": "VAE Decode"
|
||||
}
|
||||
},
|
||||
"9": {
|
||||
"inputs": {
|
||||
"filename_prefix": "ComfyUI",
|
||||
"images": ["8", 0]
|
||||
},
|
||||
"class_type": "SaveImage",
|
||||
"_meta": {
|
||||
"title": "Save Image"
|
||||
}
|
||||
},
|
||||
"10": {
|
||||
"inputs": {
|
||||
"vae_name": "ae.safetensors"
|
||||
},
|
||||
"class_type": "VAELoader",
|
||||
"_meta": {
|
||||
"title": "Load VAE"
|
||||
}
|
||||
},
|
||||
"11": {
|
||||
"inputs": {
|
||||
"clip_name1": "t5xxl_fp16.safetensors",
|
||||
"clip_name2": "clip_l.safetensors",
|
||||
"type": "flux"
|
||||
},
|
||||
"class_type": "DualCLIPLoader",
|
||||
"_meta": {
|
||||
"title": "DualCLIPLoader"
|
||||
}
|
||||
},
|
||||
"12": {
|
||||
"inputs": {
|
||||
"unet_name": "flux1-dev.safetensors",
|
||||
"weight_dtype": "default"
|
||||
},
|
||||
"class_type": "UNETLoader",
|
||||
"_meta": {
|
||||
"title": "Load Diffusion Model"
|
||||
}
|
||||
},
|
||||
"13": {
|
||||
"inputs": {
|
||||
"noise": ["25", 0],
|
||||
"guider": ["22", 0],
|
||||
"sampler": ["16", 0],
|
||||
"sigmas": ["17", 0],
|
||||
"latent_image": ["5", 0]
|
||||
},
|
||||
"class_type": "SamplerCustomAdvanced",
|
||||
"_meta": {
|
||||
"title": "SamplerCustomAdvanced"
|
||||
}
|
||||
},
|
||||
"16": {
|
||||
"inputs": {
|
||||
"sampler_name": "euler"
|
||||
},
|
||||
"class_type": "KSamplerSelect",
|
||||
"_meta": {
|
||||
"title": "KSamplerSelect"
|
||||
}
|
||||
},
|
||||
"17": {
|
||||
"inputs": {
|
||||
"scheduler": "simple",
|
||||
"steps": "{{STEPS}}",
|
||||
"denoise": 1,
|
||||
"model": ["12", 0]
|
||||
},
|
||||
"class_type": "BasicScheduler",
|
||||
"_meta": {
|
||||
"title": "BasicScheduler"
|
||||
}
|
||||
},
|
||||
"22": {
|
||||
"inputs": {
|
||||
"model": ["12", 0],
|
||||
"conditioning": ["6", 0]
|
||||
},
|
||||
"class_type": "BasicGuider",
|
||||
"_meta": {
|
||||
"title": "BasicGuider"
|
||||
}
|
||||
},
|
||||
"25": {
|
||||
"inputs": {
|
||||
"noise_seed": "{{SEED}}"
|
||||
},
|
||||
"class_type": "RandomNoise",
|
||||
"_meta": {
|
||||
"title": "RandomNoise"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,142 +0,0 @@
|
||||
{
|
||||
"input": {
|
||||
"handler": "RawWorkflow",
|
||||
"aws_access_key_id": "your-s3-access-key",
|
||||
"aws_secret_access_key": "your-s3-secret-access-key",
|
||||
"aws_endpoint_url": "https://my-endpoint.backblaze.com",
|
||||
"aws_bucket_name": "your-bucket",
|
||||
"webhook_url": "your-webhook-url",
|
||||
"webhook_extra_params": {},
|
||||
"workflow_json": {
|
||||
"6": {
|
||||
"inputs": {
|
||||
"text": "{{PROMPT}}",
|
||||
"clip": ["252", 1]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"13": {
|
||||
"inputs": {
|
||||
"shift": 3,
|
||||
"model": ["252", 0]
|
||||
},
|
||||
"class_type": "ModelSamplingSD3",
|
||||
"_meta": {
|
||||
"title": "ModelSamplingSD3"
|
||||
}
|
||||
},
|
||||
"67": {
|
||||
"inputs": {
|
||||
"conditioning": ["71", 0]
|
||||
},
|
||||
"class_type": "ConditioningZeroOut",
|
||||
"_meta": {
|
||||
"title": "ConditioningZeroOut"
|
||||
}
|
||||
},
|
||||
"68": {
|
||||
"inputs": {
|
||||
"start": 0.1,
|
||||
"end": 1,
|
||||
"conditioning": ["67", 0]
|
||||
},
|
||||
"class_type": "ConditioningSetTimestepRange",
|
||||
"_meta": {
|
||||
"title": "ConditioningSetTimestepRange"
|
||||
}
|
||||
},
|
||||
"69": {
|
||||
"inputs": {
|
||||
"conditioning_1": ["68", 0],
|
||||
"conditioning_2": ["70", 0]
|
||||
},
|
||||
"class_type": "ConditioningCombine",
|
||||
"_meta": {
|
||||
"title": "Conditioning (Combine)"
|
||||
}
|
||||
},
|
||||
"70": {
|
||||
"inputs": {
|
||||
"start": 0,
|
||||
"end": 0.1,
|
||||
"conditioning": ["71", 0]
|
||||
},
|
||||
"class_type": "ConditioningSetTimestepRange",
|
||||
"_meta": {
|
||||
"title": "ConditioningSetTimestepRange"
|
||||
}
|
||||
},
|
||||
"71": {
|
||||
"inputs": {
|
||||
"text": "bad quality, poor quality, doll, disfigured, jpg, toy, bad anatomy, missing limbs, missing fingers, 3d, cgi",
|
||||
"clip": ["252", 1]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Negative Prompt)"
|
||||
}
|
||||
},
|
||||
"135": {
|
||||
"inputs": {
|
||||
"width": "{{WIDTH}}",
|
||||
"height": "{{HEIGHT}}",
|
||||
"batch_size": 1
|
||||
},
|
||||
"class_type": "EmptySD3LatentImage",
|
||||
"_meta": {
|
||||
"title": "EmptySD3LatentImage"
|
||||
}
|
||||
},
|
||||
"231": {
|
||||
"inputs": {
|
||||
"samples": ["271", 0],
|
||||
"vae": ["252", 2]
|
||||
},
|
||||
"class_type": "VAEDecode",
|
||||
"_meta": {
|
||||
"title": "VAE Decode"
|
||||
}
|
||||
},
|
||||
"233": {
|
||||
"inputs": {
|
||||
"filename_prefix": "ComfyUI",
|
||||
"images": ["231", 0]
|
||||
},
|
||||
"class_type": "SaveImage",
|
||||
"_meta": {
|
||||
"title": "Save Image"
|
||||
}
|
||||
},
|
||||
"252": {
|
||||
"inputs": {
|
||||
"ckpt_name": "sd3_medium_incl_clips_t5xxlfp16.safetensors"
|
||||
},
|
||||
"class_type": "CheckpointLoaderSimple",
|
||||
"_meta": {
|
||||
"title": "Load Checkpoint"
|
||||
}
|
||||
},
|
||||
"271": {
|
||||
"inputs": {
|
||||
"seed": "{{SEED}}",
|
||||
"steps": "{{STEPS}}",
|
||||
"cfg": 4.5,
|
||||
"sampler_name": "dpmpp_2m",
|
||||
"scheduler": "sgm_uniform",
|
||||
"denoise": 1,
|
||||
"model": ["13", 0],
|
||||
"positive": ["6", 0],
|
||||
"negative": ["69", 0],
|
||||
"latent_image": ["135", 0]
|
||||
},
|
||||
"class_type": "KSampler",
|
||||
"_meta": {
|
||||
"title": "KSampler"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
cartoon character of a person with a hoodie , in style of cytus and deemo, ork, gold chains, realistic anime cat, dripping black goo, lineage revolution style, thug life, cute anthropomorphic bunny, balrog, arknights, aliased, very buff, black and red and yellow paint, painting illustration collage style, character composition in vector with white background
|
||||
stardew valley, fine details
|
||||
2D Vector Illustration of a child with soccer ball Art for Sublimation, Design Art, Chrome Art, Painting and Stunning Artwork, Highly Detailed Digital Painting, Airbrush Art, Highly Detailed Digital Artwork, Dramatic Artwork, stained antique yellow copper paint, digital airbrush art, detailed by Mark Brooks, Chicano airbrush art, Swagger! snake Culture
|
||||
realistic futuristic city-downtown with short buildings, sunset
|
||||
seascape by Ray Collins and artgerm, front view of a perfect wave, sunny background, ultra detailed water
|
||||
inspired by realflow-cinema4d editor features, create image of a transparent luxury cup with ice fruits and mint, connected with white, yellow and pink cream, Slow - High Speed MO Photography, YouTube Video Screenshot, Abstract Clay, Transparent Cup , molecular gastronomy, wheel, 3D fluid,Simulation rendering, still video, 4k polymer clay futras photography, very surreal, Houdini Fluid Simulation, hyperrealistic CGI and FLUIDS & MULTIPHYSICS SIMULATION effect, with Somali Stain Lurex, Metallic Jacquard, Gold Thread, Mulberry Silk, Toub Saree, Warm background, a fantastic image worthy of an award.
|
||||
biker with backpack on his back riding a motorcycle, Style by Ade Santora, Oilpunk, Cover photo, craig mullins style, on the cover of a magazine, Outdoor Magazine, inspired by Alex Petruk APe, image of a male biker, Cover of an award-winning magazine, the man has a backpack, photo for magazine, with a backpack, magazine cover
|
||||
generate a collage-style illustration inspired by the Procreate raster graphic editor, photographic illustration with the theme, 2D vector, art for textile sublimation, containing surrealistic cartoon cat wearing a baseball cap and jeans standing in front of a poster, inspired by Sadao Watanabe, Doraemon, Japanese cartoon style, Eichiro Oda, Iconic high detail character, Director: Nakahara Nantenbō, Kastuhiro Otomo, image detailed, by Miyamoto, Hidetaka Miyazaki, Katsuhiro illustration, 8k, masterpiece, Minimize noise and grain in photo quality without lose quality and increase brightness and lighting,Symmetry and Alignment, Avoid asymmetrical shapes and out-of-focus points. Focus and Sharpness: Make sure the image is focused and sharp and encourages the viewer to see it as a work of art printed on fabric.
|
||||
fantasy medieval village world inside a glass sphere , high detail, fantasy, realistic, light effect, hyper detail, volumetric lighting, cinematic, macro, depth of field, blur, red light and clouds from the back, highly detailed epic cinematic concept art cg render made in maya, blender and photoshop, octane render, excellent composition, dynamic dramatic cinematic lighting, aesthetic, very inspirational, world inside a glass sphere by james gurney by artgerm with james jean, joe fenton and tristan eaton by ross tran, fine details
|
||||
Iron Man, (Arnold Tsang, Toru Nakayama), Masterpiece, Studio Quality, 6k , toa, toaair, 1boy, glowing, axe, mecha, science_fiction, solo, weapon, jungle , green_background, nature, outdoors, solo, tree, weapon, mask, dynamic lighting, detailed shading, digital texture painting
|
||||
(Pope Francis) wearing leather jacket is a DJ in a nightclub, mixing live on stage, giant mixing table, a masterpiece
|
||||
Pope Francis wearing biker (leather jacket), a masterpiece
|
||||
Luke Skywalker ordering a burger and fries from the Death Star canteen.
|
||||
I want to generate a group avatar for a Feishu group chat. The role of this group is daily software technical communication. Now the subject technology stacks that members of this group discuss daily include: algorithms, data structures, optimization, functional programming, and the programming languages often discussed are: TypeScript, Java, python, etc. I hope this avatar has a simple aesthetic, this avatar is a single person avatar
|
||||
portrait Anime black girl cute-fine-face, pretty face, realistic shaded Perfect face, fine details. Anime. realistic shaded lighting by Ilya Kuvshinov Giuseppe Dangelico Pino and Michael Garmash and Rob Rey, IAMAG premiere, WLOP matte print, cute freckles, masterpiece
|
||||
young Disney socialite wearing a beige miniskirt, dark brown turtleneck sweater, small neckless, cute-fine-face, anime. illustration, realistic shaded perfect face, brown hair, grey eyes, fine details, realistic shaded lighting by ilya kuvshinov giuseppe dangelico pino and michael garmash and rob rey, iamag premiere, wlop matte print, a masterpiece
|
||||
Cute small cat sitting in a movie theater eating chicken wiggs watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
|
||||
Cute small dog sitting in a movie theater eating popcorn watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
|
||||
fox bracelet made of buckskin with fox features, rich details, fine carvings, studio lighting
|
||||
crane buckskin bracelet with crane features, rich details, fine carvings, studio lighting
|
||||
london luxurious interior living-room, light walls
|
||||
Parisian luxurious interior penthouse bedroom, dark walls, wooden panels
|
||||
cute girl, crop-top, blond hair, black glasses, stretching, with background by greg rutkowski makoto shinkai kyoto animation key art feminine mid shot
|
||||
houses in front, houses background, straight houses, digital art, smooth, sharp focus, gravity falls style, doraemon style, shinchan style, anime style
|
||||
Simplified technical drawing, Leonardo da Vinci, Mechanical Dinosaur Skeleton, Minimalistic annotations, Hand-drawn illustrations, Basic design and engineering, Wonder and curiosity
|
||||
High quality 8K painting impressionist style of a Japanese modern city street with a girl on the foreground wearing a traditional wedding dress with a fox mask, staring at the sky, daylight
|
||||
a landscape from the Moon with the Earth setting on the horizon, realistic, detailed
|
||||
Isometric Atlantis city,great architecture with columns, great details, ornaments,seaweed, blue ambiance, 3D cartoon style, soft light, 45° view
|
||||
A hyper realistic avatar of a guy riding on a black honda cbr 650r in leather suit,high detail, high quality,8K,photo realism
|
||||
the street of amedieval fantasy town, at dawn, dark, highly detailed
|
||||
overwhelmingly beautiful eagle framed with vector flowers, long shiny wavy flowing hair, polished, ultra detailed vector floral illustration mixed with hyper realism, muted pastel colors, vector floral details in background, muted colors, hyper detailed ultra intricate overwhelming realism in detailed complex scene with magical fantasy atmosphere, no signature, no watermark
|
||||
a highly detailed matte painting of a man on a hill watching a rocket launch in the distance by studio ghibli, makoto shinkai, by artgerm, by wlop, by greg rutkowski, volumetric lighting, octane render, 4 k resolution, trending on artstation, masterpiece | hyperrealism| highly detailed| insanely detailed| intricate| cinematic lighting| depth of field
|
||||
electronik robot and ofice ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
|
||||
exquisitely intricately detailed illustration, of a small world with a lake and a rainbow, inside a closed glass jar.
|
||||
@@ -1,143 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
import dataclasses
|
||||
import base64
|
||||
from typing import Optional, Union, Type
|
||||
|
||||
from aiohttp import web, ClientResponse
|
||||
from anyio import open_file
|
||||
|
||||
from lib.backend import Backend, LogAction
|
||||
from lib.data_types import EndpointHandler
|
||||
from lib.server import start_server
|
||||
from .data_types import DefaultComfyWorkflowData, CustomComfyWorkflowData
|
||||
|
||||
|
||||
MODEL_SERVER_URL = "http://127.0.0.1:18288" # API Wrapper Service
|
||||
|
||||
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
||||
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: http://127.0.0.1:18188"
|
||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
|
||||
"Value not in list: unet_name", # This error is emitted when the model file is not there at all
|
||||
]
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
async def generate_client_response(
|
||||
request: web.Request, response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
_ = request
|
||||
match response.status:
|
||||
case 200:
|
||||
log.debug("SUCCESS")
|
||||
res = await response.json()
|
||||
if "output" not in res:
|
||||
return web.json_response(
|
||||
data=dict(error="there was an error in the workflow"),
|
||||
status=422,
|
||||
)
|
||||
image_paths = [path["local_path"] for path in res["output"]["images"]]
|
||||
if not image_paths:
|
||||
return web.json_response(
|
||||
data=dict(error="workflow did not produce any images"),
|
||||
status=422,
|
||||
)
|
||||
images = []
|
||||
for image_path in image_paths:
|
||||
async with await open_file(image_path, mode="rb") as f:
|
||||
contents = await f.read()
|
||||
images.append(
|
||||
f"data:image/png;base64,{base64.b64encode(contents).decode('utf-8')}"
|
||||
)
|
||||
return web.json_response(data=dict(images=images))
|
||||
case code:
|
||||
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||
return web.Response(status=code)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DefaultComfyWorkflowHandler(EndpointHandler[DefaultComfyWorkflowData]):
|
||||
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/runsync"
|
||||
|
||||
@property
|
||||
def healthcheck_endpoint(self) -> Optional[str]:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[DefaultComfyWorkflowData]:
|
||||
return DefaultComfyWorkflowData
|
||||
|
||||
def make_benchmark_payload(self) -> DefaultComfyWorkflowData:
|
||||
return DefaultComfyWorkflowData.for_test()
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
return await generate_client_response(client_request, model_response)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CustomComfyWorkflowHandler(EndpointHandler[CustomComfyWorkflowData]):
|
||||
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/runsync"
|
||||
|
||||
@property
|
||||
def healthcheck_endpoint(self) -> Optional[str]:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[CustomComfyWorkflowData]:
|
||||
return CustomComfyWorkflowData
|
||||
|
||||
def make_benchmark_payload(self) -> CustomComfyWorkflowData:
|
||||
return CustomComfyWorkflowData.for_test()
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
return await generate_client_response(client_request, model_response)
|
||||
|
||||
|
||||
backend = Backend(
|
||||
model_server_url=MODEL_SERVER_URL,
|
||||
model_log_file=os.environ["MODEL_LOG"],
|
||||
allow_parallel_requests=False,
|
||||
benchmark_handler=DefaultComfyWorkflowHandler(
|
||||
benchmark_runs=3, benchmark_words=100
|
||||
),
|
||||
log_actions=[
|
||||
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
|
||||
(LogAction.Info, "Downloading:"),
|
||||
*[
|
||||
(LogAction.ModelError, error_msg)
|
||||
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def handle_ping(_):
|
||||
return web.Response(body="pong")
|
||||
|
||||
|
||||
routes = [
|
||||
web.post("/prompt", backend.create_handler(DefaultComfyWorkflowHandler())),
|
||||
web.post("/custom-workflow", backend.create_handler(CustomComfyWorkflowHandler())),
|
||||
web.get("/ping", handle_ping),
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_server(backend, routes)
|
||||
@@ -1,15 +0,0 @@
|
||||
from lib.test_utils import test_load_cmd, test_args
|
||||
from .data_types import DefaultComfyWorkflowData, Model
|
||||
|
||||
WORKER_ENDPOINT = "/prompt"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_args.add_argument(
|
||||
"-m",
|
||||
dest="comfy_model",
|
||||
choices=list(map(lambda x: x.value, Model)),
|
||||
required=True,
|
||||
help="Image generation model name",
|
||||
)
|
||||
test_load_cmd(DefaultComfyWorkflowData, WORKER_ENDPOINT, arg_parser=test_args)
|
||||
@@ -1,321 +0,0 @@
|
||||
# Vast PyWorker
|
||||
|
||||
## Hello_world example
|
||||
|
||||
There is a hello_world PyWorker implementation under `workers/hello_world`. This PyWorker is
|
||||
created for an LLM model server that runs on port 5001 has two API endpoints:
|
||||
|
||||
1. `/generate`: generates an full response to the prompt and sends a JSON response
|
||||
2. `/generate_stream`: streams a response one token at a time
|
||||
|
||||
Both of these endpoints take the same API JSON payload:
|
||||
|
||||
```
|
||||
{
|
||||
"prompt": String,
|
||||
"max_response_tokens": Number | null
|
||||
}
|
||||
```
|
||||
|
||||
We want the PyWorker to also expose two endpoints that correspond to the above endpoints.
|
||||
|
||||
### Structure
|
||||
|
||||
All PyWorkers have four files:
|
||||
|
||||
```
|
||||
.
|
||||
└── workers
|
||||
└── hello_world
|
||||
├── __init__.py
|
||||
├── data_types.py # contains data types representing model API endpoints
|
||||
├── server.py # contains endpoint handlers
|
||||
└── test_load.py # script for load testing
|
||||
|
||||
```
|
||||
|
||||
All of the classes follow strict type hinting. It is recommended that you type hint all of your function.
|
||||
This will allow your IDE or VSCode with `pyright` plugin to find any type errors in your implementation.
|
||||
You can also install `pyright` with `sudo npm install -g pyright` and run `pyright` in the root of the project to find
|
||||
any type errors.
|
||||
|
||||
### data_types.py: Contains data types representing model API endpoints
|
||||
|
||||
This file defines the structure of the data your model server expects (its API contract) and, critically, how PyWorker *interprets* that data for autoscaling purposes. You define Python data classes that mirror the JSON payloads your model's API uses.
|
||||
|
||||
These classes **must** inherit from `lib.data_types.ApiPayload`. This inheritance is not just for structure; it's how PyWorker knows how to:
|
||||
|
||||
* **Parse Incoming Requests:** Convert JSON from clients into usable Python objects.
|
||||
* **Calculate Workload:** Determine the computational cost of a request.
|
||||
* **Generate Test Data:** Create realistic inputs for benchmarking.
|
||||
* **Format Requests for the Model Server:** Prepare data for the underlying model.
|
||||
|
||||
|
||||
```python
|
||||
import dataclasses
|
||||
import random
|
||||
from typing import Dict, Any
|
||||
|
||||
from transformers import OpenAIGPTTokenizer # used to count tokens in a prompt
|
||||
import nltk # used to download a list of all words to generate a random prompt and benchmark the LLM model
|
||||
|
||||
from lib.data_types import ApiPayload
|
||||
|
||||
nltk.download("words")
|
||||
WORD_LIST = nltk.corpus.words.words()
|
||||
|
||||
# you can use any tokenizer that fits your LLM. `openai-gpt` is free to use and is a good fit for most LLMs
|
||||
tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
|
||||
|
||||
@dataclasses.dataclass
|
||||
class InputData(ApiPayload):
|
||||
prompt: str
|
||||
max_response_tokens: int
|
||||
|
||||
@classmethod
|
||||
def for_test(cls) -> "ApiPayload":
|
||||
"""defines how create a payload for load testing"""
|
||||
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
||||
return cls(prompt=prompt, max_response_tokens=300)
|
||||
|
||||
def generate_payload_json(self) -> Dict[str, Any]:
|
||||
"""defines how to convert an ApiPayload to JSON that will be sent to model API"""
|
||||
return dataclasses.asdict(self)
|
||||
|
||||
def count_workload(self) -> float:
|
||||
"""defines how to calculate workload for a payload"""
|
||||
return len(tokenizer.tokenize(self.prompt))
|
||||
|
||||
@classmethod
|
||||
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
|
||||
"""
|
||||
defines how to transform JSON data to AuthData and payload type,
|
||||
in this case `InputData` defined above represents the data sent to the model API.
|
||||
AuthData is data generated by autoscaler in order to authenticate payloads.
|
||||
In this case, the transformation is simple and 1:1. That is not always the case. See comfyui's PyWorker
|
||||
for more complicated examples
|
||||
"""
|
||||
errors = {}
|
||||
for param in inspect.signature(cls).parameters:
|
||||
if param not in json_msg:
|
||||
errors[param] = "missing parameter"
|
||||
if errors:
|
||||
raise JsonDataException(errors)
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in json_msg.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
### server.py: Creating Your Model's API Endpoints
|
||||
|
||||
This section guides you through creating the core of your custom model API: the `EndpointHandler`. Think of `EndpointHandler` as the bridge between incoming requests from users and your underlying model. It's the key to making your model accessible and scalable.
|
||||
|
||||
**Why use an `EndpointHandler`?**
|
||||
|
||||
* **Organized Request Handling:** It provides a structured way to handle different types of requests (like generating text, generating images, or performing other model-specific tasks).
|
||||
* **Scalability:** By separating request handling from the model itself, you can easily scale your API to handle many concurrent users.
|
||||
* **Flexibility:** You can customize how requests are processed, validated, and transformed before being sent to your model.
|
||||
* **Standard Interface:** It provides a consistent interface for interacting with your model, regardless of the underlying implementation.
|
||||
|
||||
For every model API endpoint you want to expose (e.g., `/generate`, `/generate_stream`), you'll implement an `EndpointHandler`. This class is responsible for:
|
||||
The `EndpointHandler` achieves this through several key methods:
|
||||
|
||||
* **Receiving and validating incoming requests (`get_data_from_request`):** This method ensures the request contains the necessary data (authentication and payload) and is in the correct format. It's the entry point for all requests.
|
||||
* **Defining the endpoint (`endpoint`):** This method specifies the URL endpoint on the model API server where requests will be sent (e.g., `/generate`).
|
||||
* **Specifying the payload type (`payload_cls`):** This method indicates the specific `ApiPayload` class used for this endpoint, defining the structure of the request data.
|
||||
* **Creating benchmark payloads (`make_benchmark_payload`):** This method creates payloads specifically for benchmarking the model's performance.
|
||||
* **Handling the model's response (`generate_client_response`):** This method takes the response from the model API server and transforms it into the format expected by the client making the request to your PyWorker. This allows you to customize the output as needed.
|
||||
|
||||
The `EndpointHandler` class has several abstract functions that you *must* implement to define the behavior of your specific endpoints. Here, we'll implement two common endpoints: `/generate` (for synchronous requests) and `/generate_stream` (for streaming responses):
|
||||
|
||||
```python
|
||||
|
||||
"""
|
||||
AuthData is a dataclass that represents Authentication data sent from Autoscaler to client requesting a route.
|
||||
When a user requests a route from autoscaler, see Vast's Serverless documentation for how routing and AuthData
|
||||
work.
|
||||
When a user receives a route for this PyWorker, they'll call PyWorkers API with the following JSON:
|
||||
{
|
||||
auth_data: AuthData,
|
||||
payload : InputData # defined above
|
||||
}
|
||||
"""
|
||||
from aiohttp import web
|
||||
|
||||
from lib.data_types import EndpointHandler, JsonDataException
|
||||
from lib.server import start_server
|
||||
from .data_types import InputData
|
||||
|
||||
# This class is the implementer for the '/generate' endpoint of model API
|
||||
@dataclasses.dataclass
|
||||
class GenerateHandler(EndpointHandler[InputData]):
|
||||
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
# the API endpoint
|
||||
return "/generate"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[InputData]:
|
||||
"""this function should just return ApiPayload subclass used by this handler"""
|
||||
return InputData
|
||||
|
||||
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
|
||||
"""
|
||||
defines how to convert `InputData` defined above, to
|
||||
JSON data to be sent to the model API. This function too is a simple dataclass -> JSON, but
|
||||
can be more complicated, See comfyui for an example
|
||||
"""
|
||||
return dataclasses.asdict(payload)
|
||||
|
||||
def make_benchmark_payload(self) -> InputData:
|
||||
"""
|
||||
defines how to generate an InputData for benchmarking. This needs to be defined in only
|
||||
one EndpointHandler, the one passed to the backend as the benchmark handler. Here we use the .for_test()
|
||||
method on InputData. However, in some cases you might need to fine tune your InputData used for
|
||||
benchmarking to closely resemble the average request users call the endpoint with in order to get best
|
||||
autoscaling performance
|
||||
"""
|
||||
return InputData.for_test()
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
"""
|
||||
defines how to convert a model API response to a response to PyWorker client
|
||||
"""
|
||||
_ = client_request
|
||||
match model_response.status:
|
||||
case 200:
|
||||
log.debug("SUCCESS")
|
||||
data = await model_response.json()
|
||||
return web.json_response(data=data)
|
||||
case code:
|
||||
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||
return web.Response(status=code)
|
||||
|
||||
|
||||
```
|
||||
|
||||
We also handle `GenerateStreamHandler` for streaming responses. It is identical to `GenerateHandler`, except for
|
||||
the endpoint name and how we create a web response, as it is a streaming response:
|
||||
|
||||
```python
|
||||
class GenerateStreamHandler(EndpointHandler[InputData]):
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/generate_stream"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[InputData]:
|
||||
return InputData
|
||||
|
||||
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
|
||||
return dataclasses.asdict(payload)
|
||||
|
||||
def make_benchmark_payload(self) -> InputData:
|
||||
return InputData.for_test()
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
match model_response.status:
|
||||
case 200:
|
||||
log.debug("Streaming response...")
|
||||
res = web.StreamResponse()
|
||||
res.content_type = "text/event-stream"
|
||||
await res.prepare(client_request)
|
||||
async for chunk in model_response.content:
|
||||
await res.write(chunk)
|
||||
await res.write_eof()
|
||||
log.debug("Done streaming response")
|
||||
return res
|
||||
case code:
|
||||
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||
return web.Response(status=code)
|
||||
|
||||
|
||||
```
|
||||
|
||||
You can now instantiate a Backend and use it to handle requests.
|
||||
|
||||
```python
|
||||
from lib.backend import Backend, LogAction
|
||||
|
||||
# the url and port of model API
|
||||
MODEL_SERVER_URL = "http://0.0.0.0:5001"
|
||||
|
||||
|
||||
# This is the log line that is emitted once the server has started
|
||||
MODEL_SERVER_START_LOG_MSG = "server has started"
|
||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||
"Exception: corrupted model file" # message in the logs indicating the unrecoverable error
|
||||
]
|
||||
|
||||
backend = Backend(
|
||||
model_server_url=MODEL_SERVER_URL,
|
||||
# location of model log file
|
||||
model_log_file=os.environ["MODEL_LOG"],
|
||||
# for some model backends that can only handle one request at a time, be sure to set this to False to
|
||||
# let PyWorker handling queueing requests.
|
||||
allow_parallel_requests=True,
|
||||
# give the backend an EndpointHandler instance that is used for benchmarking
|
||||
# number of benchmark run and number of words for a random benchmark run are given
|
||||
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
|
||||
# defines how to handle specific log messages. See docstring of LogAction for details
|
||||
log_actions=[
|
||||
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
|
||||
(LogAction.Info, '"message":"Download'),
|
||||
*[
|
||||
(LogAction.ModelError, error_msg)
|
||||
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
# this is a simple ping handler for PyWorker
|
||||
async def handle_ping(_: web.Request):
|
||||
return web.Response(body="pong")
|
||||
|
||||
# this is a handler for forwarding a health check to model API
|
||||
async def handle_healthcheck(_: web.Request):
|
||||
healthcheck_res = await backend.session.get("/healthcheck")
|
||||
return web.Response(body=healthcheck_res.content, status=healthcheck_res.status)
|
||||
|
||||
routes = [
|
||||
web.post("/generate", backend.create_handler(GenerateHandler())),
|
||||
web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
|
||||
web.get("/ping", handle_ping),
|
||||
web.get("/healthcheck", handle_healthcheck),
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
# start server, called from start_server.sh
|
||||
start_server(backend, routes)
|
||||
```
|
||||
|
||||
### test_load.py
|
||||
|
||||
Here you can create a script that allows you test an endpoint group running instances with this PyWorker
|
||||
|
||||
```python
|
||||
from lib.test_harness import run
|
||||
from .data_types import InputData
|
||||
|
||||
WORKER_ENDPOINT = "/generate"
|
||||
|
||||
if __name__ == "__main__":
|
||||
run(InputData.for_test(), WORKER_ENDPOINT)
|
||||
```
|
||||
|
||||
You can then run the following command from the root of this repo to load test endpoint group:
|
||||
|
||||
```sh
|
||||
# sends 1000 requests at the rate of 0.5 requests per second
|
||||
python3 workers.hello_world.test_load -n 1000 -rps 0.5 -k "$API_KEY" -e "$ENDPOINT_GROUP_NAME"
|
||||
```
|
||||
@@ -1,48 +0,0 @@
|
||||
import dataclasses
|
||||
import random
|
||||
import inspect
|
||||
from typing import Dict, Any
|
||||
|
||||
from transformers import OpenAIGPTTokenizer
|
||||
import nltk
|
||||
|
||||
from lib.data_types import ApiPayload, JsonDataException
|
||||
|
||||
nltk.download("words")
|
||||
WORD_LIST = nltk.corpus.words.words()
|
||||
|
||||
# used to count to count tokens and workload for LLM
|
||||
tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class InputData(ApiPayload):
|
||||
prompt: str
|
||||
max_response_tokens: int
|
||||
|
||||
@classmethod
|
||||
def for_test(cls) -> "InputData":
|
||||
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
||||
return cls(prompt=prompt, max_response_tokens=300)
|
||||
|
||||
def generate_payload_json(self) -> Dict[str, Any]:
|
||||
return dataclasses.asdict(self)
|
||||
|
||||
def count_workload(self) -> int:
|
||||
return len(tokenizer.tokenize(self.prompt))
|
||||
|
||||
@classmethod
|
||||
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
|
||||
errors = {}
|
||||
for param in inspect.signature(cls).parameters:
|
||||
if param not in json_msg:
|
||||
errors[param] = "missing parameter"
|
||||
if errors:
|
||||
raise JsonDataException(errors)
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in json_msg.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
@@ -1,175 +0,0 @@
|
||||
"""
|
||||
PyWorker works as a man-in-the-middle between the client and model API. It's function is:
|
||||
1. receive request from client, update metrics such as workload of a request, number of pending requests, etc.
|
||||
2a. transform the data and forward the transformed data to model API
|
||||
2b. send updated metrics to autoscaler
|
||||
3. transform response from model API(if needed) and forward the response to client
|
||||
|
||||
PyWorker forward requests to many model API endpoint. each endpoint must have an EndpointHandler. You can also
|
||||
write function to just forward requests that don't generate anything with the model to model API without an
|
||||
EndpointHandler. This is useful for endpoints such as healthchecks. See below for example
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import dataclasses
|
||||
from typing import Dict, Any, Optional, Union, Type
|
||||
|
||||
from aiohttp import web, ClientResponse
|
||||
|
||||
from lib.backend import Backend, LogAction
|
||||
from lib.data_types import EndpointHandler
|
||||
from lib.server import start_server
|
||||
from .data_types import InputData
|
||||
|
||||
# the url and port of model API
|
||||
MODEL_SERVER_URL = "http://0.0.0.0:5001"
|
||||
|
||||
|
||||
# This is the log line that is emitted once the server has started
|
||||
MODEL_SERVER_START_LOG_MSG = "infer server has started"
|
||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||
"Exception: corrupted model file" # message in the logs indicating the unrecoverable error
|
||||
]
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
# This class is the implementer for the '/generate' endpoint of model API
|
||||
@dataclasses.dataclass
|
||||
class GenerateHandler(EndpointHandler[InputData]):
|
||||
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
# the API endpoint
|
||||
return "/generate"
|
||||
|
||||
@property
|
||||
def healthcheck_endpoint(self) -> Optional[str]:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[InputData]:
|
||||
return InputData
|
||||
|
||||
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
|
||||
"""
|
||||
defines how to convert `InputData` defined above, to
|
||||
json data to be sent to the model API
|
||||
"""
|
||||
return dataclasses.asdict(payload)
|
||||
|
||||
def make_benchmark_payload(self) -> InputData:
|
||||
"""
|
||||
defines how to generate an InputData for benchmarking. This needs to be defined in only
|
||||
one EndpointHandler, the one passed to the backend as the benchmark handler
|
||||
"""
|
||||
return InputData.for_test()
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
"""
|
||||
defines how to convert a model API response to a response to PyWorker client
|
||||
"""
|
||||
_ = client_request
|
||||
match model_response.status:
|
||||
case 200:
|
||||
log.debug("SUCCESS")
|
||||
data = await model_response.json()
|
||||
return web.json_response(data=data)
|
||||
case code:
|
||||
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||
return web.Response(status=code)
|
||||
|
||||
|
||||
# This is the same as GenerateHandler, except that it calls a streaming endpoint of the model API and streams the
|
||||
# response, which itself is streaming, back to the client.
|
||||
# it is nearly identical to handler as above, but it calls a different model API endpoint and it streams the
|
||||
# streaming response from model API to client
|
||||
class GenerateStreamHandler(EndpointHandler[InputData]):
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/generate_stream"
|
||||
|
||||
@property
|
||||
def healthcheck_endpoint(self) -> Optional[str]:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[InputData]:
|
||||
return InputData
|
||||
|
||||
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
|
||||
return dataclasses.asdict(payload)
|
||||
|
||||
def make_benchmark_payload(self) -> InputData:
|
||||
return InputData.for_test()
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
match model_response.status:
|
||||
case 200:
|
||||
log.debug("Streaming response...")
|
||||
res = web.StreamResponse()
|
||||
res.content_type = "text/event-stream"
|
||||
await res.prepare(client_request)
|
||||
async for chunk in model_response.content:
|
||||
await res.write(chunk)
|
||||
await res.write_eof()
|
||||
log.debug("Done streaming response")
|
||||
return res
|
||||
case code:
|
||||
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||
return web.Response(status=code)
|
||||
|
||||
|
||||
# This is the backend instance of pyworker. Only one must be made which uses EndpointHandlers to process
|
||||
# incoming requests
|
||||
backend = Backend(
|
||||
model_server_url=MODEL_SERVER_URL,
|
||||
model_log_file=os.environ["MODEL_LOG"],
|
||||
allow_parallel_requests=True,
|
||||
# give the backend a handler instance that is used for benchmarking
|
||||
# number of benchmark run and number of words for a random benchmark run are given
|
||||
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
|
||||
# defines how to handle specific log messages. See docstring of LogAction for details
|
||||
log_actions=[
|
||||
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
|
||||
(LogAction.Info, '"message":"Download'),
|
||||
*[
|
||||
(LogAction.ModelError, error_msg)
|
||||
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
# this is a simple ping handler for pyworker
|
||||
async def handle_ping(_: web.Request):
|
||||
return web.Response(body="pong")
|
||||
|
||||
|
||||
# this is a handler for forwarding a health check to modelAPI
|
||||
async def handle_healthcheck(_: web.Request):
|
||||
healthcheck_res = await backend.session.get("/healthcheck")
|
||||
return web.Response(body=healthcheck_res.content, status=healthcheck_res.status)
|
||||
|
||||
|
||||
routes = [
|
||||
web.post("/generate", backend.create_handler(GenerateHandler())),
|
||||
web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
|
||||
web.get("/ping", handle_ping),
|
||||
web.get("/healthcheck", handle_healthcheck),
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
# start the PyWorker server
|
||||
start_server(backend, routes)
|
||||
@@ -1,7 +0,0 @@
|
||||
from lib.test_utils import test_load_cmd, test_args
|
||||
from .data_types import InputData
|
||||
|
||||
WORKER_ENDPOINT = "/generate"
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_load_cmd(InputData, WORKER_ENDPOINT, arg_parser=test_args)
|
||||
@@ -1,58 +0,0 @@
|
||||
import json
|
||||
from dataclasses import dataclass, field, fields, is_dataclass
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
|
||||
class SerializableDataclass:
|
||||
def _serialize_recursive(self, obj: Any) -> Any:
|
||||
if is_dataclass(obj):
|
||||
return {
|
||||
field.name: self._serialize_recursive(getattr(obj, field.name))
|
||||
for field in fields(obj)
|
||||
}
|
||||
elif isinstance(obj, dict):
|
||||
return {key: self._serialize_recursive(value) for key, value in obj.items()}
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
return [self._serialize_recursive(item) for item in obj]
|
||||
elif isinstance(obj, set):
|
||||
return [self._serialize_recursive(item) for item in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return self._serialize_recursive(self)
|
||||
|
||||
def to_json(self, indent: int = 2) -> str:
|
||||
return json.dumps(self.to_dict(), indent=indent)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompletionConfig(SerializableDataclass):
|
||||
"""Configuration for completion requests"""
|
||||
|
||||
model: str
|
||||
prompt: str = "Hello"
|
||||
max_tokens: int = 256
|
||||
temperature: float = 0.7
|
||||
top_k: int = 20
|
||||
top_p: float = 0.4
|
||||
stream: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatCompletionConfig(SerializableDataclass):
|
||||
"""Configuration for chat completion requests"""
|
||||
|
||||
model: str
|
||||
messages: list = field(default_factory=list)
|
||||
max_tokens: int = 2096
|
||||
temperature: float = 0.7
|
||||
top_k: int = 20
|
||||
top_p: float = 0.4
|
||||
stream: bool = False
|
||||
tools: Optional[List[Dict[str, Any]]] = field(default_factory=list)
|
||||
tool_choice: str = "auto"
|
||||
|
||||
def __post_init__(self):
|
||||
if self.messages is None:
|
||||
self.messages = [{"role": "user", "content": "Hello"}]
|
||||
@@ -1,207 +0,0 @@
|
||||
import os, json, random
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from lib.data_types import EndpointHandler, ApiPayload, JsonDataException
|
||||
from typing import Union, Type, Dict, Any, Optional
|
||||
from aiohttp import web, ClientResponse
|
||||
import nltk
|
||||
import logging
|
||||
|
||||
nltk.download("words")
|
||||
WORD_LIST = nltk.corpus.words.words()
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
"""
|
||||
Generic dataclass accepts any dictionary in input.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenericData(ApiPayload, ABC):
|
||||
input: Dict[str, Any]
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "GenericData":
|
||||
return cls(input=data["input"])
|
||||
|
||||
@classmethod
|
||||
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "GenericData":
|
||||
errors = {}
|
||||
|
||||
# Validate required parameters
|
||||
required_params = ["input"]
|
||||
for param in required_params:
|
||||
if param not in json_msg:
|
||||
errors[param] = "missing parameter"
|
||||
|
||||
if errors:
|
||||
raise JsonDataException(errors)
|
||||
|
||||
try:
|
||||
# Create clean data dict and delegate to from_dict
|
||||
clean_data = {"input": json_msg["input"]}
|
||||
|
||||
return cls.from_dict(clean_data)
|
||||
|
||||
except (json.JSONDecodeError, JsonDataException) as e:
|
||||
errors["parameters"] = str(e)
|
||||
raise JsonDataException(errors)
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def for_test(cls) -> "GenericData":
|
||||
pass
|
||||
|
||||
def generate_payload_json(self) -> Dict[str, Any]:
|
||||
return self.input
|
||||
|
||||
def count_workload(self) -> int:
|
||||
return self.input.get("max_tokens", 0)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenericHandler(EndpointHandler[GenericData], ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def endpoint(self) -> str:
|
||||
pass
|
||||
|
||||
@property
|
||||
def healthcheck_endpoint(self) -> Optional[str]:
|
||||
return os.environ.get("MODEL_HEALTH_ENDPOINT")
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[GenericData]:
|
||||
return GenericData
|
||||
|
||||
@abstractmethod
|
||||
def make_benchmark_payload(self) -> GenericData:
|
||||
pass
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
match model_response.status:
|
||||
case 200:
|
||||
# Check if the response is actually streaming based on response headers/content-type
|
||||
is_streaming_response = (
|
||||
model_response.content_type == "text/event-stream"
|
||||
or model_response.content_type == "application/x-ndjson"
|
||||
or model_response.headers.get("Transfer-Encoding") == "chunked"
|
||||
or "stream" in model_response.content_type.lower()
|
||||
)
|
||||
|
||||
if is_streaming_response:
|
||||
log.debug("Detected streaming response...")
|
||||
res = web.StreamResponse()
|
||||
res.content_type = model_response.content_type
|
||||
await res.prepare(client_request)
|
||||
async for chunk in model_response.content:
|
||||
await res.write(chunk)
|
||||
await res.write_eof()
|
||||
log.debug("Done streaming response")
|
||||
return res
|
||||
else:
|
||||
log.debug("Detected non-streaming response...")
|
||||
content = await model_response.read()
|
||||
return web.Response(
|
||||
body=content,
|
||||
status=200,
|
||||
content_type=model_response.content_type,
|
||||
)
|
||||
case code:
|
||||
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||
return web.Response(status=code)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompletionsData(GenericData):
|
||||
@classmethod
|
||||
def for_test(cls) -> "CompletionsData":
|
||||
system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
|
||||
|
||||
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
|
||||
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
|
||||
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
|
||||
genus Equus with horses and asses, the three groups being the only living members of the family
|
||||
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
|
||||
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
|
||||
woodlands, shrublands, and mountainous areas.
|
||||
|
||||
Please answer the following question based on the above context."""
|
||||
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
|
||||
model = os.environ.get("MODEL_NAME")
|
||||
if not model:
|
||||
raise ValueError("MODEL_NAME environment variable not set")
|
||||
|
||||
test_input = {
|
||||
"model": model,
|
||||
"prompt": f"{system_prompt}\n\n{unique_question}",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 500,
|
||||
}
|
||||
return cls(input=test_input)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompletionsHandler(GenericHandler):
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/v1/completions"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[CompletionsData]:
|
||||
return CompletionsData
|
||||
|
||||
def make_benchmark_payload(self) -> CompletionsData:
|
||||
return CompletionsData.for_test()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatCompletionsData(GenericData):
|
||||
"""Chat completions-specific data implementation"""
|
||||
|
||||
@classmethod
|
||||
def for_test(cls) -> "ChatCompletionsData":
|
||||
system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
|
||||
|
||||
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
|
||||
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
|
||||
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
|
||||
genus Equus with horses and asses, the three groups being the only living members of the family
|
||||
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
|
||||
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
|
||||
woodlands, shrublands, and mountainous areas.
|
||||
|
||||
Please answer the following question based on the above context."""
|
||||
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
|
||||
model = os.environ.get("MODEL_NAME")
|
||||
if not model:
|
||||
raise ValueError("MODEL_NAME environment variable not set")
|
||||
|
||||
# Chat completions use messages format instead of prompt
|
||||
test_input = {
|
||||
"model": model,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt}, # Shared prefix
|
||||
{"role": "user", "content": unique_question} # Unique per request
|
||||
],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 500,
|
||||
}
|
||||
return cls(input=test_input)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatCompletionsHandler(GenericHandler):
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/v1/chat/completions"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[ChatCompletionsData]:
|
||||
return ChatCompletionsData
|
||||
|
||||
def make_benchmark_payload(self) -> ChatCompletionsData:
|
||||
return ChatCompletionsData.for_test()
|
||||
@@ -1,434 +0,0 @@
|
||||
from lib.test_utils import test_args
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
from lib.data_types import AuthData
|
||||
from .data_types.server import CompletionsData
|
||||
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
import requests
|
||||
from dataclasses import dataclass
|
||||
from collections import Counter
|
||||
from urllib.parse import urljoin, urlparse
|
||||
import re
|
||||
|
||||
# Headless plotting
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import logging
|
||||
logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED
|
||||
from requests.adapters import HTTPAdapter
|
||||
|
||||
def get_incremented_path(path: str) -> str:
|
||||
base, ext = os.path.splitext(path)
|
||||
if not os.path.exists(path):
|
||||
return path
|
||||
i = 1
|
||||
while os.path.exists(f"{base}-{i}{ext}"):
|
||||
i += 1
|
||||
return f"{base}-{i}{ext}"
|
||||
|
||||
WORKER_ENDPOINT = "/v1/completions" # This will return the full text output at once. Latency metrics reflect that (ie not measuring TTFT)
|
||||
|
||||
@dataclass
|
||||
class ReqResult:
|
||||
worker_url: str
|
||||
route_ms: float
|
||||
worker_ms: float
|
||||
total_ms: float
|
||||
ok: bool
|
||||
error: str = ""
|
||||
status_code: int = 0
|
||||
t_start: float = 0.0
|
||||
t_end: float = 0.0
|
||||
workload: float = 0.0
|
||||
|
||||
def do_one(endpoint_name: str,
|
||||
endpoint_id: int,
|
||||
endpoint_api_key: str,
|
||||
server_url: str,
|
||||
worker_endpoint: str,
|
||||
payload,
|
||||
results_list,
|
||||
t0,
|
||||
status_samples,
|
||||
route_session,
|
||||
worker_session):
|
||||
try:
|
||||
workload = payload.count_workload()
|
||||
route_payload = {"endpoint": endpoint_name, "api_key": endpoint_api_key, "cost": workload}
|
||||
headers = {"Authorization": f"Bearer {endpoint_api_key}"}
|
||||
start = time.time()
|
||||
r0 = route_session.post(urljoin(server_url, "/route/"), json=route_payload, headers=headers, timeout=4)
|
||||
t_after_route = time.time()
|
||||
if r0.status_code != 200:
|
||||
results_list.append(ReqResult(worker_url="",
|
||||
route_ms=(t_after_route - start) * 1000.0,
|
||||
worker_ms=0.0,
|
||||
total_ms=(t_after_route - start) * 1000.0,
|
||||
ok=False,
|
||||
error=f"route error {r0.reason} {r0.text}",
|
||||
status_code=r0.status_code,
|
||||
t_start=start - t0,
|
||||
t_end=t_after_route - t0,
|
||||
workload=workload))
|
||||
return
|
||||
msg = r0.json()
|
||||
|
||||
# 1) Check if we got a worker back from route
|
||||
worker_url = msg.get("url", "")
|
||||
if not worker_url:
|
||||
status = msg.get("status", "")
|
||||
m = re.search(r"total workers:\s*(\d+).*loading workers:\s*(\d+).*standby workers:\s*(\d+).*error workers:\s*(\d+)", status, re.I | re.S)
|
||||
if m:
|
||||
tot, loading, standby, err = map(int, m.groups())
|
||||
idle = max(tot - loading - standby - err, 0)
|
||||
status_samples.append((time.time() - t0, idle))
|
||||
|
||||
# 2) If we got a worker, send the request
|
||||
if worker_url:
|
||||
req = dict(payload=payload.__dict__, auth_data=AuthData.from_json_msg(msg).__dict__)
|
||||
t_before_worker = time.time()
|
||||
r1 = worker_session.post(
|
||||
urljoin(worker_url, worker_endpoint),
|
||||
json=req,
|
||||
verify=get_cert_file_path(),
|
||||
timeout=(4, 120),
|
||||
)
|
||||
t_after_worker = time.time()
|
||||
if r1.status_code != 200:
|
||||
results_list.append(ReqResult(worker_url=worker_url,
|
||||
route_ms=(t_after_route - start) * 1000.0,
|
||||
worker_ms=(t_after_worker - t_before_worker) * 1000.0,
|
||||
total_ms=(t_after_worker - start) * 1000.0,
|
||||
ok=False,
|
||||
error=f"worker inference error {r1.reason} {r1.text}",
|
||||
status_code=r1.status_code,
|
||||
t_start=start - t0,
|
||||
t_end=t_after_worker - t0,
|
||||
workload=workload))
|
||||
return
|
||||
# Success case
|
||||
results_list.append(ReqResult(worker_url=worker_url,
|
||||
route_ms=(t_after_route - start) * 1000.0,
|
||||
worker_ms=(t_after_worker - t_before_worker) * 1000.0,
|
||||
total_ms=(t_after_worker - start) * 1000.0,
|
||||
ok=True,
|
||||
error="",
|
||||
status_code=200,
|
||||
t_start=start - t0,
|
||||
t_end=t_after_worker - t0,
|
||||
workload=workload))
|
||||
|
||||
# 3) If so, sample via /get_endpoint_workers/ for eligible (idle) worker tracking
|
||||
if worker_url:
|
||||
try:
|
||||
r_status = route_session.post(
|
||||
urljoin(server_url, "/get_endpoint_workers/"),
|
||||
json={"id": endpoint_id},
|
||||
headers={"Authorization": f"Bearer {endpoint_api_key}"},
|
||||
timeout=3,
|
||||
)
|
||||
if r_status.status_code == 200:
|
||||
workers = r_status.json()
|
||||
idle = 0
|
||||
for w in workers:
|
||||
st = str(w.get("status", "")).lower()
|
||||
if (st in ("idle")):
|
||||
idle += 1
|
||||
status_samples.append((time.time() - t0, idle))
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
t = time.time()
|
||||
results_list.append(ReqResult(worker_url="",
|
||||
route_ms=0.0,
|
||||
worker_ms=0.0,
|
||||
total_ms=0.0,
|
||||
ok=False,
|
||||
error=f"unknown error {e}",
|
||||
status_code=0,
|
||||
t_start=t - t0,
|
||||
t_end=t - t0,
|
||||
workload=0.0))
|
||||
|
||||
def run_load_with_metrics(num_requests: int,
|
||||
requests_per_second: float,
|
||||
endpoint_group_name: str,
|
||||
account_api_key: str,
|
||||
server_url: str,
|
||||
worker_endpoint: str,
|
||||
instance: str,
|
||||
out_path: str):
|
||||
|
||||
ep_info = Endpoint.get_endpoint_info(endpoint_name=endpoint_group_name,
|
||||
account_api_key=account_api_key,
|
||||
instance=instance)
|
||||
if not ep_info or not ep_info.get("api_key") or not ep_info.get("id"):
|
||||
print(f"Endpoint {endpoint_group_name} not found for API key")
|
||||
return
|
||||
endpoint_id = int(ep_info["id"])
|
||||
endpoint_api_key = ep_info["api_key"]
|
||||
|
||||
t0 = time.time()
|
||||
results = []
|
||||
status_samples = []
|
||||
max_concurrency = int(os.environ.get("MAX_CONCURRENCY", "8192"))
|
||||
submit_queue_factor = 2 # cap queued tasks to reduce memory
|
||||
|
||||
# Shared HTTP sessions with connection pooling (persistent connections)
|
||||
def make_session(pool_connections: int, pool_maxsize: int) -> requests.Session:
|
||||
sess = requests.Session()
|
||||
adapter = HTTPAdapter(pool_connections=pool_connections, pool_maxsize=pool_maxsize, max_retries=0)
|
||||
sess.mount("https://", adapter)
|
||||
sess.mount("http://", adapter)
|
||||
return sess
|
||||
|
||||
# Router: mostly single host, small connection pool is sufficient
|
||||
route_session = make_session(pool_connections=1, pool_maxsize=max_concurrency)
|
||||
# Workers: many hosts; allow many pools and per-host concurrency up to max_concurrency
|
||||
worker_session = make_session(pool_connections=64, pool_maxsize=max_concurrency // 8)
|
||||
|
||||
# Fire requests using a thread pool, scheduling at requested RPS
|
||||
inflight = set()
|
||||
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
|
||||
for i in range(num_requests):
|
||||
# Pace submissions to RPS
|
||||
target_time = t0 + i / max(requests_per_second, 1e-9)
|
||||
sleep_s = target_time - time.time()
|
||||
if sleep_s > 0:
|
||||
time.sleep(min(sleep_s, 0.5)) # sleep in chunks to stay responsive
|
||||
|
||||
payload = CompletionsData.for_test()
|
||||
fut = executor.submit(
|
||||
do_one,
|
||||
endpoint_group_name,
|
||||
endpoint_id,
|
||||
endpoint_api_key,
|
||||
server_url,
|
||||
worker_endpoint,
|
||||
payload,
|
||||
results,
|
||||
t0,
|
||||
status_samples,
|
||||
route_session,
|
||||
worker_session,
|
||||
)
|
||||
inflight.add(fut)
|
||||
# Prevent unbounded queue growth
|
||||
if len(inflight) >= max_concurrency * submit_queue_factor:
|
||||
done, not_done = wait(inflight, return_when=FIRST_COMPLETED)
|
||||
inflight = not_done
|
||||
# Wait for all outstanding tasks
|
||||
if inflight:
|
||||
wait(inflight)
|
||||
# Close sessions
|
||||
try:
|
||||
route_session.close()
|
||||
finally:
|
||||
worker_session.close()
|
||||
|
||||
# Aggregate results
|
||||
oks = [r for r in results if r.ok]
|
||||
errs = [r for r in results if not r.ok]
|
||||
total_reqs = len(results)
|
||||
succ = len(oks)
|
||||
|
||||
total_ms = np.array([r.total_ms for r in oks]) if succ else np.array([])
|
||||
worker_ms = np.array([r.worker_ms for r in oks]) if succ else np.array([])
|
||||
route_ms = np.array([r.route_ms for r in oks]) if succ else np.array([])
|
||||
|
||||
avg_total = float(np.mean(total_ms)) if succ else 0.0
|
||||
avg_worker = float(np.mean(worker_ms)) if succ else 0.0
|
||||
avg_route = float(np.mean(route_ms)) if succ else 0.0
|
||||
p50_total, p95_total = (float(np.percentile(total_ms, 50)), float(np.percentile(total_ms, 95))) if succ else (0.0, 0.0)
|
||||
|
||||
# Distribution over workers (by host:port)
|
||||
hosts = [urlparse(r.worker_url).netloc for r in oks if r.worker_url]
|
||||
dist = Counter(hosts)
|
||||
|
||||
# Idle over time (mode per second)
|
||||
idle_ts, idle_vals = [], []
|
||||
if status_samples:
|
||||
buckets = {}
|
||||
for ts, idle in status_samples:
|
||||
k = int(ts)
|
||||
buckets.setdefault(k, []).append(idle)
|
||||
keys = sorted(buckets.keys())
|
||||
idle_ts = keys
|
||||
# Use the most frequent sampled value per second (mode) to keep integer counts
|
||||
idle_vals = []
|
||||
for k in keys:
|
||||
vals_k = [int(v) for v in buckets[k]]
|
||||
if vals_k:
|
||||
cnt = Counter(vals_k)
|
||||
idle_vals.append(cnt.most_common(1)[0][0])
|
||||
else:
|
||||
idle_vals.append(0)
|
||||
|
||||
print(f"\nResults: total={total_reqs} success={succ} errors={len(errs)}")
|
||||
print(f"Avg latency (ms): {avg_total:.1f} p50: {p50_total:.1f} p95: {p95_total:.1f}")
|
||||
print(f"Avg route latency (ms): {avg_route:.1f} Avg worker latency (ms): {avg_worker:.1f}")
|
||||
if errs:
|
||||
print("Sample errors:")
|
||||
for e in errs[:5]:
|
||||
print(f" {e.status_code} {e.error}")
|
||||
|
||||
# Plot: 2x3 grid
|
||||
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
|
||||
fig.suptitle(f"Load test: {endpoint_group_name} n={total_reqs}, rps={requests_per_second}, success={succ}")
|
||||
|
||||
# Dist per worker
|
||||
ax0 = axes[0, 0]
|
||||
if dist:
|
||||
items = sorted(dist.items(), key=lambda kv: kv[1], reverse=True)
|
||||
labels, counts = zip(*items)
|
||||
ax0.bar(range(len(labels)), counts)
|
||||
ax0.set_xticks(range(len(labels)))
|
||||
ax0.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
|
||||
ax0.set_title("Request distribution over workers")
|
||||
ax0.set_ylabel("count")
|
||||
|
||||
# Latency histogram (total)
|
||||
ax1 = axes[0, 1]
|
||||
if succ:
|
||||
ax1.hist(total_ms, bins=30)
|
||||
ax1.set_title("Total latency (ms)")
|
||||
ax1.set_xlabel("ms")
|
||||
ax1.set_ylabel("freq")
|
||||
|
||||
# Eligible workers over time
|
||||
ax_idle = axes[0, 2]
|
||||
if idle_ts:
|
||||
ax_idle.plot(idle_ts, idle_vals, "-o", ms=3)
|
||||
ax_idle.set_title("Eligible workers over time")
|
||||
ax_idle.set_xlabel("time (s)")
|
||||
ax_idle.set_ylabel("eligible count")
|
||||
|
||||
# Throughput over time (completions/sec)
|
||||
ax_idle = axes[1, 0]
|
||||
ax_idle.clear()
|
||||
if succ:
|
||||
per_sec = {}
|
||||
for r in oks:
|
||||
s = int(r.t_end)
|
||||
per_sec[s] = per_sec.get(s, 0) + 1
|
||||
ts = sorted(per_sec.keys())
|
||||
vals = [per_sec[t] for t in ts]
|
||||
ax_idle.plot(ts, vals, "-o", ms=3)
|
||||
ax_idle.set_title("Completions per second")
|
||||
ax_idle.set_xlabel("time (s)")
|
||||
ax_idle.set_ylabel("completions / sec")
|
||||
|
||||
# Summary text
|
||||
ax3 = axes[1, 1]
|
||||
ax3.axis("off")
|
||||
text = (
|
||||
f"Total requests: {total_reqs}\n"
|
||||
f"Success: {succ} Errors: {len(errs)}\n"
|
||||
f"Avg total latency: {avg_total:.1f} ms\n"
|
||||
f"p50: {p50_total:.1f} ms p95: {p95_total:.1f} ms\n"
|
||||
f"Avg route latency: {avg_route:.1f} ms\n"
|
||||
f"Avg worker latency: {avg_worker:.1f} ms\n"
|
||||
f"300 errors: {len([r for r in errs if r.status_code >= 300 and r.status_code < 400])}\n"
|
||||
f"429 errors: {len([r for r in errs if r.status_code == 429])}\n"
|
||||
f"500 errors: {len([r for r in errs if r.status_code >= 500])}\n"
|
||||
f"Other errors: {len([r for r in errs if r.status_code not in [300, 429, 500]])}\n"
|
||||
)
|
||||
ax3.set_title("Summary")
|
||||
ax3.text(0.02, 0.98, text, va="top", ha="left", fontsize=11, transform=ax3.transAxes)
|
||||
|
||||
# Error count over time
|
||||
ax_errors = axes[1, 2]
|
||||
all_end_times = [int(r.t_end) for r in results if r.t_end > 0]
|
||||
if all_end_times:
|
||||
min_second = min(all_end_times)
|
||||
max_second = max(all_end_times)
|
||||
# Count errors per second
|
||||
errors_per_second = {}
|
||||
for result in errs:
|
||||
second = int(result.t_end)
|
||||
errors_per_second[second] = errors_per_second.get(second, 0) + 1
|
||||
# Create complete timeline including zeros
|
||||
time_seconds = list(range(min_second, max_second + 1))
|
||||
error_counts = [errors_per_second.get(sec, 0) for sec in time_seconds]
|
||||
ax_errors.plot(time_seconds, error_counts, "-o", ms=3)
|
||||
ax_errors.set_title("Errors per second")
|
||||
ax_errors.set_xlabel("time (s)")
|
||||
ax_errors.set_ylabel("errors / sec")
|
||||
|
||||
# Ensure unique output path and create directory if needed
|
||||
final_out_path = get_incremented_path(out_path)
|
||||
out_dir = os.path.dirname(final_out_path)
|
||||
if out_dir:
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
plt.tight_layout(rect=[0, 0, 1, 0.96])
|
||||
plt.savefig(final_out_path, dpi=120)
|
||||
print(f"Saved report to: {final_out_path}")
|
||||
|
||||
# Per-worker latency boxplot (top 12 by volume)
|
||||
groups = {}
|
||||
for r in oks:
|
||||
host = urlparse(r.worker_url).netloc
|
||||
groups.setdefault(host, []).append(r.total_ms)
|
||||
items = sorted(groups.items(), key=lambda kv: len(kv[1]), reverse=True)[:12]
|
||||
if items:
|
||||
labels, data = zip(*items)
|
||||
fig2, axb = plt.subplots(1, 1, figsize=(12, 5))
|
||||
axb.boxplot(data, showfliers=False)
|
||||
axb.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
|
||||
axb.set_title("Per-worker latency (ms)")
|
||||
axb.set_ylabel("ms")
|
||||
plt.tight_layout()
|
||||
extra_out = get_incremented_path(os.path.splitext(out_path)[0] + "-workers.png")
|
||||
plt.savefig(extra_out, dpi=120)
|
||||
fig2.tight_layout()
|
||||
fig2.savefig(extra_out, dpi=120)
|
||||
print(f"Saved worker latency plot to: {extra_out}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Check if MODEL_NAME environment variable is set
|
||||
model_name_set = os.environ.get("MODEL_NAME") is not None
|
||||
|
||||
# Add model argument - required only if MODEL_NAME is not set
|
||||
test_args.add_argument(
|
||||
"--model",
|
||||
dest="model",
|
||||
required=not model_name_set,
|
||||
help="Model to use for completions request (required if MODEL_NAME env var not set)",
|
||||
)
|
||||
|
||||
# Parse known args to get model early, before adding load args
|
||||
known_args, _ = test_args.parse_known_args()
|
||||
if hasattr(known_args, "model") and known_args.model:
|
||||
os.environ["MODEL_NAME"] = known_args.model
|
||||
print(f"Set MODEL_NAME environment variable to: {known_args.model}")
|
||||
|
||||
# Load test args
|
||||
test_args.add_argument("-n", dest="num_requests", type=int, required=True, help="total number of requests")
|
||||
test_args.add_argument("-rps", dest="requests_per_second", type=float, required=True, help="requests per second")
|
||||
test_args.add_argument("--out", dest="out_path", type=str, default="load_test_report.png", help="path to save the report image")
|
||||
args = test_args.parse_args()
|
||||
|
||||
server_url = {
|
||||
"prod": "https://run.vast.ai",
|
||||
"alpha": "https://run-alpha.vast.ai",
|
||||
"candidate": "https://run-candidate.vast.ai",
|
||||
"local": "http://localhost:8080"
|
||||
}.get(args.instance, "http://localhost:8080")
|
||||
|
||||
run_load_with_metrics(
|
||||
num_requests=args.num_requests,
|
||||
requests_per_second=args.requests_per_second,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
account_api_key=args.api_key,
|
||||
server_url=server_url,
|
||||
worker_endpoint=WORKER_ENDPOINT,
|
||||
instance=args.instance,
|
||||
out_path=args.out_path,
|
||||
)
|
||||
@@ -1,73 +0,0 @@
|
||||
import dataclasses
|
||||
import random
|
||||
import inspect
|
||||
from typing import Dict, Any
|
||||
|
||||
from transformers import OpenAIGPTTokenizer
|
||||
import nltk
|
||||
|
||||
from lib.data_types import ApiPayload, JsonDataException
|
||||
|
||||
nltk.download("words")
|
||||
WORD_LIST = nltk.corpus.words.words()
|
||||
|
||||
tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class InputParameters:
|
||||
max_new_tokens: int = 256
|
||||
|
||||
@classmethod
|
||||
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputParameters":
|
||||
errors = {}
|
||||
for param in inspect.signature(cls).parameters:
|
||||
if param not in json_msg:
|
||||
errors[param] = "missing parameter"
|
||||
if errors:
|
||||
raise JsonDataException(errors)
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in json_msg.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class InputData(ApiPayload):
|
||||
inputs: str
|
||||
parameters: InputParameters
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "InputData":
|
||||
return cls(
|
||||
inputs=data["inputs"], parameters=InputParameters(**data["parameters"])
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def for_test(cls) -> "InputData":
|
||||
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
||||
return cls(inputs=prompt, parameters=InputParameters())
|
||||
|
||||
def generate_payload_json(self) -> Dict[str, Any]:
|
||||
return dataclasses.asdict(self)
|
||||
|
||||
def count_workload(self) -> int:
|
||||
return self.parameters.max_new_tokens
|
||||
|
||||
@classmethod
|
||||
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
|
||||
errors = {}
|
||||
for param in inspect.signature(cls).parameters:
|
||||
if param not in json_msg:
|
||||
errors[param] = "missing parameter"
|
||||
if errors:
|
||||
raise JsonDataException(errors)
|
||||
try:
|
||||
parameters = InputParameters.from_json_msg(json_msg["parameters"])
|
||||
return cls(inputs=json_msg["inputs"], parameters=parameters)
|
||||
except JsonDataException as e:
|
||||
errors["parameters"] = e.message
|
||||
raise JsonDataException(errors)
|
||||
@@ -1,130 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Union, Type
|
||||
import dataclasses
|
||||
|
||||
from aiohttp import web, ClientResponse
|
||||
|
||||
from lib.backend import Backend, LogAction
|
||||
from lib.data_types import EndpointHandler
|
||||
from lib.server import start_server
|
||||
from .data_types import InputData
|
||||
|
||||
|
||||
MODEL_SERVER_URL = "http://0.0.0.0:5001"
|
||||
|
||||
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
||||
MODEL_SERVER_START_LOG_MSG = [
|
||||
'"message":"Connected","target":"text_generation_router"',
|
||||
'"message":"Connected","target":"text_generation_router::server"',
|
||||
]
|
||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||
"Error: WebserverFailed",
|
||||
"Error: DownloadError",
|
||||
"Error: ShardCannotStart",
|
||||
]
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class GenerateHandler(EndpointHandler[InputData]):
|
||||
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/generate"
|
||||
|
||||
@property
|
||||
def healthcheck_endpoint(self) -> str:
|
||||
return f"{MODEL_SERVER_URL}/health"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[InputData]:
|
||||
return InputData
|
||||
|
||||
def make_benchmark_payload(self) -> InputData:
|
||||
return InputData.for_test()
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
_ = client_request
|
||||
match model_response.status:
|
||||
case 200:
|
||||
log.debug("SUCCESS")
|
||||
data = await model_response.json()
|
||||
return web.json_response(data=data)
|
||||
case code:
|
||||
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||
return web.Response(status=code)
|
||||
|
||||
|
||||
class GenerateStreamHandler(EndpointHandler[InputData]):
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/generate_stream"
|
||||
|
||||
@property
|
||||
def healthcheck_endpoint(self) -> str:
|
||||
return f"{MODEL_SERVER_URL}/health"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[InputData]:
|
||||
return InputData
|
||||
|
||||
def make_benchmark_payload(self) -> InputData:
|
||||
return InputData.for_test()
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
match model_response.status:
|
||||
case 200:
|
||||
log.debug("Streaming response...")
|
||||
res = web.StreamResponse()
|
||||
res.content_type = "text/event-stream"
|
||||
await res.prepare(client_request)
|
||||
async for chunk in model_response.content:
|
||||
await res.write(chunk)
|
||||
await res.write_eof()
|
||||
log.debug("Done streaming response")
|
||||
return res
|
||||
case code:
|
||||
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||
return web.Response(status=code)
|
||||
|
||||
|
||||
backend = Backend(
|
||||
model_server_url=MODEL_SERVER_URL,
|
||||
model_log_file=os.environ["MODEL_LOG"],
|
||||
allow_parallel_requests=True,
|
||||
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
|
||||
log_actions=[
|
||||
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
|
||||
(LogAction.Info, '"message":"Download'),
|
||||
*[
|
||||
(LogAction.ModelError, error_msg)
|
||||
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def handle_ping(_):
|
||||
return web.Response(body="pong")
|
||||
|
||||
|
||||
routes = [
|
||||
web.post("/generate", backend.create_handler(GenerateHandler())),
|
||||
web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
|
||||
web.get("/ping", handle_ping),
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_server(backend, routes)
|
||||
@@ -1,7 +0,0 @@
|
||||
from lib.test_utils import test_load_cmd, test_args
|
||||
from .data_types import InputData
|
||||
|
||||
WORKER_ENDPOINT = "/generate"
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_load_cmd(InputData, WORKER_ENDPOINT, arg_parser=test_args)
|
||||
Reference in New Issue
Block a user