@@ -0,0 +1,3 @@
|
|||||||
|
.direnv
|
||||||
|
.envrc
|
||||||
|
__pycache__
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2023 Vast.ai
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
@@ -1 +1,14 @@
|
|||||||
# pyworker
|
# Vast PyWorker
|
||||||
|
|
||||||
|
Vast PyWorker is a Python web server designed to run alongside a LLM or image generation models running on vast,
|
||||||
|
enabling autoscaler integration.
|
||||||
|
It serves as the primary entry point for API requests, forwarding them to the model's API hosted on the
|
||||||
|
same instance. Additionally, it monitors performance metrics and estimates current workload based on factors
|
||||||
|
such as the number of tokens processed for LLMs or image resolution and steps for image generation models,
|
||||||
|
reporting these metrics to the autoscaler.
|
||||||
|
|
||||||
|
## How to Use
|
||||||
|
|
||||||
|
If you want to use autoscaler, you just need to use one of Vast's autoscaler templates. If you'd like to
|
||||||
|
implement PyWorker for a template that is not marked as autoscaler compatible on Vast, refer to
|
||||||
|
`workers/hello_world/README.md`
|
||||||
|
|||||||
+327
@@ -0,0 +1,327 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import base64
|
||||||
|
import subprocess
|
||||||
|
import dataclasses
|
||||||
|
import logging
|
||||||
|
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
|
||||||
|
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable
|
||||||
|
from functools import cached_property
|
||||||
|
|
||||||
|
from anyio import open_file
|
||||||
|
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from Crypto.Signature import pkcs1_15
|
||||||
|
from Crypto.Hash import SHA256
|
||||||
|
from Crypto.PublicKey import RSA
|
||||||
|
|
||||||
|
from lib.metrics import Metrics
|
||||||
|
from lib.data_types import (
|
||||||
|
AuthData,
|
||||||
|
EndpointHandler,
|
||||||
|
LogAction,
|
||||||
|
ApiPayload_T,
|
||||||
|
JsonDataException,
|
||||||
|
)
|
||||||
|
|
||||||
|
MSG_HISTORY_LEN = 100
|
||||||
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
# defines the minimum wait time between sending updates to autoscaler
|
||||||
|
LOG_POLL_INTERVAL = 0.1
|
||||||
|
BENCHMARK_INDICATOR_FILE = ".has_benchmark"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class Backend:
|
||||||
|
"""
|
||||||
|
This class is responsible for:
|
||||||
|
1. Tailing logs and updating load time metrics
|
||||||
|
2. Taking an EndpointHandler alongside incoming payload, preparing a json to be sent to the model, and
|
||||||
|
sending the request. It also updates metrics as it makes those requests.
|
||||||
|
3. Running a benchmark from an EndpointHandler
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_server_url: str
|
||||||
|
model_log_file: str
|
||||||
|
allow_parallel_requests: bool
|
||||||
|
benchmark_handler: (
|
||||||
|
EndpointHandler # this endpoint handler will be used for benchmarking
|
||||||
|
)
|
||||||
|
log_actions: List[Tuple[LogAction, str]]
|
||||||
|
reqnum = -1
|
||||||
|
msg_history = []
|
||||||
|
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
|
||||||
|
def fetch_public_key():
|
||||||
|
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
|
||||||
|
result = subprocess.check_output(command, universal_newlines=True)
|
||||||
|
log.debug("public key:")
|
||||||
|
log.debug(result)
|
||||||
|
key = None
|
||||||
|
for _ in range(5):
|
||||||
|
try:
|
||||||
|
key = RSA.import_key(result)
|
||||||
|
break
|
||||||
|
except ValueError as e:
|
||||||
|
log.debug(f"Error downloading key: {e}")
|
||||||
|
time.sleep(15)
|
||||||
|
return key
|
||||||
|
|
||||||
|
###########
|
||||||
|
|
||||||
|
self.PUBLIC_KEY = fetch_public_key()
|
||||||
|
self.metrics = Metrics()
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def session(self):
|
||||||
|
log.debug(f"starting session with {self.model_server_url}")
|
||||||
|
return ClientSession(self.model_server_url)
|
||||||
|
|
||||||
|
def create_handler(
|
||||||
|
self,
|
||||||
|
handler: EndpointHandler[ApiPayload_T],
|
||||||
|
) -> Callable[[web.Request], Awaitable[Union[web.Response, web.StreamResponse]]]:
|
||||||
|
async def handler_fn(
|
||||||
|
request: web.Request,
|
||||||
|
) -> Union[web.Response, web.StreamResponse]:
|
||||||
|
return await self.__handle_request(handler=handler, request=request)
|
||||||
|
|
||||||
|
return handler_fn
|
||||||
|
|
||||||
|
#######################################Private#######################################
|
||||||
|
async def __handle_request(
|
||||||
|
self,
|
||||||
|
handler: EndpointHandler[ApiPayload_T],
|
||||||
|
request: web.Request,
|
||||||
|
) -> Union[web.Response, web.StreamResponse]:
|
||||||
|
"""use this function to forward requests to the model endpoint"""
|
||||||
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
auth_data, payload = handler.get_data_from_request(data)
|
||||||
|
except JsonDataException as e:
|
||||||
|
return web.json_response(data=e.message, status=422)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return web.json_response(dict(error="invalid JSON"), status=422)
|
||||||
|
workload = payload.count_workload()
|
||||||
|
|
||||||
|
async def cancel_api_call_if_disconnected() -> web.Response:
|
||||||
|
await request.wait_for_disconnection()
|
||||||
|
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
|
||||||
|
self.metrics._request_canceled(workload=workload, reqnum=auth_data.reqnum)
|
||||||
|
return web.Response(status=500)
|
||||||
|
|
||||||
|
async def make_request() -> Union[web.Response, web.StreamResponse]:
|
||||||
|
log.debug(f"got request, {auth_data.reqnum}")
|
||||||
|
self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum)
|
||||||
|
if self.allow_parallel_requests is False:
|
||||||
|
log.debug(f"Waiting to aquire Sem for reqnum:{auth_data.reqnum}")
|
||||||
|
await self.sem.acquire()
|
||||||
|
log.debug(
|
||||||
|
f"Sem acquired for reqnum:{auth_data.reqnum}, starting request..."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
log.debug(f"Starting request for reqnum:{auth_data.reqnum}")
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
response = await self.__call_api(handler=handler, payload=payload)
|
||||||
|
status_code = response.status
|
||||||
|
log.debug(
|
||||||
|
" ".join(
|
||||||
|
[
|
||||||
|
f"request with reqnum:{auth_data.reqnum}",
|
||||||
|
f"returned status code: {status_code},",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
res = await handler.generate_client_response(request, response)
|
||||||
|
self.metrics._request_end(
|
||||||
|
workload=workload,
|
||||||
|
req_response_time=time.time() - start_time,
|
||||||
|
reqnum=auth_data.reqnum,
|
||||||
|
)
|
||||||
|
return res
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
log.debug(f"[backend] Request error: {e}")
|
||||||
|
self.metrics._request_errored(
|
||||||
|
workload=workload, reqnum=auth_data.reqnum
|
||||||
|
)
|
||||||
|
return web.Response(status=500)
|
||||||
|
finally:
|
||||||
|
self.sem.release()
|
||||||
|
|
||||||
|
###########
|
||||||
|
|
||||||
|
if self.__check_signature(auth_data) is False:
|
||||||
|
return web.Response(status=401)
|
||||||
|
|
||||||
|
try:
|
||||||
|
done, pending = await wait(
|
||||||
|
[
|
||||||
|
create_task(make_request()),
|
||||||
|
create_task(cancel_api_call_if_disconnected()),
|
||||||
|
],
|
||||||
|
return_when=FIRST_COMPLETED,
|
||||||
|
)
|
||||||
|
[task.cancel() for task in pending]
|
||||||
|
return done.pop().result()
|
||||||
|
except Exception as e:
|
||||||
|
log.debug(f"Exception in main handler loop {e}")
|
||||||
|
return web.Response(status=500)
|
||||||
|
|
||||||
|
async def _start_tracking(self) -> None:
|
||||||
|
await gather(self.__read_logs(), self.metrics._send_metrics_loop())
|
||||||
|
|
||||||
|
def backend_errored(self, msg: str) -> None:
|
||||||
|
self.metrics._model_errored(msg)
|
||||||
|
|
||||||
|
async def __call_api(
|
||||||
|
self, handler: EndpointHandler[ApiPayload_T], payload: ApiPayload_T
|
||||||
|
) -> ClientResponse:
|
||||||
|
api_payload = payload.generate_payload_json()
|
||||||
|
log.debug(f"posting to endpoint: '{handler.endpoint}', payload: {api_payload}")
|
||||||
|
return await self.session.post(url=handler.endpoint, json=api_payload)
|
||||||
|
|
||||||
|
def __check_signature(self, auth_data: AuthData) -> bool:
|
||||||
|
def verify_signature(message, signature):
|
||||||
|
if self.PUBLIC_KEY is None:
|
||||||
|
log.debug(f"No Public Key!")
|
||||||
|
return False
|
||||||
|
|
||||||
|
h = SHA256.new(message.encode())
|
||||||
|
try:
|
||||||
|
pkcs1_15.new(self.PUBLIC_KEY).verify(h, base64.b64decode(signature))
|
||||||
|
return True
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
message = {
|
||||||
|
key: value
|
||||||
|
for (key, value) in (dataclasses.asdict(auth_data).items())
|
||||||
|
if key != "signature"
|
||||||
|
}
|
||||||
|
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
|
||||||
|
log.debug(
|
||||||
|
f"reqnum failure, got {auth_data.reqnum}, current_reqnum: {self.reqnum}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
elif message in self.msg_history:
|
||||||
|
log.debug(f"message: {message} already in message history")
|
||||||
|
return False
|
||||||
|
elif verify_signature(json.dumps(message, indent=4), auth_data.signature):
|
||||||
|
self.reqnum = max(auth_data.reqnum, self.reqnum)
|
||||||
|
self.msg_history.append(message)
|
||||||
|
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
log.debug(
|
||||||
|
f"signature verification failed, sig:{auth_data.signature}, message: {message}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def __read_logs(self) -> Awaitable[NoReturn]:
|
||||||
|
|
||||||
|
async def run_benchmark() -> float:
|
||||||
|
log.debug("starting benchmark")
|
||||||
|
try:
|
||||||
|
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
||||||
|
log.debug("already ran benchmark")
|
||||||
|
# trigger model load
|
||||||
|
payload = self.benchmark_handler.make_benchmark_payload()
|
||||||
|
_ = await self.__call_api(
|
||||||
|
handler=self.benchmark_handler, payload=payload
|
||||||
|
)
|
||||||
|
return float(f.readline())
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
max_throughput = 0
|
||||||
|
last_throughput = 0
|
||||||
|
sum_throughput = 0
|
||||||
|
for run in range(self.benchmark_handler.benchmark_runs + 1):
|
||||||
|
start = time.time()
|
||||||
|
payload = self.benchmark_handler.make_benchmark_payload()
|
||||||
|
res = await self.__call_api(
|
||||||
|
handler=self.benchmark_handler, payload=payload
|
||||||
|
)
|
||||||
|
data = await res.json()
|
||||||
|
time_elapsed = time.time() - start
|
||||||
|
# first run triggers one-time loading of the model which is very slow, so we skip counting it
|
||||||
|
if run == 0:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
workload = payload.count_workload()
|
||||||
|
last_throughput = workload / time_elapsed
|
||||||
|
sum_throughput += last_throughput
|
||||||
|
max_throughput = max(max_throughput, last_throughput)
|
||||||
|
log.debug(
|
||||||
|
"\n".join(
|
||||||
|
[
|
||||||
|
"#" * 60,
|
||||||
|
f"Run: {run}, workload: {workload} time_elapsed: {time_elapsed}, throughput: {last_throughput}",
|
||||||
|
"",
|
||||||
|
f"response: {data}",
|
||||||
|
"#" * 60,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
average_throughput = sum_throughput / self.benchmark_handler.benchmark_runs
|
||||||
|
log.debug(
|
||||||
|
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
|
||||||
|
)
|
||||||
|
# save max_throughput so we don't have to run benchmark again on restart of cold instances
|
||||||
|
with open(BENCHMARK_INDICATOR_FILE, "w") as f:
|
||||||
|
f.write(str(max_throughput))
|
||||||
|
return max_throughput
|
||||||
|
|
||||||
|
async def handle_log_line(log_line: str) -> None:
|
||||||
|
"""
|
||||||
|
Implement this function to handle each log line for your model.
|
||||||
|
This function should mutate self.system_metrics and self.model_metrics
|
||||||
|
"""
|
||||||
|
for action, msg in self.log_actions:
|
||||||
|
match action:
|
||||||
|
case LogAction.ModelLoaded if msg in log_line:
|
||||||
|
log.debug(
|
||||||
|
f"Got log line indicating model is loaded: {log_line}"
|
||||||
|
)
|
||||||
|
# some backends need a few seconds after logging successful startup before
|
||||||
|
# they can begin accepting requests
|
||||||
|
await sleep(5)
|
||||||
|
try:
|
||||||
|
max_throughput = await run_benchmark()
|
||||||
|
self.metrics._model_loaded(
|
||||||
|
max_throughput=max_throughput,
|
||||||
|
)
|
||||||
|
except ClientConnectorError as e:
|
||||||
|
log.debug(
|
||||||
|
f"failed to connect to comfyui api during benchmark"
|
||||||
|
)
|
||||||
|
self.backend_errored(str(e))
|
||||||
|
case LogAction.ModelError if msg in log_line:
|
||||||
|
log.debug(f"Got log line indicating error: {log_line}")
|
||||||
|
self.backend_errored(msg)
|
||||||
|
break
|
||||||
|
case LogAction.Info if msg in log_line:
|
||||||
|
log.debug(f"Info from model logs: {log_line}")
|
||||||
|
|
||||||
|
async def tail_log():
|
||||||
|
log.debug(f"tailing file: {self.model_log_file}")
|
||||||
|
async with await open_file(self.model_log_file) as f:
|
||||||
|
while True:
|
||||||
|
line = await f.readline()
|
||||||
|
if line:
|
||||||
|
await handle_log_line(line.rstrip())
|
||||||
|
else:
|
||||||
|
time.sleep(LOG_POLL_INTERVAL)
|
||||||
|
|
||||||
|
###########
|
||||||
|
|
||||||
|
while True:
|
||||||
|
if os.path.isfile(self.model_log_file) is True:
|
||||||
|
return await tail_log()
|
||||||
|
else:
|
||||||
|
await sleep(1)
|
||||||
@@ -0,0 +1,269 @@
|
|||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type
|
||||||
|
from aiohttp import web, ClientResponse
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
import psutil
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
type variable representing an incoming payload to pyworker that will used to calculate load and will then
|
||||||
|
be forwarded to the model
|
||||||
|
"""
|
||||||
|
|
||||||
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
class JsonDataException(Exception):
|
||||||
|
def __init__(self, json_msg: Dict[str, Any]):
|
||||||
|
self.message = json_msg
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ApiPayload(ABC):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def for_test(cls) -> "ApiPayload":
|
||||||
|
"""defines how create a payload for load testing"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def generate_payload_json(self) -> Dict[str, Any]:
|
||||||
|
"""defines how to convert an ApiPayload to JSON that will be sent to model API"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def count_workload(self) -> float:
|
||||||
|
"""defines how to calculate workload for a payload"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "ApiPayload":
|
||||||
|
"""
|
||||||
|
defines how to create an API payload from a JSON message,
|
||||||
|
it should throw an JsonDataException if there are issues with some fields
|
||||||
|
or they are missing in the format of
|
||||||
|
{
|
||||||
|
"field": "error msg"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AuthData:
|
||||||
|
"""data used to authenticate requester"""
|
||||||
|
|
||||||
|
signature: str
|
||||||
|
cost: str
|
||||||
|
endpoint: str
|
||||||
|
reqnum: int
|
||||||
|
url: str
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json_msg(cls, json_msg: Dict[str, Any]):
|
||||||
|
errors = {}
|
||||||
|
for param in inspect.signature(cls).parameters:
|
||||||
|
if param not in json_msg:
|
||||||
|
errors[param] = "missing parameter"
|
||||||
|
if errors:
|
||||||
|
raise JsonDataException(errors)
|
||||||
|
return cls(
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in json_msg.items()
|
||||||
|
if k in inspect.signature(cls).parameters
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
ApiPayload_T = TypeVar("ApiPayload_T", bound=ApiPayload)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EndpointHandler(ABC, Generic[ApiPayload_T]):
|
||||||
|
"""
|
||||||
|
Each model endpoint will have a handler responsible for counting workload from the incoming ApiPayload
|
||||||
|
and converting it to json to be forwarded to model API
|
||||||
|
"""
|
||||||
|
|
||||||
|
benchmark_runs: int = 8
|
||||||
|
benchmark_words: int = 100
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def endpoint(self) -> str:
|
||||||
|
"""the endpoint on the model API"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def payload_cls(cls) -> Type[ApiPayload_T]:
|
||||||
|
"""ApiPayload class"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def make_benchmark_payload(self) -> ApiPayload_T:
|
||||||
|
"""defines how to create an ApiPayload for benchmarking."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def generate_client_response(
|
||||||
|
self, client_request: web.Request, model_response: ClientResponse
|
||||||
|
) -> Union[web.Response, web.StreamResponse]:
|
||||||
|
"""
|
||||||
|
defines how to convert a model API response to a response to PyWorker client
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_data_from_request(
|
||||||
|
cls, req_data: Dict[str, Any]
|
||||||
|
) -> Tuple[AuthData, ApiPayload_T]:
|
||||||
|
errors = {}
|
||||||
|
auth_data = payload = None
|
||||||
|
try:
|
||||||
|
if "auth_data" in req_data:
|
||||||
|
auth_data = AuthData.from_json_msg(req_data["auth_data"])
|
||||||
|
else:
|
||||||
|
errors["auth_data"] = "field missing"
|
||||||
|
except JsonDataException as e:
|
||||||
|
errors["auth_data"] = e.message
|
||||||
|
try:
|
||||||
|
if "payload" in req_data:
|
||||||
|
payload = cls.payload_cls().from_json_msg(req_data["payload"])
|
||||||
|
else:
|
||||||
|
errors["payload"] = "field missing"
|
||||||
|
except JsonDataException as e:
|
||||||
|
errors["payload"] = e.message
|
||||||
|
if errors:
|
||||||
|
raise JsonDataException(errors)
|
||||||
|
if auth_data and payload:
|
||||||
|
return (auth_data, payload)
|
||||||
|
else:
|
||||||
|
raise Exception("error deserializing request data")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SystemMetrics:
|
||||||
|
"""General system metrics"""
|
||||||
|
|
||||||
|
model_loading_start: float
|
||||||
|
model_loading_time: Union[float, None]
|
||||||
|
last_disk_usage: float
|
||||||
|
additional_disk_usage: float
|
||||||
|
model_is_loaded: bool
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_disk_usage_GB():
|
||||||
|
return psutil.disk_usage("/").used / (2**30) # want units of GB
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def empty(cls):
|
||||||
|
return cls(
|
||||||
|
model_loading_start=time.time(),
|
||||||
|
model_loading_time=None,
|
||||||
|
last_disk_usage=SystemMetrics.get_disk_usage_GB(),
|
||||||
|
additional_disk_usage=0.0,
|
||||||
|
model_is_loaded=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_disk_usage(self):
|
||||||
|
disk_usage = SystemMetrics.get_disk_usage_GB()
|
||||||
|
self.additional_disk_usage = disk_usage - self.last_disk_usage
|
||||||
|
self.last_disk_usage = disk_usage
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
# autoscaler excepts model_loading_time to be populated only once, when the instance has
|
||||||
|
# finished benchmarking and is ready to receive requests. This applies to restarted instances
|
||||||
|
# as well: they should send model_loading_time once when they are done loading
|
||||||
|
self.model_loading_time = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelMetrics:
|
||||||
|
"""Model specific metrics"""
|
||||||
|
|
||||||
|
# these are reset after being sent to autoscaler
|
||||||
|
workload_served: float
|
||||||
|
workload_received: float
|
||||||
|
workload_cancelled: float
|
||||||
|
workload_errored: float
|
||||||
|
workload_pending: float
|
||||||
|
# these are not
|
||||||
|
cur_perf: float
|
||||||
|
error_msg: Optional[str]
|
||||||
|
max_throughput: float
|
||||||
|
requests_recieved: Set[int] = field(default_factory=set)
|
||||||
|
requests_working: Set[int] = field(default_factory=set)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def empty(cls):
|
||||||
|
return cls(
|
||||||
|
workload_pending=0.0,
|
||||||
|
workload_served=0.0,
|
||||||
|
workload_cancelled=0.0,
|
||||||
|
workload_errored=0.0,
|
||||||
|
cur_perf=0.0,
|
||||||
|
workload_received=0.0,
|
||||||
|
error_msg=None,
|
||||||
|
max_throughput=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def workload_processing(self) -> float:
|
||||||
|
return max(self.workload_received - self.workload_cancelled, 0.0)
|
||||||
|
|
||||||
|
def set_errored(self, error_msg):
|
||||||
|
self.reset()
|
||||||
|
self.error_msg = error_msg
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.workload_served = 0
|
||||||
|
self.workload_received = 0
|
||||||
|
self.workload_cancelled = 0
|
||||||
|
self.workload_errored = 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AutoScalaerData:
|
||||||
|
"""Data that is reported to autoscaler"""
|
||||||
|
|
||||||
|
id: int
|
||||||
|
loadtime: float
|
||||||
|
cur_load: float
|
||||||
|
error_msg: str
|
||||||
|
max_perf: float
|
||||||
|
cur_perf: float
|
||||||
|
cur_capacity: float
|
||||||
|
max_capacity: float
|
||||||
|
num_requests_working: int
|
||||||
|
num_requests_recieved: int
|
||||||
|
additional_disk_usage: float
|
||||||
|
url: str
|
||||||
|
|
||||||
|
|
||||||
|
class LogAction(Enum):
|
||||||
|
"""
|
||||||
|
These actions tell the backend what a log value means, for example:
|
||||||
|
actions [
|
||||||
|
# this marks the model server as loaded
|
||||||
|
(LogAction.ModelLoaded, "Starting server"),
|
||||||
|
# these mark the model server as errored
|
||||||
|
(LogAction.ModelError, "Exception loading model"),
|
||||||
|
(LogAction.ModelError, "Server failed to bind to port"),
|
||||||
|
# this tells the backend to print any logs containing the string into its own logs
|
||||||
|
# which are visible in the vast console instance logs
|
||||||
|
(LogAction.Info, "Starting model download"),
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
|
||||||
|
ModelLoaded = 1
|
||||||
|
ModelError = 2
|
||||||
|
Info = 3
|
||||||
+153
@@ -0,0 +1,153 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
from asyncio import sleep
|
||||||
|
from dataclasses import dataclass, asdict, field
|
||||||
|
from functools import cache
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from lib.data_types import AutoScalaerData, SystemMetrics, ModelMetrics
|
||||||
|
from typing import Awaitable, NoReturn, List
|
||||||
|
|
||||||
|
METRICS_UPDATE_INTERVAL = 1
|
||||||
|
|
||||||
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def get_url() -> str:
|
||||||
|
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
||||||
|
worker_port = os.environ[f"VAST_TCP_PORT_{os.environ['WORKER_PORT']}"]
|
||||||
|
public_ip = os.environ["PUBLIC_IPADDR"]
|
||||||
|
return f"http{'s' if use_ssl else ''}://{public_ip}:{worker_port}"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Metrics:
|
||||||
|
last_metric_update: float = 0.0
|
||||||
|
update_pending: bool = False
|
||||||
|
id: int = field(default_factory=lambda: int(os.environ["CONTAINER_ID"]))
|
||||||
|
report_addr: List[str] = field(
|
||||||
|
default_factory=lambda: os.environ["REPORT_ADDR"].split(",")
|
||||||
|
)
|
||||||
|
url: str = field(default_factory=get_url)
|
||||||
|
system_metrics: SystemMetrics = field(default_factory=SystemMetrics.empty)
|
||||||
|
model_metrics: ModelMetrics = field(default_factory=ModelMetrics.empty)
|
||||||
|
|
||||||
|
def _request_start(self, workload: float, reqnum: int) -> None:
|
||||||
|
"""
|
||||||
|
this function is called prior to forwarding a request to a model API.
|
||||||
|
"""
|
||||||
|
log.debug("request start")
|
||||||
|
self.model_metrics.workload_pending += workload
|
||||||
|
self.model_metrics.workload_received += workload
|
||||||
|
self.model_metrics.requests_recieved.add(reqnum)
|
||||||
|
self.model_metrics.requests_working.add(reqnum)
|
||||||
|
|
||||||
|
def _request_end(
|
||||||
|
self, workload: float, req_response_time: float, reqnum: int
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
this function is called after a response from model API is received.
|
||||||
|
"""
|
||||||
|
self.model_metrics.workload_served += workload
|
||||||
|
self.model_metrics.workload_pending -= workload
|
||||||
|
self.model_metrics.requests_working.discard(reqnum)
|
||||||
|
self.model_metrics.cur_perf = workload / req_response_time
|
||||||
|
self.update_pending = True
|
||||||
|
|
||||||
|
def _request_errored(self, workload: float, reqnum: int) -> None:
|
||||||
|
"""
|
||||||
|
this function is called if model API returns an error
|
||||||
|
"""
|
||||||
|
self.model_metrics.workload_pending -= workload
|
||||||
|
self.model_metrics.workload_errored += workload
|
||||||
|
self.model_metrics.requests_working.discard(reqnum)
|
||||||
|
|
||||||
|
def _request_canceled(self, workload: float, reqnum: int) -> None:
|
||||||
|
"""
|
||||||
|
this function is called if client drops connection before model API has responded
|
||||||
|
"""
|
||||||
|
self.model_metrics.workload_pending -= workload
|
||||||
|
self.model_metrics.workload_cancelled += workload
|
||||||
|
self.model_metrics.requests_working.discard(reqnum)
|
||||||
|
|
||||||
|
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
|
||||||
|
while True:
|
||||||
|
await sleep(METRICS_UPDATE_INTERVAL)
|
||||||
|
elapsed = time.time() - self.last_metric_update
|
||||||
|
if self.system_metrics.model_is_loaded is False and elapsed >= 10:
|
||||||
|
log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
|
||||||
|
self.__send_metrics_and_reset(elapsed)
|
||||||
|
elif self.update_pending or elapsed > 10:
|
||||||
|
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
|
||||||
|
self.__send_metrics_and_reset(elapsed)
|
||||||
|
|
||||||
|
def _model_loaded(self, max_throughput: float) -> None:
|
||||||
|
self.system_metrics.model_loading_time = (
|
||||||
|
time.time() - self.system_metrics.model_loading_start
|
||||||
|
)
|
||||||
|
self.system_metrics.model_is_loaded = True
|
||||||
|
self.model_metrics.max_throughput = max_throughput
|
||||||
|
|
||||||
|
def _model_errored(self, error_msg: str) -> None:
|
||||||
|
self.model_metrics.set_errored(error_msg)
|
||||||
|
self.system_metrics.model_is_loaded = True
|
||||||
|
|
||||||
|
#######################################Private#######################################
|
||||||
|
|
||||||
|
def __send_metrics_and_reset(self, elapsed):
|
||||||
|
|
||||||
|
def compute_autoscaler_data() -> AutoScalaerData:
|
||||||
|
return AutoScalaerData(
|
||||||
|
id=self.id,
|
||||||
|
loadtime=(self.system_metrics.model_loading_time or 0.0),
|
||||||
|
cur_load=(self.model_metrics.workload_processing / elapsed),
|
||||||
|
max_perf=self.model_metrics.max_throughput,
|
||||||
|
cur_perf=self.model_metrics.cur_perf,
|
||||||
|
error_msg=self.model_metrics.error_msg or "",
|
||||||
|
num_requests_working=len(self.model_metrics.requests_working),
|
||||||
|
num_requests_recieved=len(self.model_metrics.requests_recieved),
|
||||||
|
additional_disk_usage=self.system_metrics.additional_disk_usage,
|
||||||
|
cur_capacity=0,
|
||||||
|
max_capacity=0,
|
||||||
|
url=self.url,
|
||||||
|
)
|
||||||
|
|
||||||
|
def send_data(report_addr: str) -> None:
|
||||||
|
data = compute_autoscaler_data()
|
||||||
|
full_path = urljoin(report_addr, "/worker_status/")
|
||||||
|
log.debug(
|
||||||
|
"\n".join(
|
||||||
|
[
|
||||||
|
"#" * 60,
|
||||||
|
f"sending data to autoscaler",
|
||||||
|
f"{json.dumps((asdict(data)), indent=2)}",
|
||||||
|
"#" * 60,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for attempt in range(1, 4):
|
||||||
|
try:
|
||||||
|
requests.post(full_path, json=asdict(data), timeout=1)
|
||||||
|
break
|
||||||
|
except requests.Timeout:
|
||||||
|
log.debug(f"autoscaler status update timed out")
|
||||||
|
except Exception as e:
|
||||||
|
log.debug(f"autoscaler status update failed with error: {e}")
|
||||||
|
time.sleep(2)
|
||||||
|
log.debug(f"retrying autoscaler status update, attempt: {attempt}")
|
||||||
|
|
||||||
|
###########
|
||||||
|
|
||||||
|
self.system_metrics.update_disk_usage()
|
||||||
|
|
||||||
|
for report_addr in self.report_addr:
|
||||||
|
send_data(report_addr)
|
||||||
|
self.update_pending = False
|
||||||
|
self.model_metrics.reset()
|
||||||
|
self.system_metrics.reset()
|
||||||
|
self.last_metric_update = time.time()
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from typing import List
|
||||||
|
import ssl
|
||||||
|
from asyncio import run, gather
|
||||||
|
|
||||||
|
|
||||||
|
from lib.backend import Backend
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
||||||
|
log.debug("getting certificate...")
|
||||||
|
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
||||||
|
if use_ssl is True:
|
||||||
|
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||||
|
ssl_context.load_cert_chain(
|
||||||
|
certfile="/etc/instance.crt",
|
||||||
|
keyfile="/etc/instance.key",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ssl_context = None
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
log.debug("starting server...")
|
||||||
|
app = web.Application()
|
||||||
|
app.add_routes(routes)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(
|
||||||
|
runner,
|
||||||
|
ssl_context=ssl_context,
|
||||||
|
port=int(os.environ["WORKER_PORT"]),
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
await gather(site.start(), backend._start_tracking())
|
||||||
|
|
||||||
|
run(main())
|
||||||
@@ -0,0 +1,267 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
import argparse
|
||||||
|
from typing import Callable, List, Dict, Tuple, Dict, Any, Type
|
||||||
|
from time import sleep
|
||||||
|
import threading
|
||||||
|
from enum import Enum
|
||||||
|
from collections import Counter
|
||||||
|
from dataclasses import dataclass, field, asdict
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from lib.data_types import AuthData, ApiPayload
|
||||||
|
|
||||||
|
|
||||||
|
class ClientStatus(Enum):
|
||||||
|
FetchEndpoint = 1
|
||||||
|
Generating = 2
|
||||||
|
Done = 3
|
||||||
|
Error = 4
|
||||||
|
|
||||||
|
|
||||||
|
total_success = 0
|
||||||
|
last_res = []
|
||||||
|
stop_event = threading.Event()
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
test_args = argparse.ArgumentParser(description="Test inference endpoint")
|
||||||
|
test_args.add_argument(
|
||||||
|
"-k", dest="api_key", type=str, required=True, help="Your vast account API key"
|
||||||
|
)
|
||||||
|
test_args.add_argument(
|
||||||
|
"-e",
|
||||||
|
dest="endpoint_group_name",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Endpoint group name",
|
||||||
|
)
|
||||||
|
test_args.add_argument(
|
||||||
|
"-l",
|
||||||
|
dest="server_url",
|
||||||
|
action="store_const",
|
||||||
|
const="http://localhost:8081",
|
||||||
|
default="https://run.vast.ai",
|
||||||
|
help="Call local autoscaler instead of prod, for dev use only",
|
||||||
|
)
|
||||||
|
|
||||||
|
GetPayloadAndWorkload = Callable[[], Tuple[Dict[str, Any], float]]
|
||||||
|
|
||||||
|
|
||||||
|
def print_truncate_res(res: str):
|
||||||
|
if len(res) > 150:
|
||||||
|
print(f"{res[:50]}....{res[-100:]}")
|
||||||
|
else:
|
||||||
|
print(res)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ClientState:
|
||||||
|
endpoint_group_name: str
|
||||||
|
api_key: str
|
||||||
|
server_url: str
|
||||||
|
worker_endpoint: str
|
||||||
|
payload: ApiPayload
|
||||||
|
url: str = ""
|
||||||
|
status: ClientStatus = ClientStatus.FetchEndpoint
|
||||||
|
as_error: List[str] = field(default_factory=list)
|
||||||
|
infer_error: List[str] = field(default_factory=list)
|
||||||
|
conn_errors: Counter = field(default_factory=Counter)
|
||||||
|
|
||||||
|
def make_call(self):
|
||||||
|
self.status = ClientStatus.FetchEndpoint
|
||||||
|
route_payload = {
|
||||||
|
"endpoint": self.endpoint_group_name,
|
||||||
|
"api_key": self.api_key,
|
||||||
|
"cost": self.payload.count_workload(),
|
||||||
|
}
|
||||||
|
response = requests.post(
|
||||||
|
urljoin(self.server_url, "/route/"),
|
||||||
|
json=route_payload,
|
||||||
|
timeout=4,
|
||||||
|
)
|
||||||
|
if response.status_code != 200:
|
||||||
|
self.as_error.append(
|
||||||
|
f"code: {response.status_code}, body: {response.text}",
|
||||||
|
)
|
||||||
|
self.status = ClientStatus.Error
|
||||||
|
return
|
||||||
|
message = response.json()
|
||||||
|
worker_address = message["url"]
|
||||||
|
req_data = dict(
|
||||||
|
payload=asdict(self.payload),
|
||||||
|
auth_data=asdict(AuthData.from_json_msg(message)),
|
||||||
|
)
|
||||||
|
self.url = worker_address
|
||||||
|
url = urljoin(worker_address, self.worker_endpoint)
|
||||||
|
self.status = ClientStatus.Generating
|
||||||
|
response = requests.post(
|
||||||
|
url,
|
||||||
|
json=req_data,
|
||||||
|
)
|
||||||
|
if response.status_code != 200:
|
||||||
|
self.infer_error.append(
|
||||||
|
f"code: {response.status_code}, body: {response.text}, url: {url}",
|
||||||
|
)
|
||||||
|
self.status = ClientStatus.Error
|
||||||
|
return
|
||||||
|
res = str(response.json())
|
||||||
|
global total_success
|
||||||
|
global last_res
|
||||||
|
total_success += 1
|
||||||
|
last_res.append(res)
|
||||||
|
self.status = ClientStatus.Done
|
||||||
|
|
||||||
|
def simulate_user(self) -> None:
|
||||||
|
try:
|
||||||
|
self.make_call()
|
||||||
|
except Exception as e:
|
||||||
|
self.status = ClientStatus.Error
|
||||||
|
_ = e
|
||||||
|
self.conn_errors[self.url] += 1
|
||||||
|
|
||||||
|
|
||||||
|
def print_state(clients: List[ClientState], num_clients: int) -> None:
|
||||||
|
print("starting up...")
|
||||||
|
sleep(2)
|
||||||
|
center_size = 14
|
||||||
|
global start_time
|
||||||
|
while len(clients) < num_clients or (
|
||||||
|
any(
|
||||||
|
map(
|
||||||
|
lambda client: client.status
|
||||||
|
in [ClientStatus.FetchEndpoint, ClientStatus.Generating],
|
||||||
|
clients,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
):
|
||||||
|
sleep(0.5)
|
||||||
|
os.system("clear")
|
||||||
|
print(
|
||||||
|
" | ".join(
|
||||||
|
[member.name.center(center_size) for member in ClientStatus]
|
||||||
|
+ [
|
||||||
|
item.center(center_size)
|
||||||
|
for item in [
|
||||||
|
"urls",
|
||||||
|
"as_error",
|
||||||
|
"infer_error",
|
||||||
|
"conn_error",
|
||||||
|
"total_success",
|
||||||
|
]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
unique_urls = len(set([c.url for c in clients if c.url != ""]))
|
||||||
|
as_errors = sum(
|
||||||
|
map(
|
||||||
|
lambda client: len(client.as_error),
|
||||||
|
[client for client in clients],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
infer_errors = sum(
|
||||||
|
map(
|
||||||
|
lambda client: len(client.infer_error),
|
||||||
|
[client for client in clients],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
conn_errors = sum([client.conn_errors for client in clients], start=Counter())
|
||||||
|
conn_errors_str = ",".join(map(str, conn_errors.values())) or "0"
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
print(
|
||||||
|
" | ".join(
|
||||||
|
map(
|
||||||
|
lambda item: str(item).center(center_size),
|
||||||
|
[
|
||||||
|
len(list(filter(lambda x: x.status == member, clients)))
|
||||||
|
for member in ClientStatus
|
||||||
|
]
|
||||||
|
+ [
|
||||||
|
unique_urls,
|
||||||
|
as_errors,
|
||||||
|
infer_errors,
|
||||||
|
conn_errors_str,
|
||||||
|
f"{total_success}({((total_success/elapsed) * 60):.2f}/minute)",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if conn_errors:
|
||||||
|
print("conn_errors:")
|
||||||
|
for url, count in conn_errors.items():
|
||||||
|
print(url.ljust(28), ": ", str(count))
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
print(f"\n elapsed: {int(elapsed // 60)}:{int(elapsed % 60)}")
|
||||||
|
if last_res:
|
||||||
|
for i, res in enumerate(last_res[-10:]):
|
||||||
|
print_truncate_res(f"res #{1+i+max(len(last_res )-10,0)}: {res}")
|
||||||
|
if stop_event.is_set():
|
||||||
|
print("\n### waiting for existing connections to close ###")
|
||||||
|
|
||||||
|
|
||||||
|
def run_test(
|
||||||
|
num_requests: int,
|
||||||
|
requests_per_second: int,
|
||||||
|
endpoint_group_name: str,
|
||||||
|
api_key: str,
|
||||||
|
server_url: str,
|
||||||
|
worker_endpoint: str,
|
||||||
|
payload_cls: Type[ApiPayload],
|
||||||
|
):
|
||||||
|
threads = []
|
||||||
|
|
||||||
|
clients = []
|
||||||
|
print_thread = threading.Thread(target=print_state, args=(clients, num_requests))
|
||||||
|
print_thread.daemon = True # makes threads get killed on program exit
|
||||||
|
print_thread.start()
|
||||||
|
try:
|
||||||
|
for _ in range(num_requests):
|
||||||
|
client = ClientState(
|
||||||
|
endpoint_group_name=endpoint_group_name,
|
||||||
|
api_key=api_key,
|
||||||
|
server_url=server_url,
|
||||||
|
worker_endpoint=worker_endpoint,
|
||||||
|
payload=payload_cls.for_test(),
|
||||||
|
)
|
||||||
|
clients.append(client)
|
||||||
|
thread = threading.Thread(target=client.simulate_user, args=())
|
||||||
|
threads.append(thread)
|
||||||
|
thread.start()
|
||||||
|
sleep(1 / requests_per_second)
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
print("done spawning workers")
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
stop_event.set()
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_cmd(
|
||||||
|
payload_cls: Type[ApiPayload], endpoint: str, arg_parser: argparse.ArgumentParser
|
||||||
|
):
|
||||||
|
arg_parser.add_argument(
|
||||||
|
"-n",
|
||||||
|
dest="num_requests",
|
||||||
|
type=int,
|
||||||
|
required=True,
|
||||||
|
help="total number of requests",
|
||||||
|
)
|
||||||
|
arg_parser.add_argument(
|
||||||
|
"-rps",
|
||||||
|
dest="requests_per_second",
|
||||||
|
type=float,
|
||||||
|
required=True,
|
||||||
|
help="requests per second",
|
||||||
|
)
|
||||||
|
args = arg_parser.parse_args()
|
||||||
|
if hasattr(args, "comfy_model"):
|
||||||
|
os.environ["COMFY_MODEL"] = args.comfy_model
|
||||||
|
run_test(
|
||||||
|
num_requests=args.num_requests,
|
||||||
|
requests_per_second=args.requests_per_second,
|
||||||
|
api_key=args.api_key,
|
||||||
|
server_url=args.server_url,
|
||||||
|
endpoint_group_name=args.endpoint_group_name,
|
||||||
|
worker_endpoint=endpoint,
|
||||||
|
payload_cls=payload_cls,
|
||||||
|
)
|
||||||
@@ -0,0 +1,51 @@
|
|||||||
|
aiofiles==24.1.0
|
||||||
|
aiohappyeyeballs==2.3.4
|
||||||
|
aiohttp==3.10.0
|
||||||
|
aiojobs==1.2.1
|
||||||
|
aiosignal==1.3.1
|
||||||
|
anyio==4.4.0
|
||||||
|
attrs==23.2.0
|
||||||
|
blinker==1.8.2
|
||||||
|
certifi==2024.7.4
|
||||||
|
charset-normalizer==3.3.2
|
||||||
|
click==8.1.7
|
||||||
|
cx_Freeze==7.1.1
|
||||||
|
filelock==3.15.4
|
||||||
|
Flask==3.0.3
|
||||||
|
frozenlist==1.4.1
|
||||||
|
fsspec==2024.6.1
|
||||||
|
gitignore_parser==0.1.11
|
||||||
|
gunicorn==22.0.0
|
||||||
|
hf_transfer==0.1.8
|
||||||
|
huggingface-hub==0.24.2
|
||||||
|
idna==3.7
|
||||||
|
itsdangerous==2.2.0
|
||||||
|
Jinja2==3.1.4
|
||||||
|
joblib==1.4.2
|
||||||
|
MarkupSafe==2.1.5
|
||||||
|
multidict==6.0.5
|
||||||
|
nltk==3.8.1
|
||||||
|
Nuitka==2.3.11
|
||||||
|
numpy==2.0.0
|
||||||
|
ordered-set==4.1.0
|
||||||
|
packaging==24.1
|
||||||
|
patchelf==0.17.2.1
|
||||||
|
psutil==6.0.0
|
||||||
|
pycryptodome==3.20.0
|
||||||
|
PyYAML==6.0.1
|
||||||
|
regex==2024.5.15
|
||||||
|
requests==2.32.3
|
||||||
|
safetensors==0.4.3
|
||||||
|
setuptools==70.3.0
|
||||||
|
sniffio==1.3.1
|
||||||
|
tiktoken==0.7.0
|
||||||
|
token-count==0.2.1
|
||||||
|
tokenizers==0.19.1
|
||||||
|
tqdm==4.66.4
|
||||||
|
transformers==4.43.2
|
||||||
|
typing_extensions==4.12.2
|
||||||
|
urllib3==2.2.2
|
||||||
|
Werkzeug==3.0.3
|
||||||
|
wheel==0.43.0
|
||||||
|
yarl==1.9.4
|
||||||
|
zstandard==0.22.0
|
||||||
Executable
+118
@@ -0,0 +1,118 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e -o pipefail
|
||||||
|
|
||||||
|
WORKSPACE_DIR="${WORKSPACE_DIR:-/workspace}"
|
||||||
|
|
||||||
|
SERVER_DIR="$WORKSPACE_DIR/vast-pyworker"
|
||||||
|
ENV_PATH="$WORKSPACE_DIR/worker-env"
|
||||||
|
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
|
||||||
|
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
|
||||||
|
|
||||||
|
REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}"
|
||||||
|
USE_SSL="${USE_SSL:-true}"
|
||||||
|
WORKER_PORT="${WORKER_PORT:-3000}"
|
||||||
|
|
||||||
|
mkdir -p "$WORKSPACE_DIR"
|
||||||
|
cd "$WORKSPACE_DIR"
|
||||||
|
|
||||||
|
# make all output go to $DEBUG_LOG and stdout without having to add `... | tee -a $DEBUG_LOG` to every command
|
||||||
|
exec &> >(tee -a "$DEBUG_LOG")
|
||||||
|
|
||||||
|
function echo_var(){
|
||||||
|
echo "$1: ${!1}"
|
||||||
|
}
|
||||||
|
|
||||||
|
[ -z "$BACKEND" ] && echo "BACKEND must be set!" && exit 1
|
||||||
|
[ -z "$MODEL_LOG" ] && echo "MODEL_LOG must be set!" && exit 1
|
||||||
|
[ -z "$HF_TOKEN" ] && echo "HF_TOKEN must be set!" && exit 1
|
||||||
|
[ "$BACKEND" = "sd3" ] && [ -z "$COMFY_MODEL" ] && echo "For sd3 backends, COMFY_MODEL must be set!" && exit 1
|
||||||
|
|
||||||
|
|
||||||
|
echo "start_server.sh"
|
||||||
|
date
|
||||||
|
|
||||||
|
echo_var BACKEND
|
||||||
|
echo_var REPORT_ADDR
|
||||||
|
echo_var WORKER_PORT
|
||||||
|
echo_var WORKSPACE_DIR
|
||||||
|
echo_var SERVER_DIR
|
||||||
|
echo_var ENV_PATH
|
||||||
|
echo_var DEBUG_LOG
|
||||||
|
echo_var PYWORKER_LOG
|
||||||
|
echo_var MODEL_LOG
|
||||||
|
|
||||||
|
env | grep _ >> /etc/environment;
|
||||||
|
|
||||||
|
|
||||||
|
if [ ! -d "$ENV_PATH" ]
|
||||||
|
then
|
||||||
|
apt install -y python3.10-venv
|
||||||
|
echo "setting up venv"
|
||||||
|
git clone https://github.com/vast-ai/pyworker "$SERVER_DIR"
|
||||||
|
|
||||||
|
python3 -m venv "$WORKSPACE_DIR/worker-env"
|
||||||
|
source "$WORKSPACE_DIR/worker-env/bin/activate"
|
||||||
|
|
||||||
|
pip install -r vast-pyworker/requirements.txt
|
||||||
|
|
||||||
|
touch ~/.no_auto_tmux
|
||||||
|
else
|
||||||
|
source "$WORKSPACE_DIR/worker-env/bin/activate"
|
||||||
|
echo "environment activated"
|
||||||
|
echo "venv: $VIRTUAL_ENV"
|
||||||
|
fi
|
||||||
|
|
||||||
|
[ ! -d "$SERVER_DIR/workers/$BACKEND" ] && echo "$BACKEND not supported!" && exit 1
|
||||||
|
|
||||||
|
if [ "$USE_SSL" = true ]; then
|
||||||
|
|
||||||
|
cat << EOF > /etc/openssl-san.cnf
|
||||||
|
[req]
|
||||||
|
default_bits = 2048
|
||||||
|
distinguished_name = req_distinguished_name
|
||||||
|
req_extensions = v3_req
|
||||||
|
|
||||||
|
[req_distinguished_name]
|
||||||
|
countryName = US
|
||||||
|
stateOrProvinceName = CA
|
||||||
|
organizationName = Vast.ai Inc.
|
||||||
|
commonName = vast.ai
|
||||||
|
|
||||||
|
[v3_req]
|
||||||
|
basicConstraints = CA:FALSE
|
||||||
|
keyUsage = nonRepudiation, digitalSignature, keyEncipherment
|
||||||
|
subjectAltName = @alt_names
|
||||||
|
|
||||||
|
[alt_names]
|
||||||
|
IP.1 = 0.0.0.0
|
||||||
|
EOF
|
||||||
|
|
||||||
|
openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \
|
||||||
|
-nodes \
|
||||||
|
-sha256 \
|
||||||
|
-keyout /etc/instance.key \
|
||||||
|
-out /etc/instance.csr \
|
||||||
|
-config /etc/openssl-san.cnf
|
||||||
|
|
||||||
|
curl --header 'Content-Type: application/octet-stream' \
|
||||||
|
--data-binary @//etc/instance.csr \
|
||||||
|
-X \
|
||||||
|
POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt;
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
export REPORT_ADDR WORKER_PORT USE_SSL
|
||||||
|
|
||||||
|
cd "$SERVER_DIR"
|
||||||
|
|
||||||
|
echo "launching PyWorker server"
|
||||||
|
|
||||||
|
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
|
||||||
|
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
|
||||||
|
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
|
||||||
|
|
||||||
|
(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
|
||||||
|
echo "launching PyWorker server done"
|
||||||
@@ -0,0 +1,92 @@
|
|||||||
|
This is the base PyWorker for comfyui. It can be used to create PyWorker that use various models and
|
||||||
|
workflows. It provides two endpoints:
|
||||||
|
|
||||||
|
1. `/prompt`: Uses the default comfy workflow defined under `misc/default_workflows`
|
||||||
|
2. `/custom_workflow`: Allows the client to send their own comfy workflow with each API request.
|
||||||
|
|
||||||
|
To use the comfyui PyWorker, `$COMFY_MODEL` env variable must be set in the template. Current options are
|
||||||
|
`sd3` and `flux`. Each have example clients.
|
||||||
|
|
||||||
|
To add new models, a JSON with name `$COMFY_MODEL.json` must be created under `misc/default_workflows`
|
||||||
|
|
||||||
|
NOTE: default workflows follow this format:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"input": {
|
||||||
|
"handler": "RawWorkflow",
|
||||||
|
"aws_access_key_id": "your-s3-access-key",
|
||||||
|
"aws_secret_access_key": "your-s3-secret-access-key",
|
||||||
|
"aws_endpoint_url": "https://my-endpoint.backblaze.com",
|
||||||
|
"aws_bucket_name": "your-bucket",
|
||||||
|
"webhook_url": "your-webhook-url",
|
||||||
|
"webhook_extra_params": {},
|
||||||
|
"workflow_json": {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
You can ignore all of these fields except for `workflow_json`.
|
||||||
|
|
||||||
|
Fields written as "{{FOO}}" will be replaced using data from a user request. For example, SD3's workflow has the
|
||||||
|
following nodes:
|
||||||
|
|
||||||
|
```json
|
||||||
|
"5": {
|
||||||
|
"inputs": {
|
||||||
|
"width": "{{WIDTH}}",
|
||||||
|
"height": "{{HEIGHT}}",
|
||||||
|
"batch_size": 1
|
||||||
|
},
|
||||||
|
|
||||||
|
"6": {
|
||||||
|
"inputs": {
|
||||||
|
"text": "{{PROMPT}}",
|
||||||
|
"clip": ["11", 0]
|
||||||
|
},
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "CLIP Text Encode (Prompt)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
...
|
||||||
|
"17": {
|
||||||
|
"inputs": {
|
||||||
|
"scheduler": "simple",
|
||||||
|
"steps": "{{STEPS}}",
|
||||||
|
"denoise": 1,
|
||||||
|
"model": ["12", 0]
|
||||||
|
},
|
||||||
|
"class_type": "BasicScheduler",
|
||||||
|
"_meta": {
|
||||||
|
"title": "BasicScheduler"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
...
|
||||||
|
"25": {
|
||||||
|
"inputs": {
|
||||||
|
"noise_seed": "{{SEED}}"
|
||||||
|
},
|
||||||
|
"class_type": "RandomNoise",
|
||||||
|
"_meta": {
|
||||||
|
"title": "RandomNoise"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
Incoming requests have the following JSON format:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
prompt: str
|
||||||
|
width: int
|
||||||
|
height: int
|
||||||
|
steps: int
|
||||||
|
seed: int
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Each value in those fields with replace the placeholder of the same name in the default workflow.
|
||||||
|
|
||||||
|
See Vast's serverless documentation for more details on how to use comfyui with autoscaler
|
||||||
@@ -0,0 +1,150 @@
|
|||||||
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from lib.test_utils import print_truncate_res
|
||||||
|
|
||||||
|
"""
|
||||||
|
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def call_default_workflow(
|
||||||
|
endpoint_group_name: str, api_key: str, server_url: str
|
||||||
|
) -> None:
|
||||||
|
WORKER_ENDPOINT = "/prompt"
|
||||||
|
COST = 100
|
||||||
|
route_payload = {
|
||||||
|
"endpoint": endpoint_group_name,
|
||||||
|
"api_key": api_key,
|
||||||
|
"cost": COST,
|
||||||
|
}
|
||||||
|
response = requests.post(
|
||||||
|
urljoin(server_url, "/route/"),
|
||||||
|
json=route_payload,
|
||||||
|
timeout=4,
|
||||||
|
)
|
||||||
|
message = response.json()
|
||||||
|
url = message["url"]
|
||||||
|
auth_data = dict(
|
||||||
|
signature=message["signature"],
|
||||||
|
cost=message["cost"],
|
||||||
|
endpoint=message["endpoint"],
|
||||||
|
reqnum=message["reqnum"],
|
||||||
|
url=message["url"],
|
||||||
|
)
|
||||||
|
payload = dict(
|
||||||
|
prompt="a fat fluffy cat", width=1024, height=1024, steps=20, seed=123456789
|
||||||
|
)
|
||||||
|
req_data = dict(payload=payload, auth_data=auth_data)
|
||||||
|
url = urljoin(url, WORKER_ENDPOINT)
|
||||||
|
print(f"url: {url}")
|
||||||
|
response = requests.post(
|
||||||
|
url,
|
||||||
|
json=req_data,
|
||||||
|
)
|
||||||
|
print_truncate_res(str(response.json()))
|
||||||
|
|
||||||
|
|
||||||
|
def call_custom_workflow_for_sd3(
|
||||||
|
endpoint_group_name: str, api_key: str, server_url: str
|
||||||
|
) -> None:
|
||||||
|
WORKER_ENDPOINT = "/custom-workflow"
|
||||||
|
COST = 100
|
||||||
|
route_payload = {
|
||||||
|
"endpoint": endpoint_group_name,
|
||||||
|
"api_key": api_key,
|
||||||
|
"cost": COST,
|
||||||
|
}
|
||||||
|
response = requests.post(
|
||||||
|
urljoin(server_url, "/route/"),
|
||||||
|
json=route_payload,
|
||||||
|
timeout=4,
|
||||||
|
)
|
||||||
|
message = response.json()
|
||||||
|
url = message["url"]
|
||||||
|
auth_data = dict(
|
||||||
|
signature=message["signature"],
|
||||||
|
cost=message["cost"],
|
||||||
|
endpoint=message["endpoint"],
|
||||||
|
reqnum=message["reqnum"],
|
||||||
|
url=message["url"],
|
||||||
|
)
|
||||||
|
workflow = {
|
||||||
|
"3": {
|
||||||
|
"inputs": {
|
||||||
|
"seed": 156680208700286,
|
||||||
|
"steps": 20,
|
||||||
|
"cfg": 8,
|
||||||
|
"sampler_name": "euler",
|
||||||
|
"scheduler": "normal",
|
||||||
|
"denoise": 1,
|
||||||
|
"model": ["4", 0],
|
||||||
|
"positive": ["6", 0],
|
||||||
|
"negative": ["7", 0],
|
||||||
|
"latent_image": ["5", 0],
|
||||||
|
},
|
||||||
|
"class_type": "KSampler",
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"inputs": {"ckpt_name": "sd3_medium_incl_clips_t5xxlfp16.safetensors"},
|
||||||
|
"class_type": "CheckpointLoaderSimple",
|
||||||
|
},
|
||||||
|
"5": {
|
||||||
|
"inputs": {"width": 512, "height": 512, "batch_size": 1},
|
||||||
|
"class_type": "EmptyLatentImage",
|
||||||
|
},
|
||||||
|
"6": {
|
||||||
|
"inputs": {
|
||||||
|
"text": "beautiful scenery nature glass bottle landscape, purple galaxy bottle",
|
||||||
|
"clip": ["4", 1],
|
||||||
|
},
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
},
|
||||||
|
"7": {
|
||||||
|
"inputs": {"text": "text, watermark", "clip": ["4", 1]},
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"inputs": {"samples": ["3", 0], "vae": ["4", 2]},
|
||||||
|
"class_type": "VAEDecode",
|
||||||
|
},
|
||||||
|
"9": {
|
||||||
|
"inputs": {"filename_prefix": "ComfyUI", "images": ["8", 0]},
|
||||||
|
"class_type": "SaveImage",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
# these values should match the values in the custom workflow above,
|
||||||
|
# they are used to calculate workload
|
||||||
|
custom_fields = dict(
|
||||||
|
steps=20,
|
||||||
|
width=512,
|
||||||
|
height=512,
|
||||||
|
)
|
||||||
|
req_data = dict(
|
||||||
|
payload=dict(custom_fields=custom_fields, workflow=workflow),
|
||||||
|
auth_data=auth_data,
|
||||||
|
)
|
||||||
|
url = urljoin(url, WORKER_ENDPOINT)
|
||||||
|
print(f"url: {url}")
|
||||||
|
response = requests.post(
|
||||||
|
url,
|
||||||
|
json=req_data,
|
||||||
|
)
|
||||||
|
print_truncate_res(str(response.json()))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from lib.test_utils import test_args
|
||||||
|
|
||||||
|
args = test_args.parse_args()
|
||||||
|
call_default_workflow(
|
||||||
|
api_key=args.api_key,
|
||||||
|
endpoint_group_name=args.endpoint_group_name,
|
||||||
|
server_url=args.server_url,
|
||||||
|
)
|
||||||
|
call_custom_workflow_for_sd3(
|
||||||
|
api_key=args.api_key,
|
||||||
|
endpoint_group_name=args.endpoint_group_name,
|
||||||
|
server_url=args.server_url,
|
||||||
|
)
|
||||||
@@ -0,0 +1,205 @@
|
|||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import dataclasses
|
||||||
|
import inspect
|
||||||
|
from typing import Dict, Any
|
||||||
|
from functools import cache
|
||||||
|
from math import ceil
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from lib.data_types import ApiPayload, JsonDataException
|
||||||
|
|
||||||
|
|
||||||
|
with open("workers/comfyui/misc/test_prompts.txt", "r") as f:
|
||||||
|
test_prompts = f.readlines()
|
||||||
|
|
||||||
|
|
||||||
|
class Model(Enum):
|
||||||
|
Flux = "flux"
|
||||||
|
Sd3 = "sd3"
|
||||||
|
|
||||||
|
def get_request_time(self) -> int:
|
||||||
|
match self:
|
||||||
|
case Model.Flux:
|
||||||
|
return 23
|
||||||
|
case Model.Sd3:
|
||||||
|
return 6
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def get_model() -> Model:
|
||||||
|
match os.environ.get("COMFY_MODEL"):
|
||||||
|
case "flux":
|
||||||
|
return Model.Flux
|
||||||
|
case "sd3":
|
||||||
|
return Model.Sd3
|
||||||
|
case None:
|
||||||
|
raise Exception(
|
||||||
|
"For comfyui pyworker, $COMFY_MODEL must be set in the vast template"
|
||||||
|
)
|
||||||
|
case model:
|
||||||
|
raise Exception(f"Unsupported comfyui model: {model}")
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def get_request_template() -> str:
|
||||||
|
with open(f"workers/comfyui/misc/default_workflows/{get_model().value}.json") as f:
|
||||||
|
return f.read()
|
||||||
|
|
||||||
|
|
||||||
|
def count_workload(width: int, height: int, steps: int) -> float:
|
||||||
|
"""
|
||||||
|
we want to normalize the workload is a number such that cur_perf(tokens/second) for 1024x1024 image with
|
||||||
|
28 steps is 200 tokens on a 4090.
|
||||||
|
|
||||||
|
in order get that we calculate the
|
||||||
|
|
||||||
|
A = ( absolute workload based on given data )
|
||||||
|
B = ( absolute workload for a 1024x1024 image with 28 steps )
|
||||||
|
|
||||||
|
and adjust the workload to 200 tokens by A/B.
|
||||||
|
|
||||||
|
we then adjust for difference between Flux and SD3 by multiplying this value by expected request time for a
|
||||||
|
standard image(23s for Flux, 6s for SD3).
|
||||||
|
On a 4090, this would give us a workload that would give a cur_perf(workload / request_time) of around 200
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _calculate_absolute_tokens(width_: int, height_: int, steps_: int) -> float:
|
||||||
|
"""
|
||||||
|
This is based on how openai counts image generation tokens, see: https://openai.com/api/pricing/
|
||||||
|
|
||||||
|
we count how many 512x512 grids are needed to cover the image.
|
||||||
|
each tile is then counted as 175 tokens.
|
||||||
|
each image generation also has constant of 85 base tokens.
|
||||||
|
|
||||||
|
we then adjust the count based on the number of steps. The baseline number of steps is assumed to be 28.
|
||||||
|
Some testing with flux gave me this data:
|
||||||
|
|
||||||
|
steps(X) | request time(Y)
|
||||||
|
__________|_________________
|
||||||
|
07(0.25x) | 11s (0.47x)
|
||||||
|
14(0.50x) | 15s (0.65x)
|
||||||
|
21(0.75x) | 20s (0.86x)
|
||||||
|
28(1.00x) | 23s (1.00x)
|
||||||
|
35(1.25x) | 28s (1.21x)
|
||||||
|
42(1.50x) | 32s (1.39x)
|
||||||
|
49(1.75x) | 37s (1.60x)
|
||||||
|
|
||||||
|
this gives a linear regression of Y = 0.61*X + 6.57
|
||||||
|
|
||||||
|
we can use this as an adjustment_factor for token count
|
||||||
|
|
||||||
|
adjustment_factor = (0.61 * steps + 6.57)
|
||||||
|
"""
|
||||||
|
|
||||||
|
width_grids = ceil(width_ / 512)
|
||||||
|
height_grids = ceil(height_ / 512)
|
||||||
|
tokens = 85 + width_grids * height_grids * 175
|
||||||
|
adjustment_factor = 0.61 * steps_ + 6.57
|
||||||
|
return tokens * adjustment_factor
|
||||||
|
|
||||||
|
REQUEST_TIME_FOR_STANDARD_IMAGE = get_model().get_request_time()
|
||||||
|
|
||||||
|
absolute_tokens = _calculate_absolute_tokens(
|
||||||
|
width_=width, height_=height, steps_=steps
|
||||||
|
)
|
||||||
|
absolute_tokens_standard_image = _calculate_absolute_tokens(
|
||||||
|
width_=1024, height_=1024, steps_=28
|
||||||
|
)
|
||||||
|
return REQUEST_TIME_FOR_STANDARD_IMAGE * (
|
||||||
|
(absolute_tokens / absolute_tokens_standard_image) * 200
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class DefaultComfyWorkflowData(ApiPayload):
|
||||||
|
prompt: str
|
||||||
|
width: int
|
||||||
|
height: int
|
||||||
|
steps: int
|
||||||
|
seed: int
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def for_test(cls):
|
||||||
|
|
||||||
|
test_prompt = random.choice(test_prompts).rstrip()
|
||||||
|
return cls(
|
||||||
|
prompt=test_prompt,
|
||||||
|
width=1024,
|
||||||
|
height=1024,
|
||||||
|
steps=28,
|
||||||
|
seed=random.randint(0, sys.maxsize),
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_payload_json(
|
||||||
|
self,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
return json.loads(
|
||||||
|
get_request_template()
|
||||||
|
.replace("{{PROMPT}}", self.prompt)
|
||||||
|
# these values should be of int type. Since "{{VAR}}" is wrapped with " in the template
|
||||||
|
# to make the JSON valid, we must replace the double quotes. i.e. "{{WIDTH}}" -> 1024 and not "1024"
|
||||||
|
.replace('"{{WIDTH}}"', str(self.width))
|
||||||
|
.replace('"{{HEIGHT}}"', str(self.height))
|
||||||
|
.replace('"{{STEPS}}"', str(self.steps))
|
||||||
|
.replace('"{{SEED}}"', str(self.seed))
|
||||||
|
)
|
||||||
|
|
||||||
|
def count_workload(self) -> float:
|
||||||
|
return count_workload(width=self.width, height=self.height, steps=self.steps)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "DefaultComfyWorkflowData":
|
||||||
|
errors = {}
|
||||||
|
for param in inspect.signature(cls).parameters:
|
||||||
|
if param not in json_msg:
|
||||||
|
errors[param] = "missing parameter"
|
||||||
|
if errors:
|
||||||
|
raise JsonDataException(errors)
|
||||||
|
return cls(
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in json_msg.items()
|
||||||
|
if k in inspect.signature(cls).parameters
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class CustomComfyWorkflowData(ApiPayload):
|
||||||
|
custom_fields: Dict[str, int]
|
||||||
|
workflow: Dict[str, Any]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def for_test(cls):
|
||||||
|
raise NotImplemented("Custom comfy workflow is not used for testing")
|
||||||
|
|
||||||
|
def count_workload(self) -> float:
|
||||||
|
return count_workload(
|
||||||
|
width=int(self.custom_fields.get("width", 1024)),
|
||||||
|
height=int(self.custom_fields.get("height", 1024)),
|
||||||
|
steps=int(self.custom_fields.get("steps", 28)),
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_payload_json(self) -> Dict[str, Any]:
|
||||||
|
template_json = json.loads(get_request_template())
|
||||||
|
template_json["input"]["workflow_json"] = self.workflow
|
||||||
|
return template_json
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "CustomComfyWorkflowData":
|
||||||
|
errors = {}
|
||||||
|
for param in inspect.signature(cls).parameters:
|
||||||
|
if param not in json_msg:
|
||||||
|
errors[param] = "missing parameter"
|
||||||
|
if errors:
|
||||||
|
raise JsonDataException(errors)
|
||||||
|
return cls(
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in json_msg.items()
|
||||||
|
if k in inspect.signature(cls).parameters
|
||||||
|
}
|
||||||
|
)
|
||||||
@@ -0,0 +1,137 @@
|
|||||||
|
{
|
||||||
|
"input": {
|
||||||
|
"handler": "RawWorkflow",
|
||||||
|
"aws_access_key_id": "your-s3-access-key",
|
||||||
|
"aws_secret_access_key": "your-s3-secret-access-key",
|
||||||
|
"aws_endpoint_url": "https://my-endpoint.backblaze.com",
|
||||||
|
"aws_bucket_name": "your-bucket",
|
||||||
|
"webhook_url": "your-webhook-url",
|
||||||
|
"webhook_extra_params": {},
|
||||||
|
"workflow_json": {
|
||||||
|
"5": {
|
||||||
|
"inputs": {
|
||||||
|
"width": "{{WIDTH}}",
|
||||||
|
"height": "{{HEIGHT}}",
|
||||||
|
"batch_size": 1
|
||||||
|
},
|
||||||
|
"class_type": "EmptyLatentImage",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Empty Latent Image"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"6": {
|
||||||
|
"inputs": {
|
||||||
|
"text": "{{PROMPT}}",
|
||||||
|
"clip": ["11", 0]
|
||||||
|
},
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "CLIP Text Encode (Prompt)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"inputs": {
|
||||||
|
"samples": ["13", 0],
|
||||||
|
"vae": ["10", 0]
|
||||||
|
},
|
||||||
|
"class_type": "VAEDecode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "VAE Decode"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"9": {
|
||||||
|
"inputs": {
|
||||||
|
"filename_prefix": "ComfyUI",
|
||||||
|
"images": ["8", 0]
|
||||||
|
},
|
||||||
|
"class_type": "SaveImage",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Save Image"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"10": {
|
||||||
|
"inputs": {
|
||||||
|
"vae_name": "ae.safetensors"
|
||||||
|
},
|
||||||
|
"class_type": "VAELoader",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Load VAE"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"11": {
|
||||||
|
"inputs": {
|
||||||
|
"clip_name1": "t5xxl_fp16.safetensors",
|
||||||
|
"clip_name2": "clip_l.safetensors",
|
||||||
|
"type": "flux"
|
||||||
|
},
|
||||||
|
"class_type": "DualCLIPLoader",
|
||||||
|
"_meta": {
|
||||||
|
"title": "DualCLIPLoader"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"12": {
|
||||||
|
"inputs": {
|
||||||
|
"unet_name": "flux1-dev.safetensors",
|
||||||
|
"weight_dtype": "default"
|
||||||
|
},
|
||||||
|
"class_type": "UNETLoader",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Load Diffusion Model"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"13": {
|
||||||
|
"inputs": {
|
||||||
|
"noise": ["25", 0],
|
||||||
|
"guider": ["22", 0],
|
||||||
|
"sampler": ["16", 0],
|
||||||
|
"sigmas": ["17", 0],
|
||||||
|
"latent_image": ["5", 0]
|
||||||
|
},
|
||||||
|
"class_type": "SamplerCustomAdvanced",
|
||||||
|
"_meta": {
|
||||||
|
"title": "SamplerCustomAdvanced"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"16": {
|
||||||
|
"inputs": {
|
||||||
|
"sampler_name": "euler"
|
||||||
|
},
|
||||||
|
"class_type": "KSamplerSelect",
|
||||||
|
"_meta": {
|
||||||
|
"title": "KSamplerSelect"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"17": {
|
||||||
|
"inputs": {
|
||||||
|
"scheduler": "simple",
|
||||||
|
"steps": "{{STEPS}}",
|
||||||
|
"denoise": 1,
|
||||||
|
"model": ["12", 0]
|
||||||
|
},
|
||||||
|
"class_type": "BasicScheduler",
|
||||||
|
"_meta": {
|
||||||
|
"title": "BasicScheduler"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"22": {
|
||||||
|
"inputs": {
|
||||||
|
"model": ["12", 0],
|
||||||
|
"conditioning": ["6", 0]
|
||||||
|
},
|
||||||
|
"class_type": "BasicGuider",
|
||||||
|
"_meta": {
|
||||||
|
"title": "BasicGuider"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"25": {
|
||||||
|
"inputs": {
|
||||||
|
"noise_seed": "{{SEED}}"
|
||||||
|
},
|
||||||
|
"class_type": "RandomNoise",
|
||||||
|
"_meta": {
|
||||||
|
"title": "RandomNoise"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,142 @@
|
|||||||
|
{
|
||||||
|
"input": {
|
||||||
|
"handler": "RawWorkflow",
|
||||||
|
"aws_access_key_id": "your-s3-access-key",
|
||||||
|
"aws_secret_access_key": "your-s3-secret-access-key",
|
||||||
|
"aws_endpoint_url": "https://my-endpoint.backblaze.com",
|
||||||
|
"aws_bucket_name": "your-bucket",
|
||||||
|
"webhook_url": "your-webhook-url",
|
||||||
|
"webhook_extra_params": {},
|
||||||
|
"workflow_json": {
|
||||||
|
"6": {
|
||||||
|
"inputs": {
|
||||||
|
"text": "{{PROMPT}}",
|
||||||
|
"clip": ["252", 1]
|
||||||
|
},
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "CLIP Text Encode (Prompt)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"13": {
|
||||||
|
"inputs": {
|
||||||
|
"shift": 3,
|
||||||
|
"model": ["252", 0]
|
||||||
|
},
|
||||||
|
"class_type": "ModelSamplingSD3",
|
||||||
|
"_meta": {
|
||||||
|
"title": "ModelSamplingSD3"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"67": {
|
||||||
|
"inputs": {
|
||||||
|
"conditioning": ["71", 0]
|
||||||
|
},
|
||||||
|
"class_type": "ConditioningZeroOut",
|
||||||
|
"_meta": {
|
||||||
|
"title": "ConditioningZeroOut"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"68": {
|
||||||
|
"inputs": {
|
||||||
|
"start": 0.1,
|
||||||
|
"end": 1,
|
||||||
|
"conditioning": ["67", 0]
|
||||||
|
},
|
||||||
|
"class_type": "ConditioningSetTimestepRange",
|
||||||
|
"_meta": {
|
||||||
|
"title": "ConditioningSetTimestepRange"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"69": {
|
||||||
|
"inputs": {
|
||||||
|
"conditioning_1": ["68", 0],
|
||||||
|
"conditioning_2": ["70", 0]
|
||||||
|
},
|
||||||
|
"class_type": "ConditioningCombine",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Conditioning (Combine)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"70": {
|
||||||
|
"inputs": {
|
||||||
|
"start": 0,
|
||||||
|
"end": 0.1,
|
||||||
|
"conditioning": ["71", 0]
|
||||||
|
},
|
||||||
|
"class_type": "ConditioningSetTimestepRange",
|
||||||
|
"_meta": {
|
||||||
|
"title": "ConditioningSetTimestepRange"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"71": {
|
||||||
|
"inputs": {
|
||||||
|
"text": "bad quality, poor quality, doll, disfigured, jpg, toy, bad anatomy, missing limbs, missing fingers, 3d, cgi",
|
||||||
|
"clip": ["252", 1]
|
||||||
|
},
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "CLIP Text Encode (Negative Prompt)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"135": {
|
||||||
|
"inputs": {
|
||||||
|
"width": "{{WIDTH}}",
|
||||||
|
"height": "{{HEIGHT}}",
|
||||||
|
"batch_size": 1
|
||||||
|
},
|
||||||
|
"class_type": "EmptySD3LatentImage",
|
||||||
|
"_meta": {
|
||||||
|
"title": "EmptySD3LatentImage"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"231": {
|
||||||
|
"inputs": {
|
||||||
|
"samples": ["271", 0],
|
||||||
|
"vae": ["252", 2]
|
||||||
|
},
|
||||||
|
"class_type": "VAEDecode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "VAE Decode"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"233": {
|
||||||
|
"inputs": {
|
||||||
|
"filename_prefix": "ComfyUI",
|
||||||
|
"images": ["231", 0]
|
||||||
|
},
|
||||||
|
"class_type": "SaveImage",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Save Image"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"252": {
|
||||||
|
"inputs": {
|
||||||
|
"ckpt_name": "sd3_medium_incl_clips_t5xxlfp16.safetensors"
|
||||||
|
},
|
||||||
|
"class_type": "CheckpointLoaderSimple",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Load Checkpoint"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"271": {
|
||||||
|
"inputs": {
|
||||||
|
"seed": "{{SEED}}",
|
||||||
|
"steps": "{{STEPS}}",
|
||||||
|
"cfg": 4.5,
|
||||||
|
"sampler_name": "dpmpp_2m",
|
||||||
|
"scheduler": "sgm_uniform",
|
||||||
|
"denoise": 1,
|
||||||
|
"model": ["13", 0],
|
||||||
|
"positive": ["6", 0],
|
||||||
|
"negative": ["69", 0],
|
||||||
|
"latent_image": ["135", 0]
|
||||||
|
},
|
||||||
|
"class_type": "KSampler",
|
||||||
|
"_meta": {
|
||||||
|
"title": "KSampler"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
cartoon character of a person with a hoodie , in style of cytus and deemo, ork, gold chains, realistic anime cat, dripping black goo, lineage revolution style, thug life, cute anthropomorphic bunny, balrog, arknights, aliased, very buff, black and red and yellow paint, painting illustration collage style, character composition in vector with white background
|
||||||
|
stardew valley, fine details
|
||||||
|
2D Vector Illustration of a child with soccer ball Art for Sublimation, Design Art, Chrome Art, Painting and Stunning Artwork, Highly Detailed Digital Painting, Airbrush Art, Highly Detailed Digital Artwork, Dramatic Artwork, stained antique yellow copper paint, digital airbrush art, detailed by Mark Brooks, Chicano airbrush art, Swagger! snake Culture
|
||||||
|
realistic futuristic city-downtown with short buildings, sunset
|
||||||
|
seascape by Ray Collins and artgerm, front view of a perfect wave, sunny background, ultra detailed water
|
||||||
|
inspired by realflow-cinema4d editor features, create image of a transparent luxury cup with ice fruits and mint, connected with white, yellow and pink cream, Slow - High Speed MO Photography, YouTube Video Screenshot, Abstract Clay, Transparent Cup , molecular gastronomy, wheel, 3D fluid,Simulation rendering, still video, 4k polymer clay futras photography, very surreal, Houdini Fluid Simulation, hyperrealistic CGI and FLUIDS & MULTIPHYSICS SIMULATION effect, with Somali Stain Lurex, Metallic Jacquard, Gold Thread, Mulberry Silk, Toub Saree, Warm background, a fantastic image worthy of an award.
|
||||||
|
biker with backpack on his back riding a motorcycle, Style by Ade Santora, Oilpunk, Cover photo, craig mullins style, on the cover of a magazine, Outdoor Magazine, inspired by Alex Petruk APe, image of a male biker, Cover of an award-winning magazine, the man has a backpack, photo for magazine, with a backpack, magazine cover
|
||||||
|
generate a collage-style illustration inspired by the Procreate raster graphic editor, photographic illustration with the theme, 2D vector, art for textile sublimation, containing surrealistic cartoon cat wearing a baseball cap and jeans standing in front of a poster, inspired by Sadao Watanabe, Doraemon, Japanese cartoon style, Eichiro Oda, Iconic high detail character, Director: Nakahara Nantenbō, Kastuhiro Otomo, image detailed, by Miyamoto, Hidetaka Miyazaki, Katsuhiro illustration, 8k, masterpiece, Minimize noise and grain in photo quality without lose quality and increase brightness and lighting,Symmetry and Alignment, Avoid asymmetrical shapes and out-of-focus points. Focus and Sharpness: Make sure the image is focused and sharp and encourages the viewer to see it as a work of art printed on fabric.
|
||||||
|
fantasy medieval village world inside a glass sphere , high detail, fantasy, realistic, light effect, hyper detail, volumetric lighting, cinematic, macro, depth of field, blur, red light and clouds from the back, highly detailed epic cinematic concept art cg render made in maya, blender and photoshop, octane render, excellent composition, dynamic dramatic cinematic lighting, aesthetic, very inspirational, world inside a glass sphere by james gurney by artgerm with james jean, joe fenton and tristan eaton by ross tran, fine details
|
||||||
|
Iron Man, (Arnold Tsang, Toru Nakayama), Masterpiece, Studio Quality, 6k , toa, toaair, 1boy, glowing, axe, mecha, science_fiction, solo, weapon, jungle , green_background, nature, outdoors, solo, tree, weapon, mask, dynamic lighting, detailed shading, digital texture painting
|
||||||
|
(Pope Francis) wearing leather jacket is a DJ in a nightclub, mixing live on stage, giant mixing table, a masterpiece
|
||||||
|
Pope Francis wearing biker (leather jacket), a masterpiece
|
||||||
|
Luke Skywalker ordering a burger and fries from the Death Star canteen.
|
||||||
|
I want to generate a group avatar for a Feishu group chat. The role of this group is daily software technical communication. Now the subject technology stacks that members of this group discuss daily include: algorithms, data structures, optimization, functional programming, and the programming languages often discussed are: TypeScript, Java, python, etc. I hope this avatar has a simple aesthetic, this avatar is a single person avatar
|
||||||
|
portrait Anime black girl cute-fine-face, pretty face, realistic shaded Perfect face, fine details. Anime. realistic shaded lighting by Ilya Kuvshinov Giuseppe Dangelico Pino and Michael Garmash and Rob Rey, IAMAG premiere, WLOP matte print, cute freckles, masterpiece
|
||||||
|
young Disney socialite wearing a beige miniskirt, dark brown turtleneck sweater, small neckless, cute-fine-face, anime. illustration, realistic shaded perfect face, brown hair, grey eyes, fine details, realistic shaded lighting by ilya kuvshinov giuseppe dangelico pino and michael garmash and rob rey, iamag premiere, wlop matte print, a masterpiece
|
||||||
|
Cute small cat sitting in a movie theater eating chicken wiggs watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
|
||||||
|
Cute small dog sitting in a movie theater eating popcorn watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
|
||||||
|
fox bracelet made of buckskin with fox features, rich details, fine carvings, studio lighting
|
||||||
|
crane buckskin bracelet with crane features, rich details, fine carvings, studio lighting
|
||||||
|
london luxurious interior living-room, light walls
|
||||||
|
Parisian luxurious interior penthouse bedroom, dark walls, wooden panels
|
||||||
|
cute girl, crop-top, blond hair, black glasses, stretching, with background by greg rutkowski makoto shinkai kyoto animation key art feminine mid shot
|
||||||
|
houses in front, houses background, straight houses, digital art, smooth, sharp focus, gravity falls style, doraemon style, shinchan style, anime style
|
||||||
|
Simplified technical drawing, Leonardo da Vinci, Mechanical Dinosaur Skeleton, Minimalistic annotations, Hand-drawn illustrations, Basic design and engineering, Wonder and curiosity
|
||||||
|
High quality 8K painting impressionist style of a Japanese modern city street with a girl on the foreground wearing a traditional wedding dress with a fox mask, staring at the sky, daylight
|
||||||
|
a landscape from the Moon with the Earth setting on the horizon, realistic, detailed
|
||||||
|
Isometric Atlantis city,great architecture with columns, great details, ornaments,seaweed, blue ambiance, 3D cartoon style, soft light, 45° view
|
||||||
|
A hyper realistic avatar of a guy riding on a black honda cbr 650r in leather suit,high detail, high quality,8K,photo realism
|
||||||
|
the street of amedieval fantasy town, at dawn, dark, highly detailed
|
||||||
|
overwhelmingly beautiful eagle framed with vector flowers, long shiny wavy flowing hair, polished, ultra detailed vector floral illustration mixed with hyper realism, muted pastel colors, vector floral details in background, muted colors, hyper detailed ultra intricate overwhelming realism in detailed complex scene with magical fantasy atmosphere, no signature, no watermark
|
||||||
|
a highly detailed matte painting of a man on a hill watching a rocket launch in the distance by studio ghibli, makoto shinkai, by artgerm, by wlop, by greg rutkowski, volumetric lighting, octane render, 4 k resolution, trending on artstation, masterpiece | hyperrealism| highly detailed| insanely detailed| intricate| cinematic lighting| depth of field
|
||||||
|
electronik robot and ofice ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
|
||||||
|
exquisitely intricately detailed illustration, of a small world with a lake and a rainbow, inside a closed glass jar.
|
||||||
@@ -0,0 +1,135 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import dataclasses
|
||||||
|
import base64
|
||||||
|
from typing import Union, Type
|
||||||
|
|
||||||
|
from aiohttp import web, ClientResponse
|
||||||
|
from anyio import open_file
|
||||||
|
|
||||||
|
from lib.backend import Backend, LogAction
|
||||||
|
from lib.data_types import EndpointHandler
|
||||||
|
from lib.server import start_server
|
||||||
|
from .data_types import DefaultComfyWorkflowData, CustomComfyWorkflowData
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_SERVER_URL = "http://0.0.0.0:38188"
|
||||||
|
|
||||||
|
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
||||||
|
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: http://127.0.0.1:18188"
|
||||||
|
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||||
|
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
|
||||||
|
"Value not in list: unet_name", # This error is emitted when the model file is not there at all
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG,
|
||||||
|
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_client_response(
|
||||||
|
request: web.Request, response: ClientResponse
|
||||||
|
) -> Union[web.Response, web.StreamResponse]:
|
||||||
|
_ = request
|
||||||
|
match response.status:
|
||||||
|
case 200:
|
||||||
|
log.debug("SUCCESS")
|
||||||
|
res = await response.json()
|
||||||
|
if "output" not in res:
|
||||||
|
return web.json_response(
|
||||||
|
data=dict(error="there was an error in the workflow"),
|
||||||
|
status=422,
|
||||||
|
)
|
||||||
|
image_paths = [path["local_path"] for path in res["output"]["images"]]
|
||||||
|
if not image_paths:
|
||||||
|
return web.json_response(
|
||||||
|
data=dict(error="workflow did not produce any images"),
|
||||||
|
status=422,
|
||||||
|
)
|
||||||
|
images = []
|
||||||
|
for image_path in image_paths:
|
||||||
|
async with await open_file(image_path, mode="rb") as f:
|
||||||
|
contents = await f.read()
|
||||||
|
images.append(
|
||||||
|
f"data:image/png;base64,{base64.b64encode(contents).decode('utf-8')}"
|
||||||
|
)
|
||||||
|
return web.json_response(data=dict(images=images))
|
||||||
|
case code:
|
||||||
|
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||||
|
return web.Response(status=code)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class DefaultComfyWorkflowHandler(EndpointHandler[DefaultComfyWorkflowData]):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def endpoint(self) -> str:
|
||||||
|
return "/runsync"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def payload_cls(cls) -> Type[DefaultComfyWorkflowData]:
|
||||||
|
return DefaultComfyWorkflowData
|
||||||
|
|
||||||
|
def make_benchmark_payload(self) -> DefaultComfyWorkflowData:
|
||||||
|
return DefaultComfyWorkflowData.for_test()
|
||||||
|
|
||||||
|
async def generate_client_response(
|
||||||
|
self, client_request: web.Request, model_response: ClientResponse
|
||||||
|
) -> Union[web.Response, web.StreamResponse]:
|
||||||
|
return await generate_client_response(client_request, model_response)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class CustomComfyWorkflowHandler(EndpointHandler[CustomComfyWorkflowData]):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def endpoint(self) -> str:
|
||||||
|
return "/runsync"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def payload_cls(cls) -> Type[CustomComfyWorkflowData]:
|
||||||
|
return CustomComfyWorkflowData
|
||||||
|
|
||||||
|
def make_benchmark_payload(self) -> CustomComfyWorkflowData:
|
||||||
|
return CustomComfyWorkflowData.for_test()
|
||||||
|
|
||||||
|
async def generate_client_response(
|
||||||
|
self, client_request: web.Request, model_response: ClientResponse
|
||||||
|
) -> Union[web.Response, web.StreamResponse]:
|
||||||
|
return await generate_client_response(client_request, model_response)
|
||||||
|
|
||||||
|
|
||||||
|
backend = Backend(
|
||||||
|
model_server_url=MODEL_SERVER_URL,
|
||||||
|
model_log_file=os.environ["MODEL_LOG"],
|
||||||
|
allow_parallel_requests=False,
|
||||||
|
benchmark_handler=DefaultComfyWorkflowHandler(
|
||||||
|
benchmark_runs=3, benchmark_words=100
|
||||||
|
),
|
||||||
|
log_actions=[
|
||||||
|
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
|
||||||
|
(LogAction.Info, "Downloading:"),
|
||||||
|
*[
|
||||||
|
(LogAction.ModelError, error_msg)
|
||||||
|
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_ping(_):
|
||||||
|
return web.Response(body="pong")
|
||||||
|
|
||||||
|
|
||||||
|
routes = [
|
||||||
|
web.post("/prompt", backend.create_handler(DefaultComfyWorkflowHandler())),
|
||||||
|
web.post("/custom-workflow", backend.create_handler(CustomComfyWorkflowHandler())),
|
||||||
|
web.get("/ping", handle_ping),
|
||||||
|
]
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
start_server(backend, routes)
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
from lib.test_utils import test_load_cmd, test_args
|
||||||
|
from .data_types import DefaultComfyWorkflowData, Model
|
||||||
|
|
||||||
|
WORKER_ENDPOINT = "/prompt"
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_args.add_argument(
|
||||||
|
"-m",
|
||||||
|
dest="comfy_model",
|
||||||
|
choices=list(map(lambda x: x.value, Model)),
|
||||||
|
required=True,
|
||||||
|
help="Image generation model name",
|
||||||
|
)
|
||||||
|
test_load_cmd(DefaultComfyWorkflowData, WORKER_ENDPOINT, arg_parser=test_args)
|
||||||
@@ -0,0 +1,300 @@
|
|||||||
|
# Vast PyWorker
|
||||||
|
|
||||||
|
## Hello_world example
|
||||||
|
|
||||||
|
There is a hello_world PyWorker implantation under `workers/hello_world`. This PyWorker is
|
||||||
|
created for an LLM model server that runs on port 5001 has two API endpoints:
|
||||||
|
|
||||||
|
1. `/generate`: generates an full response to the prompt and sends a JSON response
|
||||||
|
2. `/generate_stream`: streams a response one token at a time
|
||||||
|
|
||||||
|
Both of these endpoints take the same API JSON payload:
|
||||||
|
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"prompt": String,
|
||||||
|
"max_response_tokens": Number | null
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
We want the PyWorker to also expose two endpoints, for each of the above endpoints.
|
||||||
|
|
||||||
|
### Structure
|
||||||
|
|
||||||
|
All PyWorkers have four files:
|
||||||
|
|
||||||
|
```
|
||||||
|
.
|
||||||
|
└── workers
|
||||||
|
└── hello_world
|
||||||
|
├── __init__.py
|
||||||
|
├── data_types.py # contains data types representing model API endpoints
|
||||||
|
├── server.py # contains endpoint handlers
|
||||||
|
├── client.py # a script to call an endpoint through the autoscaler
|
||||||
|
└── test_load.py # script for load testing
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
All of the classes follow strict type hinting. It is recommended that you type hint all of your function.
|
||||||
|
This will allow your IDE or VSCode with `pyright` plugin to find any type errors in your implementation.
|
||||||
|
You can also install `pyright` with `sudo npm install -g pyright` and run `pyright` in the root of the project to find
|
||||||
|
any type errors.
|
||||||
|
|
||||||
|
#### data_Types.py
|
||||||
|
|
||||||
|
data classes representing the model API are defined here. They must inherit from
|
||||||
|
`lib.data_types.ApiPayload`. `ApiPayload` is an abstract class and you need to define several functions for it:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import dataclasses
|
||||||
|
import random
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer # used to count tokens in a prompt
|
||||||
|
import nltk # used to download a list of all words to generate a random prompt and benchmark the LLM model
|
||||||
|
|
||||||
|
from lib.data_types import ApiPayload
|
||||||
|
|
||||||
|
nltk.download("words")
|
||||||
|
WORD_LIST = nltk.corpus.words.words()
|
||||||
|
|
||||||
|
# you can use any tokenizer that fits your LLM. `openai-gpt` is free to use and is a good fit for most LLMs
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class InputData(ApiPayload):
|
||||||
|
prompt: str
|
||||||
|
max_response_tokens: int
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def for_test(cls) -> "ApiPayload":
|
||||||
|
"""defines how create a payload for load testing"""
|
||||||
|
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
||||||
|
return cls(prompt=prompt, max_response_tokens=300)
|
||||||
|
|
||||||
|
def generate_payload_json(self) -> Dict[str, Any]:
|
||||||
|
"""defines how to convert an ApiPayload to JSON that will be sent to model API"""
|
||||||
|
return dataclasses.asdict(self)
|
||||||
|
|
||||||
|
def count_workload(self) -> float:
|
||||||
|
"""defines how to calculate workload for a payload"""
|
||||||
|
return len(tokenizer.tokenize(self.prompt))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
|
||||||
|
"""
|
||||||
|
defines how to transform JSON data to AuthData and payload type,
|
||||||
|
in this case `InputData` defined above represents the data sent to the model API.
|
||||||
|
AuthData is data generated by autoscaler in order to authenticate payloads.
|
||||||
|
In this case, the transformation is simple and 1:1. That is not always the case. See comfyui's PyWorker
|
||||||
|
for more complicated examples
|
||||||
|
"""
|
||||||
|
errors = {}
|
||||||
|
for param in inspect.signature(cls).parameters:
|
||||||
|
if param not in json_msg:
|
||||||
|
errors[param] = "missing parameter"
|
||||||
|
if errors:
|
||||||
|
raise JsonDataException(errors)
|
||||||
|
return cls(
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in json_msg.items()
|
||||||
|
if k in inspect.signature(cls).parameters
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
#### server.py
|
||||||
|
|
||||||
|
For every model API endpoint you want to use, you must implement an `EndpointHandler`. This class handles incoming
|
||||||
|
requests, processes them, sends them to the model API server, and finally returns an HTTP response.
|
||||||
|
`EndpointHandler` has several abstract functions that must be implemented. Here, we implement two, one
|
||||||
|
for `/generate`, and one for `/generate_stream`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
|
||||||
|
"""
|
||||||
|
AuthData is a dataclass that represents Authentication data sent from Autoscaler to client requesting a route.
|
||||||
|
When a user requests a route from autoscaler, see Vast's Serverless documentation for how routing and AuthData
|
||||||
|
work.
|
||||||
|
When a user receives a route for this PyWorker, they'll call PyWorkers API with the following JSON:
|
||||||
|
{
|
||||||
|
auth_data: AuthData,
|
||||||
|
payload : InputData # defined above
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from lib.data_types import EndpointHandler, JsonDataException
|
||||||
|
from lib.server import start_server
|
||||||
|
from .data_types import InputData
|
||||||
|
|
||||||
|
# This class is the implementer for the '/generate' endpoint of model API
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class GenerateHandler(EndpointHandler[InputData]):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def endpoint(self) -> str:
|
||||||
|
# the API endpoint
|
||||||
|
return "/generate"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def payload_cls(cls) -> Type[InputData]:
|
||||||
|
"""this function should just return ApiPayload subclass used by this handler"""
|
||||||
|
return InputData
|
||||||
|
|
||||||
|
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
defines how to convert `InputData` defined above, to
|
||||||
|
JSON data to be sent to the model API. This function too is a simple dataclass -> JSON, but
|
||||||
|
can be more complicated, See comfyui for an example
|
||||||
|
"""
|
||||||
|
return dataclasses.asdict(payload)
|
||||||
|
|
||||||
|
def make_benchmark_payload(self) -> InputData:
|
||||||
|
"""
|
||||||
|
defines how to generate an InputData for benchmarking. This needs to be defined in only
|
||||||
|
one EndpointHandler, the one passed to the backend as the benchmark handler. Here we use the .for_test()
|
||||||
|
method on InputData. However, in some cases you might need to fine tune your InputData used for
|
||||||
|
benchmarking to closely resemble the average request users call the endpoint with in order to get best
|
||||||
|
autoscaling performance
|
||||||
|
"""
|
||||||
|
return InputData.for_test()
|
||||||
|
|
||||||
|
async def generate_client_response(
|
||||||
|
self, client_request: web.Request, model_response: ClientResponse
|
||||||
|
) -> Union[web.Response, web.StreamResponse]:
|
||||||
|
"""
|
||||||
|
defines how to convert a model API response to a response to PyWorker client
|
||||||
|
"""
|
||||||
|
_ = client_request
|
||||||
|
match model_response.status:
|
||||||
|
case 200:
|
||||||
|
log.debug("SUCCESS")
|
||||||
|
data = await model_response.json()
|
||||||
|
return web.json_response(data=data)
|
||||||
|
case code:
|
||||||
|
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||||
|
return web.Response(status=code)
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
We also handle `GenerateStreamHandler` for streaming responses. It is identical to `GenerateHandler`, except for
|
||||||
|
the endpoint name and how we create a web response, as it is a streaming response:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class GenerateStreamHandler(EndpointHandler[InputData]):
|
||||||
|
@property
|
||||||
|
def endpoint(self) -> str:
|
||||||
|
return "/generate_stream"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def payload_cls(cls) -> Type[InputData]:
|
||||||
|
return InputData
|
||||||
|
|
||||||
|
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
|
||||||
|
return dataclasses.asdict(payload)
|
||||||
|
|
||||||
|
def make_benchmark_payload(self) -> InputData:
|
||||||
|
return InputData.for_test()
|
||||||
|
|
||||||
|
async def generate_client_response(
|
||||||
|
self, client_request: web.Request, model_response: ClientResponse
|
||||||
|
) -> Union[web.Response, web.StreamResponse]:
|
||||||
|
match model_response.status:
|
||||||
|
case 200:
|
||||||
|
log.debug("Streaming response...")
|
||||||
|
res = web.StreamResponse()
|
||||||
|
res.content_type = "text/event-stream"
|
||||||
|
await res.prepare(client_request)
|
||||||
|
async for chunk in model_response.content:
|
||||||
|
await res.write(chunk)
|
||||||
|
await res.write_eof()
|
||||||
|
log.debug("Done streaming response")
|
||||||
|
return res
|
||||||
|
case code:
|
||||||
|
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||||
|
return web.Response(status=code)
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
You can now instantiate a Backend and use it to handle requests.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lib.backend import Backend, LogAction
|
||||||
|
|
||||||
|
# the url and port of model API
|
||||||
|
MODEL_SERVER_URL = "http://0.0.0.0:5001"
|
||||||
|
|
||||||
|
|
||||||
|
# This is the log line that is emitted once the server has started
|
||||||
|
MODEL_SERVER_START_LOG_MSG = "server has started"
|
||||||
|
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||||
|
"Exception: corrupted model file" # message in the logs indicating the unrecoverable error
|
||||||
|
]
|
||||||
|
|
||||||
|
backend = Backend(
|
||||||
|
model_server_url=MODEL_SERVER_URL,
|
||||||
|
# location of model log file
|
||||||
|
model_log_file=os.environ["MODEL_LOG"],
|
||||||
|
# for some model backends that can only handle one request at a time, be sure to set this to False to
|
||||||
|
# let PyWorker handling queueing requests.
|
||||||
|
allow_parallel_requests=True,
|
||||||
|
# give the backend an EndpointHandler instance that is used for benchmarking
|
||||||
|
# number of benchmark run and number of words for a random benchmark run are given
|
||||||
|
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
|
||||||
|
# defines how to handle specific log messages. See docstring of LogAction for details
|
||||||
|
log_actions=[
|
||||||
|
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
|
||||||
|
(LogAction.Info, '"message":"Download'),
|
||||||
|
*[
|
||||||
|
(LogAction.ModelError, error_msg)
|
||||||
|
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# this is a simple ping handler for PyWorker
|
||||||
|
async def handle_ping(_: web.Request):
|
||||||
|
return web.Response(body="pong")
|
||||||
|
|
||||||
|
# this is a handler for forwarding a health check to model API
|
||||||
|
async def handle_healthcheck(_: web.Request):
|
||||||
|
healthcheck_res = await backend.session.get("/healthcheck")
|
||||||
|
return web.Response(body=healthcheck_res.content, status=healthcheck_res.status)
|
||||||
|
|
||||||
|
routes = [
|
||||||
|
web.post("/generate", backend.create_handler(GenerateHandler())),
|
||||||
|
web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
|
||||||
|
web.get("/ping", handle_ping),
|
||||||
|
web.get("/healthcheck", handle_healthcheck),
|
||||||
|
]
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# start server, called from start_server.sh
|
||||||
|
start_server(backend, routes)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### test_load.py
|
||||||
|
|
||||||
|
Here you can create a script that allows you test an endpoint group running instances with this PyWorker
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lib.test_harness import run
|
||||||
|
from .data_types import InputData
|
||||||
|
|
||||||
|
WORKER_ENDPOINT = "/generate"
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run(InputData.for_test(), WORKER_ENDPOINT)
|
||||||
|
```
|
||||||
|
|
||||||
|
You can then run the following command from the root of this repo to load test endpoint group:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
# sends 1000 requests at the rate of 0.5 requests per second
|
||||||
|
python3 workers.hello_world.test_load -n 1000 -rps 0.5 -k "$API_KEY" -e "$ENDPOINT_GROUP_NAME"
|
||||||
|
```
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
import dataclasses
|
||||||
|
import random
|
||||||
|
import inspect
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
import nltk
|
||||||
|
|
||||||
|
from lib.data_types import ApiPayload, JsonDataException
|
||||||
|
|
||||||
|
nltk.download("words")
|
||||||
|
WORD_LIST = nltk.corpus.words.words()
|
||||||
|
|
||||||
|
# used to count to count tokens and workload for LLM
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class InputData(ApiPayload):
|
||||||
|
prompt: str
|
||||||
|
max_response_tokens: int
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def for_test(cls) -> "InputData":
|
||||||
|
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
||||||
|
return cls(prompt=prompt, max_response_tokens=300)
|
||||||
|
|
||||||
|
def generate_payload_json(self) -> Dict[str, Any]:
|
||||||
|
return dataclasses.asdict(self)
|
||||||
|
|
||||||
|
def count_workload(self) -> int:
|
||||||
|
return len(tokenizer.tokenize(self.prompt))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
|
||||||
|
errors = {}
|
||||||
|
for param in inspect.signature(cls).parameters:
|
||||||
|
if param not in json_msg:
|
||||||
|
errors[param] = "missing parameter"
|
||||||
|
if errors:
|
||||||
|
raise JsonDataException(errors)
|
||||||
|
return cls(
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in json_msg.items()
|
||||||
|
if k in inspect.signature(cls).parameters
|
||||||
|
}
|
||||||
|
)
|
||||||
@@ -0,0 +1,167 @@
|
|||||||
|
"""
|
||||||
|
PyWorker works as a man-in-the-middle between the client and model API. It's function is:
|
||||||
|
1. receive request from client, update metrics such as workload of a request, number of pending requests, etc.
|
||||||
|
2a. transform the data and forward the transformed data to model API
|
||||||
|
2b. send updated metrics to autoscaler
|
||||||
|
3. transform response from model API(if needed) and forward the response to client
|
||||||
|
|
||||||
|
PyWorker forward requests to many model API endpoint. each endpoint must have an EndpointHandler. You can also
|
||||||
|
write function to just forward requests that don't generate anything with the model to model API without an
|
||||||
|
EndpointHandler. This is useful for endpoints such as healthchecks. See below for example
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import dataclasses
|
||||||
|
from typing import Dict, Any, Union, Type
|
||||||
|
|
||||||
|
from aiohttp import web, ClientResponse
|
||||||
|
|
||||||
|
from lib.backend import Backend, LogAction
|
||||||
|
from lib.data_types import EndpointHandler
|
||||||
|
from lib.server import start_server
|
||||||
|
from .data_types import InputData
|
||||||
|
|
||||||
|
# the url and port of model API
|
||||||
|
MODEL_SERVER_URL = "http://0.0.0.0:5001"
|
||||||
|
|
||||||
|
|
||||||
|
# This is the log line that is emitted once the server has started
|
||||||
|
MODEL_SERVER_START_LOG_MSG = "infer server has started"
|
||||||
|
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||||
|
"Exception: corrupted model file" # message in the logs indicating the unrecoverable error
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG,
|
||||||
|
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
# This class is the implementer for the '/generate' endpoint of model API
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class GenerateHandler(EndpointHandler[InputData]):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def endpoint(self) -> str:
|
||||||
|
# the API endpoint
|
||||||
|
return "/generate"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def payload_cls(cls) -> Type[InputData]:
|
||||||
|
return InputData
|
||||||
|
|
||||||
|
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
defines how to convert `InputData` defined above, to
|
||||||
|
json data to be sent to the model API
|
||||||
|
"""
|
||||||
|
return dataclasses.asdict(payload)
|
||||||
|
|
||||||
|
def make_benchmark_payload(self) -> InputData:
|
||||||
|
"""
|
||||||
|
defines how to generate an InputData for benchmarking. This needs to be defined in only
|
||||||
|
one EndpointHandler, the one passed to the backend as the benchmark handler
|
||||||
|
"""
|
||||||
|
return InputData.for_test()
|
||||||
|
|
||||||
|
async def generate_client_response(
|
||||||
|
self, client_request: web.Request, model_response: ClientResponse
|
||||||
|
) -> Union[web.Response, web.StreamResponse]:
|
||||||
|
"""
|
||||||
|
defines how to convert a model API response to a response to PyWorker client
|
||||||
|
"""
|
||||||
|
_ = client_request
|
||||||
|
match model_response.status:
|
||||||
|
case 200:
|
||||||
|
log.debug("SUCCESS")
|
||||||
|
data = await model_response.json()
|
||||||
|
return web.json_response(data=data)
|
||||||
|
case code:
|
||||||
|
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||||
|
return web.Response(status=code)
|
||||||
|
|
||||||
|
|
||||||
|
# This is the same as GenerateHandler, except that it calls a streaming endpoint of the model API and streams the
|
||||||
|
# response, which itself is streaming, back to the client.
|
||||||
|
# it is nearly identical to handler as above, but it calls a different model API endpoint and it streams the
|
||||||
|
# streaming response from model API to client
|
||||||
|
class GenerateStreamHandler(EndpointHandler[InputData]):
|
||||||
|
@property
|
||||||
|
def endpoint(self) -> str:
|
||||||
|
return "/generate_stream"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def payload_cls(cls) -> Type[InputData]:
|
||||||
|
return InputData
|
||||||
|
|
||||||
|
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
|
||||||
|
return dataclasses.asdict(payload)
|
||||||
|
|
||||||
|
def make_benchmark_payload(self) -> InputData:
|
||||||
|
return InputData.for_test()
|
||||||
|
|
||||||
|
async def generate_client_response(
|
||||||
|
self, client_request: web.Request, model_response: ClientResponse
|
||||||
|
) -> Union[web.Response, web.StreamResponse]:
|
||||||
|
match model_response.status:
|
||||||
|
case 200:
|
||||||
|
log.debug("Streaming response...")
|
||||||
|
res = web.StreamResponse()
|
||||||
|
res.content_type = "text/event-stream"
|
||||||
|
await res.prepare(client_request)
|
||||||
|
async for chunk in model_response.content:
|
||||||
|
await res.write(chunk)
|
||||||
|
await res.write_eof()
|
||||||
|
log.debug("Done streaming response")
|
||||||
|
return res
|
||||||
|
case code:
|
||||||
|
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||||
|
return web.Response(status=code)
|
||||||
|
|
||||||
|
|
||||||
|
# This is the backend instance of pyworker. Only one must be made which uses EndpointHandlers to process
|
||||||
|
# incoming requests
|
||||||
|
backend = Backend(
|
||||||
|
model_server_url=MODEL_SERVER_URL,
|
||||||
|
model_log_file=os.environ["MODEL_LOG"],
|
||||||
|
allow_parallel_requests=True,
|
||||||
|
# give the backend a handler instance that is used for benchmarking
|
||||||
|
# number of benchmark run and number of words for a random benchmark run are given
|
||||||
|
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
|
||||||
|
# defines how to handle specific log messages. See docstring of LogAction for details
|
||||||
|
log_actions=[
|
||||||
|
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
|
||||||
|
(LogAction.Info, '"message":"Download'),
|
||||||
|
*[
|
||||||
|
(LogAction.ModelError, error_msg)
|
||||||
|
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# this is a simple ping handler for pyworker
|
||||||
|
async def handle_ping(_: web.Request):
|
||||||
|
return web.Response(body="pong")
|
||||||
|
|
||||||
|
|
||||||
|
# this is a handler for forwarding a health check to modelAPI
|
||||||
|
async def handle_healthcheck(_: web.Request):
|
||||||
|
healthcheck_res = await backend.session.get("/healthcheck")
|
||||||
|
return web.Response(body=healthcheck_res.content, status=healthcheck_res.status)
|
||||||
|
|
||||||
|
|
||||||
|
routes = [
|
||||||
|
web.post("/generate", backend.create_handler(GenerateHandler())),
|
||||||
|
web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
|
||||||
|
web.get("/ping", handle_ping),
|
||||||
|
web.get("/healthcheck", handle_healthcheck),
|
||||||
|
]
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# start the PyWorker server
|
||||||
|
start_server(backend, routes)
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
from lib.test_utils import test_load_cmd, test_args
|
||||||
|
from .data_types import InputData
|
||||||
|
|
||||||
|
WORKER_ENDPOINT = "/generate"
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_load_cmd(InputData, WORKER_ENDPOINT, arg_parser=test_args)
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
This is the base PyWorker for TGI, designed to create PyWorkers that can utilize various LLMs. It offers two primary endpoints:
|
||||||
|
|
||||||
|
1. `generate`: Generates the LLM's response to a given prompt in a single request.
|
||||||
|
2. `generate_stream`: Streams the LLM's response token by token.
|
||||||
|
|
||||||
|
Both endpoints use the following API payload format:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"inputs": "PROMPT",
|
||||||
|
"parameters": {
|
||||||
|
"max_new_tokens": 250
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Note that the max_new_tokens parameter, rather than the prompt size, impacts performance. For example, if an
|
||||||
|
instance is benchmarked to process 100 tokens per second, a request with max_new_tokens = 200 will take
|
||||||
|
approximately 2 seconds to complete.
|
||||||
@@ -0,0 +1,91 @@
|
|||||||
|
import sys
|
||||||
|
import json
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> None:
|
||||||
|
WORKER_ENDPOINT = "/generate"
|
||||||
|
COST = 100
|
||||||
|
route_payload = {
|
||||||
|
"endpoint": endpoint_group_name,
|
||||||
|
"api_key": api_key,
|
||||||
|
"cost": COST,
|
||||||
|
}
|
||||||
|
response = requests.post(
|
||||||
|
urljoin(server_url, "/route/"),
|
||||||
|
json=route_payload,
|
||||||
|
timeout=4,
|
||||||
|
)
|
||||||
|
message = response.json()
|
||||||
|
url = message["url"]
|
||||||
|
auth_data = dict(
|
||||||
|
signature=message["signature"],
|
||||||
|
cost=message["cost"],
|
||||||
|
endpoint=message["endpoint"],
|
||||||
|
reqnum=message["reqnum"],
|
||||||
|
url=message["url"],
|
||||||
|
)
|
||||||
|
payload = dict(inputs="tell me about cats", parameters=dict(max_new_tokens=500))
|
||||||
|
req_data = dict(payload=payload, auth_data=auth_data)
|
||||||
|
url = urljoin(url, WORKER_ENDPOINT)
|
||||||
|
print(f"url: {url}")
|
||||||
|
response = requests.post(
|
||||||
|
url,
|
||||||
|
json=req_data,
|
||||||
|
)
|
||||||
|
res = response.json()
|
||||||
|
print(res)
|
||||||
|
|
||||||
|
|
||||||
|
def call_generate_stream(endpoint_group_name: str, api_key: str, server_url: str):
|
||||||
|
WORKER_ENDPOINT = "/generate_stream"
|
||||||
|
COST = 100
|
||||||
|
route_payload = {
|
||||||
|
"endpoint": endpoint_group_name,
|
||||||
|
"api_key": api_key,
|
||||||
|
"cost": COST,
|
||||||
|
}
|
||||||
|
response = requests.post(
|
||||||
|
urljoin(server_url, "/route/"),
|
||||||
|
json=route_payload,
|
||||||
|
timeout=4,
|
||||||
|
)
|
||||||
|
message = response.json()
|
||||||
|
url = message["url"]
|
||||||
|
print(f"url: {url}")
|
||||||
|
auth_data = dict(
|
||||||
|
signature=message["signature"],
|
||||||
|
cost=message["cost"],
|
||||||
|
endpoint=message["endpoint"],
|
||||||
|
reqnum=message["reqnum"],
|
||||||
|
url=message["url"],
|
||||||
|
)
|
||||||
|
payload = dict(inputs="tell me about dogs", parameters=dict(max_new_tokens=500))
|
||||||
|
req_data = dict(payload=payload, auth_data=auth_data)
|
||||||
|
url = urljoin(url, WORKER_ENDPOINT)
|
||||||
|
response = requests.post(url, json=req_data, stream=True)
|
||||||
|
for line in response.iter_lines():
|
||||||
|
payload = line.decode().lstrip("data:").rstrip()
|
||||||
|
if payload:
|
||||||
|
data = json.loads(payload)
|
||||||
|
print(data["token"]["text"], end="")
|
||||||
|
sys.stdout.flush()
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from lib.test_utils import test_args
|
||||||
|
|
||||||
|
args = test_args.parse_args()
|
||||||
|
call_generate(
|
||||||
|
api_key=args.api_key,
|
||||||
|
endpoint_group_name=args.endpoint_group_name,
|
||||||
|
server_url=args.server_url,
|
||||||
|
)
|
||||||
|
call_generate_stream(
|
||||||
|
api_key=args.api_key,
|
||||||
|
endpoint_group_name=args.endpoint_group_name,
|
||||||
|
server_url=args.server_url,
|
||||||
|
)
|
||||||
@@ -0,0 +1,73 @@
|
|||||||
|
import dataclasses
|
||||||
|
import random
|
||||||
|
import inspect
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
import nltk
|
||||||
|
|
||||||
|
from lib.data_types import ApiPayload, JsonDataException
|
||||||
|
|
||||||
|
nltk.download("words")
|
||||||
|
WORD_LIST = nltk.corpus.words.words()
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class InputParameters:
|
||||||
|
max_new_tokens: int = 256
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputParameters":
|
||||||
|
errors = {}
|
||||||
|
for param in inspect.signature(cls).parameters:
|
||||||
|
if param not in json_msg:
|
||||||
|
errors[param] = "missing parameter"
|
||||||
|
if errors:
|
||||||
|
raise JsonDataException(errors)
|
||||||
|
return cls(
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in json_msg.items()
|
||||||
|
if k in inspect.signature(cls).parameters
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class InputData(ApiPayload):
|
||||||
|
inputs: str
|
||||||
|
parameters: InputParameters
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "InputData":
|
||||||
|
return cls(
|
||||||
|
inputs=data["inputs"], parameters=InputParameters(**data["parameters"])
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def for_test(cls) -> "InputData":
|
||||||
|
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
||||||
|
return cls(inputs=prompt, parameters=InputParameters())
|
||||||
|
|
||||||
|
def generate_payload_json(self) -> Dict[str, Any]:
|
||||||
|
return dataclasses.asdict(self)
|
||||||
|
|
||||||
|
def count_workload(self) -> int:
|
||||||
|
return self.parameters.max_new_tokens
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
|
||||||
|
errors = {}
|
||||||
|
for param in inspect.signature(cls).parameters:
|
||||||
|
if param not in json_msg:
|
||||||
|
errors[param] = "missing parameter"
|
||||||
|
if errors:
|
||||||
|
raise JsonDataException(errors)
|
||||||
|
try:
|
||||||
|
parameters = InputParameters.from_json_msg(json_msg["parameters"])
|
||||||
|
return cls(inputs=json_msg["inputs"], parameters=parameters)
|
||||||
|
except JsonDataException as e:
|
||||||
|
errors["parameters"] = e.message
|
||||||
|
raise JsonDataException(errors)
|
||||||
@@ -0,0 +1,115 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from typing import Union, Type
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
|
from aiohttp import web, ClientResponse
|
||||||
|
|
||||||
|
from lib.backend import Backend, LogAction
|
||||||
|
from lib.data_types import EndpointHandler
|
||||||
|
from lib.server import start_server
|
||||||
|
from .data_types import InputData
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_SERVER_URL = "http://0.0.0.0:5001"
|
||||||
|
|
||||||
|
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
||||||
|
MODEL_SERVER_START_LOG_MSG = '"message":"Connected","target":"text_generation_router"'
|
||||||
|
MODEL_SERVER_ERROR_LOG_MSGS = ["Error: WebserverFailed", "Error: DownloadError"]
|
||||||
|
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG,
|
||||||
|
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class GenerateHandler(EndpointHandler[InputData]):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def endpoint(self) -> str:
|
||||||
|
return "/generate"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def payload_cls(cls) -> Type[InputData]:
|
||||||
|
return InputData
|
||||||
|
|
||||||
|
def make_benchmark_payload(self) -> InputData:
|
||||||
|
return InputData.for_test()
|
||||||
|
|
||||||
|
async def generate_client_response(
|
||||||
|
self, client_request: web.Request, model_response: ClientResponse
|
||||||
|
) -> Union[web.Response, web.StreamResponse]:
|
||||||
|
_ = client_request
|
||||||
|
match model_response.status:
|
||||||
|
case 200:
|
||||||
|
log.debug("SUCCESS")
|
||||||
|
data = await model_response.json()
|
||||||
|
return web.json_response(data=data)
|
||||||
|
case code:
|
||||||
|
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||||
|
return web.Response(status=code)
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateStreamHandler(EndpointHandler[InputData]):
|
||||||
|
@property
|
||||||
|
def endpoint(self) -> str:
|
||||||
|
return "/generate_stream"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def payload_cls(cls) -> Type[InputData]:
|
||||||
|
return InputData
|
||||||
|
|
||||||
|
def make_benchmark_payload(self) -> InputData:
|
||||||
|
return InputData.for_test()
|
||||||
|
|
||||||
|
async def generate_client_response(
|
||||||
|
self, client_request: web.Request, model_response: ClientResponse
|
||||||
|
) -> Union[web.Response, web.StreamResponse]:
|
||||||
|
match model_response.status:
|
||||||
|
case 200:
|
||||||
|
log.debug("Streaming response...")
|
||||||
|
res = web.StreamResponse()
|
||||||
|
res.content_type = "text/event-stream"
|
||||||
|
await res.prepare(client_request)
|
||||||
|
async for chunk in model_response.content:
|
||||||
|
await res.write(chunk)
|
||||||
|
await res.write_eof()
|
||||||
|
log.debug("Done streaming response")
|
||||||
|
return res
|
||||||
|
case code:
|
||||||
|
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||||
|
return web.Response(status=code)
|
||||||
|
|
||||||
|
|
||||||
|
backend = Backend(
|
||||||
|
model_server_url=MODEL_SERVER_URL,
|
||||||
|
model_log_file=os.environ["MODEL_LOG"],
|
||||||
|
allow_parallel_requests=True,
|
||||||
|
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
|
||||||
|
log_actions=[
|
||||||
|
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
|
||||||
|
(LogAction.Info, '"message":"Download'),
|
||||||
|
*[
|
||||||
|
(LogAction.ModelError, error_msg)
|
||||||
|
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_ping(_):
|
||||||
|
return web.Response(body="pong")
|
||||||
|
|
||||||
|
|
||||||
|
routes = [
|
||||||
|
web.post("/generate", backend.create_handler(GenerateHandler())),
|
||||||
|
web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
|
||||||
|
web.get("/ping", handle_ping),
|
||||||
|
]
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
start_server(backend, routes)
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
from lib.test_utils import test_load_cmd, test_args
|
||||||
|
from .data_types import InputData
|
||||||
|
|
||||||
|
WORKER_ENDPOINT = "/generate"
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_load_cmd(InputData, WORKER_ENDPOINT, arg_parser=test_args)
|
||||||
Reference in New Issue
Block a user