Compare commits

...

43 Commits

Author SHA1 Message Date
Lucas Armand 7d3be849d9 Handle errors from model for comfyui-json 2025-10-08 12:00:45 -07:00
Colter-Downing 0397af719d Merge pull request #37 from robballantyne/bugfix/healthcheck-endpoint
Fix healthcheck endpoint URL

Tested and merged by Colter
2025-10-06 15:11:27 -07:00
Rob Ballantyne 4fdc314fd9 Fix healthcheck endpoint URL 2025-10-06 22:16:09 +01:00
Colter-Downing 639d82f5b4 Merge pull request #35 from vast-ai/AUTO-664--Healthcheck-error
Fix healthcheck with separate session
2025-10-02 12:51:19 -07:00
Colter Downing 25db78e39d Fix healthcheck with separate session 2025-10-01 18:04:31 -07:00
Scott-Laytart 4e2f2311d0 Merge pull request #33 from vast-ai/comfy-blind-fix-override
undo the fix for comfy yesterday.
2025-09-03 11:50:07 -07:00
abiola-vastai 38782d89bc undo the fix for comfy yesterday. 2025-09-03 17:12:35 +00:00
Scott-Laytart 0185216ccb Merge pull request #32 from vast-ai/blindhotfix_comfy_ui_default_port
Blind hotfix to see if comfy UI default is needed. if it does work we…
2025-09-02 18:26:25 -07:00
abiola-vastai b20d9e714c Blind hotfix to see if comfy UI default is needed. if it does work we would revert back. 2025-09-03 01:20:09 +00:00
Rob Ballantyne b1eb65d75d Merge pull request #31 from vast-ai/bugfix/startup-script-20250901
Update uv venv creation command
2025-09-01 18:19:17 +01:00
Rob Ballantyne 1d09d7fe96 Update uv venv creation command 2025-09-01 16:55:20 +01:00
Colter-Downing 1b37054dec Merge pull request #28 from vast-ai/bugfix/backend-timeout-infinite
Bugfix/backend timeout infinite
2025-08-28 11:22:33 -07:00
Colter-Downing 1a1e4174b8 Merge pull request #29 from vast-ai/bugfix/comfyui-json-cost-fix
Set cost to 100
2025-08-28 11:22:21 -07:00
Rob Ballantyne b8377c4081 Set cost to 100 2025-08-28 16:13:17 +01:00
Rob Ballantyne 1e4fa87437 Prevent timeout and allow long running connections 2025-08-28 15:48:57 +01:00
Rob Ballantyne 4c5fa03c7b adds import for ClientTimeout 2025-08-27 20:54:27 +01:00
Rob Ballantyne a8fe74f771 Remove default 300s timeout 2025-08-27 18:34:45 +01:00
Rob Ballantyne b482de8394 Merge pull request #27 from vast-ai/feat/comfyui-api-s3-webhook
Adds new ComfyUI worker

Upload assets to s3 compatible storage via intermediate API wrapper
2025-08-26 14:22:05 +01:00
Rob Ballantyne 703435d10e Improve MODEL_SERVER_START_* messages 2025-08-26 12:42:04 +01:00
Rob Ballantyne 947fc5eea4 Improve benchmarking explanation 2025-08-26 12:41:30 +01:00
Rob Ballantyne 7c1a544b19 Improve error reporting when no ready workers 2025-08-26 12:41:05 +01:00
Rob Ballantyne 16b414676e Use count_workload() function for cost 2025-08-25 18:31:10 +01:00
Rob Ballantyne ba74ac8136 Use cost value 1 for all jobs 2025-08-25 17:58:22 +01:00
Rob Ballantyne 92ff412679 Use MODEL_SERVER_URL environment variable 2025-08-25 17:57:32 +01:00
Rob Ballantyne fc75a64684 Use MODEL_SERVER_URL environment variable 2025-08-25 17:56:27 +01:00
Rob Ballantyne b00bef547c Ensure uv env script is present before sourcing 2025-08-22 17:08:42 +01:00
Rob Ballantyne 3f4acb29fa Improved client exception handling 2025-08-22 15:20:15 +01:00
Rob Ballantyne 58b078f908 Fix modifier class 2025-08-20 18:06:02 +01:00
Rob Ballantyne f9fdf04884 Fix signature 2025-08-20 13:27:29 +01:00
Rob Ballantyne 636f17d27f Fix workflow modifier class 2025-08-20 09:57:07 +01:00
Rob Ballantyne 08c88f7527 Improve testability 2025-08-20 09:34:09 +01:00
Rob Ballantyne 8797b504af Initial ComfyUI implementation with updated wrapper 2025-08-19 17:59:20 +01:00
Nader Arbabian cd946b0a9f update report_addr to use new webserver endpoint with AS fallback 2025-08-12 13:31:19 -07:00
Nader Arbabian c595b42410 for benchmarking, use concurrent requests (#26) 2025-08-11 12:39:28 -07:00
Nader Arbabian 0bf3247a34 fix completions and interactive client 2025-08-11 12:37:53 -07:00
Nader Arbabian 52ac4c0c1a fix endpoint_util not using the correct instance's endpoint 2025-08-11 12:05:58 -07:00
Nader Arbabian 8804e17201 download vast.ai's root certificate in order to make pyworker requests (#25) 2025-08-08 17:04:16 -07:00
Nader Arbabian 4016cf9a53 redo metrics tracking for requests, fixes bug wherere some requests were marked as pending, even though they had finished (#24) 2025-08-08 17:01:21 -07:00
Rob Ballantyne e0be45f39a Addresses breaking change in core pyworker (#22)
* Addresses breaking change in test_utils.py

Endpoint.get_endpoint_api_key() now requires instance

Moves the call to this function out of the APIClient and into main

* Ensure make_benchmark_payload has a value to calculate the workload

---------

Co-authored-by: Nader Arbabian <nader@vast.ai>
2025-07-18 16:11:10 -07:00
Nader Arbabian be2aafdb1f fix pyright errors + revert to old way of handling cancelled api requests (#23) 2025-07-17 16:59:06 -07:00
Rob Ballantyne 9e369c55a5 Ensure venv creation where python is unavailable (#21) 2025-07-17 09:59:35 -07:00
Rob Ballantyne 69d9b7455f OpenAI compatible worker (#19)
Adds initial support for OpenAI compatible inference servers

Available endpoints:

- `/v1/completions`
- `/v1/chat/completions`
2025-07-16 09:46:26 +01:00
Nader Arbabian 6fb610cb5b fix pyworker miscounting active connections (#20)
* fix pyworker miscounting active connections

* clean up some issues

* add option to skip auth
2025-07-15 15:33:27 -07:00
27 changed files with 1888 additions and 101 deletions
+67 -39
View File
@@ -8,9 +8,10 @@ import logging
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
from functools import cached_property
from distutils.util import strtobool
from anyio import open_file
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
import requests
from Crypto.Signature import pkcs1_15
@@ -55,6 +56,9 @@ class Backend:
reqnum = -1
msg_history = []
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
unsecured: bool = dataclasses.field(
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
)
def __post_init__(self):
self.metrics = Metrics()
@@ -71,7 +75,13 @@ class Backend:
@cached_property
def session(self):
log.debug(f"starting session with {self.model_server_url}")
return ClientSession(self.model_server_url)
connector = TCPConnector(
force_close=True, # Required for long running jobs
enable_cleanup_closed=True,
)
timeout = ClientTimeout(total=None)
return ClientSession(self.model_server_url, timeout=timeout, connector=connector)
def create_handler(
self,
@@ -119,14 +129,10 @@ class Backend:
return web.json_response(dict(error="invalid JSON"), status=422)
workload = payload.count_workload()
async def wait_for_disconnection() -> None:
while request.transport and not request.transport.is_closing():
await sleep(0.5)
async def cancel_api_call_if_disconnected() -> web.Response:
await wait_for_disconnection()
await request.wait_for_disconnection()
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
self.metrics._request_canceled(workload=workload, reqnum=auth_data.reqnum)
self.metrics._request_canceled(workload=workload)
return web.Response(status=500)
async def make_request() -> Union[web.Response, web.StreamResponse]:
@@ -141,7 +147,6 @@ class Backend:
else:
log.debug(f"Starting request for reqnum:{auth_data.reqnum}")
try:
start_time = time.time()
response = await self.__call_api(handler=handler, payload=payload)
status_code = response.status
log.debug(
@@ -153,19 +158,17 @@ class Backend:
)
)
res = await handler.generate_client_response(request, response)
self.metrics._request_end(
workload=workload,
req_response_time=time.time() - start_time,
reqnum=auth_data.reqnum,
)
self.metrics._request_success(workload=workload)
return res
except requests.exceptions.RequestException as e:
log.debug(f"[backend] Request error: {e}")
self.metrics._request_errored(
workload=workload, reqnum=auth_data.reqnum
)
self.metrics._request_errored(workload=workload)
return web.Response(status=500)
finally:
self.metrics._request_end(
workload=workload,
reqnum=auth_data.reqnum,
)
self.sem.release()
###########
@@ -187,18 +190,30 @@ class Backend:
log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500)
@cached_property
def healthcheck_session(self):
"""Dedicated session for healthchecks to avoid conflicts with API session"""
log.debug("creating dedicated healthcheck session")
connector = TCPConnector(
force_close=True, # Keep this for isolation
enable_cleanup_closed=True,
)
timeout = ClientTimeout(total=10) # Reasonable timeout for healthchecks
return ClientSession(timeout=timeout, connector=connector)
async def __healthcheck(self):
health_check_url = self.benchmark_handler.healthcheck_endpoint
if health_check_url is None:
log.debug("No healthcheck endpoint defined, skipping healthcheck")
return
while True:
await sleep(10)
if self.__start_healthcheck is False:
continue
try:
log.debug(f"Performing healthcheck on {health_check_url}")
async with self.session.get(health_check_url) as response:
async with self.healthcheck_session.get(health_check_url) as response:
if response.status == 200:
log.debug("Healthcheck successful")
elif response.status == 503:
@@ -207,7 +222,6 @@ class Backend:
f"Healthcheck failed with status: {response.status}"
)
else:
# endpoint not ready yet so bail
log.debug(f"Healthcheck Endpoint not ready: {response.status}")
except Exception as e:
log.debug(f"Healthcheck failed with exception: {e}")
@@ -229,6 +243,9 @@ class Backend:
return await self.session.post(url=handler.endpoint, json=api_payload)
def __check_signature(self, auth_data: AuthData) -> bool:
if self.unsecured is True:
return True
def verify_signature(message, signature):
if self.pubkey is None:
log.debug(f"No Public Key!")
@@ -280,41 +297,52 @@ class Backend:
return float(f.readline())
except FileNotFoundError:
pass
max_throughput = 0
last_throughput = 0
sum_throughput = 0
for run in range(self.benchmark_handler.benchmark_runs + 1):
start = time.time()
log.debug("Initial run to trigger model loading...")
payload = self.benchmark_handler.make_benchmark_payload()
res = await self.__call_api(
handler=self.benchmark_handler, payload=payload
await self.__call_api(handler=self.benchmark_handler, payload=payload)
max_throughput = 0
sum_throughput = 0
concurrent_requests = 10 if self.allow_parallel_requests else 1
for run in range(1, self.benchmark_handler.benchmark_runs + 1):
start = time.time()
tasks = []
total_workload = 0
for _ in range(concurrent_requests):
payload = self.benchmark_handler.make_benchmark_payload()
total_workload += payload.count_workload()
tasks.append(
self.__call_api(handler=self.benchmark_handler, payload=payload)
)
data = await res.json()
responses = await gather(*tasks)
time_elapsed = time.time() - start
# first run triggers one-time loading of the model which is very slow, so we skip counting it
if run == 0:
continue
else:
workload = payload.count_workload()
last_throughput = workload / time_elapsed
sum_throughput += last_throughput
max_throughput = max(max_throughput, last_throughput)
throughput = total_workload / time_elapsed
sum_throughput += throughput
max_throughput = max(max_throughput, throughput)
# Log results for debugging
log.debug(
"\n".join(
[
"#" * 60,
f"Run: {run}, workload: {workload} time_elapsed: {time_elapsed}, throughput: {last_throughput}",
"",
f"response: {data}",
f"Run: {run}, concurrent_requests: {concurrent_requests}",
f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s",
f"Throughput: {throughput} workload/s",
f"Successful responses: {len([r for r in responses if r.status == 200])}",
"#" * 60,
]
)
)
average_throughput = sum_throughput / self.benchmark_handler.benchmark_runs
log.debug(
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
)
# save max_throughput so we don't have to run benchmark again on restart of cold instances
with open(BENCHMARK_INDICATOR_FILE, "w") as f:
f.write(str(max_throughput))
return max_throughput
+7 -4
View File
@@ -8,7 +8,6 @@ from aiohttp import web, ClientResponse
import inspect
import psutil
import requests
"""
@@ -206,13 +205,13 @@ class ModelMetrics:
workload_received: float
workload_cancelled: float
workload_errored: float
workload_pending: float
# these are not
cur_perf: float
workload_pending: float
error_msg: Optional[str]
max_throughput: float
requests_recieved: Set[int] = field(default_factory=set)
requests_working: Set[int] = field(default_factory=set)
last_update: float = field(default_factory=time.time)
@classmethod
def empty(cls):
@@ -221,12 +220,15 @@ class ModelMetrics:
workload_served=0.0,
workload_cancelled=0.0,
workload_errored=0.0,
cur_perf=0.0,
workload_received=0.0,
error_msg=None,
max_throughput=0.0,
)
@property
def cur_perf(self) -> float:
return max(self.workload_served / (time.time() - self.last_update), 0.0)
@property
def workload_processing(self) -> float:
return max(self.workload_received - self.workload_cancelled, 0.0)
@@ -240,6 +242,7 @@ class ModelMetrics:
self.workload_received = 0
self.workload_cancelled = 0
self.workload_errored = 0
self.last_update = time.time()
@dataclass
+19 -16
View File
@@ -46,33 +46,31 @@ class Metrics:
self.model_metrics.requests_recieved.add(reqnum)
self.model_metrics.requests_working.add(reqnum)
def _request_end(
self, workload: float, req_response_time: float, reqnum: int
) -> None:
def _request_end(self, workload: float, reqnum: int) -> None:
"""
this function is called after a response from model API is received.
this function is called after handling of a request ends, regardless of the outcome
"""
self.model_metrics.workload_served += workload
self.model_metrics.workload_pending -= workload
self.model_metrics.requests_working.discard(reqnum)
self.model_metrics.cur_perf = workload / req_response_time
def _request_success(self, workload: float) -> None:
"""
this function is called after a response from model API is received and forwarded.
"""
self.model_metrics.workload_served += workload
self.update_pending = True
def _request_errored(self, workload: float, reqnum: int) -> None:
def _request_errored(self, workload: float) -> None:
"""
this function is called if model API returns an error
"""
self.model_metrics.workload_pending -= workload
self.model_metrics.workload_errored += workload
self.model_metrics.requests_working.discard(reqnum)
def _request_canceled(self, workload: float, reqnum: int) -> None:
def _request_canceled(self, workload: float) -> None:
"""
this function is called if client drops connection before model API has responded
"""
self.model_metrics.workload_pending -= workload
self.model_metrics.workload_cancelled += workload
self.model_metrics.requests_working.discard(reqnum)
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
while True:
@@ -116,7 +114,7 @@ class Metrics:
url=self.url,
)
def send_data(report_addr: str) -> None:
def send_data(report_addr: str) -> bool:
data = compute_autoscaler_data()
full_path = report_addr.rstrip("/") + "/worker_status/"
log.debug(
@@ -131,21 +129,26 @@ class Metrics:
)
for attempt in range(1, 4):
try:
requests.post(full_path, json=asdict(data), timeout=1)
break
res = requests.post(full_path, json=asdict(data), timeout=1)
res.raise_for_status()
return True
except requests.Timeout:
log.debug(f"autoscaler status update timed out")
except Exception as e:
log.debug(f"autoscaler status update failed with error: {e}")
time.sleep(2)
log.debug(f"retrying autoscaler status update, attempt: {attempt}")
log.debug(f"failed to send update through {report_addr}")
return False
###########
self.system_metrics.update_disk_usage()
for report_addr in self.report_addr:
send_data(report_addr)
success = send_data(report_addr)
if success is True:
break
self.update_pending = False
self.model_metrics.reset()
self.system_metrics.reset()
+27 -9
View File
@@ -10,6 +10,7 @@ from collections import Counter
from dataclasses import dataclass, field, asdict
from urllib.parse import urljoin
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
import requests
from lib.data_types import AuthData, ApiPayload
@@ -53,6 +54,13 @@ test_args.add_argument(
default="https://run.vast.ai",
help="Call local autoscaler instead of prod, for dev use only",
)
test_args.add_argument(
"-i",
dest="instance",
type=str,
default="prod",
help="Autoscaler shard to run the command against, default: prod",
)
GetPayloadAndWorkload = Callable[[], Tuple[Dict[str, Any], float]]
@@ -70,6 +78,7 @@ class ClientState:
api_key: str
server_url: str
worker_endpoint: str
instance: str
payload: ApiPayload
url: str = ""
status: ClientStatus = ClientStatus.FetchEndpoint
@@ -79,11 +88,7 @@ class ClientState:
def make_call(self):
self.status = ClientStatus.FetchEndpoint
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=self.endpoint_group_name,
account_api_key=self.api_key,
)
if not endpoint_api_key:
if not self.api_key:
self.as_error.append(
f"Endpoint {self.endpoint_group_name} not found for API key",
)
@@ -91,12 +96,14 @@ class ClientState:
return
route_payload = {
"endpoint": self.endpoint_group_name,
"api_key": endpoint_api_key,
"api_key": self.api_key,
"cost": self.payload.count_workload(),
}
headers = {"Authorization": f"Bearer {self.api_key}"}
response = requests.post(
urljoin(self.server_url, "/route/"),
json=route_payload,
headers=headers,
timeout=4,
)
if response.status_code != 200:
@@ -114,9 +121,11 @@ class ClientState:
self.url = worker_address
url = urljoin(worker_address, self.worker_endpoint)
self.status = ClientStatus.Generating
response = requests.post(
url,
json=req_data,
verify=get_cert_file_path(),
)
if response.status_code != 200:
self.infer_error.append(
@@ -135,6 +144,7 @@ class ClientState:
try:
self.make_call()
except Exception as e:
print(e)
self.status = ClientStatus.Error
_ = e
self.conn_errors[self.url] += 1
@@ -226,6 +236,7 @@ def run_test(
server_url: str,
worker_endpoint: str,
payload_cls: Type[ApiPayload],
instance: str,
):
threads = []
@@ -234,8 +245,7 @@ def run_test(
print_thread.daemon = True # makes threads get killed on program exit
print_thread.start()
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=endpoint_group_name,
account_api_key=api_key,
endpoint_name=endpoint_group_name, account_api_key=api_key, instance=instance
)
if not endpoint_api_key:
log.debug(f"Endpoint {endpoint_group_name} not found for API key")
@@ -248,6 +258,7 @@ def run_test(
server_url=server_url,
worker_endpoint=worker_endpoint,
payload=payload_cls.for_test(),
instance=instance,
)
clients.append(client)
thread = threading.Thread(target=client.simulate_user, args=())
@@ -281,12 +292,19 @@ def test_load_cmd(
args = arg_parser.parse_args()
if hasattr(args, "comfy_model"):
os.environ["COMFY_MODEL"] = args.comfy_model
server_url = dict(
prod="https://run.vast.ai",
alpha="https://run-alpha.vast.ai",
candidate="https://run-candidate.vast.ai",
local="http://localhost:8080",
)[args.instance]
run_test(
num_requests=args.num_requests,
requests_per_second=args.requests_per_second,
api_key=args.api_key,
server_url=args.server_url,
server_url=server_url,
endpoint_group_name=args.endpoint_group_name,
worker_endpoint=endpoint,
payload_cls=payload_cls,
instance=args.instance,
)
+2 -2
View File
@@ -1,4 +1,4 @@
aiohttp~=3.11
aiohttp[speedups]==3.10.1
anyio~=4.4
lib~=4.0
nltk~=3.9
@@ -6,5 +6,5 @@ psutil~=6.0
pycryptodome~=3.20
Requests~=2.32
transformers~=4.52
utils~=1.0
utils==1.0.*
hf_transfer>=0.1.9
+24 -9
View File
@@ -9,7 +9,7 @@ ENV_PATH="$WORKSPACE_DIR/worker-env"
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}"
REPORT_ADDR="${REPORT_ADDR:-https://cloud.vast.ai/api/v0,https://run.vast.ai}"
USE_SSL="${USE_SSL:-true}"
WORKER_PORT="${WORKER_PORT:-3000}"
mkdir -p "$WORKSPACE_DIR"
@@ -41,22 +41,37 @@ echo_var DEBUG_LOG
echo_var PYWORKER_LOG
echo_var MODEL_LOG
env | grep _ >> /etc/environment;
# Populate /etc/environment with quoted values
if ! grep -q "VAST" /etc/environment; then
env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do
name=${line%%=*}
value=${line#*=}
printf '%s="%s"\n' "$name" "$value"
done > /etc/environment
fi
if [ ! -d "$ENV_PATH" ]
then
apt install -y python3.10-venv
echo "setting up venv"
git clone https://github.com/vast-ai/pyworker "$SERVER_DIR"
if ! which uv; then
curl -LsSf https://astral.sh/uv/install.sh | sh
source ~/.local/bin/env
fi
python3 -m venv "$WORKSPACE_DIR/worker-env"
source "$WORKSPACE_DIR/worker-env/bin/activate"
# Fork testing
[[ ! -d $SERVER_DIR ]] && git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"
if [[ -n ${PYWORKER_REF:-} ]]; then
(cd "$SERVER_DIR" && git checkout "$PYWORKER_REF")
fi
pip install -r vast-pyworker/requirements.txt
uv venv --python-preference only-managed "$ENV_PATH" -p 3.10
source "$ENV_PATH/bin/activate"
uv pip install -r "${SERVER_DIR}/requirements.txt"
touch ~/.no_auto_tmux
else
[[ -f ~/.local/bin/env ]] && source ~/.local/bin/env
source "$WORKSPACE_DIR/worker-env/bin/activate"
echo "environment activated"
echo "venv: $VIRTUAL_ENV"
@@ -103,7 +118,7 @@ fi
export REPORT_ADDR WORKER_PORT USE_SSL
export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED
cd "$SERVER_DIR"
+25 -3
View File
@@ -17,7 +17,27 @@ class Endpoint:
"""
@staticmethod
def get_endpoint_api_key(endpoint_name: str, account_api_key: str) -> Optional[str]:
def get_autoscaler_server_url(instance: str) -> str:
endpoints = {
"alpha": "run-alpha",
"candidate": "run-candidate",
"prod": "run",
}
return f"https://{endpoints[instance]}.vast.ai/"
@staticmethod
def get_server_url(instance: str) -> str:
endpoints = {
"alpha": "alpha",
"candidate": "candidate",
"prod": "console",
}
return f"https://{endpoints[instance]}.vast.ai/api/v0/endptjobs/"
@staticmethod
def get_endpoint_api_key(
endpoint_name: str, account_api_key: str, instance: str
) -> Optional[str]:
"""
Fetch endpoint API key from VastAI console following the healthcheck pattern.
@@ -28,12 +48,14 @@ class Endpoint:
Returns:
Endpoint API key if successful, None otherwise
"""
vast_console_url = "https://console.vast.ai/api/v0/endptjobs/"
headers = {"Authorization": f"Bearer {account_api_key}"}
try:
log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}")
response = requests.get(vast_console_url, headers=headers)
response = requests.get(
f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}",
headers=headers,
)
if response.status_code != 200:
error_msg = f"Failed to fetch endpoint API key: {response.status_code} - {response.text}"
+15
View File
@@ -0,0 +1,15 @@
import tempfile
from functools import cache
import requests
@cache
def get_cert_file_path():
cert_url = "https://console.vast.ai/static/jvastai_root.cer"
response = requests.get(cert_url)
response.raise_for_status()
# Use a temporary file that is not deleted on close
with tempfile.NamedTemporaryFile(delete=False, suffix=".cer", mode="wb") as f:
f.write(response.content)
return f.name
+210
View File
@@ -0,0 +1,210 @@
# ComfyUI PyWorker
This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture.
The cost for each request has a static value of `1`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node.
## Requirements
This worker requires both [ComfyUI](https://github.com/comfyanonymous/ComfyUI) and [ComfyUI API Wrapper](https://github.com/ai-dock/comfyui-api-wrapper).
A docker image is provided but you may use any if the above requirements are met.
## Benchmarking
A simple image generation benchmark runs when each worker initializes to validate GPU performance and identify underperforming machines.
The benchmark uses Stable Diffusion v1.5 with ComfyUI's default text-to-image workflow. Configure the benchmark complexity and duration using these variables:
| Environment Variable | Default Value | Description |
| -------------------- | ------------- | ----------- |
| BENCHMARK_TEST_WIDTH | 512 | Image width (pixels) |
| BENCHMARK_TEST_HEIGHT | 512 | Image height (pixels) |
| BENCHMARK_TEST_STEPS | 20 | Number of denoising steps |
Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns.
### Calibrating Benchmark Duration
To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements.
**Example:** If your typical workflow should complete in 90 seconds on acceptable hardware:
```bash
# 1. Measure it/sec on your reference machine
# RTX 4090 typically achieves ~43 it/sec with SD1.5
# 2. Calculate required steps
# 90 seconds × 43 it/sec = 3870 steps
# 3. Configure benchmark
export BENCHMARK_TEST_STEPS=3870
# 4. Machines completing significantly slower than 90s indicate hardware issues
```
**Performance expectations:**
- Benchmark duration should remain consistent across identical GPU models
- Significant variation (>20%) may indicate thermal, power, or configuration issues
## Endpoint
The worker provides a single endpoint:
- `/generate/sync`: Processes ComfyUI workflows using either predefined modifiers or custom workflow JSON
## Request Format
The worker accepts requests in the following format. Choose either modifier mode OR custom workflow mode:
**Modifier Mode:**
```json
{
"input": {
"request_id": "uuid-string", // optional - UUID generated if not provided
"modifier": "RawWorkflow",
"modifications": {
"prompt": "a beautiful landscape",
"width": 1024,
"height": 1024,
"steps": 20,
"seed": 123456789
},
"s3": { ... }, // optional
"webhook": { ... } // optional
}
}
```
**Custom Workflow Mode:**
```json
{
"input": {
"request_id": "uuid-string", // optional - UUID generated if not provided
"workflow_json": {
// Complete ComfyUI workflow JSON
},
"s3": { ... }, // optional
"webhook": { ... } // optional
}
}
```
## Request Fields
### Required Fields
- **`input`**: Contains the main workflow data
- **`input.request_id`**: Unique identifier for the request
### Workflow Mode (Choose One)
You must provide either `modifier` OR `workflow_json`, but not both:
#### Option 1: Modifier Mode
- **`input.modifier`**: Name of the predefined workflow modifier (e.g., "Text2Image")
- **`input.modifications`**: Parameters to pass to the modifier
#### Option 2: Custom Workflow Mode
- **`input.workflow_json`**: Complete ComfyUI workflow JSON
### Optional Fields
- **`input.s3`**: S3 configuration for file storage
- **`input.webhook`**: Webhook configuration for notifications
These configurations can be provided in the request JSON or via environment variables. Request-level configuration takes precedence over environment variables.
#### S3 Configuration
**Via Request JSON:**
```json
"s3": {
"access_key_id": "your-s3-access-key",
"secret_access_key": "your-s3-secret-access-key",
"endpoint_url": "https://my-endpoint.backblaze.com",
"bucket_name": "your-bucket",
"region": "us-east-1"
}
```
**Via Environment Variables:**
```bash
S3_ACCESS_KEY_ID=your-key
S3_SECRET_ACCESS_KEY=your-secret
S3_BUCKET_NAME=your-bucket
S3_ENDPOINT_URL=https://s3.amazonaws.com
S3_REGION=us-east-1
```
#### Webhook Configuration
**Via Request JSON:**
```json
"webhook": {
"url": "your-webhook-url",
"extra_params": {
"custom_field": "value"
}
}
```
**Via Environment Variables:**
```bash
WEBHOOK_URL=https://your-webhook.com # Default webhook URL
WEBHOOK_TIMEOUT=30 # Webhook timeout in seconds
```
## Examples
### Basic Text-to-Image (Modifier Mode)
```json
{
"input": {
"modifier": "Text2Image",
"modifications": {
"prompt": "a cat sitting on a windowsill",
"width": 512,
"height": 512,
"steps": 20,
"seed": 42
}
}
}
```
### Custom Workflow Mode
```json
{
"input": {
"request_id": "67890", // optional - using custom ID for tracking
"workflow_json": {
"3": {
"inputs": {
"seed": 42,
"steps": 20,
"cfg": 8,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1,
"model": ["4", 0],
"positive": ["6", 0],
"negative": ["7", 0],
"latent_image": ["5", 0]
},
"class_type": "KSampler"
}
}
}
}
```
## Client Libraries
See the test client examples for implementation details on how to integrate with the ComfyUI worker.
---
See Vast's serverless documentation for more details on how to use ComfyUI with autoscaler.
View File
+155
View File
@@ -0,0 +1,155 @@
import logging
import uuid
import random
from urllib.parse import urljoin
import json
import requests
from lib.test_utils import print_truncate_res
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from .data_types import count_workload
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
def call_text2image_workflow(
endpoint_group_name: str, api_key: str, server_url: str
) -> None:
"""Simple Text2Image using the new modifier-based approach"""
def make_request(url: str, payload: dict, timeout: int = None, verify=True, context: str = "request"):
"""Helper function for making requests with consistent error handling"""
try:
response = requests.post(
url,
json=payload,
timeout=timeout,
verify=verify
)
response.raise_for_status()
return response.json()
except requests.exceptions.HTTPError as http_err:
log.error(f"HTTP error occurred during {context}: {http_err}")
log.error(f"Status Code: {response.status_code}")
log.error("Response content:", response.text)
return None
except requests.exceptions.Timeout:
log.error(f"Timeout occurred during {context}: {url}")
return None
except requests.exceptions.ConnectionError:
log.error(f"Connection error occurred during {context}: {url}")
return None
except json.JSONDecodeError as json_err:
log.error(f"Failed to decode JSON response during {context}: {json_err}")
if 'response' in locals():
print("Response content:", response.text)
return None
except Exception as err:
log.error(f"An unexpected error occurred during {context}: {err}")
if 'response' in locals():
log.error("Response content (if available):", response.text)
return None
WORKER_ENDPOINT = "/generate/sync"
# This worker has concurrency = 1. All workloads have cost value 1.0
COST = count_workload()
# Route to get worker URL
route_payload = {
"endpoint": endpoint_group_name,
"api_key": api_key,
"cost": COST,
}
# First request - get routing information
route_response = make_request(
url=urljoin(server_url, "/route/"),
payload=route_payload,
timeout=4,
context="route request"
)
if route_response is None:
return None
if "url" not in route_response or not route_response["url"]:
log.error("Error: No worker in 'Ready' state. Please wait while the serverless engine removes errored workers or finishes loading new workers.")
return None
if "status" in route_response:
print(f"Autoscaler status: {route_response['status']}")
return None
# Extract data from route response
url = route_response["url"]
auth_data = dict(
signature=route_response["signature"],
cost=route_response["cost"],
endpoint=route_response["endpoint"],
reqnum=route_response["reqnum"],
url=route_response["url"],
)
# Build the payload for the worker request
worker_payload = {
"input": {
"request_id": str(uuid.uuid4()),
"modifier": "Text2Image",
"modifications": {
"prompt": "a beautiful landscape with mountains and lakes",
"width": 1024,
"height": 1024,
"steps": 20,
"seed": random.randint(0, 2**32 - 1)
},
"workflow_json": {} # Empty since using modifier approach
}
}
req_data = dict(payload=worker_payload, auth_data=auth_data)
worker_url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {worker_url}")
# Second request - call the worker endpoint
worker_response = make_request(
url=worker_url,
payload=req_data,
verify=get_cert_file_path(),
context="worker request"
)
return worker_response
if __name__ == "__main__":
from lib.test_utils import test_args
args = test_args.parse_args()
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if endpoint_api_key:
result = call_text2image_workflow(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
if result is None:
log.error("Text2Image workflow failed")
else:
print(result)
else:
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name}")
+60
View File
@@ -0,0 +1,60 @@
import os
import sys
import random
import dataclasses
from typing import Dict, Any
from functools import cache
from math import ceil
from lib.data_types import ApiPayload, JsonDataException
with open("workers/comfyui/misc/test_prompts.txt", "r") as f:
test_prompts = f.readlines()
def count_workload() -> float:
# Always 100.0 where there is a single instance of ComfyUI handling requests
# Results will indicate % or a job completed per second. Avoids sub 0.1 sec performance indication
return 100.0
@dataclasses.dataclass
class ComfyWorkflowData(ApiPayload):
input: dict
@classmethod
def for_test(cls):
"""
Use the variables available to simulate workflows of the required running time
Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090)
"""
test_prompt = random.choice(test_prompts).rstrip()
return cls(
input={
"request_id": f"test-{random.randint(1000, 99999)}",
"modifier": "Text2Image",
"modifications": {
"prompt": test_prompt,
"width": os.getenv('BENCHMARK_TEST_WIDTH', 512),
"height": os.getenv('BENCHMARK_TEST_HEIGHT', 512),
"steps": os.getenv('BENCHMARK_TEST_STEPS', 20),
"seed": random.randint(0, sys.maxsize),
}
}
)
def generate_payload_json(self) -> Dict[str, Any]:
# input is already a dict, just return it wrapped in the expected structure
return {"input": self.input}
def count_workload(self) -> float:
return count_workload()
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "ComfyWorkflowData":
# Extract required fields
if "input" not in json_msg:
raise JsonDataException({"input": "missing parameter"})
return cls(
input=json_msg["input"]
)
@@ -0,0 +1,34 @@
cartoon character of a person with a hoodie , in style of cytus and deemo, ork, gold chains, realistic anime cat, dripping black goo, lineage revolution style, thug life, cute anthropomorphic bunny, balrog, arknights, aliased, very buff, black and red and yellow paint, painting illustration collage style, character composition in vector with white background
stardew valley, fine details
2D Vector Illustration of a child with soccer ball Art for Sublimation, Design Art, Chrome Art, Painting and Stunning Artwork, Highly Detailed Digital Painting, Airbrush Art, Highly Detailed Digital Artwork, Dramatic Artwork, stained antique yellow copper paint, digital airbrush art, detailed by Mark Brooks, Chicano airbrush art, Swagger! snake Culture
realistic futuristic city-downtown with short buildings, sunset
seascape by Ray Collins and artgerm, front view of a perfect wave, sunny background, ultra detailed water
inspired by realflow-cinema4d editor features, create image of a transparent luxury cup with ice fruits and mint, connected with white, yellow and pink cream, Slow - High Speed MO Photography, YouTube Video Screenshot, Abstract Clay, Transparent Cup , molecular gastronomy, wheel, 3D fluid,Simulation rendering, still video, 4k polymer clay futras photography, very surreal, Houdini Fluid Simulation, hyperrealistic CGI and FLUIDS & MULTIPHYSICS SIMULATION effect, with Somali Stain Lurex, Metallic Jacquard, Gold Thread, Mulberry Silk, Toub Saree, Warm background, a fantastic image worthy of an award.
biker with backpack on his back riding a motorcycle, Style by Ade Santora, Oilpunk, Cover photo, craig mullins style, on the cover of a magazine, Outdoor Magazine, inspired by Alex Petruk APe, image of a male biker, Cover of an award-winning magazine, the man has a backpack, photo for magazine, with a backpack, magazine cover
generate a collage-style illustration inspired by the Procreate raster graphic editor, photographic illustration with the theme, 2D vector, art for textile sublimation, containing surrealistic cartoon cat wearing a baseball cap and jeans standing in front of a poster, inspired by Sadao Watanabe, Doraemon, Japanese cartoon style, Eichiro Oda, Iconic high detail character, Director: Nakahara Nantenbō, Kastuhiro Otomo, image detailed, by Miyamoto, Hidetaka Miyazaki, Katsuhiro illustration, 8k, masterpiece, Minimize noise and grain in photo quality without lose quality and increase brightness and lighting,Symmetry and Alignment, Avoid asymmetrical shapes and out-of-focus points. Focus and Sharpness: Make sure the image is focused and sharp and encourages the viewer to see it as a work of art printed on fabric.
fantasy medieval village world inside a glass sphere , high detail, fantasy, realistic, light effect, hyper detail, volumetric lighting, cinematic, macro, depth of field, blur, red light and clouds from the back, highly detailed epic cinematic concept art cg render made in maya, blender and photoshop, octane render, excellent composition, dynamic dramatic cinematic lighting, aesthetic, very inspirational, world inside a glass sphere by james gurney by artgerm with james jean, joe fenton and tristan eaton by ross tran, fine details
Iron Man, (Arnold Tsang, Toru Nakayama), Masterpiece, Studio Quality, 6k , toa, toaair, 1boy, glowing, axe, mecha, science_fiction, solo, weapon, jungle , green_background, nature, outdoors, solo, tree, weapon, mask, dynamic lighting, detailed shading, digital texture painting
(Pope Francis) wearing leather jacket is a DJ in a nightclub, mixing live on stage, giant mixing table, a masterpiece
Pope Francis wearing biker (leather jacket), a masterpiece
Luke Skywalker ordering a burger and fries from the Death Star canteen.
I want to generate a group avatar for a Feishu group chat. The role of this group is daily software technical communication. Now the subject technology stacks that members of this group discuss daily include: algorithms, data structures, optimization, functional programming, and the programming languages often discussed are: TypeScript, Java, python, etc. I hope this avatar has a simple aesthetic, this avatar is a single person avatar
portrait Anime black girl cute-fine-face, pretty face, realistic shaded Perfect face, fine details. Anime. realistic shaded lighting by Ilya Kuvshinov Giuseppe Dangelico Pino and Michael Garmash and Rob Rey, IAMAG premiere, WLOP matte print, cute freckles, masterpiece
young Disney socialite wearing a beige miniskirt, dark brown turtleneck sweater, small neckless, cute-fine-face, anime. illustration, realistic shaded perfect face, brown hair, grey eyes, fine details, realistic shaded lighting by ilya kuvshinov giuseppe dangelico pino and michael garmash and rob rey, iamag premiere, wlop matte print, a masterpiece
Cute small cat sitting in a movie theater eating chicken wiggs watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
Cute small dog sitting in a movie theater eating popcorn watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
fox bracelet made of buckskin with fox features, rich details, fine carvings, studio lighting
crane buckskin bracelet with crane features, rich details, fine carvings, studio lighting
london luxurious interior living-room, light walls
Parisian luxurious interior penthouse bedroom, dark walls, wooden panels
cute girl, crop-top, blond hair, black glasses, stretching, with background by greg rutkowski makoto shinkai kyoto animation key art feminine mid shot
houses in front, houses background, straight houses, digital art, smooth, sharp focus, gravity falls style, doraemon style, shinchan style, anime style
Simplified technical drawing, Leonardo da Vinci, Mechanical Dinosaur Skeleton, Minimalistic annotations, Hand-drawn illustrations, Basic design and engineering, Wonder and curiosity
High quality 8K painting impressionist style of a Japanese modern city street with a girl on the foreground wearing a traditional wedding dress with a fox mask, staring at the sky, daylight
a landscape from the Moon with the Earth setting on the horizon, realistic, detailed
Isometric Atlantis city,great architecture with columns, great details, ornaments,seaweed, blue ambiance, 3D cartoon style, soft light, 45° view
A hyper realistic avatar of a guy riding on a black honda cbr 650r in leather suit,high detail, high quality,8K,photo realism
the street of amedieval fantasy town, at dawn, dark, highly detailed
overwhelmingly beautiful eagle framed with vector flowers, long shiny wavy flowing hair, polished, ultra detailed vector floral illustration mixed with hyper realism, muted pastel colors, vector floral details in background, muted colors, hyper detailed ultra intricate overwhelming realism in detailed complex scene with magical fantasy atmosphere, no signature, no watermark
a highly detailed matte painting of a man on a hill watching a rocket launch in the distance by studio ghibli, makoto shinkai, by artgerm, by wlop, by greg rutkowski, volumetric lighting, octane render, 4 k resolution, trending on artstation, masterpiece | hyperrealism| highly detailed| insanely detailed| intricate| cinematic lighting| depth of field
electronik robot and ofice ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
exquisitely intricately detailed illustration, of a small world with a lake and a rainbow, inside a closed glass jar.
+122
View File
@@ -0,0 +1,122 @@
import os
import logging
import dataclasses
import base64
from typing import Optional, Union, Type
from aiohttp import web, ClientResponse
from lib.backend import Backend, LogAction
from lib.data_types import EndpointHandler
from lib.server import start_server
from .data_types import ComfyWorkflowData
MODEL_SERVER_URL = os.getenv("MODEL_SERVER_URL", "http://127.0.0.1:18288")
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: "
MODEL_SERVER_ERROR_LOG_MSGS = [
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
"Value not in list: ", # This error is emitted when the model file is not there at all
]
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
async def generate_client_response(
client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
match model_response.status:
case 200:
log.debug("SUCCESS")
# Check if the response is actually streaming based on response headers/content-type
is_streaming_response = (
model_response.content_type == "text/event-stream"
or model_response.content_type == "application/x-ndjson"
or model_response.headers.get("Transfer-Encoding") == "chunked"
or "stream" in model_response.content_type.lower()
)
if is_streaming_response:
log.debug("Detected streaming response...")
res = web.StreamResponse()
res.content_type = model_response.content_type
await res.prepare(client_request)
async for chunk in model_response.content:
await res.write(chunk)
await res.write_eof()
log.debug("Done streaming response")
return res
else:
log.debug("Detected non-streaming response...")
content = await model_response.read()
return web.Response(
body=content,
status=model_response.status,
content_type=model_response.content_type
)
case code:
log.debug(f"Model responded with error {code}")
return web.Response(status=code)
@dataclasses.dataclass
class ComfyWorkflowHandler(EndpointHandler[ComfyWorkflowData]):
@property
def endpoint(self) -> str:
return "/generate/sync"
@property
def healthcheck_endpoint(self) -> Optional[str]:
return f"{MODEL_SERVER_URL}/health"
@classmethod
def payload_cls(cls) -> Type[ComfyWorkflowData]:
return ComfyWorkflowData
def make_benchmark_payload(self) -> ComfyWorkflowData:
return ComfyWorkflowData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
return await generate_client_response(client_request, model_response)
backend = Backend(
model_server_url=MODEL_SERVER_URL,
model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=False,
benchmark_handler=ComfyWorkflowHandler(
benchmark_runs=3, benchmark_words=100
),
log_actions=[
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
(LogAction.Info, "Downloading:"),
*[
(LogAction.ModelError, error_msg)
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
],
],
)
async def handle_ping(_):
return web.Response(body="pong")
routes = [
web.post("/generate/sync", backend.create_handler(ComfyWorkflowHandler())),
web.get("/ping", handle_ping),
]
if __name__ == "__main__":
start_server(backend, routes)
+8
View File
@@ -0,0 +1,8 @@
from lib.test_utils import test_load_cmd, test_args
from .data_types import ComfyWorkflowData
WORKER_ENDPOINT = "/generate/sync"
if __name__ == "__main__":
test_load_cmd(ComfyWorkflowData, WORKER_ENDPOINT, arg_parser=test_args)
+4
View File
@@ -5,6 +5,7 @@ import requests
from lib.test_utils import print_truncate_res
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
"""
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
@@ -51,6 +52,7 @@ def call_default_workflow(
response = requests.post(
url,
json=req_data,
verify=get_cert_file_path(),
)
response.raise_for_status()
print_truncate_res(str(response.json()))
@@ -141,6 +143,7 @@ def call_custom_workflow_for_sd3(
response = requests.post(
url,
json=req_data,
verify=get_cert_file_path(),
)
response.raise_for_status()
print_truncate_res(str(response.json()))
@@ -153,6 +156,7 @@ if __name__ == "__main__":
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if endpoint_api_key:
try:
+1 -1
View File
@@ -13,7 +13,7 @@ from lib.server import start_server
from .data_types import DefaultComfyWorkflowData, CustomComfyWorkflowData
MODEL_SERVER_URL = "http://0.0.0.0:38188"
MODEL_SERVER_URL = "http://127.0.0.1:18288" # API Wrapper Service
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: http://127.0.0.1:18188"
+80
View File
@@ -0,0 +1,80 @@
# OpenAI Compatible PyWorker
This is the base PyWorker for OpenAI compatible inference servers. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
## Instance Setup
1. Pick a template
This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker.
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20%2B%20Qwen%2FQwen3-8B%20(Serverless)) (recommended)
- [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless))
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20%2B%20Qwen3-8B%20(Serverless))
All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected.
2. Follow the [getting started guide](https://docs.vast.ai/serverless/getting-started) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
## Client Setup (Demo)
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
```bash
git clone https://github.com/vast-ai/pyworker
cd pyworker
pip install uv
uv venv -p 3.12
source .venv/bin/activate
uv pip install -r requirements.txt
```
## Using the Test Client
Several examples have been provided in the client to help you get started with your own implementation.
### Completions
Call to `/v1/completions` with json response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
```
### Chat Completion (json)
Call to `/v1/chat/completions` with json response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
```
### Chat Completion (streaming)
Call to `/v1/chat/completions` with streaming response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
```
### Tool Use (json)
Call to `/v1/chat/completions` with tool and json response.
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
```
### Interactive Chat (streaming)
Interactive session with calls to `/v1/chat/completions`.
Type `clear` to clear the chat history or `quit` to exit.
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
```
+77
View File
@@ -0,0 +1,77 @@
# <INFERENCE_SERVER> + <MODEL_NAME> (serverless)
Run <INFERENCE_SERVER> with our serverless autoscaling infrastructure.
See the [serverless documentation](https://docs.vast.ai/serverless) and the [Getting Started](https://docs.vast.ai/serverless/getting-started) guide for in-depth details about how to use these templates.
## Configuration
Two environment variables are provided to help you configure the <INFERENCE_SERVER> server:
| Variable | Default Value | Used For |
| --- | --- | --- |
| `MODEL_NAME` | `<MODEL_NAME>` | The model to load. Also accepts [hf.co/repo/model](#) links |
| `<ARGS_VAR>` | `<ARGS_VAL>` | Arguments to pass to the `<ARGS_RECEIVER>` command |
This template has been configured to work with <MIN_VRAM> VRAM. Setting alternative models and server arguments will change the VRAM requirements. Check model cards and <INFERENCE_SERVER_DOCS> for guidance.
## Usage
We have provided a demonstration client to help you implement this template into your own infrastructure
### Client Setup
Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
```bash
git clone https://github.com/vast-ai/pyworker
cd pyworker
pip install uv
uv venv -p 3.12
source .venv/bin/activate
uv pip install -r requirements.txt
```
### Completions
Call to `/v1/completions` with json response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
```
### Chat Completion (json)
Call to `/v1/chat/completions` with json response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
```
### Chat Completion (streaming)
Call to `/v1/chat/completions` with streaming response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
```
### Tool Use (json)
Call to `/v1/chat/completions` with tool and json response.
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
```
### Interactive Chat (streaming)
Interactive session with calls to `/v1/chat/completions`.
Type `clear` to clear the chat history or `quit` to exit.
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
```
View File
+599
View File
@@ -0,0 +1,599 @@
import logging
import sys
import json
import subprocess
from urllib.parse import urljoin
from typing import Dict, Any, Optional, Iterator, Union, List
import requests
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from .data_types.client import CompletionConfig, ChatCompletionConfig
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
COMPLETIONS_PROMPT = "the capital of USA is"
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?"
class APIClient:
"""Lightweight client focused solely on API communication"""
# Remove the generic WORKER_ENDPOINT since we're now going direct
DEFAULT_COST = 100
DEFAULT_TIMEOUT = 4
def __init__(
self,
endpoint_group_name: str,
api_key: str,
server_url: str,
endpoint_api_key: str,
):
self.endpoint_group_name = endpoint_group_name
self.api_key = api_key
self.server_url = server_url
self.endpoint_api_key = endpoint_api_key
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
"""Get worker URL and auth data from routing service"""
if not self.endpoint_api_key:
raise ValueError("No valid endpoint API key available")
route_payload = {
"endpoint": self.endpoint_group_name,
"api_key": self.endpoint_api_key,
"cost": cost,
}
response = requests.post(
urljoin(self.server_url, "/route/"),
json=route_payload,
timeout=self.DEFAULT_TIMEOUT,
)
response.raise_for_status()
return response.json()
def _create_auth_data(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""Create auth data from routing response"""
return {
"signature": message["signature"],
"cost": message["cost"],
"endpoint": message["endpoint"],
"reqnum": message["reqnum"],
"url": message["url"],
}
def _make_request(
self,
payload: Dict[str, Any],
endpoint: str,
method: str = "POST",
stream: bool = False,
) -> Union[Dict[str, Any], Iterator[str]]:
"""Make request directly to the specific worker endpoint"""
# Get worker URL and auth data
cost = payload.get("max_tokens", self.DEFAULT_COST)
message = self._get_worker_url(cost=cost)
worker_url = message["url"]
auth_data = self._create_auth_data(message)
req_data = {"payload": {"input": payload}, "auth_data": auth_data}
url = urljoin(worker_url, endpoint)
log.debug(f"Making direct request to: {url}")
log.debug(f"Payload: {req_data}")
# Make the request using the specified method
if method.upper() == "POST":
response = requests.post(
url, json=req_data, stream=stream, verify=get_cert_file_path()
)
elif method.upper() == "GET":
response = requests.get(
url, params=req_data, stream=stream, verify=get_cert_file_path()
)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status()
if stream:
return self._handle_streaming_response(response)
else:
return response.json()
def _handle_streaming_response(self, response: requests.Response) -> Iterator[str]:
"""Handle streaming response and yield tokens"""
try:
for line in response.iter_lines(decode_unicode=True):
if line:
if line.startswith("data: "):
data_str = line[6:]
if data_str.strip() == "[DONE]":
break
try:
data = json.loads(data_str)
yield data # Yield the full chunk
except json.JSONDecodeError:
continue
except Exception as e:
log.error(f"Error handling streaming response: {e}")
raise
def call_completions(
self, config: CompletionConfig
) -> Union[Dict[str, Any], Iterator[str]]:
payload = config.to_dict()
return self._make_request(
payload=payload, endpoint="/v1/completions", stream=config.stream
)
def call_chat_completions(
self, config: ChatCompletionConfig
) -> Union[Dict[str, Any], Iterator[str]]:
payload = config.to_dict()
return self._make_request(
payload=payload, endpoint="/v1/chat/completions", stream=config.stream
)
class ToolManager:
"""Handles tool definitions and execution"""
@staticmethod
def list_files() -> str:
"""Execute ls on current directory"""
try:
result = subprocess.run(
["ls", "-la", "."], capture_output=True, text=True, timeout=10
)
if result.returncode == 0:
return result.stdout
else:
return f"Error: {result.stderr}"
except Exception as e:
return f"Error running ls: {e}"
@staticmethod
def get_ls_tool_definition() -> List[Dict[str, Any]]:
"""Get the ls tool definition"""
return [
{
"type": "function",
"function": {
"name": "list_files",
"description": "List files and directories in the cwd",
"parameters": {"type": "object", "properties": {}, "required": []},
},
}
]
def execute_tool_call(self, tool_call: Dict[str, Any]) -> str:
"""Execute a tool call and return the result"""
function_name = tool_call["function"]["name"]
if function_name == "list_files":
return self.list_files()
else:
raise ValueError(f"Unknown tool function: {function_name}")
class APIDemo:
"""Demo and testing functionality for the API client"""
def __init__(
self, client: APIClient, model: str, tool_manager: Optional[ToolManager] = None
):
self.client = client
self.model = model
self.tool_manager = tool_manager or ToolManager()
def handle_streaming_response(
self, response_stream, show_reasoning: bool = True
) -> str:
"""
Handle streaming chat response and display all output.
"""
full_response = ""
reasoning_content = ""
reasoning_started = False
content_started = False
for chunk in response_stream:
# Normalize the chunk
if isinstance(chunk, str):
chunk = chunk.strip()
if chunk.startswith("data: "):
chunk = chunk[6:].strip()
if chunk in ["[DONE]", ""]:
continue
try:
parsed_chunk = json.loads(chunk)
except json.JSONDecodeError:
continue
elif isinstance(chunk, dict):
parsed_chunk = chunk
else:
continue
# Parse delta from the chunk
choices = parsed_chunk.get("choices", [])
if not choices:
continue
delta = choices[0].get("delta", {})
reasoning_token = delta.get("reasoning_content", "")
content_token = delta.get("content", "")
# Print reasoning token if applicable
if show_reasoning and reasoning_token:
if not reasoning_started:
print("\n🧠 Reasoning: ", end="", flush=True)
reasoning_started = True
print(f"\033[90m{reasoning_token}\033[0m", end="", flush=True)
reasoning_content += reasoning_token
# Print content token
if content_token:
if not content_started:
if show_reasoning and reasoning_started:
print(f"\n💬 Response: ", end="", flush=True)
else:
print("Assistant: ", end="", flush=True)
content_started = True
print(content_token, end="", flush=True)
full_response += content_token
print() # Ensure newline after response
if show_reasoning:
if reasoning_started or content_started:
print("\nStreaming completed.")
if reasoning_started:
print(f"Reasoning tokens: {len(reasoning_content.split())}")
if content_started:
print(f"Response tokens: {len(full_response.split())}")
return full_response
def test_tool_support(self) -> bool:
"""Test if the endpoint supports function calling"""
log.debug("Testing endpoint tool calling support...")
# Try a simple request with minimal tools to test support
messages = [{"role": "user", "content": "Hello"}]
minimal_tool = [
{
"type": "function",
"function": {"name": "test_function", "description": "Test function"},
}
]
config = ChatCompletionConfig(
model=self.model,
messages=messages,
max_tokens=10,
tools=minimal_tool,
tool_choice="none", # Don't actually call the tool
)
try:
response = self.client.call_chat_completions(config)
return True
except Exception as e:
log.error(f"Error: Endpoint does not support tool calling: {e}")
return False
def demo_completions(self) -> None:
"""Demo: test basic completions endpoint"""
print("=" * 60)
print("COMPLETIONS DEMO")
print("=" * 60)
config = CompletionConfig(
model=self.model, prompt=COMPLETIONS_PROMPT, stream=False
)
log.info(
f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'"
)
response = self.client.call_completions(config)
if isinstance(response, dict):
print("\nResponse:")
print(json.dumps(response, indent=2))
else:
log.error("Unexpected response format")
def demo_chat(self, use_streaming: bool = True) -> None:
"""
Demo: test chat completions endpoint with optional streaming
"""
print("=" * 60)
print(
f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}"
)
print("=" * 60)
config = ChatCompletionConfig(
model=self.model,
messages=[{"role": "user", "content": CHAT_PROMPT}],
stream=use_streaming,
)
log.info(f"Testing chat completions with model '{self.model}'...")
response = self.client.call_chat_completions(config)
if use_streaming:
try:
self.handle_streaming_response(response, show_reasoning=True)
except Exception as e:
log.error(f"\nError during streaming: {e}")
import traceback
traceback.print_exc()
return
else:
if isinstance(response, dict):
choice = response.get("choices", [{}])[0]
message = choice.get("message", {})
content = message.get("content", "")
reasoning = message.get("reasoning_content", "") or message.get(
"reasoning", ""
)
if reasoning:
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
print(f"\n💬 Assistant: {content}")
print(f"\nFull Response:")
print(json.dumps(response, indent=2))
else:
log.error("Unexpected response format")
def demo_ls_tool(self) -> None:
"""Demo: ask LLM to list files in the current directory and describe what it sees"""
print("=" * 60)
print("TOOL USE DEMO: List Directory Contents")
print("=" * 60)
# Test if tools are supported first
if not self.test_tool_support():
return
# Request with tool available
messages = [{"role": "user", "content": TOOLS_PROMPT}]
config = ChatCompletionConfig(
model=self.model,
messages=messages,
tools=self.tool_manager.get_ls_tool_definition(),
tool_choice="auto",
)
log.info(f"Making initial request with tool using model '{self.model}'...")
response = self.client.call_chat_completions(config)
if not isinstance(response, dict):
raise ValueError("Expected dict response for tool use")
choice = response.get("choices", [{}])[0]
message = choice.get("message", {})
print(f"Assistant response: {message.get('content', 'No content')}")
# Check for tool calls
tool_calls = message.get("tool_calls")
if not tool_calls:
raise ValueError(
"No tool calls made - model may not support function calling"
)
print(f"Tool calls detected: {len(tool_calls)}")
# Execute the tool call
for tool_call in tool_calls:
function_name = tool_call["function"]["name"]
print(f"Executing tool: {function_name}")
tool_result = self.tool_manager.execute_tool_call(tool_call)
print(f"Tool result:\n{tool_result}")
# Add tool result and continue conversation
messages.append(message) # Add assistant's message with tool call
messages.append(
{
"role": "tool",
"tool_call_id": tool_call["id"],
"content": tool_result,
}
)
# Get final response
final_config = ChatCompletionConfig(
model=self.model,
messages=messages,
tools=self.tool_manager.get_ls_tool_definition(),
)
print("Getting final response...")
final_response = self.client.call_chat_completions(final_config)
if isinstance(final_response, dict):
final_choice = final_response.get("choices", [{}])[0]
final_message = final_choice.get("message", {})
final_content = final_message.get("content", "")
print("\n" + "=" * 60)
print("FINAL LLM ANALYSIS:")
print("=" * 60)
print(final_content)
print("=" * 60)
def interactive_chat(self) -> None:
"""Interactive chat session with streaming"""
print("=" * 60)
print("INTERACTIVE STREAMING CHAT")
print("=" * 60)
print(f"Using model: {self.model}")
print("Type 'quit' to exit, 'clear' to clear history")
print()
messages = []
while True:
try:
user_input = input("You: ").strip()
if user_input.lower() == "quit":
print("👋 Goodbye!")
break
elif user_input.lower() == "clear":
messages = []
print("Chat history cleared")
continue
elif not user_input:
continue
messages.append({"role": "user", "content": user_input})
config = ChatCompletionConfig(
model=self.model, messages=messages, stream=True, temperature=0.7
)
print("Assistant: ", end="", flush=True)
response = self.client.call_chat_completions(config)
assistant_content = self.handle_streaming_response(
response, show_reasoning=True
)
# Add assistant response to conversation history
messages.append({"role": "assistant", "content": assistant_content})
except KeyboardInterrupt:
print("\n👋 Chat interrupted. Goodbye!")
break
except Exception as e:
log.error(f"\nError: {e}")
continue
def main():
"""Main function with CLI switches for different tests"""
from lib.test_utils import test_args
# Add mandatory model argument
test_args.add_argument(
"--model", required=True, help="Model to use for requests (required)"
)
# Add test mode arguments
test_args.add_argument(
"--completion", action="store_true", help="Test completions endpoint"
)
test_args.add_argument(
"--chat",
action="store_true",
help="Test chat completions endpoint (non-streaming)",
)
test_args.add_argument(
"--chat-stream",
action="store_true",
help="Test chat completions endpoint with streaming",
)
test_args.add_argument(
"--tools",
action="store_true",
help="Test function calling with ls tool (non-streaming)",
)
test_args.add_argument(
"--interactive",
action="store_true",
help="Start interactive streaming chat session",
)
args = test_args.parse_args()
# Check that only one test mode is selected
test_modes = [
args.completion,
args.chat,
args.chat_stream,
args.tools,
args.interactive,
]
selected_count = sum(test_modes)
if selected_count == 0:
print("Please specify exactly one test mode:")
print(" --completion : Test completions endpoint")
print(" --chat : Test chat completions endpoint (non-streaming)")
print(" --chat-stream : Test chat completions endpoint with streaming")
print(" --tools : Test function calling with ls tool (non-streaming)")
print(" --interactive : Start interactive streaming chat session")
print(
f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT"
)
sys.exit(1)
elif selected_count > 1:
print("Please specify exactly one test mode")
sys.exit(1)
try:
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if not endpoint_api_key:
log.error(
f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting."
)
sys.exit(1)
# Create the core API client
client = APIClient(
endpoint_group_name=args.endpoint_group_name,
api_key=args.api_key,
server_url=Endpoint.get_autoscaler_server_url(args.instance),
endpoint_api_key=endpoint_api_key,
)
# Create tool manager and demo (passing the model parameter)
tool_manager = ToolManager()
demo = APIDemo(client, args.model, tool_manager)
print(f"Using model: {args.model}")
print("=" * 60)
# Run the selected test
if args.completion:
demo.demo_completions()
elif args.chat:
demo.demo_chat(use_streaming=False)
elif args.chat_stream:
demo.demo_chat(use_streaming=True)
elif args.tools:
demo.demo_ls_tool()
elif args.interactive:
demo.interactive_chat()
except Exception as e:
log.error(f"Error during test: {e}", exc_info=True)
sys.exit(1)
if __name__ == "__main__":
main()
+58
View File
@@ -0,0 +1,58 @@
import json
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Optional, List, Dict, Any
class SerializableDataclass:
def _serialize_recursive(self, obj: Any) -> Any:
if is_dataclass(obj):
return {
field.name: self._serialize_recursive(getattr(obj, field.name))
for field in fields(obj)
}
elif isinstance(obj, dict):
return {key: self._serialize_recursive(value) for key, value in obj.items()}
elif isinstance(obj, (list, tuple)):
return [self._serialize_recursive(item) for item in obj]
elif isinstance(obj, set):
return [self._serialize_recursive(item) for item in obj]
else:
return obj
def to_dict(self) -> Dict[str, Any]:
return self._serialize_recursive(self)
def to_json(self, indent: int = 2) -> str:
return json.dumps(self.to_dict(), indent=indent)
@dataclass
class CompletionConfig(SerializableDataclass):
"""Configuration for completion requests"""
model: str
prompt: str = "Hello"
max_tokens: int = 256
temperature: float = 0.7
top_k: int = 20
top_p: float = 0.4
stream: bool = False
@dataclass
class ChatCompletionConfig(SerializableDataclass):
"""Configuration for chat completion requests"""
model: str
messages: list = field(default_factory=list)
max_tokens: int = 2096
temperature: float = 0.7
top_k: int = 20
top_p: float = 0.4
stream: bool = False
tools: Optional[List[Dict[str, Any]]] = field(default_factory=list)
tool_choice: str = "auto"
def __post_init__(self):
if self.messages is None:
self.messages = [{"role": "user", "content": "Hello"}]
+182
View File
@@ -0,0 +1,182 @@
import os, json, random
from abc import ABC, abstractmethod
from dataclasses import dataclass
from lib.data_types import EndpointHandler, ApiPayload, JsonDataException
from typing import Union, Type, Dict, Any, Optional
from aiohttp import web, ClientResponse
import nltk
import logging
nltk.download("words")
WORD_LIST = nltk.corpus.words.words()
log = logging.getLogger(__name__)
"""
Generic dataclass accepts any dictionary in input.
"""
@dataclass
class GenericData(ApiPayload, ABC):
input: Dict[str, Any]
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GenericData":
return cls(input=data["input"])
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "GenericData":
errors = {}
# Validate required parameters
required_params = ["input"]
for param in required_params:
if param not in json_msg:
errors[param] = "missing parameter"
if errors:
raise JsonDataException(errors)
try:
# Create clean data dict and delegate to from_dict
clean_data = {"input": json_msg["input"]}
return cls.from_dict(clean_data)
except (json.JSONDecodeError, JsonDataException) as e:
errors["parameters"] = str(e)
raise JsonDataException(errors)
@classmethod
@abstractmethod
def for_test(cls) -> "GenericData":
pass
def generate_payload_json(self) -> Dict[str, Any]:
return self.input
def count_workload(self) -> int:
return self.input.get("max_tokens", 0)
@dataclass
class GenericHandler(EndpointHandler[GenericData], ABC):
@property
@abstractmethod
def endpoint(self) -> str:
pass
@property
def healthcheck_endpoint(self) -> Optional[str]:
return os.environ.get("MODEL_HEALTH_ENDPOINT")
@classmethod
def payload_cls(cls) -> Type[GenericData]:
return GenericData
@abstractmethod
def make_benchmark_payload(self) -> GenericData:
pass
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
match model_response.status:
case 200:
# Check if the response is actually streaming based on response headers/content-type
is_streaming_response = (
model_response.content_type == "text/event-stream"
or model_response.content_type == "application/x-ndjson"
or model_response.headers.get("Transfer-Encoding") == "chunked"
or "stream" in model_response.content_type.lower()
)
if is_streaming_response:
log.debug("Detected streaming response...")
res = web.StreamResponse()
res.content_type = model_response.content_type
await res.prepare(client_request)
async for chunk in model_response.content:
await res.write(chunk)
await res.write_eof()
log.debug("Done streaming response")
return res
else:
log.debug("Detected non-streaming response...")
content = await model_response.read()
return web.Response(
body=content,
status=200,
content_type=model_response.content_type,
)
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
@dataclass
class CompletionsData(GenericData):
@classmethod
def for_test(cls) -> "CompletionsData":
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
model = os.environ.get("MODEL_NAME")
if not model:
raise ValueError("MODEL_NAME environment variable not set")
test_input = {
"model": model,
"prompt": prompt,
"temperature": 0.7,
"max_tokens": 500,
}
return cls(input=test_input)
@dataclass
class CompletionsHandler(GenericHandler):
@property
def endpoint(self) -> str:
return "/v1/completions"
@classmethod
def payload_cls(cls) -> Type[CompletionsData]:
return CompletionsData
def make_benchmark_payload(self) -> CompletionsData:
return CompletionsData.for_test()
@dataclass
class ChatCompletionsData(GenericData):
"""Chat completions-specific data implementation"""
@classmethod
def for_test(cls) -> "ChatCompletionsData":
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
model = os.environ.get("MODEL_NAME")
if not model:
raise ValueError("MODEL_NAME environment variable not set")
# Chat completions use messages format instead of prompt
test_input = {
"model": model,
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.7,
"max_tokens": 500,
}
return cls(input=test_input)
@dataclass
class ChatCompletionsHandler(GenericHandler):
@property
def endpoint(self) -> str:
return "/v1/chat/completions"
@classmethod
def payload_cls(cls) -> Type[ChatCompletionsData]:
return ChatCompletionsData
def make_benchmark_payload(self) -> ChatCompletionsData:
return ChatCompletionsData.for_test()
+60
View File
@@ -0,0 +1,60 @@
import os
import logging
from .data_types.server import CompletionsHandler, ChatCompletionsHandler
from aiohttp import web
from lib.backend import Backend, LogAction
from lib.server import start_server
# This line indicates that the inference server is listening
MODEL_SERVER_START_LOG_MSG = [
"Application startup complete.", # vLLM
"llama runner started", # Ollama
'"message":"Connected","target":"text_generation_router"', # TGI
'"message":"Connected","target":"text_generation_router::server"', # TGI
]
MODEL_SERVER_ERROR_LOG_MSGS = [
"INFO exited: vllm", # vLLM
"RuntimeError: Engine", # vLLM
"Error: pull model manifest:", # Ollama
"stalled; retrying", # Ollama
"Error: WebserverFailed", # TGI
"Error: DownloadError", # TGI
"Error: ShardCannotStart", # TGI
]
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
backend = Backend(
model_server_url=os.environ["MODEL_SERVER_URL"],
model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=True,
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
log_actions=[
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
(LogAction.Info, '"message":"Download'),
*[
(LogAction.ModelError, error_msg)
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
],
],
)
async def handle_ping(_):
return web.Response(body="pong")
routes = [
web.post("/v1/completions", backend.create_handler(CompletionsHandler())),
web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())),
web.get("/ping", handle_ping),
]
if __name__ == "__main__":
start_server(backend, routes)
+28
View File
@@ -0,0 +1,28 @@
from lib.test_utils import test_load_cmd, test_args
from .data_types.server import CompletionsData
import os
WORKER_ENDPOINT = "/v1/completions"
if __name__ == "__main__":
# Check if MODEL_NAME environment variable is set
model_name_set = os.environ.get("MODEL_NAME") is not None
# Add model argument - required only if MODEL_NAME is not set
test_args.add_argument(
"--model",
dest="model",
required=not model_name_set,
help="Model to use for completions request (required if MODEL_NAME env var not set)",
)
# Parse known args to get model early, before test_load_cmd adds its args
known_args, _ = test_args.parse_known_args()
# Set environment variable if model was provided
if hasattr(known_args, "model") and known_args.model:
os.environ["MODEL_NAME"] = known_args.model
print(f"Set MODEL_NAME environment variable to: {known_args.model}")
# Now call test_load_cmd normally - it will add its own args and re-parse
test_load_cmd(CompletionsData, WORKER_ENDPOINT, arg_parser=test_args)
+7 -1
View File
@@ -4,6 +4,7 @@ import json
from urllib.parse import urljoin
import requests
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
logging.basicConfig(
level=logging.DEBUG,
@@ -42,7 +43,11 @@ def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> No
req_data = dict(payload=payload, auth_data=auth_data)
url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {url}")
response = requests.post(url, json=req_data)
response = requests.post(
url,
json=req_data,
verify=get_cert_file_path(),
)
response.raise_for_status()
res = response.json()
print(res)
@@ -100,6 +105,7 @@ if __name__ == "__main__":
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if endpoint_api_key:
try: