Merge pull request #1 from Nader-gator/main

add pyworker v2
This commit is contained in:
Nader Arbabian
2024-09-04 11:53:45 -07:00
committed by Nader Arbabian
31 changed files with 3000 additions and 1 deletions
+3
View File
@@ -0,0 +1,3 @@
.direnv
.envrc
__pycache__
+21
View File
@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Vast.ai
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+14 -1
View File
@@ -1 +1,14 @@
# pyworker # Vast PyWorker
Vast PyWorker is a Python web server designed to run alongside a LLM or image generation models running on vast,
enabling autoscaler integration.
It serves as the primary entry point for API requests, forwarding them to the model's API hosted on the
same instance. Additionally, it monitors performance metrics and estimates current workload based on factors
such as the number of tokens processed for LLMs or image resolution and steps for image generation models,
reporting these metrics to the autoscaler.
## How to Use
If you want to use autoscaler, you just need to use one of Vast's autoscaler templates. If you'd like to
implement PyWorker for a template that is not marked as autoscaler compatible on Vast, refer to
`workers/hello_world/README.md`
View File
+327
View File
@@ -0,0 +1,327 @@
import os
import json
import time
import base64
import subprocess
import dataclasses
import logging
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable
from functools import cached_property
from anyio import open_file
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError
import requests
from Crypto.Signature import pkcs1_15
from Crypto.Hash import SHA256
from Crypto.PublicKey import RSA
from lib.metrics import Metrics
from lib.data_types import (
AuthData,
EndpointHandler,
LogAction,
ApiPayload_T,
JsonDataException,
)
MSG_HISTORY_LEN = 100
log = logging.getLogger(__file__)
# defines the minimum wait time between sending updates to autoscaler
LOG_POLL_INTERVAL = 0.1
BENCHMARK_INDICATOR_FILE = ".has_benchmark"
@dataclasses.dataclass
class Backend:
"""
This class is responsible for:
1. Tailing logs and updating load time metrics
2. Taking an EndpointHandler alongside incoming payload, preparing a json to be sent to the model, and
sending the request. It also updates metrics as it makes those requests.
3. Running a benchmark from an EndpointHandler
"""
model_server_url: str
model_log_file: str
allow_parallel_requests: bool
benchmark_handler: (
EndpointHandler # this endpoint handler will be used for benchmarking
)
log_actions: List[Tuple[LogAction, str]]
reqnum = -1
msg_history = []
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
def __post_init__(self):
def fetch_public_key():
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
result = subprocess.check_output(command, universal_newlines=True)
log.debug("public key:")
log.debug(result)
key = None
for _ in range(5):
try:
key = RSA.import_key(result)
break
except ValueError as e:
log.debug(f"Error downloading key: {e}")
time.sleep(15)
return key
###########
self.PUBLIC_KEY = fetch_public_key()
self.metrics = Metrics()
@cached_property
def session(self):
log.debug(f"starting session with {self.model_server_url}")
return ClientSession(self.model_server_url)
def create_handler(
self,
handler: EndpointHandler[ApiPayload_T],
) -> Callable[[web.Request], Awaitable[Union[web.Response, web.StreamResponse]]]:
async def handler_fn(
request: web.Request,
) -> Union[web.Response, web.StreamResponse]:
return await self.__handle_request(handler=handler, request=request)
return handler_fn
#######################################Private#######################################
async def __handle_request(
self,
handler: EndpointHandler[ApiPayload_T],
request: web.Request,
) -> Union[web.Response, web.StreamResponse]:
"""use this function to forward requests to the model endpoint"""
try:
data = await request.json()
auth_data, payload = handler.get_data_from_request(data)
except JsonDataException as e:
return web.json_response(data=e.message, status=422)
except json.JSONDecodeError:
return web.json_response(dict(error="invalid JSON"), status=422)
workload = payload.count_workload()
async def cancel_api_call_if_disconnected() -> web.Response:
await request.wait_for_disconnection()
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
self.metrics._request_canceled(workload=workload, reqnum=auth_data.reqnum)
return web.Response(status=500)
async def make_request() -> Union[web.Response, web.StreamResponse]:
log.debug(f"got request, {auth_data.reqnum}")
self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum)
if self.allow_parallel_requests is False:
log.debug(f"Waiting to aquire Sem for reqnum:{auth_data.reqnum}")
await self.sem.acquire()
log.debug(
f"Sem acquired for reqnum:{auth_data.reqnum}, starting request..."
)
else:
log.debug(f"Starting request for reqnum:{auth_data.reqnum}")
try:
start_time = time.time()
response = await self.__call_api(handler=handler, payload=payload)
status_code = response.status
log.debug(
" ".join(
[
f"request with reqnum:{auth_data.reqnum}",
f"returned status code: {status_code},",
]
)
)
res = await handler.generate_client_response(request, response)
self.metrics._request_end(
workload=workload,
req_response_time=time.time() - start_time,
reqnum=auth_data.reqnum,
)
return res
except requests.exceptions.RequestException as e:
log.debug(f"[backend] Request error: {e}")
self.metrics._request_errored(
workload=workload, reqnum=auth_data.reqnum
)
return web.Response(status=500)
finally:
self.sem.release()
###########
if self.__check_signature(auth_data) is False:
return web.Response(status=401)
try:
done, pending = await wait(
[
create_task(make_request()),
create_task(cancel_api_call_if_disconnected()),
],
return_when=FIRST_COMPLETED,
)
[task.cancel() for task in pending]
return done.pop().result()
except Exception as e:
log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500)
async def _start_tracking(self) -> None:
await gather(self.__read_logs(), self.metrics._send_metrics_loop())
def backend_errored(self, msg: str) -> None:
self.metrics._model_errored(msg)
async def __call_api(
self, handler: EndpointHandler[ApiPayload_T], payload: ApiPayload_T
) -> ClientResponse:
api_payload = payload.generate_payload_json()
log.debug(f"posting to endpoint: '{handler.endpoint}', payload: {api_payload}")
return await self.session.post(url=handler.endpoint, json=api_payload)
def __check_signature(self, auth_data: AuthData) -> bool:
def verify_signature(message, signature):
if self.PUBLIC_KEY is None:
log.debug(f"No Public Key!")
return False
h = SHA256.new(message.encode())
try:
pkcs1_15.new(self.PUBLIC_KEY).verify(h, base64.b64decode(signature))
return True
except (ValueError, TypeError):
return False
message = {
key: value
for (key, value) in (dataclasses.asdict(auth_data).items())
if key != "signature"
}
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
log.debug(
f"reqnum failure, got {auth_data.reqnum}, current_reqnum: {self.reqnum}"
)
return False
elif message in self.msg_history:
log.debug(f"message: {message} already in message history")
return False
elif verify_signature(json.dumps(message, indent=4), auth_data.signature):
self.reqnum = max(auth_data.reqnum, self.reqnum)
self.msg_history.append(message)
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
return True
else:
log.debug(
f"signature verification failed, sig:{auth_data.signature}, message: {message}"
)
return False
async def __read_logs(self) -> Awaitable[NoReturn]:
async def run_benchmark() -> float:
log.debug("starting benchmark")
try:
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
log.debug("already ran benchmark")
# trigger model load
payload = self.benchmark_handler.make_benchmark_payload()
_ = await self.__call_api(
handler=self.benchmark_handler, payload=payload
)
return float(f.readline())
except FileNotFoundError:
pass
max_throughput = 0
last_throughput = 0
sum_throughput = 0
for run in range(self.benchmark_handler.benchmark_runs + 1):
start = time.time()
payload = self.benchmark_handler.make_benchmark_payload()
res = await self.__call_api(
handler=self.benchmark_handler, payload=payload
)
data = await res.json()
time_elapsed = time.time() - start
# first run triggers one-time loading of the model which is very slow, so we skip counting it
if run == 0:
continue
else:
workload = payload.count_workload()
last_throughput = workload / time_elapsed
sum_throughput += last_throughput
max_throughput = max(max_throughput, last_throughput)
log.debug(
"\n".join(
[
"#" * 60,
f"Run: {run}, workload: {workload} time_elapsed: {time_elapsed}, throughput: {last_throughput}",
"",
f"response: {data}",
"#" * 60,
]
)
)
average_throughput = sum_throughput / self.benchmark_handler.benchmark_runs
log.debug(
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
)
# save max_throughput so we don't have to run benchmark again on restart of cold instances
with open(BENCHMARK_INDICATOR_FILE, "w") as f:
f.write(str(max_throughput))
return max_throughput
async def handle_log_line(log_line: str) -> None:
"""
Implement this function to handle each log line for your model.
This function should mutate self.system_metrics and self.model_metrics
"""
for action, msg in self.log_actions:
match action:
case LogAction.ModelLoaded if msg in log_line:
log.debug(
f"Got log line indicating model is loaded: {log_line}"
)
# some backends need a few seconds after logging successful startup before
# they can begin accepting requests
await sleep(5)
try:
max_throughput = await run_benchmark()
self.metrics._model_loaded(
max_throughput=max_throughput,
)
except ClientConnectorError as e:
log.debug(
f"failed to connect to comfyui api during benchmark"
)
self.backend_errored(str(e))
case LogAction.ModelError if msg in log_line:
log.debug(f"Got log line indicating error: {log_line}")
self.backend_errored(msg)
break
case LogAction.Info if msg in log_line:
log.debug(f"Info from model logs: {log_line}")
async def tail_log():
log.debug(f"tailing file: {self.model_log_file}")
async with await open_file(self.model_log_file) as f:
while True:
line = await f.readline()
if line:
await handle_log_line(line.rstrip())
else:
time.sleep(LOG_POLL_INTERVAL)
###########
while True:
if os.path.isfile(self.model_log_file) is True:
return await tail_log()
else:
await sleep(1)
+269
View File
@@ -0,0 +1,269 @@
import time
import logging
from dataclasses import dataclass, field
from enum import Enum
from abc import ABC, abstractmethod
from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type
from aiohttp import web, ClientResponse
import inspect
import psutil
"""
type variable representing an incoming payload to pyworker that will used to calculate load and will then
be forwarded to the model
"""
log = logging.getLogger(__file__)
class JsonDataException(Exception):
def __init__(self, json_msg: Dict[str, Any]):
self.message = json_msg
@dataclass
class ApiPayload(ABC):
@classmethod
@abstractmethod
def for_test(cls) -> "ApiPayload":
"""defines how create a payload for load testing"""
pass
@abstractmethod
def generate_payload_json(self) -> Dict[str, Any]:
"""defines how to convert an ApiPayload to JSON that will be sent to model API"""
pass
@abstractmethod
def count_workload(self) -> float:
"""defines how to calculate workload for a payload"""
pass
@classmethod
@abstractmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "ApiPayload":
"""
defines how to create an API payload from a JSON message,
it should throw an JsonDataException if there are issues with some fields
or they are missing in the format of
{
"field": "error msg"
}
"""
pass
@dataclass
class AuthData:
"""data used to authenticate requester"""
signature: str
cost: str
endpoint: str
reqnum: int
url: str
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]):
errors = {}
for param in inspect.signature(cls).parameters:
if param not in json_msg:
errors[param] = "missing parameter"
if errors:
raise JsonDataException(errors)
return cls(
**{
k: v
for k, v in json_msg.items()
if k in inspect.signature(cls).parameters
}
)
ApiPayload_T = TypeVar("ApiPayload_T", bound=ApiPayload)
@dataclass
class EndpointHandler(ABC, Generic[ApiPayload_T]):
"""
Each model endpoint will have a handler responsible for counting workload from the incoming ApiPayload
and converting it to json to be forwarded to model API
"""
benchmark_runs: int = 8
benchmark_words: int = 100
@property
@abstractmethod
def endpoint(self) -> str:
"""the endpoint on the model API"""
pass
@classmethod
@abstractmethod
def payload_cls(cls) -> Type[ApiPayload_T]:
"""ApiPayload class"""
pass
@abstractmethod
def make_benchmark_payload(self) -> ApiPayload_T:
"""defines how to create an ApiPayload for benchmarking."""
pass
@abstractmethod
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
"""
defines how to convert a model API response to a response to PyWorker client
"""
pass
@classmethod
def get_data_from_request(
cls, req_data: Dict[str, Any]
) -> Tuple[AuthData, ApiPayload_T]:
errors = {}
auth_data = payload = None
try:
if "auth_data" in req_data:
auth_data = AuthData.from_json_msg(req_data["auth_data"])
else:
errors["auth_data"] = "field missing"
except JsonDataException as e:
errors["auth_data"] = e.message
try:
if "payload" in req_data:
payload = cls.payload_cls().from_json_msg(req_data["payload"])
else:
errors["payload"] = "field missing"
except JsonDataException as e:
errors["payload"] = e.message
if errors:
raise JsonDataException(errors)
if auth_data and payload:
return (auth_data, payload)
else:
raise Exception("error deserializing request data")
@dataclass
class SystemMetrics:
"""General system metrics"""
model_loading_start: float
model_loading_time: Union[float, None]
last_disk_usage: float
additional_disk_usage: float
model_is_loaded: bool
@staticmethod
def get_disk_usage_GB():
return psutil.disk_usage("/").used / (2**30) # want units of GB
@classmethod
def empty(cls):
return cls(
model_loading_start=time.time(),
model_loading_time=None,
last_disk_usage=SystemMetrics.get_disk_usage_GB(),
additional_disk_usage=0.0,
model_is_loaded=False,
)
def update_disk_usage(self):
disk_usage = SystemMetrics.get_disk_usage_GB()
self.additional_disk_usage = disk_usage - self.last_disk_usage
self.last_disk_usage = disk_usage
def reset(self):
# autoscaler excepts model_loading_time to be populated only once, when the instance has
# finished benchmarking and is ready to receive requests. This applies to restarted instances
# as well: they should send model_loading_time once when they are done loading
self.model_loading_time = None
@dataclass
class ModelMetrics:
"""Model specific metrics"""
# these are reset after being sent to autoscaler
workload_served: float
workload_received: float
workload_cancelled: float
workload_errored: float
workload_pending: float
# these are not
cur_perf: float
error_msg: Optional[str]
max_throughput: float
requests_recieved: Set[int] = field(default_factory=set)
requests_working: Set[int] = field(default_factory=set)
@classmethod
def empty(cls):
return cls(
workload_pending=0.0,
workload_served=0.0,
workload_cancelled=0.0,
workload_errored=0.0,
cur_perf=0.0,
workload_received=0.0,
error_msg=None,
max_throughput=0.0,
)
@property
def workload_processing(self) -> float:
return max(self.workload_received - self.workload_cancelled, 0.0)
def set_errored(self, error_msg):
self.reset()
self.error_msg = error_msg
def reset(self):
self.workload_served = 0
self.workload_received = 0
self.workload_cancelled = 0
self.workload_errored = 0
@dataclass
class AutoScalaerData:
"""Data that is reported to autoscaler"""
id: int
loadtime: float
cur_load: float
error_msg: str
max_perf: float
cur_perf: float
cur_capacity: float
max_capacity: float
num_requests_working: int
num_requests_recieved: int
additional_disk_usage: float
url: str
class LogAction(Enum):
"""
These actions tell the backend what a log value means, for example:
actions [
# this marks the model server as loaded
(LogAction.ModelLoaded, "Starting server"),
# these mark the model server as errored
(LogAction.ModelError, "Exception loading model"),
(LogAction.ModelError, "Server failed to bind to port"),
# this tells the backend to print any logs containing the string into its own logs
# which are visible in the vast console instance logs
(LogAction.Info, "Starting model download"),
]
"""
ModelLoaded = 1
ModelError = 2
Info = 3
+153
View File
@@ -0,0 +1,153 @@
import os
import time
import logging
import json
from asyncio import sleep
from dataclasses import dataclass, asdict, field
from functools import cache
from urllib.parse import urljoin
import requests
from lib.data_types import AutoScalaerData, SystemMetrics, ModelMetrics
from typing import Awaitable, NoReturn, List
METRICS_UPDATE_INTERVAL = 1
log = logging.getLogger(__file__)
@cache
def get_url() -> str:
use_ssl = os.environ.get("USE_SSL", "false") == "true"
worker_port = os.environ[f"VAST_TCP_PORT_{os.environ['WORKER_PORT']}"]
public_ip = os.environ["PUBLIC_IPADDR"]
return f"http{'s' if use_ssl else ''}://{public_ip}:{worker_port}"
@dataclass
class Metrics:
last_metric_update: float = 0.0
update_pending: bool = False
id: int = field(default_factory=lambda: int(os.environ["CONTAINER_ID"]))
report_addr: List[str] = field(
default_factory=lambda: os.environ["REPORT_ADDR"].split(",")
)
url: str = field(default_factory=get_url)
system_metrics: SystemMetrics = field(default_factory=SystemMetrics.empty)
model_metrics: ModelMetrics = field(default_factory=ModelMetrics.empty)
def _request_start(self, workload: float, reqnum: int) -> None:
"""
this function is called prior to forwarding a request to a model API.
"""
log.debug("request start")
self.model_metrics.workload_pending += workload
self.model_metrics.workload_received += workload
self.model_metrics.requests_recieved.add(reqnum)
self.model_metrics.requests_working.add(reqnum)
def _request_end(
self, workload: float, req_response_time: float, reqnum: int
) -> None:
"""
this function is called after a response from model API is received.
"""
self.model_metrics.workload_served += workload
self.model_metrics.workload_pending -= workload
self.model_metrics.requests_working.discard(reqnum)
self.model_metrics.cur_perf = workload / req_response_time
self.update_pending = True
def _request_errored(self, workload: float, reqnum: int) -> None:
"""
this function is called if model API returns an error
"""
self.model_metrics.workload_pending -= workload
self.model_metrics.workload_errored += workload
self.model_metrics.requests_working.discard(reqnum)
def _request_canceled(self, workload: float, reqnum: int) -> None:
"""
this function is called if client drops connection before model API has responded
"""
self.model_metrics.workload_pending -= workload
self.model_metrics.workload_cancelled += workload
self.model_metrics.requests_working.discard(reqnum)
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
while True:
await sleep(METRICS_UPDATE_INTERVAL)
elapsed = time.time() - self.last_metric_update
if self.system_metrics.model_is_loaded is False and elapsed >= 10:
log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
self.__send_metrics_and_reset(elapsed)
elif self.update_pending or elapsed > 10:
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
self.__send_metrics_and_reset(elapsed)
def _model_loaded(self, max_throughput: float) -> None:
self.system_metrics.model_loading_time = (
time.time() - self.system_metrics.model_loading_start
)
self.system_metrics.model_is_loaded = True
self.model_metrics.max_throughput = max_throughput
def _model_errored(self, error_msg: str) -> None:
self.model_metrics.set_errored(error_msg)
self.system_metrics.model_is_loaded = True
#######################################Private#######################################
def __send_metrics_and_reset(self, elapsed):
def compute_autoscaler_data() -> AutoScalaerData:
return AutoScalaerData(
id=self.id,
loadtime=(self.system_metrics.model_loading_time or 0.0),
cur_load=(self.model_metrics.workload_processing / elapsed),
max_perf=self.model_metrics.max_throughput,
cur_perf=self.model_metrics.cur_perf,
error_msg=self.model_metrics.error_msg or "",
num_requests_working=len(self.model_metrics.requests_working),
num_requests_recieved=len(self.model_metrics.requests_recieved),
additional_disk_usage=self.system_metrics.additional_disk_usage,
cur_capacity=0,
max_capacity=0,
url=self.url,
)
def send_data(report_addr: str) -> None:
data = compute_autoscaler_data()
full_path = urljoin(report_addr, "/worker_status/")
log.debug(
"\n".join(
[
"#" * 60,
f"sending data to autoscaler",
f"{json.dumps((asdict(data)), indent=2)}",
"#" * 60,
]
)
)
for attempt in range(1, 4):
try:
requests.post(full_path, json=asdict(data), timeout=1)
break
except requests.Timeout:
log.debug(f"autoscaler status update timed out")
except Exception as e:
log.debug(f"autoscaler status update failed with error: {e}")
time.sleep(2)
log.debug(f"retrying autoscaler status update, attempt: {attempt}")
###########
self.system_metrics.update_disk_usage()
for report_addr in self.report_addr:
send_data(report_addr)
self.update_pending = False
self.model_metrics.reset()
self.system_metrics.reset()
self.last_metric_update = time.time()
+40
View File
@@ -0,0 +1,40 @@
import os
import logging
from typing import List
import ssl
from asyncio import run, gather
from lib.backend import Backend
from aiohttp import web
log = logging.getLogger(__file__)
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
log.debug("getting certificate...")
use_ssl = os.environ.get("USE_SSL", "false") == "true"
if use_ssl is True:
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context.load_cert_chain(
certfile="/etc/instance.crt",
keyfile="/etc/instance.key",
)
else:
ssl_context = None
async def main():
log.debug("starting server...")
app = web.Application()
app.add_routes(routes)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(
runner,
ssl_context=ssl_context,
port=int(os.environ["WORKER_PORT"]),
**kwargs
)
await gather(site.start(), backend._start_tracking())
run(main())
+267
View File
@@ -0,0 +1,267 @@
import os
import time
import argparse
from typing import Callable, List, Dict, Tuple, Dict, Any, Type
from time import sleep
import threading
from enum import Enum
from collections import Counter
from dataclasses import dataclass, field, asdict
from urllib.parse import urljoin
import requests
from lib.data_types import AuthData, ApiPayload
class ClientStatus(Enum):
FetchEndpoint = 1
Generating = 2
Done = 3
Error = 4
total_success = 0
last_res = []
stop_event = threading.Event()
start_time = time.time()
test_args = argparse.ArgumentParser(description="Test inference endpoint")
test_args.add_argument(
"-k", dest="api_key", type=str, required=True, help="Your vast account API key"
)
test_args.add_argument(
"-e",
dest="endpoint_group_name",
type=str,
required=True,
help="Endpoint group name",
)
test_args.add_argument(
"-l",
dest="server_url",
action="store_const",
const="http://localhost:8081",
default="https://run.vast.ai",
help="Call local autoscaler instead of prod, for dev use only",
)
GetPayloadAndWorkload = Callable[[], Tuple[Dict[str, Any], float]]
def print_truncate_res(res: str):
if len(res) > 150:
print(f"{res[:50]}....{res[-100:]}")
else:
print(res)
@dataclass
class ClientState:
endpoint_group_name: str
api_key: str
server_url: str
worker_endpoint: str
payload: ApiPayload
url: str = ""
status: ClientStatus = ClientStatus.FetchEndpoint
as_error: List[str] = field(default_factory=list)
infer_error: List[str] = field(default_factory=list)
conn_errors: Counter = field(default_factory=Counter)
def make_call(self):
self.status = ClientStatus.FetchEndpoint
route_payload = {
"endpoint": self.endpoint_group_name,
"api_key": self.api_key,
"cost": self.payload.count_workload(),
}
response = requests.post(
urljoin(self.server_url, "/route/"),
json=route_payload,
timeout=4,
)
if response.status_code != 200:
self.as_error.append(
f"code: {response.status_code}, body: {response.text}",
)
self.status = ClientStatus.Error
return
message = response.json()
worker_address = message["url"]
req_data = dict(
payload=asdict(self.payload),
auth_data=asdict(AuthData.from_json_msg(message)),
)
self.url = worker_address
url = urljoin(worker_address, self.worker_endpoint)
self.status = ClientStatus.Generating
response = requests.post(
url,
json=req_data,
)
if response.status_code != 200:
self.infer_error.append(
f"code: {response.status_code}, body: {response.text}, url: {url}",
)
self.status = ClientStatus.Error
return
res = str(response.json())
global total_success
global last_res
total_success += 1
last_res.append(res)
self.status = ClientStatus.Done
def simulate_user(self) -> None:
try:
self.make_call()
except Exception as e:
self.status = ClientStatus.Error
_ = e
self.conn_errors[self.url] += 1
def print_state(clients: List[ClientState], num_clients: int) -> None:
print("starting up...")
sleep(2)
center_size = 14
global start_time
while len(clients) < num_clients or (
any(
map(
lambda client: client.status
in [ClientStatus.FetchEndpoint, ClientStatus.Generating],
clients,
)
)
):
sleep(0.5)
os.system("clear")
print(
" | ".join(
[member.name.center(center_size) for member in ClientStatus]
+ [
item.center(center_size)
for item in [
"urls",
"as_error",
"infer_error",
"conn_error",
"total_success",
]
]
)
)
unique_urls = len(set([c.url for c in clients if c.url != ""]))
as_errors = sum(
map(
lambda client: len(client.as_error),
[client for client in clients],
)
)
infer_errors = sum(
map(
lambda client: len(client.infer_error),
[client for client in clients],
)
)
conn_errors = sum([client.conn_errors for client in clients], start=Counter())
conn_errors_str = ",".join(map(str, conn_errors.values())) or "0"
elapsed = time.time() - start_time
print(
" | ".join(
map(
lambda item: str(item).center(center_size),
[
len(list(filter(lambda x: x.status == member, clients)))
for member in ClientStatus
]
+ [
unique_urls,
as_errors,
infer_errors,
conn_errors_str,
f"{total_success}({((total_success/elapsed) * 60):.2f}/minute)",
],
)
)
)
if conn_errors:
print("conn_errors:")
for url, count in conn_errors.items():
print(url.ljust(28), ": ", str(count))
elapsed = time.time() - start_time
print(f"\n elapsed: {int(elapsed // 60)}:{int(elapsed % 60)}")
if last_res:
for i, res in enumerate(last_res[-10:]):
print_truncate_res(f"res #{1+i+max(len(last_res )-10,0)}: {res}")
if stop_event.is_set():
print("\n### waiting for existing connections to close ###")
def run_test(
num_requests: int,
requests_per_second: int,
endpoint_group_name: str,
api_key: str,
server_url: str,
worker_endpoint: str,
payload_cls: Type[ApiPayload],
):
threads = []
clients = []
print_thread = threading.Thread(target=print_state, args=(clients, num_requests))
print_thread.daemon = True # makes threads get killed on program exit
print_thread.start()
try:
for _ in range(num_requests):
client = ClientState(
endpoint_group_name=endpoint_group_name,
api_key=api_key,
server_url=server_url,
worker_endpoint=worker_endpoint,
payload=payload_cls.for_test(),
)
clients.append(client)
thread = threading.Thread(target=client.simulate_user, args=())
threads.append(thread)
thread.start()
sleep(1 / requests_per_second)
for thread in threads:
thread.join()
print("done spawning workers")
except KeyboardInterrupt:
stop_event.set()
def test_load_cmd(
payload_cls: Type[ApiPayload], endpoint: str, arg_parser: argparse.ArgumentParser
):
arg_parser.add_argument(
"-n",
dest="num_requests",
type=int,
required=True,
help="total number of requests",
)
arg_parser.add_argument(
"-rps",
dest="requests_per_second",
type=float,
required=True,
help="requests per second",
)
args = arg_parser.parse_args()
if hasattr(args, "comfy_model"):
os.environ["COMFY_MODEL"] = args.comfy_model
run_test(
num_requests=args.num_requests,
requests_per_second=args.requests_per_second,
api_key=args.api_key,
server_url=args.server_url,
endpoint_group_name=args.endpoint_group_name,
worker_endpoint=endpoint,
payload_cls=payload_cls,
)
+51
View File
@@ -0,0 +1,51 @@
aiofiles==24.1.0
aiohappyeyeballs==2.3.4
aiohttp==3.10.0
aiojobs==1.2.1
aiosignal==1.3.1
anyio==4.4.0
attrs==23.2.0
blinker==1.8.2
certifi==2024.7.4
charset-normalizer==3.3.2
click==8.1.7
cx_Freeze==7.1.1
filelock==3.15.4
Flask==3.0.3
frozenlist==1.4.1
fsspec==2024.6.1
gitignore_parser==0.1.11
gunicorn==22.0.0
hf_transfer==0.1.8
huggingface-hub==0.24.2
idna==3.7
itsdangerous==2.2.0
Jinja2==3.1.4
joblib==1.4.2
MarkupSafe==2.1.5
multidict==6.0.5
nltk==3.8.1
Nuitka==2.3.11
numpy==2.0.0
ordered-set==4.1.0
packaging==24.1
patchelf==0.17.2.1
psutil==6.0.0
pycryptodome==3.20.0
PyYAML==6.0.1
regex==2024.5.15
requests==2.32.3
safetensors==0.4.3
setuptools==70.3.0
sniffio==1.3.1
tiktoken==0.7.0
token-count==0.2.1
tokenizers==0.19.1
tqdm==4.66.4
transformers==4.43.2
typing_extensions==4.12.2
urllib3==2.2.2
Werkzeug==3.0.3
wheel==0.43.0
yarl==1.9.4
zstandard==0.22.0
+118
View File
@@ -0,0 +1,118 @@
#!/bin/bash
set -e -o pipefail
WORKSPACE_DIR="${WORKSPACE_DIR:-/workspace}"
SERVER_DIR="$WORKSPACE_DIR/vast-pyworker"
ENV_PATH="$WORKSPACE_DIR/worker-env"
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}"
USE_SSL="${USE_SSL:-true}"
WORKER_PORT="${WORKER_PORT:-3000}"
mkdir -p "$WORKSPACE_DIR"
cd "$WORKSPACE_DIR"
# make all output go to $DEBUG_LOG and stdout without having to add `... | tee -a $DEBUG_LOG` to every command
exec &> >(tee -a "$DEBUG_LOG")
function echo_var(){
echo "$1: ${!1}"
}
[ -z "$BACKEND" ] && echo "BACKEND must be set!" && exit 1
[ -z "$MODEL_LOG" ] && echo "MODEL_LOG must be set!" && exit 1
[ -z "$HF_TOKEN" ] && echo "HF_TOKEN must be set!" && exit 1
[ "$BACKEND" = "sd3" ] && [ -z "$COMFY_MODEL" ] && echo "For sd3 backends, COMFY_MODEL must be set!" && exit 1
echo "start_server.sh"
date
echo_var BACKEND
echo_var REPORT_ADDR
echo_var WORKER_PORT
echo_var WORKSPACE_DIR
echo_var SERVER_DIR
echo_var ENV_PATH
echo_var DEBUG_LOG
echo_var PYWORKER_LOG
echo_var MODEL_LOG
env | grep _ >> /etc/environment;
if [ ! -d "$ENV_PATH" ]
then
apt install -y python3.10-venv
echo "setting up venv"
git clone https://github.com/vast-ai/pyworker "$SERVER_DIR"
python3 -m venv "$WORKSPACE_DIR/worker-env"
source "$WORKSPACE_DIR/worker-env/bin/activate"
pip install -r vast-pyworker/requirements.txt
touch ~/.no_auto_tmux
else
source "$WORKSPACE_DIR/worker-env/bin/activate"
echo "environment activated"
echo "venv: $VIRTUAL_ENV"
fi
[ ! -d "$SERVER_DIR/workers/$BACKEND" ] && echo "$BACKEND not supported!" && exit 1
if [ "$USE_SSL" = true ]; then
cat << EOF > /etc/openssl-san.cnf
[req]
default_bits = 2048
distinguished_name = req_distinguished_name
req_extensions = v3_req
[req_distinguished_name]
countryName = US
stateOrProvinceName = CA
organizationName = Vast.ai Inc.
commonName = vast.ai
[v3_req]
basicConstraints = CA:FALSE
keyUsage = nonRepudiation, digitalSignature, keyEncipherment
subjectAltName = @alt_names
[alt_names]
IP.1 = 0.0.0.0
EOF
openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \
-nodes \
-sha256 \
-keyout /etc/instance.key \
-out /etc/instance.csr \
-config /etc/openssl-san.cnf
curl --header 'Content-Type: application/octet-stream' \
--data-binary @//etc/instance.csr \
-X \
POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt;
fi
export REPORT_ADDR WORKER_PORT USE_SSL
cd "$SERVER_DIR"
echo "launching PyWorker server"
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
echo "launching PyWorker server done"
+92
View File
@@ -0,0 +1,92 @@
This is the base PyWorker for comfyui. It can be used to create PyWorker that use various models and
workflows. It provides two endpoints:
1. `/prompt`: Uses the default comfy workflow defined under `misc/default_workflows`
2. `/custom_workflow`: Allows the client to send their own comfy workflow with each API request.
To use the comfyui PyWorker, `$COMFY_MODEL` env variable must be set in the template. Current options are
`sd3` and `flux`. Each have example clients.
To add new models, a JSON with name `$COMFY_MODEL.json` must be created under `misc/default_workflows`
NOTE: default workflows follow this format:
```json
{
"input": {
"handler": "RawWorkflow",
"aws_access_key_id": "your-s3-access-key",
"aws_secret_access_key": "your-s3-secret-access-key",
"aws_endpoint_url": "https://my-endpoint.backblaze.com",
"aws_bucket_name": "your-bucket",
"webhook_url": "your-webhook-url",
"webhook_extra_params": {},
"workflow_json": {}
}
}
```
You can ignore all of these fields except for `workflow_json`.
Fields written as "{{FOO}}" will be replaced using data from a user request. For example, SD3's workflow has the
following nodes:
```json
"5": {
"inputs": {
"width": "{{WIDTH}}",
"height": "{{HEIGHT}}",
"batch_size": 1
},
"6": {
"inputs": {
"text": "{{PROMPT}}",
"clip": ["11", 0]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
...
"17": {
"inputs": {
"scheduler": "simple",
"steps": "{{STEPS}}",
"denoise": 1,
"model": ["12", 0]
},
"class_type": "BasicScheduler",
"_meta": {
"title": "BasicScheduler"
}
},
...
"25": {
"inputs": {
"noise_seed": "{{SEED}}"
},
"class_type": "RandomNoise",
"_meta": {
"title": "RandomNoise"
}
}
```
Incoming requests have the following JSON format:
```json
{
prompt: str
width: int
height: int
steps: int
seed: int
}
```
Each value in those fields with replace the placeholder of the same name in the default workflow.
See Vast's serverless documentation for more details on how to use comfyui with autoscaler
View File
+150
View File
@@ -0,0 +1,150 @@
from urllib.parse import urljoin
import requests
from lib.test_utils import print_truncate_res
"""
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
"""
def call_default_workflow(
endpoint_group_name: str, api_key: str, server_url: str
) -> None:
WORKER_ENDPOINT = "/prompt"
COST = 100
route_payload = {
"endpoint": endpoint_group_name,
"api_key": api_key,
"cost": COST,
}
response = requests.post(
urljoin(server_url, "/route/"),
json=route_payload,
timeout=4,
)
message = response.json()
url = message["url"]
auth_data = dict(
signature=message["signature"],
cost=message["cost"],
endpoint=message["endpoint"],
reqnum=message["reqnum"],
url=message["url"],
)
payload = dict(
prompt="a fat fluffy cat", width=1024, height=1024, steps=20, seed=123456789
)
req_data = dict(payload=payload, auth_data=auth_data)
url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {url}")
response = requests.post(
url,
json=req_data,
)
print_truncate_res(str(response.json()))
def call_custom_workflow_for_sd3(
endpoint_group_name: str, api_key: str, server_url: str
) -> None:
WORKER_ENDPOINT = "/custom-workflow"
COST = 100
route_payload = {
"endpoint": endpoint_group_name,
"api_key": api_key,
"cost": COST,
}
response = requests.post(
urljoin(server_url, "/route/"),
json=route_payload,
timeout=4,
)
message = response.json()
url = message["url"]
auth_data = dict(
signature=message["signature"],
cost=message["cost"],
endpoint=message["endpoint"],
reqnum=message["reqnum"],
url=message["url"],
)
workflow = {
"3": {
"inputs": {
"seed": 156680208700286,
"steps": 20,
"cfg": 8,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1,
"model": ["4", 0],
"positive": ["6", 0],
"negative": ["7", 0],
"latent_image": ["5", 0],
},
"class_type": "KSampler",
},
"4": {
"inputs": {"ckpt_name": "sd3_medium_incl_clips_t5xxlfp16.safetensors"},
"class_type": "CheckpointLoaderSimple",
},
"5": {
"inputs": {"width": 512, "height": 512, "batch_size": 1},
"class_type": "EmptyLatentImage",
},
"6": {
"inputs": {
"text": "beautiful scenery nature glass bottle landscape, purple galaxy bottle",
"clip": ["4", 1],
},
"class_type": "CLIPTextEncode",
},
"7": {
"inputs": {"text": "text, watermark", "clip": ["4", 1]},
"class_type": "CLIPTextEncode",
},
"8": {
"inputs": {"samples": ["3", 0], "vae": ["4", 2]},
"class_type": "VAEDecode",
},
"9": {
"inputs": {"filename_prefix": "ComfyUI", "images": ["8", 0]},
"class_type": "SaveImage",
},
}
# these values should match the values in the custom workflow above,
# they are used to calculate workload
custom_fields = dict(
steps=20,
width=512,
height=512,
)
req_data = dict(
payload=dict(custom_fields=custom_fields, workflow=workflow),
auth_data=auth_data,
)
url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {url}")
response = requests.post(
url,
json=req_data,
)
print_truncate_res(str(response.json()))
if __name__ == "__main__":
from lib.test_utils import test_args
args = test_args.parse_args()
call_default_workflow(
api_key=args.api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
call_custom_workflow_for_sd3(
api_key=args.api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
+205
View File
@@ -0,0 +1,205 @@
import sys
import os
import json
import random
import dataclasses
import inspect
from typing import Dict, Any
from functools import cache
from math import ceil
from enum import Enum
from lib.data_types import ApiPayload, JsonDataException
with open("workers/comfyui/misc/test_prompts.txt", "r") as f:
test_prompts = f.readlines()
class Model(Enum):
Flux = "flux"
Sd3 = "sd3"
def get_request_time(self) -> int:
match self:
case Model.Flux:
return 23
case Model.Sd3:
return 6
@cache
def get_model() -> Model:
match os.environ.get("COMFY_MODEL"):
case "flux":
return Model.Flux
case "sd3":
return Model.Sd3
case None:
raise Exception(
"For comfyui pyworker, $COMFY_MODEL must be set in the vast template"
)
case model:
raise Exception(f"Unsupported comfyui model: {model}")
@cache
def get_request_template() -> str:
with open(f"workers/comfyui/misc/default_workflows/{get_model().value}.json") as f:
return f.read()
def count_workload(width: int, height: int, steps: int) -> float:
"""
we want to normalize the workload is a number such that cur_perf(tokens/second) for 1024x1024 image with
28 steps is 200 tokens on a 4090.
in order get that we calculate the
A = ( absolute workload based on given data )
B = ( absolute workload for a 1024x1024 image with 28 steps )
and adjust the workload to 200 tokens by A/B.
we then adjust for difference between Flux and SD3 by multiplying this value by expected request time for a
standard image(23s for Flux, 6s for SD3).
On a 4090, this would give us a workload that would give a cur_perf(workload / request_time) of around 200
"""
def _calculate_absolute_tokens(width_: int, height_: int, steps_: int) -> float:
"""
This is based on how openai counts image generation tokens, see: https://openai.com/api/pricing/
we count how many 512x512 grids are needed to cover the image.
each tile is then counted as 175 tokens.
each image generation also has constant of 85 base tokens.
we then adjust the count based on the number of steps. The baseline number of steps is assumed to be 28.
Some testing with flux gave me this data:
steps(X) | request time(Y)
__________|_________________
07(0.25x) | 11s (0.47x)
14(0.50x) | 15s (0.65x)
21(0.75x) | 20s (0.86x)
28(1.00x) | 23s (1.00x)
35(1.25x) | 28s (1.21x)
42(1.50x) | 32s (1.39x)
49(1.75x) | 37s (1.60x)
this gives a linear regression of Y = 0.61*X + 6.57
we can use this as an adjustment_factor for token count
adjustment_factor = (0.61 * steps + 6.57)
"""
width_grids = ceil(width_ / 512)
height_grids = ceil(height_ / 512)
tokens = 85 + width_grids * height_grids * 175
adjustment_factor = 0.61 * steps_ + 6.57
return tokens * adjustment_factor
REQUEST_TIME_FOR_STANDARD_IMAGE = get_model().get_request_time()
absolute_tokens = _calculate_absolute_tokens(
width_=width, height_=height, steps_=steps
)
absolute_tokens_standard_image = _calculate_absolute_tokens(
width_=1024, height_=1024, steps_=28
)
return REQUEST_TIME_FOR_STANDARD_IMAGE * (
(absolute_tokens / absolute_tokens_standard_image) * 200
)
@dataclasses.dataclass
class DefaultComfyWorkflowData(ApiPayload):
prompt: str
width: int
height: int
steps: int
seed: int
@classmethod
def for_test(cls):
test_prompt = random.choice(test_prompts).rstrip()
return cls(
prompt=test_prompt,
width=1024,
height=1024,
steps=28,
seed=random.randint(0, sys.maxsize),
)
def generate_payload_json(
self,
) -> Dict[str, Any]:
return json.loads(
get_request_template()
.replace("{{PROMPT}}", self.prompt)
# these values should be of int type. Since "{{VAR}}" is wrapped with " in the template
# to make the JSON valid, we must replace the double quotes. i.e. "{{WIDTH}}" -> 1024 and not "1024"
.replace('"{{WIDTH}}"', str(self.width))
.replace('"{{HEIGHT}}"', str(self.height))
.replace('"{{STEPS}}"', str(self.steps))
.replace('"{{SEED}}"', str(self.seed))
)
def count_workload(self) -> float:
return count_workload(width=self.width, height=self.height, steps=self.steps)
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "DefaultComfyWorkflowData":
errors = {}
for param in inspect.signature(cls).parameters:
if param not in json_msg:
errors[param] = "missing parameter"
if errors:
raise JsonDataException(errors)
return cls(
**{
k: v
for k, v in json_msg.items()
if k in inspect.signature(cls).parameters
}
)
@dataclasses.dataclass
class CustomComfyWorkflowData(ApiPayload):
custom_fields: Dict[str, int]
workflow: Dict[str, Any]
@classmethod
def for_test(cls):
raise NotImplemented("Custom comfy workflow is not used for testing")
def count_workload(self) -> float:
return count_workload(
width=int(self.custom_fields.get("width", 1024)),
height=int(self.custom_fields.get("height", 1024)),
steps=int(self.custom_fields.get("steps", 28)),
)
def generate_payload_json(self) -> Dict[str, Any]:
template_json = json.loads(get_request_template())
template_json["input"]["workflow_json"] = self.workflow
return template_json
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "CustomComfyWorkflowData":
errors = {}
for param in inspect.signature(cls).parameters:
if param not in json_msg:
errors[param] = "missing parameter"
if errors:
raise JsonDataException(errors)
return cls(
**{
k: v
for k, v in json_msg.items()
if k in inspect.signature(cls).parameters
}
)
@@ -0,0 +1,137 @@
{
"input": {
"handler": "RawWorkflow",
"aws_access_key_id": "your-s3-access-key",
"aws_secret_access_key": "your-s3-secret-access-key",
"aws_endpoint_url": "https://my-endpoint.backblaze.com",
"aws_bucket_name": "your-bucket",
"webhook_url": "your-webhook-url",
"webhook_extra_params": {},
"workflow_json": {
"5": {
"inputs": {
"width": "{{WIDTH}}",
"height": "{{HEIGHT}}",
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"6": {
"inputs": {
"text": "{{PROMPT}}",
"clip": ["11", 0]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"8": {
"inputs": {
"samples": ["13", 0],
"vae": ["10", 0]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": ["8", 0]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
},
"10": {
"inputs": {
"vae_name": "ae.safetensors"
},
"class_type": "VAELoader",
"_meta": {
"title": "Load VAE"
}
},
"11": {
"inputs": {
"clip_name1": "t5xxl_fp16.safetensors",
"clip_name2": "clip_l.safetensors",
"type": "flux"
},
"class_type": "DualCLIPLoader",
"_meta": {
"title": "DualCLIPLoader"
}
},
"12": {
"inputs": {
"unet_name": "flux1-dev.safetensors",
"weight_dtype": "default"
},
"class_type": "UNETLoader",
"_meta": {
"title": "Load Diffusion Model"
}
},
"13": {
"inputs": {
"noise": ["25", 0],
"guider": ["22", 0],
"sampler": ["16", 0],
"sigmas": ["17", 0],
"latent_image": ["5", 0]
},
"class_type": "SamplerCustomAdvanced",
"_meta": {
"title": "SamplerCustomAdvanced"
}
},
"16": {
"inputs": {
"sampler_name": "euler"
},
"class_type": "KSamplerSelect",
"_meta": {
"title": "KSamplerSelect"
}
},
"17": {
"inputs": {
"scheduler": "simple",
"steps": "{{STEPS}}",
"denoise": 1,
"model": ["12", 0]
},
"class_type": "BasicScheduler",
"_meta": {
"title": "BasicScheduler"
}
},
"22": {
"inputs": {
"model": ["12", 0],
"conditioning": ["6", 0]
},
"class_type": "BasicGuider",
"_meta": {
"title": "BasicGuider"
}
},
"25": {
"inputs": {
"noise_seed": "{{SEED}}"
},
"class_type": "RandomNoise",
"_meta": {
"title": "RandomNoise"
}
}
}
}
}
@@ -0,0 +1,142 @@
{
"input": {
"handler": "RawWorkflow",
"aws_access_key_id": "your-s3-access-key",
"aws_secret_access_key": "your-s3-secret-access-key",
"aws_endpoint_url": "https://my-endpoint.backblaze.com",
"aws_bucket_name": "your-bucket",
"webhook_url": "your-webhook-url",
"webhook_extra_params": {},
"workflow_json": {
"6": {
"inputs": {
"text": "{{PROMPT}}",
"clip": ["252", 1]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"13": {
"inputs": {
"shift": 3,
"model": ["252", 0]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "ModelSamplingSD3"
}
},
"67": {
"inputs": {
"conditioning": ["71", 0]
},
"class_type": "ConditioningZeroOut",
"_meta": {
"title": "ConditioningZeroOut"
}
},
"68": {
"inputs": {
"start": 0.1,
"end": 1,
"conditioning": ["67", 0]
},
"class_type": "ConditioningSetTimestepRange",
"_meta": {
"title": "ConditioningSetTimestepRange"
}
},
"69": {
"inputs": {
"conditioning_1": ["68", 0],
"conditioning_2": ["70", 0]
},
"class_type": "ConditioningCombine",
"_meta": {
"title": "Conditioning (Combine)"
}
},
"70": {
"inputs": {
"start": 0,
"end": 0.1,
"conditioning": ["71", 0]
},
"class_type": "ConditioningSetTimestepRange",
"_meta": {
"title": "ConditioningSetTimestepRange"
}
},
"71": {
"inputs": {
"text": "bad quality, poor quality, doll, disfigured, jpg, toy, bad anatomy, missing limbs, missing fingers, 3d, cgi",
"clip": ["252", 1]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Negative Prompt)"
}
},
"135": {
"inputs": {
"width": "{{WIDTH}}",
"height": "{{HEIGHT}}",
"batch_size": 1
},
"class_type": "EmptySD3LatentImage",
"_meta": {
"title": "EmptySD3LatentImage"
}
},
"231": {
"inputs": {
"samples": ["271", 0],
"vae": ["252", 2]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"233": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": ["231", 0]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
},
"252": {
"inputs": {
"ckpt_name": "sd3_medium_incl_clips_t5xxlfp16.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"271": {
"inputs": {
"seed": "{{SEED}}",
"steps": "{{STEPS}}",
"cfg": 4.5,
"sampler_name": "dpmpp_2m",
"scheduler": "sgm_uniform",
"denoise": 1,
"model": ["13", 0],
"positive": ["6", 0],
"negative": ["69", 0],
"latent_image": ["135", 0]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
}
}
}
}
+34
View File
@@ -0,0 +1,34 @@
cartoon character of a person with a hoodie , in style of cytus and deemo, ork, gold chains, realistic anime cat, dripping black goo, lineage revolution style, thug life, cute anthropomorphic bunny, balrog, arknights, aliased, very buff, black and red and yellow paint, painting illustration collage style, character composition in vector with white background
stardew valley, fine details
2D Vector Illustration of a child with soccer ball Art for Sublimation, Design Art, Chrome Art, Painting and Stunning Artwork, Highly Detailed Digital Painting, Airbrush Art, Highly Detailed Digital Artwork, Dramatic Artwork, stained antique yellow copper paint, digital airbrush art, detailed by Mark Brooks, Chicano airbrush art, Swagger! snake Culture
realistic futuristic city-downtown with short buildings, sunset
seascape by Ray Collins and artgerm, front view of a perfect wave, sunny background, ultra detailed water
inspired by realflow-cinema4d editor features, create image of a transparent luxury cup with ice fruits and mint, connected with white, yellow and pink cream, Slow - High Speed MO Photography, YouTube Video Screenshot, Abstract Clay, Transparent Cup , molecular gastronomy, wheel, 3D fluid,Simulation rendering, still video, 4k polymer clay futras photography, very surreal, Houdini Fluid Simulation, hyperrealistic CGI and FLUIDS & MULTIPHYSICS SIMULATION effect, with Somali Stain Lurex, Metallic Jacquard, Gold Thread, Mulberry Silk, Toub Saree, Warm background, a fantastic image worthy of an award.
biker with backpack on his back riding a motorcycle, Style by Ade Santora, Oilpunk, Cover photo, craig mullins style, on the cover of a magazine, Outdoor Magazine, inspired by Alex Petruk APe, image of a male biker, Cover of an award-winning magazine, the man has a backpack, photo for magazine, with a backpack, magazine cover
generate a collage-style illustration inspired by the Procreate raster graphic editor, photographic illustration with the theme, 2D vector, art for textile sublimation, containing surrealistic cartoon cat wearing a baseball cap and jeans standing in front of a poster, inspired by Sadao Watanabe, Doraemon, Japanese cartoon style, Eichiro Oda, Iconic high detail character, Director: Nakahara Nantenbō, Kastuhiro Otomo, image detailed, by Miyamoto, Hidetaka Miyazaki, Katsuhiro illustration, 8k, masterpiece, Minimize noise and grain in photo quality without lose quality and increase brightness and lighting,Symmetry and Alignment, Avoid asymmetrical shapes and out-of-focus points. Focus and Sharpness: Make sure the image is focused and sharp and encourages the viewer to see it as a work of art printed on fabric.
fantasy medieval village world inside a glass sphere , high detail, fantasy, realistic, light effect, hyper detail, volumetric lighting, cinematic, macro, depth of field, blur, red light and clouds from the back, highly detailed epic cinematic concept art cg render made in maya, blender and photoshop, octane render, excellent composition, dynamic dramatic cinematic lighting, aesthetic, very inspirational, world inside a glass sphere by james gurney by artgerm with james jean, joe fenton and tristan eaton by ross tran, fine details
Iron Man, (Arnold Tsang, Toru Nakayama), Masterpiece, Studio Quality, 6k , toa, toaair, 1boy, glowing, axe, mecha, science_fiction, solo, weapon, jungle , green_background, nature, outdoors, solo, tree, weapon, mask, dynamic lighting, detailed shading, digital texture painting
(Pope Francis) wearing leather jacket is a DJ in a nightclub, mixing live on stage, giant mixing table, a masterpiece
Pope Francis wearing biker (leather jacket), a masterpiece
Luke Skywalker ordering a burger and fries from the Death Star canteen.
I want to generate a group avatar for a Feishu group chat. The role of this group is daily software technical communication. Now the subject technology stacks that members of this group discuss daily include: algorithms, data structures, optimization, functional programming, and the programming languages often discussed are: TypeScript, Java, python, etc. I hope this avatar has a simple aesthetic, this avatar is a single person avatar
portrait Anime black girl cute-fine-face, pretty face, realistic shaded Perfect face, fine details. Anime. realistic shaded lighting by Ilya Kuvshinov Giuseppe Dangelico Pino and Michael Garmash and Rob Rey, IAMAG premiere, WLOP matte print, cute freckles, masterpiece
young Disney socialite wearing a beige miniskirt, dark brown turtleneck sweater, small neckless, cute-fine-face, anime. illustration, realistic shaded perfect face, brown hair, grey eyes, fine details, realistic shaded lighting by ilya kuvshinov giuseppe dangelico pino and michael garmash and rob rey, iamag premiere, wlop matte print, a masterpiece
Cute small cat sitting in a movie theater eating chicken wiggs watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
Cute small dog sitting in a movie theater eating popcorn watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
fox bracelet made of buckskin with fox features, rich details, fine carvings, studio lighting
crane buckskin bracelet with crane features, rich details, fine carvings, studio lighting
london luxurious interior living-room, light walls
Parisian luxurious interior penthouse bedroom, dark walls, wooden panels
cute girl, crop-top, blond hair, black glasses, stretching, with background by greg rutkowski makoto shinkai kyoto animation key art feminine mid shot
houses in front, houses background, straight houses, digital art, smooth, sharp focus, gravity falls style, doraemon style, shinchan style, anime style
Simplified technical drawing, Leonardo da Vinci, Mechanical Dinosaur Skeleton, Minimalistic annotations, Hand-drawn illustrations, Basic design and engineering, Wonder and curiosity
High quality 8K painting impressionist style of a Japanese modern city street with a girl on the foreground wearing a traditional wedding dress with a fox mask, staring at the sky, daylight
a landscape from the Moon with the Earth setting on the horizon, realistic, detailed
Isometric Atlantis city,great architecture with columns, great details, ornaments,seaweed, blue ambiance, 3D cartoon style, soft light, 45° view
A hyper realistic avatar of a guy riding on a black honda cbr 650r in leather suit,high detail, high quality,8K,photo realism
the street of amedieval fantasy town, at dawn, dark, highly detailed
overwhelmingly beautiful eagle framed with vector flowers, long shiny wavy flowing hair, polished, ultra detailed vector floral illustration mixed with hyper realism, muted pastel colors, vector floral details in background, muted colors, hyper detailed ultra intricate overwhelming realism in detailed complex scene with magical fantasy atmosphere, no signature, no watermark
a highly detailed matte painting of a man on a hill watching a rocket launch in the distance by studio ghibli, makoto shinkai, by artgerm, by wlop, by greg rutkowski, volumetric lighting, octane render, 4 k resolution, trending on artstation, masterpiece | hyperrealism| highly detailed| insanely detailed| intricate| cinematic lighting| depth of field
electronik robot and ofice ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
exquisitely intricately detailed illustration, of a small world with a lake and a rainbow, inside a closed glass jar.
+135
View File
@@ -0,0 +1,135 @@
import os
import logging
import dataclasses
import base64
from typing import Union, Type
from aiohttp import web, ClientResponse
from anyio import open_file
from lib.backend import Backend, LogAction
from lib.data_types import EndpointHandler
from lib.server import start_server
from .data_types import DefaultComfyWorkflowData, CustomComfyWorkflowData
MODEL_SERVER_URL = "http://0.0.0.0:38188"
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: http://127.0.0.1:18188"
MODEL_SERVER_ERROR_LOG_MSGS = [
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
"Value not in list: unet_name", # This error is emitted when the model file is not there at all
]
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
async def generate_client_response(
request: web.Request, response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
_ = request
match response.status:
case 200:
log.debug("SUCCESS")
res = await response.json()
if "output" not in res:
return web.json_response(
data=dict(error="there was an error in the workflow"),
status=422,
)
image_paths = [path["local_path"] for path in res["output"]["images"]]
if not image_paths:
return web.json_response(
data=dict(error="workflow did not produce any images"),
status=422,
)
images = []
for image_path in image_paths:
async with await open_file(image_path, mode="rb") as f:
contents = await f.read()
images.append(
f"data:image/png;base64,{base64.b64encode(contents).decode('utf-8')}"
)
return web.json_response(data=dict(images=images))
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
@dataclasses.dataclass
class DefaultComfyWorkflowHandler(EndpointHandler[DefaultComfyWorkflowData]):
@property
def endpoint(self) -> str:
return "/runsync"
@classmethod
def payload_cls(cls) -> Type[DefaultComfyWorkflowData]:
return DefaultComfyWorkflowData
def make_benchmark_payload(self) -> DefaultComfyWorkflowData:
return DefaultComfyWorkflowData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
return await generate_client_response(client_request, model_response)
@dataclasses.dataclass
class CustomComfyWorkflowHandler(EndpointHandler[CustomComfyWorkflowData]):
@property
def endpoint(self) -> str:
return "/runsync"
@classmethod
def payload_cls(cls) -> Type[CustomComfyWorkflowData]:
return CustomComfyWorkflowData
def make_benchmark_payload(self) -> CustomComfyWorkflowData:
return CustomComfyWorkflowData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
return await generate_client_response(client_request, model_response)
backend = Backend(
model_server_url=MODEL_SERVER_URL,
model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=False,
benchmark_handler=DefaultComfyWorkflowHandler(
benchmark_runs=3, benchmark_words=100
),
log_actions=[
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
(LogAction.Info, "Downloading:"),
*[
(LogAction.ModelError, error_msg)
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
],
],
)
async def handle_ping(_):
return web.Response(body="pong")
routes = [
web.post("/prompt", backend.create_handler(DefaultComfyWorkflowHandler())),
web.post("/custom-workflow", backend.create_handler(CustomComfyWorkflowHandler())),
web.get("/ping", handle_ping),
]
if __name__ == "__main__":
start_server(backend, routes)
+15
View File
@@ -0,0 +1,15 @@
from lib.test_utils import test_load_cmd, test_args
from .data_types import DefaultComfyWorkflowData, Model
WORKER_ENDPOINT = "/prompt"
if __name__ == "__main__":
test_args.add_argument(
"-m",
dest="comfy_model",
choices=list(map(lambda x: x.value, Model)),
required=True,
help="Image generation model name",
)
test_load_cmd(DefaultComfyWorkflowData, WORKER_ENDPOINT, arg_parser=test_args)
+300
View File
@@ -0,0 +1,300 @@
# Vast PyWorker
## Hello_world example
There is a hello_world PyWorker implantation under `workers/hello_world`. This PyWorker is
created for an LLM model server that runs on port 5001 has two API endpoints:
1. `/generate`: generates an full response to the prompt and sends a JSON response
2. `/generate_stream`: streams a response one token at a time
Both of these endpoints take the same API JSON payload:
```
{
"prompt": String,
"max_response_tokens": Number | null
}
```
We want the PyWorker to also expose two endpoints, for each of the above endpoints.
### Structure
All PyWorkers have four files:
```
.
└── workers
└── hello_world
├── __init__.py
├── data_types.py # contains data types representing model API endpoints
├── server.py # contains endpoint handlers
├── client.py # a script to call an endpoint through the autoscaler
└── test_load.py # script for load testing
```
All of the classes follow strict type hinting. It is recommended that you type hint all of your function.
This will allow your IDE or VSCode with `pyright` plugin to find any type errors in your implementation.
You can also install `pyright` with `sudo npm install -g pyright` and run `pyright` in the root of the project to find
any type errors.
#### data_Types.py
data classes representing the model API are defined here. They must inherit from
`lib.data_types.ApiPayload`. `ApiPayload` is an abstract class and you need to define several functions for it:
```python
import dataclasses
import random
from typing import Dict, Any
from transformers import AutoTokenizer # used to count tokens in a prompt
import nltk # used to download a list of all words to generate a random prompt and benchmark the LLM model
from lib.data_types import ApiPayload
nltk.download("words")
WORD_LIST = nltk.corpus.words.words()
# you can use any tokenizer that fits your LLM. `openai-gpt` is free to use and is a good fit for most LLMs
tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")
@dataclasses.dataclass
class InputData(ApiPayload):
prompt: str
max_response_tokens: int
@classmethod
def for_test(cls) -> "ApiPayload":
"""defines how create a payload for load testing"""
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
return cls(prompt=prompt, max_response_tokens=300)
def generate_payload_json(self) -> Dict[str, Any]:
"""defines how to convert an ApiPayload to JSON that will be sent to model API"""
return dataclasses.asdict(self)
def count_workload(self) -> float:
"""defines how to calculate workload for a payload"""
return len(tokenizer.tokenize(self.prompt))
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
"""
defines how to transform JSON data to AuthData and payload type,
in this case `InputData` defined above represents the data sent to the model API.
AuthData is data generated by autoscaler in order to authenticate payloads.
In this case, the transformation is simple and 1:1. That is not always the case. See comfyui's PyWorker
for more complicated examples
"""
errors = {}
for param in inspect.signature(cls).parameters:
if param not in json_msg:
errors[param] = "missing parameter"
if errors:
raise JsonDataException(errors)
return cls(
**{
k: v
for k, v in json_msg.items()
if k in inspect.signature(cls).parameters
}
)
```
#### server.py
For every model API endpoint you want to use, you must implement an `EndpointHandler`. This class handles incoming
requests, processes them, sends them to the model API server, and finally returns an HTTP response.
`EndpointHandler` has several abstract functions that must be implemented. Here, we implement two, one
for `/generate`, and one for `/generate_stream`:
```python
"""
AuthData is a dataclass that represents Authentication data sent from Autoscaler to client requesting a route.
When a user requests a route from autoscaler, see Vast's Serverless documentation for how routing and AuthData
work.
When a user receives a route for this PyWorker, they'll call PyWorkers API with the following JSON:
{
auth_data: AuthData,
payload : InputData # defined above
}
"""
from aiohttp import web
from lib.data_types import EndpointHandler, JsonDataException
from lib.server import start_server
from .data_types import InputData
# This class is the implementer for the '/generate' endpoint of model API
@dataclasses.dataclass
class GenerateHandler(EndpointHandler[InputData]):
@property
def endpoint(self) -> str:
# the API endpoint
return "/generate"
@classmethod
def payload_cls(cls) -> Type[InputData]:
"""this function should just return ApiPayload subclass used by this handler"""
return InputData
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
"""
defines how to convert `InputData` defined above, to
JSON data to be sent to the model API. This function too is a simple dataclass -> JSON, but
can be more complicated, See comfyui for an example
"""
return dataclasses.asdict(payload)
def make_benchmark_payload(self) -> InputData:
"""
defines how to generate an InputData for benchmarking. This needs to be defined in only
one EndpointHandler, the one passed to the backend as the benchmark handler. Here we use the .for_test()
method on InputData. However, in some cases you might need to fine tune your InputData used for
benchmarking to closely resemble the average request users call the endpoint with in order to get best
autoscaling performance
"""
return InputData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
"""
defines how to convert a model API response to a response to PyWorker client
"""
_ = client_request
match model_response.status:
case 200:
log.debug("SUCCESS")
data = await model_response.json()
return web.json_response(data=data)
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
```
We also handle `GenerateStreamHandler` for streaming responses. It is identical to `GenerateHandler`, except for
the endpoint name and how we create a web response, as it is a streaming response:
```python
class GenerateStreamHandler(EndpointHandler[InputData]):
@property
def endpoint(self) -> str:
return "/generate_stream"
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
return dataclasses.asdict(payload)
def make_benchmark_payload(self) -> InputData:
return InputData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
match model_response.status:
case 200:
log.debug("Streaming response...")
res = web.StreamResponse()
res.content_type = "text/event-stream"
await res.prepare(client_request)
async for chunk in model_response.content:
await res.write(chunk)
await res.write_eof()
log.debug("Done streaming response")
return res
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
```
You can now instantiate a Backend and use it to handle requests.
```python
from lib.backend import Backend, LogAction
# the url and port of model API
MODEL_SERVER_URL = "http://0.0.0.0:5001"
# This is the log line that is emitted once the server has started
MODEL_SERVER_START_LOG_MSG = "server has started"
MODEL_SERVER_ERROR_LOG_MSGS = [
"Exception: corrupted model file" # message in the logs indicating the unrecoverable error
]
backend = Backend(
model_server_url=MODEL_SERVER_URL,
# location of model log file
model_log_file=os.environ["MODEL_LOG"],
# for some model backends that can only handle one request at a time, be sure to set this to False to
# let PyWorker handling queueing requests.
allow_parallel_requests=True,
# give the backend an EndpointHandler instance that is used for benchmarking
# number of benchmark run and number of words for a random benchmark run are given
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
# defines how to handle specific log messages. See docstring of LogAction for details
log_actions=[
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
(LogAction.Info, '"message":"Download'),
*[
(LogAction.ModelError, error_msg)
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
],
],
)
# this is a simple ping handler for PyWorker
async def handle_ping(_: web.Request):
return web.Response(body="pong")
# this is a handler for forwarding a health check to model API
async def handle_healthcheck(_: web.Request):
healthcheck_res = await backend.session.get("/healthcheck")
return web.Response(body=healthcheck_res.content, status=healthcheck_res.status)
routes = [
web.post("/generate", backend.create_handler(GenerateHandler())),
web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
web.get("/ping", handle_ping),
web.get("/healthcheck", handle_healthcheck),
]
if __name__ == "__main__":
# start server, called from start_server.sh
start_server(backend, routes)
```
#### test_load.py
Here you can create a script that allows you test an endpoint group running instances with this PyWorker
```python
from lib.test_harness import run
from .data_types import InputData
WORKER_ENDPOINT = "/generate"
if __name__ == "__main__":
run(InputData.for_test(), WORKER_ENDPOINT)
```
You can then run the following command from the root of this repo to load test endpoint group:
```sh
# sends 1000 requests at the rate of 0.5 requests per second
python3 workers.hello_world.test_load -n 1000 -rps 0.5 -k "$API_KEY" -e "$ENDPOINT_GROUP_NAME"
```
View File
+48
View File
@@ -0,0 +1,48 @@
import dataclasses
import random
import inspect
from typing import Dict, Any
from transformers import AutoTokenizer
import nltk
from lib.data_types import ApiPayload, JsonDataException
nltk.download("words")
WORD_LIST = nltk.corpus.words.words()
# used to count to count tokens and workload for LLM
tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")
@dataclasses.dataclass
class InputData(ApiPayload):
prompt: str
max_response_tokens: int
@classmethod
def for_test(cls) -> "InputData":
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
return cls(prompt=prompt, max_response_tokens=300)
def generate_payload_json(self) -> Dict[str, Any]:
return dataclasses.asdict(self)
def count_workload(self) -> int:
return len(tokenizer.tokenize(self.prompt))
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
errors = {}
for param in inspect.signature(cls).parameters:
if param not in json_msg:
errors[param] = "missing parameter"
if errors:
raise JsonDataException(errors)
return cls(
**{
k: v
for k, v in json_msg.items()
if k in inspect.signature(cls).parameters
}
)
+167
View File
@@ -0,0 +1,167 @@
"""
PyWorker works as a man-in-the-middle between the client and model API. It's function is:
1. receive request from client, update metrics such as workload of a request, number of pending requests, etc.
2a. transform the data and forward the transformed data to model API
2b. send updated metrics to autoscaler
3. transform response from model API(if needed) and forward the response to client
PyWorker forward requests to many model API endpoint. each endpoint must have an EndpointHandler. You can also
write function to just forward requests that don't generate anything with the model to model API without an
EndpointHandler. This is useful for endpoints such as healthchecks. See below for example
"""
import os
import logging
import dataclasses
from typing import Dict, Any, Union, Type
from aiohttp import web, ClientResponse
from lib.backend import Backend, LogAction
from lib.data_types import EndpointHandler
from lib.server import start_server
from .data_types import InputData
# the url and port of model API
MODEL_SERVER_URL = "http://0.0.0.0:5001"
# This is the log line that is emitted once the server has started
MODEL_SERVER_START_LOG_MSG = "infer server has started"
MODEL_SERVER_ERROR_LOG_MSGS = [
"Exception: corrupted model file" # message in the logs indicating the unrecoverable error
]
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
# This class is the implementer for the '/generate' endpoint of model API
@dataclasses.dataclass
class GenerateHandler(EndpointHandler[InputData]):
@property
def endpoint(self) -> str:
# the API endpoint
return "/generate"
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
"""
defines how to convert `InputData` defined above, to
json data to be sent to the model API
"""
return dataclasses.asdict(payload)
def make_benchmark_payload(self) -> InputData:
"""
defines how to generate an InputData for benchmarking. This needs to be defined in only
one EndpointHandler, the one passed to the backend as the benchmark handler
"""
return InputData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
"""
defines how to convert a model API response to a response to PyWorker client
"""
_ = client_request
match model_response.status:
case 200:
log.debug("SUCCESS")
data = await model_response.json()
return web.json_response(data=data)
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
# This is the same as GenerateHandler, except that it calls a streaming endpoint of the model API and streams the
# response, which itself is streaming, back to the client.
# it is nearly identical to handler as above, but it calls a different model API endpoint and it streams the
# streaming response from model API to client
class GenerateStreamHandler(EndpointHandler[InputData]):
@property
def endpoint(self) -> str:
return "/generate_stream"
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
return dataclasses.asdict(payload)
def make_benchmark_payload(self) -> InputData:
return InputData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
match model_response.status:
case 200:
log.debug("Streaming response...")
res = web.StreamResponse()
res.content_type = "text/event-stream"
await res.prepare(client_request)
async for chunk in model_response.content:
await res.write(chunk)
await res.write_eof()
log.debug("Done streaming response")
return res
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
# This is the backend instance of pyworker. Only one must be made which uses EndpointHandlers to process
# incoming requests
backend = Backend(
model_server_url=MODEL_SERVER_URL,
model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=True,
# give the backend a handler instance that is used for benchmarking
# number of benchmark run and number of words for a random benchmark run are given
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
# defines how to handle specific log messages. See docstring of LogAction for details
log_actions=[
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
(LogAction.Info, '"message":"Download'),
*[
(LogAction.ModelError, error_msg)
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
],
],
)
# this is a simple ping handler for pyworker
async def handle_ping(_: web.Request):
return web.Response(body="pong")
# this is a handler for forwarding a health check to modelAPI
async def handle_healthcheck(_: web.Request):
healthcheck_res = await backend.session.get("/healthcheck")
return web.Response(body=healthcheck_res.content, status=healthcheck_res.status)
routes = [
web.post("/generate", backend.create_handler(GenerateHandler())),
web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
web.get("/ping", handle_ping),
web.get("/healthcheck", handle_healthcheck),
]
if __name__ == "__main__":
# start the PyWorker server
start_server(backend, routes)
+7
View File
@@ -0,0 +1,7 @@
from lib.test_utils import test_load_cmd, test_args
from .data_types import InputData
WORKER_ENDPOINT = "/generate"
if __name__ == "__main__":
test_load_cmd(InputData, WORKER_ENDPOINT, arg_parser=test_args)
+19
View File
@@ -0,0 +1,19 @@
This is the base PyWorker for TGI, designed to create PyWorkers that can utilize various LLMs. It offers two primary endpoints:
1. `generate`: Generates the LLM's response to a given prompt in a single request.
2. `generate_stream`: Streams the LLM's response token by token.
Both endpoints use the following API payload format:
```json
{
"inputs": "PROMPT",
"parameters": {
"max_new_tokens": 250
}
}
```
Note that the max_new_tokens parameter, rather than the prompt size, impacts performance. For example, if an
instance is benchmarked to process 100 tokens per second, a request with max_new_tokens = 200 will take
approximately 2 seconds to complete.
View File
+91
View File
@@ -0,0 +1,91 @@
import sys
import json
from urllib.parse import urljoin
import requests
def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> None:
WORKER_ENDPOINT = "/generate"
COST = 100
route_payload = {
"endpoint": endpoint_group_name,
"api_key": api_key,
"cost": COST,
}
response = requests.post(
urljoin(server_url, "/route/"),
json=route_payload,
timeout=4,
)
message = response.json()
url = message["url"]
auth_data = dict(
signature=message["signature"],
cost=message["cost"],
endpoint=message["endpoint"],
reqnum=message["reqnum"],
url=message["url"],
)
payload = dict(inputs="tell me about cats", parameters=dict(max_new_tokens=500))
req_data = dict(payload=payload, auth_data=auth_data)
url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {url}")
response = requests.post(
url,
json=req_data,
)
res = response.json()
print(res)
def call_generate_stream(endpoint_group_name: str, api_key: str, server_url: str):
WORKER_ENDPOINT = "/generate_stream"
COST = 100
route_payload = {
"endpoint": endpoint_group_name,
"api_key": api_key,
"cost": COST,
}
response = requests.post(
urljoin(server_url, "/route/"),
json=route_payload,
timeout=4,
)
message = response.json()
url = message["url"]
print(f"url: {url}")
auth_data = dict(
signature=message["signature"],
cost=message["cost"],
endpoint=message["endpoint"],
reqnum=message["reqnum"],
url=message["url"],
)
payload = dict(inputs="tell me about dogs", parameters=dict(max_new_tokens=500))
req_data = dict(payload=payload, auth_data=auth_data)
url = urljoin(url, WORKER_ENDPOINT)
response = requests.post(url, json=req_data, stream=True)
for line in response.iter_lines():
payload = line.decode().lstrip("data:").rstrip()
if payload:
data = json.loads(payload)
print(data["token"]["text"], end="")
sys.stdout.flush()
print()
if __name__ == "__main__":
from lib.test_utils import test_args
args = test_args.parse_args()
call_generate(
api_key=args.api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
call_generate_stream(
api_key=args.api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
+73
View File
@@ -0,0 +1,73 @@
import dataclasses
import random
import inspect
from typing import Dict, Any
from transformers import AutoTokenizer
import nltk
from lib.data_types import ApiPayload, JsonDataException
nltk.download("words")
WORD_LIST = nltk.corpus.words.words()
tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")
@dataclasses.dataclass
class InputParameters:
max_new_tokens: int = 256
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputParameters":
errors = {}
for param in inspect.signature(cls).parameters:
if param not in json_msg:
errors[param] = "missing parameter"
if errors:
raise JsonDataException(errors)
return cls(
**{
k: v
for k, v in json_msg.items()
if k in inspect.signature(cls).parameters
}
)
@dataclasses.dataclass
class InputData(ApiPayload):
inputs: str
parameters: InputParameters
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "InputData":
return cls(
inputs=data["inputs"], parameters=InputParameters(**data["parameters"])
)
@classmethod
def for_test(cls) -> "InputData":
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
return cls(inputs=prompt, parameters=InputParameters())
def generate_payload_json(self) -> Dict[str, Any]:
return dataclasses.asdict(self)
def count_workload(self) -> int:
return self.parameters.max_new_tokens
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
errors = {}
for param in inspect.signature(cls).parameters:
if param not in json_msg:
errors[param] = "missing parameter"
if errors:
raise JsonDataException(errors)
try:
parameters = InputParameters.from_json_msg(json_msg["parameters"])
return cls(inputs=json_msg["inputs"], parameters=parameters)
except JsonDataException as e:
errors["parameters"] = e.message
raise JsonDataException(errors)
+115
View File
@@ -0,0 +1,115 @@
import os
import logging
from typing import Union, Type
import dataclasses
from aiohttp import web, ClientResponse
from lib.backend import Backend, LogAction
from lib.data_types import EndpointHandler
from lib.server import start_server
from .data_types import InputData
MODEL_SERVER_URL = "http://0.0.0.0:5001"
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
MODEL_SERVER_START_LOG_MSG = '"message":"Connected","target":"text_generation_router"'
MODEL_SERVER_ERROR_LOG_MSGS = ["Error: WebserverFailed", "Error: DownloadError"]
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
@dataclasses.dataclass
class GenerateHandler(EndpointHandler[InputData]):
@property
def endpoint(self) -> str:
return "/generate"
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
def make_benchmark_payload(self) -> InputData:
return InputData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
_ = client_request
match model_response.status:
case 200:
log.debug("SUCCESS")
data = await model_response.json()
return web.json_response(data=data)
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
class GenerateStreamHandler(EndpointHandler[InputData]):
@property
def endpoint(self) -> str:
return "/generate_stream"
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
def make_benchmark_payload(self) -> InputData:
return InputData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
match model_response.status:
case 200:
log.debug("Streaming response...")
res = web.StreamResponse()
res.content_type = "text/event-stream"
await res.prepare(client_request)
async for chunk in model_response.content:
await res.write(chunk)
await res.write_eof()
log.debug("Done streaming response")
return res
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
backend = Backend(
model_server_url=MODEL_SERVER_URL,
model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=True,
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
log_actions=[
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
(LogAction.Info, '"message":"Download'),
*[
(LogAction.ModelError, error_msg)
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
],
],
)
async def handle_ping(_):
return web.Response(body="pong")
routes = [
web.post("/generate", backend.create_handler(GenerateHandler())),
web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
web.get("/ping", handle_ping),
]
if __name__ == "__main__":
start_server(backend, routes)
+7
View File
@@ -0,0 +1,7 @@
from lib.test_utils import test_load_cmd, test_args
from .data_types import InputData
WORKER_ENDPOINT = "/generate"
if __name__ == "__main__":
test_load_cmd(InputData, WORKER_ENDPOINT, arg_parser=test_args)