Compare commits
104 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| eedf81c0a3 | |||
| 3adec1826d | |||
| b55bfa9611 | |||
| 7db54f3bd7 | |||
| d63a060202 | |||
| c6521cb6d4 | |||
| b7fe4ebb91 | |||
| 8ae7b74605 | |||
| 106067d716 | |||
| f5134d4bf5 | |||
| 47e5460532 | |||
| ec2ac0a21a | |||
| 2cde573c56 | |||
| b2e4a5db0c | |||
| 7437028cb2 | |||
| 02c8307af7 | |||
| 7c0f316eeb | |||
| b4025a744f | |||
| d190308329 | |||
| 9f5a432513 | |||
| e09f1fa953 | |||
| ba6f1c2e4b | |||
| 944f83fc03 | |||
| 298590fb88 | |||
| 814c3acd4c | |||
| 22bca74087 | |||
| 9c795e2a01 | |||
| 830b532781 | |||
| d6a6e34c6b | |||
| ac1e109c48 | |||
| d6eb498ee4 | |||
| f56bbc0ebe | |||
| bcecd6df40 | |||
| 4d9bf2048c | |||
| 7788bc4a62 | |||
| 37ad3f8d46 | |||
| 70d51bafe1 | |||
| 63909736bb | |||
| f4f7080df1 | |||
| d51a338e8f | |||
| 92a04bd7af | |||
| 0f13506938 | |||
| 01e752d31f | |||
| 5edfa968ca | |||
| 5b5ef7227a | |||
| 16990ff8ff | |||
| 9748176366 | |||
| b39193ae70 | |||
| 9a6ca5d412 | |||
| e9ba1b03e4 | |||
| c98d661513 | |||
| f6fd1c6ac1 | |||
| 055e346c8c | |||
| 1cedb28acf | |||
| ec25dda3ad | |||
| 0397af719d | |||
| 4fdc314fd9 | |||
| 3786cf978d | |||
| a86d4bcf9c | |||
| e9b6a14a5e | |||
| cadac033e1 | |||
| 639d82f5b4 | |||
| 25db78e39d | |||
| 4e2f2311d0 | |||
| 38782d89bc | |||
| 0185216ccb | |||
| b20d9e714c | |||
| b1eb65d75d | |||
| 1d09d7fe96 | |||
| 1b37054dec | |||
| 1a1e4174b8 | |||
| b8377c4081 | |||
| 1e4fa87437 | |||
| 4c5fa03c7b | |||
| a8fe74f771 | |||
| b482de8394 | |||
| 703435d10e | |||
| 947fc5eea4 | |||
| 7c1a544b19 | |||
| 16b414676e | |||
| ba74ac8136 | |||
| 92ff412679 | |||
| fc75a64684 | |||
| b00bef547c | |||
| 3f4acb29fa | |||
| 58b078f908 | |||
| f9fdf04884 | |||
| 636f17d27f | |||
| 08c88f7527 | |||
| 8797b504af | |||
| cd946b0a9f | |||
| c595b42410 | |||
| 0bf3247a34 | |||
| 52ac4c0c1a | |||
| 8804e17201 | |||
| 4016cf9a53 | |||
| e0be45f39a | |||
| be2aafdb1f | |||
| 9e369c55a5 | |||
| 69d9b7455f | |||
| 6fb610cb5b | |||
| 0bf2d04223 | |||
| 9ebf1924ea | |||
| 0ab9a13a46 |
@@ -3,3 +3,4 @@
|
||||
__pycache__
|
||||
bin/
|
||||
lib64
|
||||
.venv
|
||||
@@ -39,11 +39,12 @@ reporting these metrics to the autoscaler.
|
||||
|
||||
If you are using a Vast.ai template that includes PyWorker integration (marked as autoscaler compatible), it should work out of the box. The template will typically start the appropriate PyWorker server automatically. Here's a few:
|
||||
|
||||
* **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=72d8dcb41ea3a58e06c741e2c725bc00)
|
||||
* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=ad72c8bf7cf695c3c9ddf0eaf6da0447)
|
||||
* **vLLM:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=63ae93902bf3978bea033782592b784d)
|
||||
* **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=6fa6bd5bdf5f0df63db80e40b086037d)
|
||||
* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=e6748878ba688e765e3e9fca29541938)
|
||||
|
||||
Currently available workers:
|
||||
* `hello_world`: A simple example worker for a basic LLM server.
|
||||
* `openai`: A simple example worker for a basic vLLM server.
|
||||
* `comfyui`: A worker for the ComfyUI image generation backend.
|
||||
* `tgi`: A worker for the Text Generation Inference backend.
|
||||
|
||||
|
||||
+149
-80
@@ -8,9 +8,11 @@ import logging
|
||||
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
|
||||
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
|
||||
from functools import cached_property
|
||||
from distutils.util import strtobool
|
||||
|
||||
from anyio import open_file
|
||||
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError
|
||||
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
|
||||
import asyncio
|
||||
|
||||
import requests
|
||||
from Crypto.Signature import pkcs1_15
|
||||
@@ -24,8 +26,12 @@ from lib.data_types import (
|
||||
LogAction,
|
||||
ApiPayload_T,
|
||||
JsonDataException,
|
||||
RequestMetrics,
|
||||
BenchmarkResult
|
||||
)
|
||||
|
||||
VERSION = "0.2.0"
|
||||
|
||||
MSG_HISTORY_LEN = 100
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
@@ -52,14 +58,28 @@ class Backend:
|
||||
EndpointHandler # this endpoint handler will be used for benchmarking
|
||||
)
|
||||
log_actions: List[Tuple[LogAction, str]]
|
||||
max_wait_time: float = 10.0
|
||||
reqnum = -1
|
||||
version = VERSION
|
||||
msg_history = []
|
||||
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
|
||||
unsecured: bool = dataclasses.field(
|
||||
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
||||
)
|
||||
report_addr: str = dataclasses.field(
|
||||
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
|
||||
)
|
||||
mtoken: str = dataclasses.field(
|
||||
default_factory=lambda: os.environ.get("MASTER_TOKEN", "")
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
self.metrics = Metrics()
|
||||
self.metrics._set_version(self.version)
|
||||
self.metrics._set_mtoken(self.mtoken)
|
||||
self._total_pubkey_fetch_errors = 0
|
||||
self._pubkey = self._fetch_pubkey()
|
||||
self.__start_healthcheck: bool = False
|
||||
|
||||
@property
|
||||
def pubkey(self) -> Optional[RSA.RsaKey]:
|
||||
@@ -70,7 +90,13 @@ class Backend:
|
||||
@cached_property
|
||||
def session(self):
|
||||
log.debug(f"starting session with {self.model_server_url}")
|
||||
return ClientSession(self.model_server_url)
|
||||
connector = TCPConnector(
|
||||
force_close=True, # Required for long running jobs
|
||||
enable_cleanup_closed=True,
|
||||
)
|
||||
|
||||
timeout = ClientTimeout(total=None)
|
||||
return ClientSession(self.model_server_url, timeout=timeout, connector=connector)
|
||||
|
||||
def create_handler(
|
||||
self,
|
||||
@@ -85,23 +111,19 @@ class Backend:
|
||||
|
||||
#######################################Private#######################################
|
||||
def _fetch_pubkey(self):
|
||||
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
|
||||
report_addr = self.report_addr.rstrip("/")
|
||||
command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"]
|
||||
try:
|
||||
result = subprocess.check_output(command, universal_newlines=True)
|
||||
log.debug("public key:")
|
||||
log.debug(result)
|
||||
key = None
|
||||
for _ in range(5):
|
||||
try:
|
||||
key = RSA.import_key(result)
|
||||
break
|
||||
except ValueError as e:
|
||||
log.debug(f"Error downloading key: {e}")
|
||||
time.sleep(15)
|
||||
if key is None:
|
||||
self._total_pubkey_fetch_errors += 1
|
||||
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
|
||||
self.backend_errored("Failed to get autoscaler pubkey")
|
||||
if key is not None:
|
||||
return key
|
||||
except (ValueError , subprocess.CalledProcessError) as e:
|
||||
log.debug(f"Error downloading key: {e}")
|
||||
self.backend_errored("Failed to get autoscaler pubkey")
|
||||
|
||||
|
||||
async def __handle_request(
|
||||
self,
|
||||
@@ -117,62 +139,56 @@ class Backend:
|
||||
except json.JSONDecodeError:
|
||||
return web.json_response(dict(error="invalid JSON"), status=422)
|
||||
workload = payload.count_workload()
|
||||
|
||||
async def wait_for_disconnection() -> None:
|
||||
while request.transport and not request.transport.is_closing():
|
||||
await sleep(0.5)
|
||||
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
|
||||
|
||||
async def cancel_api_call_if_disconnected() -> web.Response:
|
||||
await wait_for_disconnection()
|
||||
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
|
||||
self.metrics._request_canceled(workload=workload, reqnum=auth_data.reqnum)
|
||||
return web.Response(status=500)
|
||||
await request.wait_for_disconnection()
|
||||
log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled")
|
||||
self.metrics._request_canceled(request_metrics)
|
||||
raise asyncio.CancelledError
|
||||
|
||||
async def make_request() -> Union[web.Response, web.StreamResponse]:
|
||||
log.debug(f"got request, {auth_data.reqnum}")
|
||||
self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum)
|
||||
if self.allow_parallel_requests is False:
|
||||
log.debug(f"Waiting to aquire Sem for reqnum:{auth_data.reqnum}")
|
||||
await self.sem.acquire()
|
||||
log.debug(
|
||||
f"Sem acquired for reqnum:{auth_data.reqnum}, starting request..."
|
||||
)
|
||||
else:
|
||||
log.debug(f"Starting request for reqnum:{auth_data.reqnum}")
|
||||
try:
|
||||
start_time = time.time()
|
||||
response = await self.__call_api(handler=handler, payload=payload)
|
||||
status_code = response.status
|
||||
log.debug(
|
||||
" ".join(
|
||||
[
|
||||
f"request with reqnum:{auth_data.reqnum}",
|
||||
f"request with reqnum:{request_metrics.reqnum}",
|
||||
f"returned status code: {status_code},",
|
||||
]
|
||||
)
|
||||
)
|
||||
res = await handler.generate_client_response(request, response)
|
||||
self.metrics._request_end(
|
||||
workload=workload,
|
||||
req_response_time=time.time() - start_time,
|
||||
reqnum=auth_data.reqnum,
|
||||
)
|
||||
self.metrics._request_success(request_metrics)
|
||||
return res
|
||||
except requests.exceptions.RequestException as e:
|
||||
log.debug(f"[backend] Request error: {e}")
|
||||
self.metrics._request_errored(
|
||||
workload=workload, reqnum=auth_data.reqnum
|
||||
)
|
||||
self.metrics._request_errored(request_metrics)
|
||||
return web.Response(status=500)
|
||||
finally:
|
||||
self.sem.release()
|
||||
|
||||
###########
|
||||
|
||||
if self.__check_signature(auth_data) is False:
|
||||
self.metrics._request_reject(request_metrics)
|
||||
return web.Response(status=401)
|
||||
|
||||
if self.metrics.model_metrics.wait_time > self.max_wait_time:
|
||||
self.metrics._request_reject(request_metrics)
|
||||
return web.Response(status=429)
|
||||
|
||||
acquired = False
|
||||
try:
|
||||
self.metrics._request_start(request_metrics)
|
||||
if self.allow_parallel_requests is False:
|
||||
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
|
||||
await self.sem.acquire()
|
||||
acquired = True
|
||||
log.debug(
|
||||
f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..."
|
||||
)
|
||||
else:
|
||||
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
|
||||
done, pending = await wait(
|
||||
[
|
||||
create_task(make_request()),
|
||||
@@ -180,21 +196,52 @@ class Backend:
|
||||
],
|
||||
return_when=FIRST_COMPLETED,
|
||||
)
|
||||
[task.cancel() for task in pending]
|
||||
return done.pop().result()
|
||||
for t in pending:
|
||||
t.cancel()
|
||||
await asyncio.gather(*pending, return_exceptions=True)
|
||||
|
||||
done_task = done.pop()
|
||||
try:
|
||||
return done_task.result()
|
||||
except Exception as e:
|
||||
log.debug(f"Request task raised exception: {e}")
|
||||
return web.Response(status=500)
|
||||
except asyncio.CancelledError:
|
||||
# Client is gone. Do not write a response; just unwind.
|
||||
return web.Response(status=499)
|
||||
except Exception as e:
|
||||
log.debug(f"Exception in main handler loop {e}")
|
||||
return web.Response(status=500)
|
||||
finally:
|
||||
# Always release the semaphore if it was acquired
|
||||
if acquired:
|
||||
self.sem.release()
|
||||
self.metrics._request_end(request_metrics)
|
||||
|
||||
@cached_property
|
||||
def healthcheck_session(self):
|
||||
"""Dedicated session for healthchecks to avoid conflicts with API session"""
|
||||
log.debug("creating dedicated healthcheck session")
|
||||
connector = TCPConnector(
|
||||
force_close=True, # Keep this for isolation
|
||||
enable_cleanup_closed=True,
|
||||
)
|
||||
timeout = ClientTimeout(total=10) # Reasonable timeout for healthchecks
|
||||
return ClientSession(timeout=timeout, connector=connector)
|
||||
|
||||
async def __healthcheck(self):
|
||||
health_check_url = self.benchmark_handler.healthcheck_endpoint
|
||||
if health_check_url is None:
|
||||
log.debug("No healthcheck endpoint defined, skipping healthcheck")
|
||||
return
|
||||
await sleep(5)
|
||||
|
||||
while True:
|
||||
await sleep(10)
|
||||
if self.__start_healthcheck is False:
|
||||
continue
|
||||
try:
|
||||
log.debug(f"Performing healthcheck on {health_check_url}")
|
||||
async with self.session.get(health_check_url) as response:
|
||||
async with self.healthcheck_session.get(health_check_url) as response:
|
||||
if response.status == 200:
|
||||
log.debug("Healthcheck successful")
|
||||
elif response.status == 503:
|
||||
@@ -203,7 +250,6 @@ class Backend:
|
||||
f"Healthcheck failed with status: {response.status}"
|
||||
)
|
||||
else:
|
||||
# endpoint not ready yet so bail
|
||||
log.debug(f"Healthcheck Endpoint not ready: {response.status}")
|
||||
except Exception as e:
|
||||
log.debug(f"Healthcheck failed with exception: {e}")
|
||||
@@ -211,7 +257,7 @@ class Backend:
|
||||
|
||||
async def _start_tracking(self) -> None:
|
||||
await gather(
|
||||
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck()
|
||||
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop()
|
||||
)
|
||||
|
||||
def backend_errored(self, msg: str) -> None:
|
||||
@@ -225,6 +271,9 @@ class Backend:
|
||||
return await self.session.post(url=handler.endpoint, json=api_payload)
|
||||
|
||||
def __check_signature(self, auth_data: AuthData) -> bool:
|
||||
if self.unsecured is True:
|
||||
return True
|
||||
|
||||
def verify_signature(message, signature):
|
||||
if self.pubkey is None:
|
||||
log.debug(f"No Public Key!")
|
||||
@@ -240,7 +289,7 @@ class Backend:
|
||||
message = {
|
||||
key: value
|
||||
for (key, value) in (dataclasses.asdict(auth_data).items())
|
||||
if key != "signature"
|
||||
if key != "signature" and key != "__request_id"
|
||||
}
|
||||
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
|
||||
log.debug(
|
||||
@@ -250,7 +299,7 @@ class Backend:
|
||||
elif message in self.msg_history:
|
||||
log.debug(f"message: {message} already in message history")
|
||||
return False
|
||||
elif verify_signature(json.dumps(message, indent=4), auth_data.signature):
|
||||
elif verify_signature(json.dumps(message, indent=4, sort_keys=True), auth_data.signature):
|
||||
self.reqnum = max(auth_data.reqnum, self.reqnum)
|
||||
self.msg_history.append(message)
|
||||
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
|
||||
@@ -269,48 +318,67 @@ class Backend:
|
||||
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
||||
log.debug("already ran benchmark")
|
||||
# trigger model load
|
||||
payload = self.benchmark_handler.make_benchmark_payload()
|
||||
_ = await self.__call_api(
|
||||
handler=self.benchmark_handler, payload=payload
|
||||
)
|
||||
# payload = self.benchmark_handler.make_benchmark_payload()
|
||||
# _ = await self.__call_api(
|
||||
# handler=self.benchmark_handler, payload=payload
|
||||
# )
|
||||
return float(f.readline())
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
max_throughput = 0
|
||||
last_throughput = 0
|
||||
sum_throughput = 0
|
||||
for run in range(self.benchmark_handler.benchmark_runs + 1):
|
||||
start = time.time()
|
||||
|
||||
log.debug("Initial run to trigger model loading...")
|
||||
payload = self.benchmark_handler.make_benchmark_payload()
|
||||
await self.__call_api(handler=self.benchmark_handler, payload=payload)
|
||||
|
||||
max_throughput = 0
|
||||
sum_throughput = 0
|
||||
concurrent_requests = 10 if self.allow_parallel_requests else 1
|
||||
|
||||
for run in range(1, self.benchmark_handler.benchmark_runs + 1):
|
||||
start = time.time()
|
||||
benchmark_requests = []
|
||||
|
||||
for i in range(concurrent_requests):
|
||||
payload = self.benchmark_handler.make_benchmark_payload()
|
||||
res = await self.__call_api(
|
||||
handler=self.benchmark_handler, payload=payload
|
||||
)
|
||||
data = await res.json()
|
||||
time_elapsed = time.time() - start
|
||||
# first run triggers one-time loading of the model which is very slow, so we skip counting it
|
||||
if run == 0:
|
||||
continue
|
||||
else:
|
||||
workload = payload.count_workload()
|
||||
last_throughput = workload / time_elapsed
|
||||
sum_throughput += last_throughput
|
||||
max_throughput = max(max_throughput, last_throughput)
|
||||
task = self.__call_api(handler=self.benchmark_handler, payload=payload)
|
||||
benchmark_requests.append(
|
||||
BenchmarkResult(request_idx=i, workload=workload, task=task)
|
||||
)
|
||||
|
||||
responses = await gather(*[br.task for br in benchmark_requests])
|
||||
for br, response in zip(benchmark_requests, responses):
|
||||
br.response = response
|
||||
|
||||
total_workload = sum(br.workload for br in benchmark_requests if br.is_successful)
|
||||
time_elapsed = time.time() - start
|
||||
successful_responses = sum([1 for br in benchmark_requests if br.is_successful])
|
||||
if successful_responses == 0:
|
||||
self.backend_errored("No successful responses from benchmark")
|
||||
log.debug(f"benchmark failed: {successful_responses}/{concurrent_requests} successful responses")
|
||||
|
||||
throughput = total_workload / time_elapsed
|
||||
sum_throughput += throughput
|
||||
max_throughput = max(max_throughput, throughput)
|
||||
|
||||
# Log results for debugging
|
||||
log.debug(
|
||||
"\n".join(
|
||||
[
|
||||
"#" * 60,
|
||||
f"Run: {run}, workload: {workload} time_elapsed: {time_elapsed}, throughput: {last_throughput}",
|
||||
"",
|
||||
f"response: {data}",
|
||||
f"Run: {run}, concurrent_requests: {concurrent_requests}",
|
||||
f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s",
|
||||
f"Throughput: {throughput} workload/s",
|
||||
f"Successful responses: {successful_responses}/{concurrent_requests}",
|
||||
"#" * 60,
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
average_throughput = sum_throughput / self.benchmark_handler.benchmark_runs
|
||||
log.debug(
|
||||
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
|
||||
)
|
||||
# save max_throughput so we don't have to run benchmark again on restart of cold instances
|
||||
with open(BENCHMARK_INDICATOR_FILE, "w") as f:
|
||||
f.write(str(max_throughput))
|
||||
return max_throughput
|
||||
@@ -328,9 +396,10 @@ class Backend:
|
||||
)
|
||||
# some backends need a few seconds after logging successful startup before
|
||||
# they can begin accepting requests
|
||||
await sleep(5)
|
||||
# await sleep(5)
|
||||
try:
|
||||
max_throughput = await run_benchmark()
|
||||
self.__start_healthcheck = True
|
||||
self.metrics._model_loaded(
|
||||
max_throughput=max_throughput,
|
||||
)
|
||||
@@ -348,13 +417,13 @@ class Backend:
|
||||
|
||||
async def tail_log():
|
||||
log.debug(f"tailing file: {self.model_log_file}")
|
||||
async with await open_file(self.model_log_file) as f:
|
||||
async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore'):
|
||||
while True:
|
||||
line = await f.readline()
|
||||
if line:
|
||||
await handle_log_line(line.rstrip())
|
||||
else:
|
||||
time.sleep(LOG_POLL_INTERVAL)
|
||||
await asyncio.sleep(LOG_POLL_INTERVAL)
|
||||
|
||||
###########
|
||||
|
||||
|
||||
+53
-9
@@ -3,12 +3,11 @@ import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type
|
||||
from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type, Awaitable
|
||||
from aiohttp import web, ClientResponse
|
||||
import inspect
|
||||
|
||||
import psutil
|
||||
import requests
|
||||
|
||||
|
||||
"""
|
||||
@@ -66,10 +65,11 @@ class ApiPayload(ABC):
|
||||
class AuthData:
|
||||
"""data used to authenticate requester"""
|
||||
|
||||
signature: str
|
||||
cost: str
|
||||
endpoint: str
|
||||
reqnum: int
|
||||
request_idx: int
|
||||
signature: str
|
||||
url: str
|
||||
|
||||
@classmethod
|
||||
@@ -190,13 +190,34 @@ class SystemMetrics:
|
||||
self.additional_disk_usage = disk_usage - self.last_disk_usage
|
||||
self.last_disk_usage = disk_usage
|
||||
|
||||
def reset(self):
|
||||
def reset(self, expected: float | None) -> None:
|
||||
# autoscaler excepts model_loading_time to be populated only once, when the instance has
|
||||
# finished benchmarking and is ready to receive requests. This applies to restarted instances
|
||||
# as well: they should send model_loading_time once when they are done loading
|
||||
if self.model_loading_time == expected:
|
||||
self.model_loading_time = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestMetrics:
|
||||
"""Tracks metrics for an active request."""
|
||||
request_idx: int
|
||||
reqnum: int
|
||||
workload: float
|
||||
status: str
|
||||
success: bool = False
|
||||
|
||||
@dataclass
|
||||
class BenchmarkResult:
|
||||
request_idx: int
|
||||
workload: float
|
||||
task: Awaitable[ClientResponse]
|
||||
response: Optional[ClientResponse] = None
|
||||
|
||||
@property
|
||||
def is_successful(self) -> bool:
|
||||
return self.response is not None and self.response.status == 200
|
||||
|
||||
@dataclass
|
||||
class ModelMetrics:
|
||||
"""Model specific metrics"""
|
||||
@@ -206,13 +227,15 @@ class ModelMetrics:
|
||||
workload_received: float
|
||||
workload_cancelled: float
|
||||
workload_errored: float
|
||||
workload_pending: float
|
||||
workload_rejected: float
|
||||
# these are not
|
||||
cur_perf: float
|
||||
workload_pending: float
|
||||
error_msg: Optional[str]
|
||||
max_throughput: float
|
||||
requests_recieved: Set[int] = field(default_factory=set)
|
||||
requests_working: Set[int] = field(default_factory=set)
|
||||
requests_working: dict[int, RequestMetrics] = field(default_factory=dict)
|
||||
requests_deleting: list[RequestMetrics] = field(default_factory=list)
|
||||
last_update: float = field(default_factory=time.time)
|
||||
|
||||
@classmethod
|
||||
def empty(cls):
|
||||
@@ -221,7 +244,7 @@ class ModelMetrics:
|
||||
workload_served=0.0,
|
||||
workload_cancelled=0.0,
|
||||
workload_errored=0.0,
|
||||
cur_perf=0.0,
|
||||
workload_rejected=0.0,
|
||||
workload_received=0.0,
|
||||
error_msg=None,
|
||||
max_throughput=0.0,
|
||||
@@ -231,6 +254,20 @@ class ModelMetrics:
|
||||
def workload_processing(self) -> float:
|
||||
return max(self.workload_received - self.workload_cancelled, 0.0)
|
||||
|
||||
@property
|
||||
def wait_time(self) -> float:
|
||||
if (len(self.requests_working) == 0):
|
||||
return 0.0
|
||||
return sum([request.workload for request in self.requests_working.values()]) / max(self.max_throughput, 0.00001)
|
||||
|
||||
@property
|
||||
def cur_load(self) -> float:
|
||||
return sum([request.workload for request in self.requests_working.values()])
|
||||
|
||||
@property
|
||||
def working_request_idxs(self) -> list[int]:
|
||||
return [req.request_idx for req in self.requests_working.values()]
|
||||
|
||||
def set_errored(self, error_msg):
|
||||
self.reset()
|
||||
self.error_msg = error_msg
|
||||
@@ -240,15 +277,21 @@ class ModelMetrics:
|
||||
self.workload_received = 0
|
||||
self.workload_cancelled = 0
|
||||
self.workload_errored = 0
|
||||
self.workload_rejected = 0
|
||||
self.last_update = time.time()
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoScalaerData:
|
||||
class AutoScalerData:
|
||||
"""Data that is reported to autoscaler"""
|
||||
|
||||
id: int
|
||||
mtoken: str
|
||||
version: str
|
||||
loadtime: float
|
||||
cur_load: float
|
||||
rej_load: float
|
||||
new_load: float
|
||||
error_msg: str
|
||||
max_perf: float
|
||||
cur_perf: float
|
||||
@@ -257,6 +300,7 @@ class AutoScalaerData:
|
||||
num_requests_working: int
|
||||
num_requests_recieved: int
|
||||
additional_disk_usage: float
|
||||
working_request_idxs: list[int]
|
||||
url: str
|
||||
|
||||
|
||||
|
||||
+179
-46
@@ -5,14 +5,14 @@ import json
|
||||
from asyncio import sleep
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from functools import cache
|
||||
from urllib.parse import urljoin
|
||||
import asyncio
|
||||
from aiohttp import ClientSession, ClientTimeout, TCPConnector, ClientResponseError
|
||||
|
||||
import requests
|
||||
|
||||
from lib.data_types import AutoScalaerData, SystemMetrics, ModelMetrics
|
||||
from lib.data_types import AutoScalerData, SystemMetrics, ModelMetrics, RequestMetrics
|
||||
from typing import Awaitable, NoReturn, List
|
||||
|
||||
METRICS_UPDATE_INTERVAL = 1
|
||||
DELETE_REQUESTS_INTERVAL = 1
|
||||
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
@@ -27,7 +27,10 @@ def get_url() -> str:
|
||||
|
||||
@dataclass
|
||||
class Metrics:
|
||||
version: str = "0"
|
||||
mtoken: str = ""
|
||||
last_metric_update: float = 0.0
|
||||
last_request_served: float = 0.0
|
||||
update_pending: bool = False
|
||||
id: int = field(default_factory=lambda: int(os.environ["CONTAINER_ID"]))
|
||||
report_addr: List[str] = field(
|
||||
@@ -36,44 +39,84 @@ class Metrics:
|
||||
url: str = field(default_factory=get_url)
|
||||
system_metrics: SystemMetrics = field(default_factory=SystemMetrics.empty)
|
||||
model_metrics: ModelMetrics = field(default_factory=ModelMetrics.empty)
|
||||
_session: ClientSession | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def _request_start(self, workload: float, reqnum: int) -> None:
|
||||
async def http(self) -> ClientSession:
|
||||
if self._session is None:
|
||||
self._session = ClientSession(
|
||||
timeout=ClientTimeout(total=10),
|
||||
connector=TCPConnector(limit=8, limit_per_host=4, force_close=True, enable_cleanup_closed=True)
|
||||
)
|
||||
return self._session
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if self._session is not None:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
def _request_start(self, request: RequestMetrics) -> None:
|
||||
"""
|
||||
this function is called prior to forwarding a request to a model API.
|
||||
"""
|
||||
log.debug("request start")
|
||||
self.model_metrics.workload_pending += workload
|
||||
self.model_metrics.workload_received += workload
|
||||
self.model_metrics.requests_recieved.add(reqnum)
|
||||
self.model_metrics.requests_working.add(reqnum)
|
||||
|
||||
def _request_end(
|
||||
self, workload: float, req_response_time: float, reqnum: int
|
||||
) -> None:
|
||||
"""
|
||||
this function is called after a response from model API is received.
|
||||
"""
|
||||
self.model_metrics.workload_served += workload
|
||||
self.model_metrics.workload_pending -= workload
|
||||
self.model_metrics.requests_working.discard(reqnum)
|
||||
self.model_metrics.cur_perf = workload / req_response_time
|
||||
request.status = "Started"
|
||||
self.model_metrics.workload_pending += request.workload
|
||||
self.model_metrics.workload_received += request.workload
|
||||
self.model_metrics.requests_recieved.add(request.reqnum)
|
||||
self.model_metrics.requests_working[request.reqnum] = request
|
||||
self.update_pending = True
|
||||
|
||||
def _request_errored(self, workload: float, reqnum: int) -> None:
|
||||
def _request_end(self, request: RequestMetrics) -> None:
|
||||
"""
|
||||
this function is called after handling of a request ends, regardless of the outcome
|
||||
"""
|
||||
self.model_metrics.workload_pending -= request.workload
|
||||
self.model_metrics.requests_working.pop(request.reqnum, None)
|
||||
self.model_metrics.requests_deleting.append(request)
|
||||
self.last_request_served = time.time()
|
||||
|
||||
def _request_success(self, request: RequestMetrics) -> None:
|
||||
"""
|
||||
this function is called after a response from model API is received and forwarded.
|
||||
"""
|
||||
self.model_metrics.workload_served += request.workload
|
||||
request.status = "Success"
|
||||
request.success = True
|
||||
self.update_pending = True
|
||||
|
||||
def _request_errored(self, request: RequestMetrics) -> None:
|
||||
"""
|
||||
this function is called if model API returns an error
|
||||
"""
|
||||
self.model_metrics.workload_pending -= workload
|
||||
self.model_metrics.workload_errored += workload
|
||||
self.model_metrics.requests_working.discard(reqnum)
|
||||
self.model_metrics.workload_errored += request.workload
|
||||
request.status = "Error"
|
||||
request.success = False
|
||||
self.update_pending = True
|
||||
|
||||
def _request_canceled(self, workload: float, reqnum: int) -> None:
|
||||
def _request_canceled(self, request: RequestMetrics) -> None:
|
||||
"""
|
||||
this function is called if client drops connection before model API has responded
|
||||
"""
|
||||
self.model_metrics.workload_pending -= workload
|
||||
self.model_metrics.workload_cancelled += workload
|
||||
self.model_metrics.requests_working.discard(reqnum)
|
||||
self.model_metrics.workload_cancelled += request.workload
|
||||
request.success = True
|
||||
request.status = "Cancelled"
|
||||
|
||||
def _request_reject(self, request: RequestMetrics):
|
||||
"""
|
||||
this function is called if the current wait time for the model is above max_wait_time
|
||||
"""
|
||||
self.model_metrics.requests_recieved.add(request.reqnum)
|
||||
self.model_metrics.requests_deleting.append(request)
|
||||
self.model_metrics.workload_rejected += request.workload
|
||||
request.success = False
|
||||
request.status = "Rejected"
|
||||
self.update_pending = True
|
||||
|
||||
async def _send_delete_requests_loop(self) -> Awaitable[NoReturn]:
|
||||
while True:
|
||||
await sleep(DELETE_REQUESTS_INTERVAL)
|
||||
if len(self.model_metrics.requests_deleting) > 0:
|
||||
await self.__send_delete_requests_and_reset()
|
||||
|
||||
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
|
||||
while True:
|
||||
@@ -81,10 +124,10 @@ class Metrics:
|
||||
elapsed = time.time() - self.last_metric_update
|
||||
if self.system_metrics.model_is_loaded is False and elapsed >= 10:
|
||||
log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
|
||||
self.__send_metrics_and_reset(elapsed)
|
||||
await self.__send_metrics_and_reset()
|
||||
elif self.update_pending or elapsed > 10:
|
||||
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
|
||||
self.__send_metrics_and_reset(elapsed)
|
||||
await self.__send_metrics_and_reset()
|
||||
|
||||
def _model_loaded(self, max_throughput: float) -> None:
|
||||
self.system_metrics.model_loading_time = (
|
||||
@@ -97,57 +140,147 @@ class Metrics:
|
||||
self.model_metrics.set_errored(error_msg)
|
||||
self.system_metrics.model_is_loaded = True
|
||||
|
||||
def _set_version(self, version: str) -> None:
|
||||
self.version = version
|
||||
|
||||
def _set_mtoken(self, mtoken: str) -> None:
|
||||
self.mtoken = mtoken
|
||||
|
||||
#######################################Private#######################################
|
||||
|
||||
def __send_metrics_and_reset(self, elapsed):
|
||||
async def __send_delete_requests_and_reset(self):
|
||||
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
|
||||
data = {
|
||||
"worker_id": self.id,
|
||||
"mtoken": self.mtoken,
|
||||
"request_idxs": idxs,
|
||||
"success": success_flag,
|
||||
}
|
||||
log.debug(
|
||||
f"Deleting requests that {'succeeded' if success_flag else 'failed'}: {data['request_idxs']}"
|
||||
)
|
||||
full_path = report_addr.rstrip("/") + "/delete_requests/"
|
||||
for attempt in range(1, 4):
|
||||
try:
|
||||
session = await self.http()
|
||||
async with session.post(full_path, json=data) as res:
|
||||
log.debug(f"delete_requests response: {res.status}")
|
||||
res.raise_for_status()
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
log.debug("delete_requests timed out")
|
||||
except (ClientResponseError, Exception) as e:
|
||||
log.debug(f"delete_requests failed with error: {e}")
|
||||
await asyncio.sleep(2)
|
||||
log.debug(f"retrying delete_request, attempt: {attempt}")
|
||||
return False
|
||||
|
||||
def compute_autoscaler_data() -> AutoScalaerData:
|
||||
return AutoScalaerData(
|
||||
# Take a snapshot of what we plan to send this tick.
|
||||
# New arrivals after this snapshot will remain in the queue for the next tick.
|
||||
snapshot = list(self.model_metrics.requests_deleting)
|
||||
success_idxs = [r.request_idx for r in snapshot if r.success is True]
|
||||
failed_idxs = [r.request_idx for r in snapshot if r.success is False]
|
||||
|
||||
if not success_idxs and not failed_idxs:
|
||||
return # nothing to do
|
||||
|
||||
for report_addr in self.report_addr:
|
||||
# TODO: Add a Redis subscriber queue for delete_requests
|
||||
if report_addr == "https://cloud.vast.ai/api/v0":
|
||||
# Patch: ignore the Redis API report_addr
|
||||
continue
|
||||
sent_success = True
|
||||
sent_failed = True
|
||||
|
||||
if success_idxs:
|
||||
sent_success = await post(report_addr, success_idxs, True)
|
||||
if failed_idxs:
|
||||
sent_failed = await post(report_addr, failed_idxs, False)
|
||||
|
||||
if sent_success and sent_failed:
|
||||
# Remove only the items we actually sent from the live queue.
|
||||
sent_set = set(success_idxs) | set(failed_idxs)
|
||||
self.model_metrics.requests_deleting[:] = [
|
||||
r for r in self.model_metrics.requests_deleting
|
||||
if r.request_idx not in sent_set
|
||||
]
|
||||
break
|
||||
|
||||
|
||||
async def __send_metrics_and_reset(self):
|
||||
|
||||
loadtime_snapshot = self.system_metrics.model_loading_time
|
||||
|
||||
def compute_autoscaler_data() -> AutoScalerData:
|
||||
return AutoScalerData(
|
||||
id=self.id,
|
||||
loadtime=(self.system_metrics.model_loading_time or 0.0),
|
||||
cur_load=(self.model_metrics.workload_processing / elapsed),
|
||||
mtoken=self.mtoken,
|
||||
version=self.version,
|
||||
loadtime=(loadtime_snapshot or 0.0),
|
||||
new_load=self.model_metrics.workload_processing,
|
||||
cur_load=self.model_metrics.cur_load,
|
||||
rej_load=self.model_metrics.workload_rejected,
|
||||
max_perf=self.model_metrics.max_throughput,
|
||||
cur_perf=self.model_metrics.cur_perf,
|
||||
cur_perf=self.model_metrics.workload_served,
|
||||
error_msg=self.model_metrics.error_msg or "",
|
||||
num_requests_working=len(self.model_metrics.requests_working),
|
||||
num_requests_recieved=len(self.model_metrics.requests_recieved),
|
||||
additional_disk_usage=self.system_metrics.additional_disk_usage,
|
||||
working_request_idxs=self.model_metrics.working_request_idxs,
|
||||
cur_capacity=0,
|
||||
max_capacity=0,
|
||||
url=self.url,
|
||||
)
|
||||
|
||||
def send_data(report_addr: str) -> None:
|
||||
async def send_data(report_addr: str) -> bool:
|
||||
data = compute_autoscaler_data()
|
||||
full_path = urljoin(report_addr, "/worker_status/")
|
||||
log_data = asdict(data)
|
||||
def obfuscate(secret: str) -> str:
|
||||
if secret is None:
|
||||
return ""
|
||||
return secret[:7] + "..." if len(secret) > 7 else ("*" * len(secret))
|
||||
|
||||
log_data["mtoken"] = obfuscate(log_data.get("mtoken"))
|
||||
log.debug(
|
||||
"\n".join(
|
||||
[
|
||||
"#" * 60,
|
||||
f"sending data to autoscaler",
|
||||
f"{json.dumps((asdict(data)), indent=2)}",
|
||||
f"{json.dumps(log_data, indent=2)}",
|
||||
"#" * 60,
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
full_path = report_addr.rstrip("/") + "/worker_status/"
|
||||
for attempt in range(1, 4):
|
||||
try:
|
||||
requests.post(full_path, json=asdict(data), timeout=1)
|
||||
break
|
||||
except requests.Timeout:
|
||||
session = await self.http()
|
||||
async with session.post(full_path, json=asdict(data)) as res:
|
||||
res.raise_for_status()
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
log.debug(f"autoscaler status update timed out")
|
||||
except Exception as e:
|
||||
except (ClientResponseError, Exception) as e:
|
||||
log.debug(f"autoscaler status update failed with error: {e}")
|
||||
time.sleep(2)
|
||||
await asyncio.sleep(2)
|
||||
log.debug(f"retrying autoscaler status update, attempt: {attempt}")
|
||||
log.debug(f"failed to send update through {report_addr}")
|
||||
return False
|
||||
|
||||
###########
|
||||
|
||||
self.system_metrics.update_disk_usage()
|
||||
|
||||
sent = False
|
||||
for report_addr in self.report_addr:
|
||||
send_data(report_addr)
|
||||
if await send_data(report_addr):
|
||||
sent = True
|
||||
break
|
||||
|
||||
if sent:
|
||||
# clear the one-shot loadtime only if we actually sent *this* value
|
||||
self.system_metrics.reset(expected=loadtime_snapshot)
|
||||
self.update_pending = False
|
||||
self.model_metrics.reset()
|
||||
self.system_metrics.reset()
|
||||
self.last_metric_update = time.time()
|
||||
|
||||
+27
-9
@@ -10,6 +10,7 @@ from collections import Counter
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from urllib.parse import urljoin
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
import requests
|
||||
|
||||
from lib.data_types import AuthData, ApiPayload
|
||||
@@ -53,6 +54,13 @@ test_args.add_argument(
|
||||
default="https://run.vast.ai",
|
||||
help="Call local autoscaler instead of prod, for dev use only",
|
||||
)
|
||||
test_args.add_argument(
|
||||
"-i",
|
||||
dest="instance",
|
||||
type=str,
|
||||
default="prod",
|
||||
help="Autoscaler shard to run the command against, default: prod",
|
||||
)
|
||||
|
||||
GetPayloadAndWorkload = Callable[[], Tuple[Dict[str, Any], float]]
|
||||
|
||||
@@ -70,6 +78,7 @@ class ClientState:
|
||||
api_key: str
|
||||
server_url: str
|
||||
worker_endpoint: str
|
||||
instance: str
|
||||
payload: ApiPayload
|
||||
url: str = ""
|
||||
status: ClientStatus = ClientStatus.FetchEndpoint
|
||||
@@ -79,11 +88,7 @@ class ClientState:
|
||||
|
||||
def make_call(self):
|
||||
self.status = ClientStatus.FetchEndpoint
|
||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||
endpoint_name=self.endpoint_group_name,
|
||||
account_api_key=self.api_key,
|
||||
)
|
||||
if not endpoint_api_key:
|
||||
if not self.api_key:
|
||||
self.as_error.append(
|
||||
f"Endpoint {self.endpoint_group_name} not found for API key",
|
||||
)
|
||||
@@ -91,12 +96,14 @@ class ClientState:
|
||||
return
|
||||
route_payload = {
|
||||
"endpoint": self.endpoint_group_name,
|
||||
"api_key": endpoint_api_key,
|
||||
"api_key": self.api_key,
|
||||
"cost": self.payload.count_workload(),
|
||||
}
|
||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
response = requests.post(
|
||||
urljoin(self.server_url, "/route/"),
|
||||
json=route_payload,
|
||||
headers=headers,
|
||||
timeout=4,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
@@ -114,9 +121,11 @@ class ClientState:
|
||||
self.url = worker_address
|
||||
url = urljoin(worker_address, self.worker_endpoint)
|
||||
self.status = ClientStatus.Generating
|
||||
|
||||
response = requests.post(
|
||||
url,
|
||||
json=req_data,
|
||||
verify=get_cert_file_path(),
|
||||
)
|
||||
if response.status_code != 200:
|
||||
self.infer_error.append(
|
||||
@@ -135,6 +144,7 @@ class ClientState:
|
||||
try:
|
||||
self.make_call()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
self.status = ClientStatus.Error
|
||||
_ = e
|
||||
self.conn_errors[self.url] += 1
|
||||
@@ -226,6 +236,7 @@ def run_test(
|
||||
server_url: str,
|
||||
worker_endpoint: str,
|
||||
payload_cls: Type[ApiPayload],
|
||||
instance: str,
|
||||
):
|
||||
threads = []
|
||||
|
||||
@@ -234,8 +245,7 @@ def run_test(
|
||||
print_thread.daemon = True # makes threads get killed on program exit
|
||||
print_thread.start()
|
||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||
endpoint_name=endpoint_group_name,
|
||||
account_api_key=api_key,
|
||||
endpoint_name=endpoint_group_name, account_api_key=api_key, instance=instance
|
||||
)
|
||||
if not endpoint_api_key:
|
||||
log.debug(f"Endpoint {endpoint_group_name} not found for API key")
|
||||
@@ -248,6 +258,7 @@ def run_test(
|
||||
server_url=server_url,
|
||||
worker_endpoint=worker_endpoint,
|
||||
payload=payload_cls.for_test(),
|
||||
instance=instance,
|
||||
)
|
||||
clients.append(client)
|
||||
thread = threading.Thread(target=client.simulate_user, args=())
|
||||
@@ -281,12 +292,19 @@ def test_load_cmd(
|
||||
args = arg_parser.parse_args()
|
||||
if hasattr(args, "comfy_model"):
|
||||
os.environ["COMFY_MODEL"] = args.comfy_model
|
||||
server_url = {
|
||||
"prod": "https://run.vast.ai",
|
||||
"alpha": "https://run-alpha.vast.ai",
|
||||
"candidate": "https://run-candidate.vast.ai",
|
||||
"local": "http://localhost:8080",
|
||||
}.get(args.instance, "http://localhost:8080")
|
||||
run_test(
|
||||
num_requests=args.num_requests,
|
||||
requests_per_second=args.requests_per_second,
|
||||
api_key=args.api_key,
|
||||
server_url=args.server_url,
|
||||
server_url=server_url,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
worker_endpoint=endpoint,
|
||||
payload_cls=payload_cls,
|
||||
instance=args.instance,
|
||||
)
|
||||
|
||||
+4
-2
@@ -1,4 +1,4 @@
|
||||
aiohttp~=3.11
|
||||
aiohttp[speedups]==3.10.1
|
||||
anyio~=4.4
|
||||
lib~=4.0
|
||||
nltk~=3.9
|
||||
@@ -6,4 +6,6 @@ psutil~=6.0
|
||||
pycryptodome~=3.20
|
||||
Requests~=2.32
|
||||
transformers~=4.52
|
||||
utils~=1.0
|
||||
utils==1.0.*
|
||||
hf_transfer>=0.1.9
|
||||
vastai-sdk>=0.2.0g
|
||||
+25
-10
@@ -41,22 +41,37 @@ echo_var DEBUG_LOG
|
||||
echo_var PYWORKER_LOG
|
||||
echo_var MODEL_LOG
|
||||
|
||||
env | grep _ >> /etc/environment;
|
||||
|
||||
# Populate /etc/environment with quoted values
|
||||
if ! grep -q "VAST" /etc/environment; then
|
||||
env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do
|
||||
name=${line%%=*}
|
||||
value=${line#*=}
|
||||
printf '%s="%s"\n' "$name" "$value"
|
||||
done > /etc/environment
|
||||
fi
|
||||
|
||||
if [ ! -d "$ENV_PATH" ]
|
||||
then
|
||||
apt install -y python3.10-venv
|
||||
echo "setting up venv"
|
||||
git clone https://github.com/vast-ai/pyworker "$SERVER_DIR"
|
||||
if ! which uv; then
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
source ~/.local/bin/env
|
||||
fi
|
||||
|
||||
python3 -m venv "$WORKSPACE_DIR/worker-env"
|
||||
source "$WORKSPACE_DIR/worker-env/bin/activate"
|
||||
# Fork testing
|
||||
[[ ! -d $SERVER_DIR ]] && git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"
|
||||
if [[ -n ${PYWORKER_REF:-} ]]; then
|
||||
(cd "$SERVER_DIR" && git checkout "$PYWORKER_REF")
|
||||
fi
|
||||
|
||||
pip install -r vast-pyworker/requirements.txt
|
||||
uv venv --python-preference only-managed "$ENV_PATH" -p 3.10
|
||||
source "$ENV_PATH/bin/activate"
|
||||
|
||||
uv pip install -r "${SERVER_DIR}/requirements.txt"
|
||||
|
||||
touch ~/.no_auto_tmux
|
||||
else
|
||||
[[ -f ~/.local/bin/env ]] && source ~/.local/bin/env
|
||||
source "$WORKSPACE_DIR/worker-env/bin/activate"
|
||||
echo "environment activated"
|
||||
echo "venv: $VIRTUAL_ENV"
|
||||
@@ -87,14 +102,14 @@ if [ "$USE_SSL" = true ]; then
|
||||
IP.1 = 0.0.0.0
|
||||
EOF
|
||||
|
||||
openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \
|
||||
openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \
|
||||
-nodes \
|
||||
-sha256 \
|
||||
-keyout /etc/instance.key \
|
||||
-out /etc/instance.csr \
|
||||
-config /etc/openssl-san.cnf
|
||||
|
||||
curl --header 'Content-Type: application/octet-stream' \
|
||||
curl --header 'Content-Type: application/octet-stream' \
|
||||
--data-binary @//etc/instance.csr \
|
||||
-X \
|
||||
POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt;
|
||||
@@ -103,7 +118,7 @@ fi
|
||||
|
||||
|
||||
|
||||
export REPORT_ADDR WORKER_PORT USE_SSL
|
||||
export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED
|
||||
|
||||
cd "$SERVER_DIR"
|
||||
|
||||
|
||||
+66
-6
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
import time
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
@@ -17,7 +18,63 @@ class Endpoint:
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_endpoint_api_key(endpoint_name: str, account_api_key: str) -> Optional[str]:
|
||||
def get_endpoint_info(
|
||||
endpoint_name: str, account_api_key: str, instance: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
headers = {"Authorization": f"Bearer {account_api_key}"}
|
||||
url = f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}"
|
||||
# Retry a few times to smooth over transient propagation/network delays
|
||||
for attempt in range(4):
|
||||
try:
|
||||
response = requests.get(url, headers=headers, timeout=8)
|
||||
if response.status_code != 200:
|
||||
# brief backoff and retry
|
||||
time.sleep(0.3 * (attempt + 1))
|
||||
continue
|
||||
try:
|
||||
data = response.json()
|
||||
except Exception:
|
||||
# JSON parse failed; backoff and retry
|
||||
time.sleep(0.3 * (attempt + 1))
|
||||
continue
|
||||
result = data.get("results", []) if isinstance(data, dict) else []
|
||||
endpoint = next(
|
||||
(item for item in result if item.get("endpoint_name") == endpoint_name),
|
||||
None,
|
||||
)
|
||||
if endpoint and endpoint.get("id") and endpoint.get("api_key"):
|
||||
return {"id": endpoint.get("id"), "api_key": endpoint.get("api_key")}
|
||||
except Exception:
|
||||
# network or other transient error; retry
|
||||
time.sleep(0.3 * (attempt + 1))
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_autoscaler_server_url(instance: str) -> str:
|
||||
endpoints = {
|
||||
"alpha": "run-alpha",
|
||||
"candidate": "run-candidate",
|
||||
"prod": "run",
|
||||
}
|
||||
host = endpoints.get(instance)
|
||||
if host:
|
||||
return f"https://{host}.vast.ai/"
|
||||
return "http://localhost:8080"
|
||||
|
||||
@staticmethod
|
||||
def get_server_url(instance: str) -> str:
|
||||
endpoints = {
|
||||
"alpha": "alpha",
|
||||
"candidate": "candidate",
|
||||
"prod": "console",
|
||||
}
|
||||
host = endpoints.get(instance, "alpha")
|
||||
return f"https://{host}.vast.ai/api/v0/endptjobs/"
|
||||
|
||||
@staticmethod
|
||||
def get_endpoint_api_key(
|
||||
endpoint_name: str, account_api_key: str, instance: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Fetch endpoint API key from VastAI console following the healthcheck pattern.
|
||||
|
||||
@@ -28,12 +85,15 @@ class Endpoint:
|
||||
Returns:
|
||||
Endpoint API key if successful, None otherwise
|
||||
"""
|
||||
vast_console_url = "https://console.vast.ai/api/v0/endptjobs/"
|
||||
headers = {"Authorization": f"Bearer {account_api_key}"}
|
||||
|
||||
try:
|
||||
log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}")
|
||||
response = requests.get(vast_console_url, headers=headers)
|
||||
response = requests.get(
|
||||
f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}",
|
||||
headers=headers,
|
||||
timeout=8,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_msg = f"Failed to fetch endpoint API key: {response.status_code} - {response.text}"
|
||||
@@ -42,14 +102,14 @@ class Endpoint:
|
||||
|
||||
try:
|
||||
data = response.json()
|
||||
except requests.exceptions.JSONDecodeError as e:
|
||||
except Exception as e:
|
||||
log.debug(f"Failed to parse JSON response: {e}")
|
||||
return None
|
||||
|
||||
result = data.get("results", [])
|
||||
|
||||
endpoint: Optional[Dict[str, Any]] = next(
|
||||
(item for item in result if item["endpoint_name"] == endpoint_name),
|
||||
(item for item in result if item.get("endpoint_name") == endpoint_name),
|
||||
None,
|
||||
)
|
||||
if not endpoint:
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
import tempfile
|
||||
from functools import cache
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
@cache
|
||||
def get_cert_file_path():
|
||||
cert_url = "https://console.vast.ai/static/jvastai_root.cer"
|
||||
response = requests.get(cert_url)
|
||||
response.raise_for_status()
|
||||
# Use a temporary file that is not deleted on close
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".cer", mode="wb") as f:
|
||||
f.write(response.content)
|
||||
return f.name
|
||||
@@ -0,0 +1,222 @@
|
||||
# ComfyUI PyWorker
|
||||
|
||||
This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture.
|
||||
|
||||
The cost for each request has a static value of `1`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node.
|
||||
|
||||
## Requirements
|
||||
|
||||
This worker requires both [ComfyUI](https://github.com/comfyanonymous/ComfyUI) and [ComfyUI API Wrapper](https://github.com/ai-dock/comfyui-api-wrapper).
|
||||
|
||||
A docker image is provided but you may use any if the above requirements are met.
|
||||
|
||||
## Benchmarking
|
||||
|
||||
### Custom Benchmark Workflows
|
||||
|
||||
You can provide a custom ComfyUI workflow for benchmarking by creating `workers/comfyui-json/misc/benchmark.json`. This allows you to test performance using your preferred models and workflow complexity.
|
||||
|
||||
**Ways to provide the benchmark file:**
|
||||
- Fork this repository and add your `benchmark.json` file
|
||||
- Write the file during worker provisioning (onstart script or setup phase)
|
||||
|
||||
An example file is provided in the repository. To ensure varied generations, use the placeholder `__RANDOM_INT__` in place of static seed values - it will be replaced with a random integer for each benchmark run.
|
||||
|
||||
### Default Benchmark (Fallback)
|
||||
|
||||
If `benchmark.json` is not available, a simple image generation benchmark runs when each worker initializes. This validates GPU performance and helps identify underperforming machines.
|
||||
|
||||
The default benchmark uses Stable Diffusion v1.5 with ComfyUI's standard text-to-image workflow. Configure it using these environment variables:
|
||||
|
||||
| Environment Variable | Default Value | Description |
|
||||
| -------------------- | ------------- | ----------- |
|
||||
| BENCHMARK_TEST_WIDTH | 512 | Image width (pixels) |
|
||||
| BENCHMARK_TEST_HEIGHT | 512 | Image height (pixels) |
|
||||
| BENCHMARK_TEST_STEPS | 20 | Number of denoising steps |
|
||||
|
||||
Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns.
|
||||
|
||||
#### Calibrating Fallback Benchmark Duration
|
||||
|
||||
To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements.
|
||||
|
||||
**Example:** If your typical workflow should complete in 90 seconds on acceptable hardware:
|
||||
|
||||
```bash
|
||||
# 1. Measure it/sec on your reference machine
|
||||
# RTX 4090 typically achieves ~43 it/sec with SD1.5
|
||||
|
||||
# 2. Calculate required steps
|
||||
# 90 seconds × 43 it/sec = 3870 steps
|
||||
|
||||
# 3. Configure benchmark
|
||||
export BENCHMARK_TEST_STEPS=3870
|
||||
|
||||
# 4. Machines completing significantly slower than 90s indicate hardware issues
|
||||
```
|
||||
|
||||
**Performance expectations:**
|
||||
- Benchmark duration should remain consistent across identical GPU models
|
||||
- Significant variation (>20%) may indicate thermal, power, or configuration issues
|
||||
|
||||
## Endpoint
|
||||
|
||||
The worker provides a single endpoint:
|
||||
|
||||
- `/generate/sync`: Processes ComfyUI workflows using either predefined modifiers or custom workflow JSON
|
||||
|
||||
## Request Format
|
||||
|
||||
The worker accepts requests in the following format. Choose either modifier mode OR custom workflow mode:
|
||||
|
||||
**Modifier Mode:**
|
||||
```json
|
||||
{
|
||||
"input": {
|
||||
"request_id": "uuid-string", // optional - UUID generated if not provided
|
||||
"modifier": "RawWorkflow",
|
||||
"modifications": {
|
||||
"prompt": "a beautiful landscape",
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"steps": 20,
|
||||
"seed": 123456789
|
||||
},
|
||||
"s3": { ... }, // optional
|
||||
"webhook": { ... } // optional
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Custom Workflow Mode:**
|
||||
```json
|
||||
{
|
||||
"input": {
|
||||
"request_id": "uuid-string", // optional - UUID generated if not provided
|
||||
"workflow_json": {
|
||||
// Complete ComfyUI workflow JSON
|
||||
},
|
||||
"s3": { ... }, // optional
|
||||
"webhook": { ... } // optional
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Request Fields
|
||||
|
||||
### Required Fields
|
||||
|
||||
- **`input`**: Contains the main workflow data
|
||||
- **`input.request_id`**: Unique identifier for the request
|
||||
|
||||
### Workflow Mode (Choose One)
|
||||
|
||||
You must provide either `modifier` OR `workflow_json`, but not both:
|
||||
|
||||
#### Option 1: Modifier Mode
|
||||
- **`input.modifier`**: Name of the predefined workflow modifier (e.g., "Text2Image")
|
||||
- **`input.modifications`**: Parameters to pass to the modifier
|
||||
|
||||
#### Option 2: Custom Workflow Mode
|
||||
- **`input.workflow_json`**: Complete ComfyUI workflow JSON
|
||||
|
||||
### Optional Fields
|
||||
|
||||
- **`input.s3`**: S3 configuration for file storage
|
||||
- **`input.webhook`**: Webhook configuration for notifications
|
||||
|
||||
These configurations can be provided in the request JSON or via environment variables. Request-level configuration takes precedence over environment variables.
|
||||
|
||||
#### S3 Configuration
|
||||
|
||||
**Via Request JSON:**
|
||||
```json
|
||||
"s3": {
|
||||
"access_key_id": "your-s3-access-key",
|
||||
"secret_access_key": "your-s3-secret-access-key",
|
||||
"endpoint_url": "https://my-endpoint.backblaze.com",
|
||||
"bucket_name": "your-bucket",
|
||||
"region": "us-east-1"
|
||||
}
|
||||
```
|
||||
|
||||
**Via Environment Variables:**
|
||||
```bash
|
||||
S3_ACCESS_KEY_ID=your-key
|
||||
S3_SECRET_ACCESS_KEY=your-secret
|
||||
S3_BUCKET_NAME=your-bucket
|
||||
S3_ENDPOINT_URL=https://s3.amazonaws.com
|
||||
S3_REGION=us-east-1
|
||||
```
|
||||
|
||||
#### Webhook Configuration
|
||||
|
||||
**Via Request JSON:**
|
||||
```json
|
||||
"webhook": {
|
||||
"url": "your-webhook-url",
|
||||
"extra_params": {
|
||||
"custom_field": "value"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Via Environment Variables:**
|
||||
```bash
|
||||
WEBHOOK_URL=https://your-webhook.com # Default webhook URL
|
||||
WEBHOOK_TIMEOUT=30 # Webhook timeout in seconds
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
### Basic Text-to-Image (Modifier Mode)
|
||||
|
||||
```json
|
||||
{
|
||||
"input": {
|
||||
"modifier": "Text2Image",
|
||||
"modifications": {
|
||||
"prompt": "a cat sitting on a windowsill",
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
"steps": 20,
|
||||
"seed": 42
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Custom Workflow Mode
|
||||
|
||||
```json
|
||||
{
|
||||
"input": {
|
||||
"request_id": "67890", // optional - using custom ID for tracking
|
||||
"workflow_json": {
|
||||
"3": {
|
||||
"inputs": {
|
||||
"seed": 42,
|
||||
"steps": 20,
|
||||
"cfg": 8,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "normal",
|
||||
"denoise": 1,
|
||||
"model": ["4", 0],
|
||||
"positive": ["6", 0],
|
||||
"negative": ["7", 0],
|
||||
"latent_image": ["5", 0]
|
||||
},
|
||||
"class_type": "KSampler"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Client Libraries
|
||||
|
||||
See the test client examples for implementation details on how to integrate with the ComfyUI worker.
|
||||
|
||||
---
|
||||
|
||||
See Vast's serverless documentation for more details on how to use ComfyUI with autoscaler.
|
||||
@@ -0,0 +1,35 @@
|
||||
from .data_types import count_workload
|
||||
import uuid
|
||||
import random
|
||||
import asyncio
|
||||
import random
|
||||
|
||||
from vastai import Serverless
|
||||
|
||||
async def main():
|
||||
async with Serverless() as client:
|
||||
endpoint = await client.get_endpoint(name="my-comfy-endpoint") # Change this to your endpoint name
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
"request_id": str(uuid.uuid4()),
|
||||
"modifier": "Text2Image",
|
||||
"modifications": {
|
||||
"prompt": "a beautiful landscape with mountains and lakes",
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"steps": 20,
|
||||
"seed": random.randint(0, 2**32 - 1)
|
||||
},
|
||||
"workflow_json": {} # Empty since using modifier approach
|
||||
}
|
||||
}
|
||||
|
||||
response = await endpoint.request("/generate/sync", payload, cost=count_workload())
|
||||
|
||||
# Get the file from the path on the local machine using SCP or SFTP
|
||||
# or configure S3 to upload to cloud storage.
|
||||
print(response["response"]["output"][0]["local_path"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,84 @@
|
||||
import os
|
||||
import sys
|
||||
import random
|
||||
import dataclasses
|
||||
from typing import Dict, Any
|
||||
from functools import cache
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
import json
|
||||
import logging
|
||||
|
||||
from lib.data_types import ApiPayload, JsonDataException
|
||||
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
def count_workload() -> float:
|
||||
# Always 100.0 where there is a single instance of ComfyUI handling requests
|
||||
# Results will indicate % or a job completed per second. Avoids sub 0.1 sec performance indication
|
||||
return 100.0
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ComfyWorkflowData(ApiPayload):
|
||||
input: dict
|
||||
|
||||
@classmethod
|
||||
def for_test(cls):
|
||||
"""
|
||||
If the user has provided a benchmark workflow we can use it here to properly gauge performance.
|
||||
Otherwise, use the variables available to simulate workflows of the required running time
|
||||
Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090)
|
||||
"""
|
||||
# Try to load benchmark.json
|
||||
benchmark_file = Path("workers/comfyui-json/misc/benchmark.json")
|
||||
|
||||
if benchmark_file.exists():
|
||||
try:
|
||||
with open(benchmark_file, "r") as f:
|
||||
benchmark_workflow = json.load(f)
|
||||
return cls(
|
||||
input={
|
||||
"request_id": f"test-{random.randint(1000, 99999)}",
|
||||
"workflow_json": benchmark_workflow
|
||||
}
|
||||
)
|
||||
except (json.JSONDecodeError, IOError):
|
||||
# JSON is malformed or file can't be read, fall through to default
|
||||
log.error(f"Failed to benchmark using {benchmark_file}")
|
||||
|
||||
# Fallback: read prompts and construct payload
|
||||
log.info("Using fallback method for benchmarking")
|
||||
with open("workers/comfyui-json/misc/test_prompts.txt", "r") as f:
|
||||
test_prompts = f.readlines()
|
||||
|
||||
test_prompt = random.choice(test_prompts).rstrip()
|
||||
return cls(
|
||||
input={
|
||||
"request_id": f"test-{random.randint(1000, 99999)}",
|
||||
"modifier": "Text2Image",
|
||||
"modifications": {
|
||||
"prompt": test_prompt,
|
||||
"width": os.getenv('BENCHMARK_TEST_WIDTH', 512),
|
||||
"height": os.getenv('BENCHMARK_TEST_HEIGHT', 512),
|
||||
"steps": os.getenv('BENCHMARK_TEST_STEPS', 20),
|
||||
"seed": random.randint(0, sys.maxsize),
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
def generate_payload_json(self) -> Dict[str, Any]:
|
||||
# input is already a dict, just return it wrapped in the expected structure
|
||||
return {"input": self.input}
|
||||
|
||||
def count_workload(self) -> float:
|
||||
return count_workload()
|
||||
|
||||
@classmethod
|
||||
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "ComfyWorkflowData":
|
||||
# Extract required fields
|
||||
if "input" not in json_msg:
|
||||
raise JsonDataException({"input": "missing parameter"})
|
||||
|
||||
return cls(
|
||||
input=json_msg["input"]
|
||||
)
|
||||
@@ -0,0 +1,107 @@
|
||||
{
|
||||
"3": {
|
||||
"inputs": {
|
||||
"seed": "__RANDOM_INT__",
|
||||
"steps": 20,
|
||||
"cfg": 8,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "normal",
|
||||
"denoise": 1,
|
||||
"model": [
|
||||
"4",
|
||||
0
|
||||
],
|
||||
"positive": [
|
||||
"6",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"7",
|
||||
0
|
||||
],
|
||||
"latent_image": [
|
||||
"5",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "KSampler",
|
||||
"_meta": {
|
||||
"title": "KSampler"
|
||||
}
|
||||
},
|
||||
"4": {
|
||||
"inputs": {
|
||||
"ckpt_name": "v1-5-pruned-emaonly-fp16.safetensors"
|
||||
},
|
||||
"class_type": "CheckpointLoaderSimple",
|
||||
"_meta": {
|
||||
"title": "Load Checkpoint"
|
||||
}
|
||||
},
|
||||
"5": {
|
||||
"inputs": {
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
"batch_size": 1
|
||||
},
|
||||
"class_type": "EmptyLatentImage",
|
||||
"_meta": {
|
||||
"title": "Empty Latent Image"
|
||||
}
|
||||
},
|
||||
"6": {
|
||||
"inputs": {
|
||||
"text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,",
|
||||
"clip": [
|
||||
"4",
|
||||
1
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"7": {
|
||||
"inputs": {
|
||||
"text": "text, watermark",
|
||||
"clip": [
|
||||
"4",
|
||||
1
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"8": {
|
||||
"inputs": {
|
||||
"samples": [
|
||||
"3",
|
||||
0
|
||||
],
|
||||
"vae": [
|
||||
"4",
|
||||
2
|
||||
]
|
||||
},
|
||||
"class_type": "VAEDecode",
|
||||
"_meta": {
|
||||
"title": "VAE Decode"
|
||||
}
|
||||
},
|
||||
"9": {
|
||||
"inputs": {
|
||||
"filename_prefix": "ComfyUI",
|
||||
"images": [
|
||||
"8",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "SaveImage",
|
||||
"_meta": {
|
||||
"title": "Save Image"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
cartoon character of a person with a hoodie , in style of cytus and deemo, ork, gold chains, realistic anime cat, dripping black goo, lineage revolution style, thug life, cute anthropomorphic bunny, balrog, arknights, aliased, very buff, black and red and yellow paint, painting illustration collage style, character composition in vector with white background
|
||||
stardew valley, fine details
|
||||
2D Vector Illustration of a child with soccer ball Art for Sublimation, Design Art, Chrome Art, Painting and Stunning Artwork, Highly Detailed Digital Painting, Airbrush Art, Highly Detailed Digital Artwork, Dramatic Artwork, stained antique yellow copper paint, digital airbrush art, detailed by Mark Brooks, Chicano airbrush art, Swagger! snake Culture
|
||||
realistic futuristic city-downtown with short buildings, sunset
|
||||
seascape by Ray Collins and artgerm, front view of a perfect wave, sunny background, ultra detailed water
|
||||
inspired by realflow-cinema4d editor features, create image of a transparent luxury cup with ice fruits and mint, connected with white, yellow and pink cream, Slow - High Speed MO Photography, YouTube Video Screenshot, Abstract Clay, Transparent Cup , molecular gastronomy, wheel, 3D fluid,Simulation rendering, still video, 4k polymer clay futras photography, very surreal, Houdini Fluid Simulation, hyperrealistic CGI and FLUIDS & MULTIPHYSICS SIMULATION effect, with Somali Stain Lurex, Metallic Jacquard, Gold Thread, Mulberry Silk, Toub Saree, Warm background, a fantastic image worthy of an award.
|
||||
biker with backpack on his back riding a motorcycle, Style by Ade Santora, Oilpunk, Cover photo, craig mullins style, on the cover of a magazine, Outdoor Magazine, inspired by Alex Petruk APe, image of a male biker, Cover of an award-winning magazine, the man has a backpack, photo for magazine, with a backpack, magazine cover
|
||||
generate a collage-style illustration inspired by the Procreate raster graphic editor, photographic illustration with the theme, 2D vector, art for textile sublimation, containing surrealistic cartoon cat wearing a baseball cap and jeans standing in front of a poster, inspired by Sadao Watanabe, Doraemon, Japanese cartoon style, Eichiro Oda, Iconic high detail character, Director: Nakahara Nantenbō, Kastuhiro Otomo, image detailed, by Miyamoto, Hidetaka Miyazaki, Katsuhiro illustration, 8k, masterpiece, Minimize noise and grain in photo quality without lose quality and increase brightness and lighting,Symmetry and Alignment, Avoid asymmetrical shapes and out-of-focus points. Focus and Sharpness: Make sure the image is focused and sharp and encourages the viewer to see it as a work of art printed on fabric.
|
||||
fantasy medieval village world inside a glass sphere , high detail, fantasy, realistic, light effect, hyper detail, volumetric lighting, cinematic, macro, depth of field, blur, red light and clouds from the back, highly detailed epic cinematic concept art cg render made in maya, blender and photoshop, octane render, excellent composition, dynamic dramatic cinematic lighting, aesthetic, very inspirational, world inside a glass sphere by james gurney by artgerm with james jean, joe fenton and tristan eaton by ross tran, fine details
|
||||
Iron Man, (Arnold Tsang, Toru Nakayama), Masterpiece, Studio Quality, 6k , toa, toaair, 1boy, glowing, axe, mecha, science_fiction, solo, weapon, jungle , green_background, nature, outdoors, solo, tree, weapon, mask, dynamic lighting, detailed shading, digital texture painting
|
||||
(Pope Francis) wearing leather jacket is a DJ in a nightclub, mixing live on stage, giant mixing table, a masterpiece
|
||||
Pope Francis wearing biker (leather jacket), a masterpiece
|
||||
Luke Skywalker ordering a burger and fries from the Death Star canteen.
|
||||
I want to generate a group avatar for a Feishu group chat. The role of this group is daily software technical communication. Now the subject technology stacks that members of this group discuss daily include: algorithms, data structures, optimization, functional programming, and the programming languages often discussed are: TypeScript, Java, python, etc. I hope this avatar has a simple aesthetic, this avatar is a single person avatar
|
||||
portrait Anime black girl cute-fine-face, pretty face, realistic shaded Perfect face, fine details. Anime. realistic shaded lighting by Ilya Kuvshinov Giuseppe Dangelico Pino and Michael Garmash and Rob Rey, IAMAG premiere, WLOP matte print, cute freckles, masterpiece
|
||||
young Disney socialite wearing a beige miniskirt, dark brown turtleneck sweater, small neckless, cute-fine-face, anime. illustration, realistic shaded perfect face, brown hair, grey eyes, fine details, realistic shaded lighting by ilya kuvshinov giuseppe dangelico pino and michael garmash and rob rey, iamag premiere, wlop matte print, a masterpiece
|
||||
Cute small cat sitting in a movie theater eating chicken wiggs watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
|
||||
Cute small dog sitting in a movie theater eating popcorn watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
|
||||
fox bracelet made of buckskin with fox features, rich details, fine carvings, studio lighting
|
||||
crane buckskin bracelet with crane features, rich details, fine carvings, studio lighting
|
||||
london luxurious interior living-room, light walls
|
||||
Parisian luxurious interior penthouse bedroom, dark walls, wooden panels
|
||||
cute girl, crop-top, blond hair, black glasses, stretching, with background by greg rutkowski makoto shinkai kyoto animation key art feminine mid shot
|
||||
houses in front, houses background, straight houses, digital art, smooth, sharp focus, gravity falls style, doraemon style, shinchan style, anime style
|
||||
Simplified technical drawing, Leonardo da Vinci, Mechanical Dinosaur Skeleton, Minimalistic annotations, Hand-drawn illustrations, Basic design and engineering, Wonder and curiosity
|
||||
High quality 8K painting impressionist style of a Japanese modern city street with a girl on the foreground wearing a traditional wedding dress with a fox mask, staring at the sky, daylight
|
||||
a landscape from the Moon with the Earth setting on the horizon, realistic, detailed
|
||||
Isometric Atlantis city,great architecture with columns, great details, ornaments,seaweed, blue ambiance, 3D cartoon style, soft light, 45° view
|
||||
A hyper realistic avatar of a guy riding on a black honda cbr 650r in leather suit,high detail, high quality,8K,photo realism
|
||||
the street of amedieval fantasy town, at dawn, dark, highly detailed
|
||||
overwhelmingly beautiful eagle framed with vector flowers, long shiny wavy flowing hair, polished, ultra detailed vector floral illustration mixed with hyper realism, muted pastel colors, vector floral details in background, muted colors, hyper detailed ultra intricate overwhelming realism in detailed complex scene with magical fantasy atmosphere, no signature, no watermark
|
||||
a highly detailed matte painting of a man on a hill watching a rocket launch in the distance by studio ghibli, makoto shinkai, by artgerm, by wlop, by greg rutkowski, volumetric lighting, octane render, 4 k resolution, trending on artstation, masterpiece | hyperrealism| highly detailed| insanely detailed| intricate| cinematic lighting| depth of field
|
||||
electronik robot and ofice ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
|
||||
exquisitely intricately detailed illustration, of a small world with a lake and a rainbow, inside a closed glass jar.
|
||||
@@ -0,0 +1,117 @@
|
||||
import os
|
||||
import logging
|
||||
import dataclasses
|
||||
import base64
|
||||
from typing import Optional, Union, Type
|
||||
|
||||
from aiohttp import web, ClientResponse
|
||||
|
||||
from lib.backend import Backend, LogAction
|
||||
from lib.data_types import EndpointHandler
|
||||
from lib.server import start_server
|
||||
from .data_types import ComfyWorkflowData
|
||||
|
||||
|
||||
MODEL_SERVER_URL = os.getenv("MODEL_SERVER_URL", "http://127.0.0.1:18288")
|
||||
|
||||
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
||||
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: "
|
||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
|
||||
"Value not in list: ", # This error is emitted when the model file is not there at all
|
||||
"[ERROR] Provisioning Script failed", # Error inserted by provisioning script if models/nodes fail to download
|
||||
]
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
async def generate_client_response(
|
||||
client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
# Check if the response is actually streaming based on response headers/content-type
|
||||
is_streaming_response = (
|
||||
model_response.content_type == "text/event-stream"
|
||||
or model_response.content_type == "application/x-ndjson"
|
||||
or model_response.headers.get("Transfer-Encoding") == "chunked"
|
||||
or "stream" in model_response.content_type.lower()
|
||||
)
|
||||
|
||||
if is_streaming_response:
|
||||
log.debug("Detected streaming response...")
|
||||
res = web.StreamResponse()
|
||||
res.content_type = model_response.content_type
|
||||
await res.prepare(client_request)
|
||||
async for chunk in model_response.content:
|
||||
await res.write(chunk)
|
||||
await res.write_eof()
|
||||
log.debug("Done streaming response")
|
||||
return res
|
||||
else:
|
||||
log.debug("Detected non-streaming response...")
|
||||
content = await model_response.read()
|
||||
return web.Response(
|
||||
body=content,
|
||||
status=model_response.status,
|
||||
content_type=model_response.content_type
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ComfyWorkflowHandler(EndpointHandler[ComfyWorkflowData]):
|
||||
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/generate/sync"
|
||||
|
||||
@property
|
||||
def healthcheck_endpoint(self) -> Optional[str]:
|
||||
return f"{MODEL_SERVER_URL}/health"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[ComfyWorkflowData]:
|
||||
return ComfyWorkflowData
|
||||
|
||||
def make_benchmark_payload(self) -> ComfyWorkflowData:
|
||||
return ComfyWorkflowData.for_test()
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
return await generate_client_response(client_request, model_response)
|
||||
|
||||
|
||||
backend = Backend(
|
||||
model_server_url=MODEL_SERVER_URL,
|
||||
model_log_file=os.environ["MODEL_LOG"],
|
||||
allow_parallel_requests=False,
|
||||
benchmark_handler=ComfyWorkflowHandler(
|
||||
benchmark_runs=3, benchmark_words=100
|
||||
),
|
||||
log_actions=[
|
||||
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
|
||||
(LogAction.Info, "Downloading:"),
|
||||
*[
|
||||
(LogAction.ModelError, error_msg)
|
||||
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def handle_ping(_):
|
||||
return web.Response(body="pong")
|
||||
|
||||
|
||||
routes = [
|
||||
web.post("/generate/sync", backend.create_handler(ComfyWorkflowHandler())),
|
||||
web.get("/ping", handle_ping),
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_server(backend, routes)
|
||||
@@ -0,0 +1,8 @@
|
||||
from lib.test_utils import test_load_cmd, test_args
|
||||
from .data_types import ComfyWorkflowData
|
||||
|
||||
WORKER_ENDPOINT = "/generate/sync"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_load_cmd(ComfyWorkflowData, WORKER_ENDPOINT, arg_parser=test_args)
|
||||
+10
-12
@@ -5,21 +5,15 @@ import requests
|
||||
|
||||
from lib.test_utils import print_truncate_res
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
|
||||
"""
|
||||
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
|
||||
"""
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
from vastai import Serverless
|
||||
|
||||
|
||||
def call_default_workflow(
|
||||
endpoint_group_name: str, api_key: str, server_url: str
|
||||
) -> None:
|
||||
ENDPOINT_NAME = "my-comfyui-endpoint"
|
||||
COST = 100 # Use a constant cost for image generation
|
||||
|
||||
def call_default_workflow(client: Serverless) -> None:
|
||||
WORKER_ENDPOINT = "/prompt"
|
||||
COST = 100
|
||||
route_payload = {
|
||||
@@ -51,6 +45,7 @@ def call_default_workflow(
|
||||
response = requests.post(
|
||||
url,
|
||||
json=req_data,
|
||||
verify=get_cert_file_path(),
|
||||
)
|
||||
response.raise_for_status()
|
||||
print_truncate_res(str(response.json()))
|
||||
@@ -80,6 +75,7 @@ def call_custom_workflow_for_sd3(
|
||||
endpoint=message["endpoint"],
|
||||
reqnum=message["reqnum"],
|
||||
url=message["url"],
|
||||
request_idx=message["request_idx"],
|
||||
)
|
||||
workflow = {
|
||||
"3": {
|
||||
@@ -141,6 +137,7 @@ def call_custom_workflow_for_sd3(
|
||||
response = requests.post(
|
||||
url,
|
||||
json=req_data,
|
||||
verify=get_cert_file_path(),
|
||||
)
|
||||
response.raise_for_status()
|
||||
print_truncate_res(str(response.json()))
|
||||
@@ -153,6 +150,7 @@ if __name__ == "__main__":
|
||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||
endpoint_name=args.endpoint_group_name,
|
||||
account_api_key=args.api_key,
|
||||
instance=args.instance,
|
||||
)
|
||||
if endpoint_api_key:
|
||||
try:
|
||||
|
||||
@@ -13,7 +13,7 @@ from lib.server import start_server
|
||||
from .data_types import DefaultComfyWorkflowData, CustomComfyWorkflowData
|
||||
|
||||
|
||||
MODEL_SERVER_URL = "http://0.0.0.0:38188"
|
||||
MODEL_SERVER_URL = "http://127.0.0.1:18288" # API Wrapper Service
|
||||
|
||||
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
||||
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: http://127.0.0.1:18188"
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
# OpenAI Compatible PyWorker
|
||||
|
||||
This is the base PyWorker for OpenAI compatible inference servers. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
|
||||
|
||||
## Instance Setup
|
||||
|
||||
1. Pick a template
|
||||
|
||||
This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker.
|
||||
|
||||
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20%2B%20Qwen%2FQwen3-8B%20(Serverless)) (recommended)
|
||||
- [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless))
|
||||
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20%2B%20Qwen3-8B%20(Serverless))
|
||||
|
||||
|
||||
All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected.
|
||||
|
||||
2. Follow the [getting started guide](https://docs.vast.ai/serverless/getting-started) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
||||
|
||||
## Client Setup (Demo)
|
||||
|
||||
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/vast-ai/pyworker
|
||||
cd pyworker
|
||||
pip install uv
|
||||
uv venv -p 3.12
|
||||
source .venv/bin/activate
|
||||
uv pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Using the Test Client
|
||||
|
||||
Several examples have been provided in the client to help you get started with your own implementation.
|
||||
|
||||
### Completions
|
||||
|
||||
Call to `/v1/completions` with json response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Chat Completion (json)
|
||||
|
||||
Call to `/v1/chat/completions` with json response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Chat Completion (streaming)
|
||||
|
||||
Call to `/v1/chat/completions` with streaming response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Tool Use (json)
|
||||
|
||||
Call to `/v1/chat/completions` with tool and json response.
|
||||
|
||||
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Interactive Chat (streaming)
|
||||
|
||||
Interactive session with calls to `/v1/chat/completions`.
|
||||
|
||||
Type `clear` to clear the chat history or `quit` to exit.
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
# <INFERENCE_SERVER> + <MODEL_NAME> (serverless)
|
||||
|
||||
Run <INFERENCE_SERVER> with our serverless autoscaling infrastructure.
|
||||
|
||||
See the [serverless documentation](https://docs.vast.ai/serverless) and the [Getting Started](https://docs.vast.ai/serverless/getting-started) guide for in-depth details about how to use these templates.
|
||||
|
||||
## Configuration
|
||||
|
||||
Two environment variables are provided to help you configure the <INFERENCE_SERVER> server:
|
||||
|
||||
| Variable | Default Value | Used For |
|
||||
| --- | --- | --- |
|
||||
| `MODEL_NAME` | `<MODEL_NAME>` | The model to load. Also accepts [hf.co/repo/model](#) links |
|
||||
| `<ARGS_VAR>` | `<ARGS_VAL>` | Arguments to pass to the `<ARGS_RECEIVER>` command |
|
||||
|
||||
This template has been configured to work with <MIN_VRAM> VRAM. Setting alternative models and server arguments will change the VRAM requirements. Check model cards and <INFERENCE_SERVER_DOCS> for guidance.
|
||||
|
||||
## Usage
|
||||
|
||||
We have provided a demonstration client to help you implement this template into your own infrastructure
|
||||
|
||||
### Client Setup
|
||||
|
||||
Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/vast-ai/pyworker
|
||||
cd pyworker
|
||||
pip install uv
|
||||
uv venv -p 3.12
|
||||
source .venv/bin/activate
|
||||
uv pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### Completions
|
||||
|
||||
Call to `/v1/completions` with json response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Chat Completion (json)
|
||||
|
||||
Call to `/v1/chat/completions` with json response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Chat Completion (streaming)
|
||||
|
||||
Call to `/v1/chat/completions` with streaming response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Tool Use (json)
|
||||
|
||||
Call to `/v1/chat/completions` with tool and json response.
|
||||
|
||||
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Interactive Chat (streaming)
|
||||
|
||||
Interactive session with calls to `/v1/chat/completions`.
|
||||
|
||||
Type `clear` to clear the chat history or `quit` to exit.
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
|
||||
```
|
||||
@@ -0,0 +1,529 @@
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import argparse
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from vastai import Serverless
|
||||
import asyncio
|
||||
|
||||
# ---------------------- Logging ----------------------
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
# ---------------------- Prompts ----------------------
|
||||
COMPLETIONS_PROMPT = "the capital of USA is"
|
||||
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
|
||||
TOOLS_PROMPT = (
|
||||
"Can you list the files in the current working directory and tell me what you see? "
|
||||
"What do you think this directory might be for?"
|
||||
)
|
||||
|
||||
ENDPOINT_NAME = "my-vllm-endpoint" # change this to your vLLM endpoint name
|
||||
DEFAULT_MODEL = "Qwen/Qwen3-8B" # must support tool calling
|
||||
MAX_TOKENS = 1024
|
||||
DEFAULT_TEMPERATURE = 0.7
|
||||
|
||||
# ---------------------- Tooling ----------------------
|
||||
class ToolManager:
|
||||
"""Handles tool definitions and execution"""
|
||||
|
||||
@staticmethod
|
||||
def list_files() -> str:
|
||||
"""Execute ls on current directory"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["ls", "-la", "."], capture_output=True, text=True, timeout=10
|
||||
)
|
||||
if result.returncode == 0:
|
||||
return result.stdout
|
||||
else:
|
||||
return f"Error: {result.stderr}"
|
||||
except Exception as e:
|
||||
return f"Error running ls: {e}"
|
||||
|
||||
@staticmethod
|
||||
def get_ls_tool_definition() -> List[Dict[str, Any]]:
|
||||
"""OpenAI-compatible tool schema"""
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "list_files",
|
||||
"description": "List files and directories in the cwd",
|
||||
"parameters": {"type": "object", "properties": {}, "required": []},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
def execute_tool_call(self, tool_call: Dict[str, Any]) -> str:
|
||||
"""Execute a tool call and return the result"""
|
||||
function_name = (tool_call.get("function") or {}).get("name")
|
||||
if function_name == "list_files":
|
||||
return self.list_files()
|
||||
raise ValueError(f"Unknown tool function: {function_name}")
|
||||
|
||||
|
||||
# ----- Helpers to handle streamed tool_calls assembly -----
|
||||
def _merge_tool_call_delta(state: Dict[int, Dict[str, Any]], tc_delta: Dict[str, Any]) -> None:
|
||||
"""
|
||||
OpenAI-style streaming sends partial tool_calls with an index and partial fields.
|
||||
We merge into a per-index state dict until the assistant message finishes.
|
||||
"""
|
||||
idx = tc_delta.get("index")
|
||||
if idx is None:
|
||||
return
|
||||
|
||||
entry = state.setdefault(idx, {"id": None, "function": {"name": None, "arguments": ""}, "type": "function"})
|
||||
|
||||
if tc_delta.get("id"):
|
||||
entry["id"] = tc_delta["id"]
|
||||
|
||||
fn_delta = tc_delta.get("function") or {}
|
||||
if "name" in fn_delta and fn_delta["name"]:
|
||||
entry["function"]["name"] = fn_delta["name"]
|
||||
if "arguments" in fn_delta and fn_delta["arguments"]:
|
||||
entry["function"]["arguments"] += fn_delta["arguments"]
|
||||
|
||||
|
||||
def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
return [state[i] for i in sorted(state.keys())]
|
||||
|
||||
|
||||
# ---- OpenAI-compatible calls (non-streaming) ----
|
||||
async def call_completions(client: Serverless, *, model: str, prompt: str, **kwargs) -> Dict[str, Any]:
|
||||
|
||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
}
|
||||
}
|
||||
log.debug("POST /v1/completions %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
|
||||
return resp["response"]
|
||||
|
||||
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
|
||||
|
||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
|
||||
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
|
||||
}
|
||||
}
|
||||
log.debug("POST /v1/chat/completions %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"])
|
||||
return resp["response"]
|
||||
|
||||
# ---- Streaming variants ----
|
||||
async def stream_completions(client: Serverless, *, model: str, prompt: str, **kwargs):
|
||||
|
||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
"stream": True,
|
||||
**({"stop": kwargs["stop"]} if "stop" in kwargs else {}),
|
||||
}
|
||||
}
|
||||
log.debug("STREAM /v1/completions %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
|
||||
return resp["response"] # async generator
|
||||
|
||||
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs):
|
||||
|
||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
"stream": True,
|
||||
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
|
||||
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
|
||||
}
|
||||
}
|
||||
log.debug("STREAM /v1/chat/completions %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
|
||||
return resp["response"] # async generator
|
||||
|
||||
|
||||
# ---------------------- Demo Runner ----------------------
|
||||
class APIDemo:
|
||||
"""Demo and testing functionality for the API client"""
|
||||
|
||||
def __init__(self, client: Serverless, model: str, tool_manager: Optional[ToolManager] = None):
|
||||
self.client = client
|
||||
self.model = model
|
||||
self.tool_manager = tool_manager or ToolManager()
|
||||
|
||||
# ----- Streaming handler -----
|
||||
async def handle_streaming_response(self, stream, show_reasoning: bool = True) -> str:
|
||||
full_response = ""
|
||||
reasoning_content = ""
|
||||
printed_reasoning = False
|
||||
printed_answer = False
|
||||
|
||||
async for chunk in stream:
|
||||
choice = (chunk.get("choices") or [{}])[0]
|
||||
delta = choice.get("delta", {})
|
||||
|
||||
# reasoning tokens
|
||||
rc = delta.get("reasoning_content")
|
||||
if rc and show_reasoning:
|
||||
if not printed_reasoning:
|
||||
print("\n🧠 Reasoning: ", end="", flush=True)
|
||||
printed_reasoning = True
|
||||
print(rc, end="", flush=True)
|
||||
reasoning_content += rc
|
||||
|
||||
# content tokens
|
||||
content_part = delta.get("content")
|
||||
if content_part:
|
||||
if not printed_answer:
|
||||
if show_reasoning and printed_reasoning:
|
||||
print("\n💬 Response: ", end="", flush=True)
|
||||
else:
|
||||
print("Assistant: ", end="", flush=True)
|
||||
printed_answer = True
|
||||
print(content_part, end="", flush=True)
|
||||
full_response += content_part
|
||||
|
||||
print() # newline
|
||||
if show_reasoning:
|
||||
if printed_reasoning or printed_answer:
|
||||
print("\nStreaming completed.")
|
||||
if printed_reasoning:
|
||||
print(f"Reasoning tokens: {len(reasoning_content.split())}")
|
||||
if printed_answer:
|
||||
print(f"Response tokens: {len(full_response.split())}")
|
||||
|
||||
return full_response
|
||||
|
||||
async def demo_completions(self) -> None:
|
||||
print("=" * 60)
|
||||
print("COMPLETIONS DEMO")
|
||||
print("=" * 60)
|
||||
|
||||
response = await call_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
prompt=COMPLETIONS_PROMPT,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
)
|
||||
print("\nResponse:")
|
||||
print(json.dumps(response, indent=2))
|
||||
|
||||
async def demo_chat(self, use_streaming: bool = True) -> None:
|
||||
print("=" * 60)
|
||||
print(f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}")
|
||||
print("=" * 60)
|
||||
|
||||
messages = [{"role": "user", "content": CHAT_PROMPT}]
|
||||
|
||||
if use_streaming:
|
||||
stream = await stream_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE
|
||||
)
|
||||
try:
|
||||
await self.handle_streaming_response(stream, show_reasoning=True)
|
||||
except Exception as e:
|
||||
log.error("\nError during streaming: %s", e, exc_info=True)
|
||||
else:
|
||||
response = await call_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE
|
||||
)
|
||||
choice = (response.get("choices") or [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
content = message.get("content", "")
|
||||
reasoning = message.get("reasoning_content", "") or message.get("reasoning", "")
|
||||
if reasoning:
|
||||
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
|
||||
print(f"\n💬 Assistant: {content}")
|
||||
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
|
||||
|
||||
async def test_tool_support(self) -> bool:
|
||||
"""Probe that tool schema is accepted (no actual call)"""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
minimal_tool = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {"name": "test_function", "description": "Test function"},
|
||||
}
|
||||
]
|
||||
try:
|
||||
_ = await call_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
tools=minimal_tool,
|
||||
tool_choice="none",
|
||||
max_tokens=10
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error("Endpoint does not support tool calling: %s", e)
|
||||
return False
|
||||
|
||||
async def demo_ls_tool(self) -> None:
|
||||
"""Ask to list files using function calling, then provide final analysis"""
|
||||
print("=" * 60)
|
||||
print("TOOL USE DEMO: List Directory Contents")
|
||||
print("=" * 60)
|
||||
|
||||
if not await self.test_tool_support():
|
||||
return
|
||||
|
||||
messages: List[Dict[str, Any]] = [{"role": "user", "content": TOOLS_PROMPT}]
|
||||
|
||||
# First pass: let the model decide tools, stream tool_calls and partial content
|
||||
stream = await stream_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
tools=self.tool_manager.get_ls_tool_definition(),
|
||||
tool_choice="auto",
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
)
|
||||
|
||||
assistant_content_buf: List[str] = []
|
||||
tool_calls_state: Dict[int, Dict[str, Any]] = {}
|
||||
printed_reasoning = False
|
||||
printed_answer = False
|
||||
|
||||
async for chunk in stream:
|
||||
choice = (chunk.get("choices") or [{}])[0]
|
||||
delta = choice.get("delta", {})
|
||||
|
||||
rc = delta.get("reasoning_content")
|
||||
if rc:
|
||||
if not printed_reasoning:
|
||||
printed_reasoning = True
|
||||
print("🧠 Reasoning: ", end="", flush=True)
|
||||
print(rc, end="", flush=True)
|
||||
|
||||
content_part = delta.get("content")
|
||||
if content_part:
|
||||
assistant_content_buf.append(content_part)
|
||||
if not printed_answer:
|
||||
printed_answer = True
|
||||
print("\n💬 Response: ", end="", flush=True)
|
||||
print(content_part, end="", flush=True)
|
||||
|
||||
if "tool_calls" in delta and delta["tool_calls"]:
|
||||
for tc_delta in delta["tool_calls"]:
|
||||
_merge_tool_call_delta(tool_calls_state, tc_delta)
|
||||
|
||||
# If no tool calls, we’re done.
|
||||
if not tool_calls_state:
|
||||
print("\n(No tool calls were made.)")
|
||||
return
|
||||
|
||||
# Build assistant message with tool_calls
|
||||
assistant_message = {
|
||||
"role": "assistant",
|
||||
"content": "".join(assistant_content_buf) if assistant_content_buf else None,
|
||||
"tool_calls": _tool_state_to_message_tool_calls(tool_calls_state),
|
||||
}
|
||||
messages.append(assistant_message)
|
||||
|
||||
# Execute tools and feed results back
|
||||
for tc in assistant_message["tool_calls"]:
|
||||
tool_name = (tc.get("function") or {}).get("name")
|
||||
call_id = tc.get("id")
|
||||
raw_args = (tc.get("function") or {}).get("arguments") or "{}"
|
||||
|
||||
try:
|
||||
args = json.loads(raw_args) if raw_args.strip() else {}
|
||||
except Exception as e:
|
||||
tool_result = json.dumps({"error": f"Argument parse failed: {str(e)}", "raw_arguments": raw_args})
|
||||
messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result})
|
||||
continue
|
||||
|
||||
try:
|
||||
if tool_name == "list_files":
|
||||
tool_result = self.tool_manager.list_files()
|
||||
else:
|
||||
tool_result = json.dumps({"error": f"Unknown tool '{tool_name}'"})
|
||||
except Exception as e:
|
||||
tool_result = json.dumps({"error": f"Tool '{tool_name}' failed: {str(e)}"})
|
||||
|
||||
print("\n[Tool executed]", tool_name)
|
||||
print(tool_result[:500] + ("..." if len(tool_result) > 500 else ""))
|
||||
messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result})
|
||||
|
||||
# Second pass: get final streamed answer after tool results
|
||||
stream2 = await stream_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
)
|
||||
|
||||
final_buf = []
|
||||
printed_reasoning2 = False
|
||||
printed_answer2 = False
|
||||
|
||||
async for chunk in stream2:
|
||||
choice = (chunk.get("choices") or [{}])[0]
|
||||
delta = choice.get("delta", {})
|
||||
|
||||
rc2 = delta.get("reasoning_content")
|
||||
if rc2:
|
||||
if not printed_reasoning2:
|
||||
printed_reasoning2 = True
|
||||
print("\n🧠 Reasoning (post-tools): ", end="", flush=True)
|
||||
print(rc2, end="", flush=True)
|
||||
|
||||
c2 = delta.get("content")
|
||||
if c2:
|
||||
final_buf.append(c2)
|
||||
if not printed_answer2:
|
||||
printed_answer2 = True
|
||||
print("\n💬 Response (final): ", end="", flush=True)
|
||||
print(c2, end="", flush=True)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("FINAL LLM ANALYSIS:")
|
||||
print("=" * 60)
|
||||
print("".join(final_buf))
|
||||
print("=" * 60)
|
||||
|
||||
async def interactive_chat(self) -> None:
|
||||
"""Interactive chat session with streaming"""
|
||||
print("=" * 60)
|
||||
print("INTERACTIVE STREAMING CHAT")
|
||||
print("=" * 60)
|
||||
print(f"Using model: {self.model}")
|
||||
print("Type 'quit' to exit, 'clear' to clear history")
|
||||
print()
|
||||
|
||||
messages: List[Dict[str, Any]] = []
|
||||
|
||||
while True:
|
||||
try:
|
||||
user_input = input("You: ").strip()
|
||||
|
||||
if user_input.lower() == "quit":
|
||||
print("👋 Goodbye!")
|
||||
break
|
||||
elif user_input.lower() == "clear":
|
||||
messages = []
|
||||
print("Chat history cleared")
|
||||
continue
|
||||
elif not user_input:
|
||||
continue
|
||||
|
||||
messages.append({"role": "user", "content": user_input})
|
||||
|
||||
print("Assistant: ", end="", flush=True)
|
||||
stream = await stream_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=0.7
|
||||
)
|
||||
assistant_content = await self.handle_streaming_response(stream, show_reasoning=True)
|
||||
|
||||
# Add assistant response to conversation history
|
||||
messages.append({"role": "assistant", "content": assistant_content})
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 Chat interrupted. Goodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
log.error("\nError: %s", e)
|
||||
continue
|
||||
|
||||
|
||||
# ---------------------- CLI ----------------------
|
||||
def build_arg_parser() -> argparse.ArgumentParser:
|
||||
p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
|
||||
p.add_argument("--model", required=True, help="Model to use for requests (required)")
|
||||
p.add_argument("--endpoint", default="my-vllm-endpoint", help="Vast endpoint name (default: my-vllm-endpoint)")
|
||||
|
||||
modes = p.add_mutually_exclusive_group(required=False)
|
||||
modes.add_argument("--completion", action="store_true", help="Test completions endpoint")
|
||||
modes.add_argument("--chat", action="store_true", help="Test chat completions endpoint (non-streaming)")
|
||||
modes.add_argument("--chat-stream", action="store_true", help="Test chat completions endpoint with streaming")
|
||||
modes.add_argument("--tools", action="store_true", help="Test function calling with ls tool (non-streaming+streamed phases)")
|
||||
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming chat session")
|
||||
return p
|
||||
|
||||
|
||||
async def main_async():
|
||||
args = build_arg_parser().parse_args()
|
||||
|
||||
selected = sum([args.completion, args.chat, args.chat_stream, args.tools, args.interactive])
|
||||
if selected == 0:
|
||||
print("Please specify exactly one test mode:")
|
||||
print(" --completion : Test completions endpoint")
|
||||
print(" --chat : Test chat completions endpoint (non-streaming)")
|
||||
print(" --chat-stream : Test chat completions endpoint with streaming")
|
||||
print(" --tools : Test function calling with ls tool")
|
||||
print(" --interactive : Start interactive streaming chat session")
|
||||
print(f"\nExample: python {os.path.basename(sys.argv[0])} --model Qwen/Qwen3-8B --chat-stream --endpoint my-vllm-endpoint")
|
||||
sys.exit(1)
|
||||
elif selected > 1:
|
||||
print("Please specify exactly one test mode")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Using model: {args.model}")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
async with Serverless() as client:
|
||||
demo = APIDemo(client, args.model, ToolManager())
|
||||
|
||||
if args.completion:
|
||||
await demo.demo_completions()
|
||||
elif args.chat:
|
||||
await demo.demo_chat(use_streaming=False)
|
||||
elif args.chat_stream:
|
||||
await demo.demo_chat(use_streaming=True)
|
||||
elif args.tools:
|
||||
await demo.demo_ls_tool()
|
||||
elif args.interactive:
|
||||
await demo.interactive_chat()
|
||||
|
||||
except Exception as e:
|
||||
log.error("Error during test: %s", e, exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main_async())
|
||||
@@ -0,0 +1,58 @@
|
||||
import json
|
||||
from dataclasses import dataclass, field, fields, is_dataclass
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
|
||||
class SerializableDataclass:
|
||||
def _serialize_recursive(self, obj: Any) -> Any:
|
||||
if is_dataclass(obj):
|
||||
return {
|
||||
field.name: self._serialize_recursive(getattr(obj, field.name))
|
||||
for field in fields(obj)
|
||||
}
|
||||
elif isinstance(obj, dict):
|
||||
return {key: self._serialize_recursive(value) for key, value in obj.items()}
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
return [self._serialize_recursive(item) for item in obj]
|
||||
elif isinstance(obj, set):
|
||||
return [self._serialize_recursive(item) for item in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return self._serialize_recursive(self)
|
||||
|
||||
def to_json(self, indent: int = 2) -> str:
|
||||
return json.dumps(self.to_dict(), indent=indent)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompletionConfig(SerializableDataclass):
|
||||
"""Configuration for completion requests"""
|
||||
|
||||
model: str
|
||||
prompt: str = "Hello"
|
||||
max_tokens: int = 256
|
||||
temperature: float = 0.7
|
||||
top_k: int = 20
|
||||
top_p: float = 0.4
|
||||
stream: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatCompletionConfig(SerializableDataclass):
|
||||
"""Configuration for chat completion requests"""
|
||||
|
||||
model: str
|
||||
messages: list = field(default_factory=list)
|
||||
max_tokens: int = 2096
|
||||
temperature: float = 0.7
|
||||
top_k: int = 20
|
||||
top_p: float = 0.4
|
||||
stream: bool = False
|
||||
tools: Optional[List[Dict[str, Any]]] = field(default_factory=list)
|
||||
tool_choice: str = "auto"
|
||||
|
||||
def __post_init__(self):
|
||||
if self.messages is None:
|
||||
self.messages = [{"role": "user", "content": "Hello"}]
|
||||
@@ -0,0 +1,207 @@
|
||||
import os, json, random
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from lib.data_types import EndpointHandler, ApiPayload, JsonDataException
|
||||
from typing import Union, Type, Dict, Any, Optional
|
||||
from aiohttp import web, ClientResponse
|
||||
import nltk
|
||||
import logging
|
||||
|
||||
nltk.download("words")
|
||||
WORD_LIST = nltk.corpus.words.words()
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
"""
|
||||
Generic dataclass accepts any dictionary in input.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenericData(ApiPayload, ABC):
|
||||
input: Dict[str, Any]
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "GenericData":
|
||||
return cls(input=data["input"])
|
||||
|
||||
@classmethod
|
||||
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "GenericData":
|
||||
errors = {}
|
||||
|
||||
# Validate required parameters
|
||||
required_params = ["input"]
|
||||
for param in required_params:
|
||||
if param not in json_msg:
|
||||
errors[param] = "missing parameter"
|
||||
|
||||
if errors:
|
||||
raise JsonDataException(errors)
|
||||
|
||||
try:
|
||||
# Create clean data dict and delegate to from_dict
|
||||
clean_data = {"input": json_msg["input"]}
|
||||
|
||||
return cls.from_dict(clean_data)
|
||||
|
||||
except (json.JSONDecodeError, JsonDataException) as e:
|
||||
errors["parameters"] = str(e)
|
||||
raise JsonDataException(errors)
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def for_test(cls) -> "GenericData":
|
||||
pass
|
||||
|
||||
def generate_payload_json(self) -> Dict[str, Any]:
|
||||
return self.input
|
||||
|
||||
def count_workload(self) -> int:
|
||||
return self.input.get("max_tokens", 0)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenericHandler(EndpointHandler[GenericData], ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def endpoint(self) -> str:
|
||||
pass
|
||||
|
||||
@property
|
||||
def healthcheck_endpoint(self) -> Optional[str]:
|
||||
return os.environ.get("MODEL_HEALTH_ENDPOINT")
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[GenericData]:
|
||||
return GenericData
|
||||
|
||||
@abstractmethod
|
||||
def make_benchmark_payload(self) -> GenericData:
|
||||
pass
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
match model_response.status:
|
||||
case 200:
|
||||
# Check if the response is actually streaming based on response headers/content-type
|
||||
is_streaming_response = (
|
||||
model_response.content_type == "text/event-stream"
|
||||
or model_response.content_type == "application/x-ndjson"
|
||||
or model_response.headers.get("Transfer-Encoding") == "chunked"
|
||||
or "stream" in model_response.content_type.lower()
|
||||
)
|
||||
|
||||
if is_streaming_response:
|
||||
log.debug("Detected streaming response...")
|
||||
res = web.StreamResponse()
|
||||
res.content_type = model_response.content_type
|
||||
await res.prepare(client_request)
|
||||
async for chunk in model_response.content:
|
||||
await res.write(chunk)
|
||||
await res.write_eof()
|
||||
log.debug("Done streaming response")
|
||||
return res
|
||||
else:
|
||||
log.debug("Detected non-streaming response...")
|
||||
content = await model_response.read()
|
||||
return web.Response(
|
||||
body=content,
|
||||
status=200,
|
||||
content_type=model_response.content_type,
|
||||
)
|
||||
case code:
|
||||
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||
return web.Response(status=code)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompletionsData(GenericData):
|
||||
@classmethod
|
||||
def for_test(cls) -> "CompletionsData":
|
||||
system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
|
||||
|
||||
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
|
||||
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
|
||||
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
|
||||
genus Equus with horses and asses, the three groups being the only living members of the family
|
||||
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
|
||||
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
|
||||
woodlands, shrublands, and mountainous areas.
|
||||
|
||||
Please answer the following question based on the above context."""
|
||||
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
|
||||
model = os.environ.get("MODEL_NAME")
|
||||
if not model:
|
||||
raise ValueError("MODEL_NAME environment variable not set")
|
||||
|
||||
test_input = {
|
||||
"model": model,
|
||||
"prompt": f"{system_prompt}\n\n{unique_question}",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 500,
|
||||
}
|
||||
return cls(input=test_input)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompletionsHandler(GenericHandler):
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/v1/completions"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[CompletionsData]:
|
||||
return CompletionsData
|
||||
|
||||
def make_benchmark_payload(self) -> CompletionsData:
|
||||
return CompletionsData.for_test()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatCompletionsData(GenericData):
|
||||
"""Chat completions-specific data implementation"""
|
||||
|
||||
@classmethod
|
||||
def for_test(cls) -> "ChatCompletionsData":
|
||||
system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
|
||||
|
||||
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
|
||||
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
|
||||
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
|
||||
genus Equus with horses and asses, the three groups being the only living members of the family
|
||||
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
|
||||
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
|
||||
woodlands, shrublands, and mountainous areas.
|
||||
|
||||
Please answer the following question based on the above context."""
|
||||
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
|
||||
model = os.environ.get("MODEL_NAME")
|
||||
if not model:
|
||||
raise ValueError("MODEL_NAME environment variable not set")
|
||||
|
||||
# Chat completions use messages format instead of prompt
|
||||
test_input = {
|
||||
"model": model,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt}, # Shared prefix
|
||||
{"role": "user", "content": unique_question} # Unique per request
|
||||
],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 500,
|
||||
}
|
||||
return cls(input=test_input)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatCompletionsHandler(GenericHandler):
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/v1/chat/completions"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[ChatCompletionsData]:
|
||||
return ChatCompletionsData
|
||||
|
||||
def make_benchmark_payload(self) -> ChatCompletionsData:
|
||||
return ChatCompletionsData.for_test()
|
||||
@@ -0,0 +1,60 @@
|
||||
import os
|
||||
import logging
|
||||
from .data_types.server import CompletionsHandler, ChatCompletionsHandler
|
||||
from aiohttp import web
|
||||
from lib.backend import Backend, LogAction
|
||||
from lib.server import start_server
|
||||
|
||||
# This line indicates that the inference server is listening
|
||||
MODEL_SERVER_START_LOG_MSG = [
|
||||
"Application startup complete.", # vLLM
|
||||
"llama runner started", # Ollama
|
||||
'"message":"Connected","target":"text_generation_router"', # TGI
|
||||
'"message":"Connected","target":"text_generation_router::server"', # TGI
|
||||
]
|
||||
|
||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||
"INFO exited: vllm", # vLLM
|
||||
"RuntimeError: Engine", # vLLM
|
||||
"Error: pull model manifest:", # Ollama
|
||||
"stalled; retrying", # Ollama
|
||||
"Error: WebserverFailed", # TGI
|
||||
"Error: DownloadError", # TGI
|
||||
"Error: ShardCannotStart", # TGI
|
||||
]
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
backend = Backend(
|
||||
model_server_url=os.environ["MODEL_SERVER_URL"],
|
||||
model_log_file=os.environ["MODEL_LOG"],
|
||||
allow_parallel_requests=True,
|
||||
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
|
||||
log_actions=[
|
||||
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
|
||||
(LogAction.Info, '"message":"Download'),
|
||||
*[
|
||||
(LogAction.ModelError, error_msg)
|
||||
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def handle_ping(_):
|
||||
return web.Response(body="pong")
|
||||
|
||||
|
||||
routes = [
|
||||
web.post("/v1/completions", backend.create_handler(CompletionsHandler())),
|
||||
web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())),
|
||||
web.get("/ping", handle_ping),
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_server(backend, routes)
|
||||
@@ -0,0 +1,434 @@
|
||||
from lib.test_utils import test_args
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
from lib.data_types import AuthData
|
||||
from .data_types.server import CompletionsData
|
||||
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
import requests
|
||||
from dataclasses import dataclass
|
||||
from collections import Counter
|
||||
from urllib.parse import urljoin, urlparse
|
||||
import re
|
||||
|
||||
# Headless plotting
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import logging
|
||||
logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED
|
||||
from requests.adapters import HTTPAdapter
|
||||
|
||||
def get_incremented_path(path: str) -> str:
|
||||
base, ext = os.path.splitext(path)
|
||||
if not os.path.exists(path):
|
||||
return path
|
||||
i = 1
|
||||
while os.path.exists(f"{base}-{i}{ext}"):
|
||||
i += 1
|
||||
return f"{base}-{i}{ext}"
|
||||
|
||||
WORKER_ENDPOINT = "/v1/completions" # This will return the full text output at once. Latency metrics reflect that (ie not measuring TTFT)
|
||||
|
||||
@dataclass
|
||||
class ReqResult:
|
||||
worker_url: str
|
||||
route_ms: float
|
||||
worker_ms: float
|
||||
total_ms: float
|
||||
ok: bool
|
||||
error: str = ""
|
||||
status_code: int = 0
|
||||
t_start: float = 0.0
|
||||
t_end: float = 0.0
|
||||
workload: float = 0.0
|
||||
|
||||
def do_one(endpoint_name: str,
|
||||
endpoint_id: int,
|
||||
endpoint_api_key: str,
|
||||
server_url: str,
|
||||
worker_endpoint: str,
|
||||
payload,
|
||||
results_list,
|
||||
t0,
|
||||
status_samples,
|
||||
route_session,
|
||||
worker_session):
|
||||
try:
|
||||
workload = payload.count_workload()
|
||||
route_payload = {"endpoint": endpoint_name, "api_key": endpoint_api_key, "cost": workload}
|
||||
headers = {"Authorization": f"Bearer {endpoint_api_key}"}
|
||||
start = time.time()
|
||||
r0 = route_session.post(urljoin(server_url, "/route/"), json=route_payload, headers=headers, timeout=4)
|
||||
t_after_route = time.time()
|
||||
if r0.status_code != 200:
|
||||
results_list.append(ReqResult(worker_url="",
|
||||
route_ms=(t_after_route - start) * 1000.0,
|
||||
worker_ms=0.0,
|
||||
total_ms=(t_after_route - start) * 1000.0,
|
||||
ok=False,
|
||||
error=f"route error {r0.reason} {r0.text}",
|
||||
status_code=r0.status_code,
|
||||
t_start=start - t0,
|
||||
t_end=t_after_route - t0,
|
||||
workload=workload))
|
||||
return
|
||||
msg = r0.json()
|
||||
|
||||
# 1) Check if we got a worker back from route
|
||||
worker_url = msg.get("url", "")
|
||||
if not worker_url:
|
||||
status = msg.get("status", "")
|
||||
m = re.search(r"total workers:\s*(\d+).*loading workers:\s*(\d+).*standby workers:\s*(\d+).*error workers:\s*(\d+)", status, re.I | re.S)
|
||||
if m:
|
||||
tot, loading, standby, err = map(int, m.groups())
|
||||
idle = max(tot - loading - standby - err, 0)
|
||||
status_samples.append((time.time() - t0, idle))
|
||||
|
||||
# 2) If we got a worker, send the request
|
||||
if worker_url:
|
||||
req = dict(payload=payload.__dict__, auth_data=AuthData.from_json_msg(msg).__dict__)
|
||||
t_before_worker = time.time()
|
||||
r1 = worker_session.post(
|
||||
urljoin(worker_url, worker_endpoint),
|
||||
json=req,
|
||||
verify=get_cert_file_path(),
|
||||
timeout=(4, 120),
|
||||
)
|
||||
t_after_worker = time.time()
|
||||
if r1.status_code != 200:
|
||||
results_list.append(ReqResult(worker_url=worker_url,
|
||||
route_ms=(t_after_route - start) * 1000.0,
|
||||
worker_ms=(t_after_worker - t_before_worker) * 1000.0,
|
||||
total_ms=(t_after_worker - start) * 1000.0,
|
||||
ok=False,
|
||||
error=f"worker inference error {r1.reason} {r1.text}",
|
||||
status_code=r1.status_code,
|
||||
t_start=start - t0,
|
||||
t_end=t_after_worker - t0,
|
||||
workload=workload))
|
||||
return
|
||||
# Success case
|
||||
results_list.append(ReqResult(worker_url=worker_url,
|
||||
route_ms=(t_after_route - start) * 1000.0,
|
||||
worker_ms=(t_after_worker - t_before_worker) * 1000.0,
|
||||
total_ms=(t_after_worker - start) * 1000.0,
|
||||
ok=True,
|
||||
error="",
|
||||
status_code=200,
|
||||
t_start=start - t0,
|
||||
t_end=t_after_worker - t0,
|
||||
workload=workload))
|
||||
|
||||
# 3) If so, sample via /get_endpoint_workers/ for eligible (idle) worker tracking
|
||||
if worker_url:
|
||||
try:
|
||||
r_status = route_session.post(
|
||||
urljoin(server_url, "/get_endpoint_workers/"),
|
||||
json={"id": endpoint_id},
|
||||
headers={"Authorization": f"Bearer {endpoint_api_key}"},
|
||||
timeout=3,
|
||||
)
|
||||
if r_status.status_code == 200:
|
||||
workers = r_status.json()
|
||||
idle = 0
|
||||
for w in workers:
|
||||
st = str(w.get("status", "")).lower()
|
||||
if (st in ("idle")):
|
||||
idle += 1
|
||||
status_samples.append((time.time() - t0, idle))
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
t = time.time()
|
||||
results_list.append(ReqResult(worker_url="",
|
||||
route_ms=0.0,
|
||||
worker_ms=0.0,
|
||||
total_ms=0.0,
|
||||
ok=False,
|
||||
error=f"unknown error {e}",
|
||||
status_code=0,
|
||||
t_start=t - t0,
|
||||
t_end=t - t0,
|
||||
workload=0.0))
|
||||
|
||||
def run_load_with_metrics(num_requests: int,
|
||||
requests_per_second: float,
|
||||
endpoint_group_name: str,
|
||||
account_api_key: str,
|
||||
server_url: str,
|
||||
worker_endpoint: str,
|
||||
instance: str,
|
||||
out_path: str):
|
||||
|
||||
ep_info = Endpoint.get_endpoint_info(endpoint_name=endpoint_group_name,
|
||||
account_api_key=account_api_key,
|
||||
instance=instance)
|
||||
if not ep_info or not ep_info.get("api_key") or not ep_info.get("id"):
|
||||
print(f"Endpoint {endpoint_group_name} not found for API key")
|
||||
return
|
||||
endpoint_id = int(ep_info["id"])
|
||||
endpoint_api_key = ep_info["api_key"]
|
||||
|
||||
t0 = time.time()
|
||||
results = []
|
||||
status_samples = []
|
||||
max_concurrency = int(os.environ.get("MAX_CONCURRENCY", "8192"))
|
||||
submit_queue_factor = 2 # cap queued tasks to reduce memory
|
||||
|
||||
# Shared HTTP sessions with connection pooling (persistent connections)
|
||||
def make_session(pool_connections: int, pool_maxsize: int) -> requests.Session:
|
||||
sess = requests.Session()
|
||||
adapter = HTTPAdapter(pool_connections=pool_connections, pool_maxsize=pool_maxsize, max_retries=0)
|
||||
sess.mount("https://", adapter)
|
||||
sess.mount("http://", adapter)
|
||||
return sess
|
||||
|
||||
# Router: mostly single host, small connection pool is sufficient
|
||||
route_session = make_session(pool_connections=1, pool_maxsize=max_concurrency)
|
||||
# Workers: many hosts; allow many pools and per-host concurrency up to max_concurrency
|
||||
worker_session = make_session(pool_connections=64, pool_maxsize=max_concurrency // 8)
|
||||
|
||||
# Fire requests using a thread pool, scheduling at requested RPS
|
||||
inflight = set()
|
||||
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
|
||||
for i in range(num_requests):
|
||||
# Pace submissions to RPS
|
||||
target_time = t0 + i / max(requests_per_second, 1e-9)
|
||||
sleep_s = target_time - time.time()
|
||||
if sleep_s > 0:
|
||||
time.sleep(min(sleep_s, 0.5)) # sleep in chunks to stay responsive
|
||||
|
||||
payload = CompletionsData.for_test()
|
||||
fut = executor.submit(
|
||||
do_one,
|
||||
endpoint_group_name,
|
||||
endpoint_id,
|
||||
endpoint_api_key,
|
||||
server_url,
|
||||
worker_endpoint,
|
||||
payload,
|
||||
results,
|
||||
t0,
|
||||
status_samples,
|
||||
route_session,
|
||||
worker_session,
|
||||
)
|
||||
inflight.add(fut)
|
||||
# Prevent unbounded queue growth
|
||||
if len(inflight) >= max_concurrency * submit_queue_factor:
|
||||
done, not_done = wait(inflight, return_when=FIRST_COMPLETED)
|
||||
inflight = not_done
|
||||
# Wait for all outstanding tasks
|
||||
if inflight:
|
||||
wait(inflight)
|
||||
# Close sessions
|
||||
try:
|
||||
route_session.close()
|
||||
finally:
|
||||
worker_session.close()
|
||||
|
||||
# Aggregate results
|
||||
oks = [r for r in results if r.ok]
|
||||
errs = [r for r in results if not r.ok]
|
||||
total_reqs = len(results)
|
||||
succ = len(oks)
|
||||
|
||||
total_ms = np.array([r.total_ms for r in oks]) if succ else np.array([])
|
||||
worker_ms = np.array([r.worker_ms for r in oks]) if succ else np.array([])
|
||||
route_ms = np.array([r.route_ms for r in oks]) if succ else np.array([])
|
||||
|
||||
avg_total = float(np.mean(total_ms)) if succ else 0.0
|
||||
avg_worker = float(np.mean(worker_ms)) if succ else 0.0
|
||||
avg_route = float(np.mean(route_ms)) if succ else 0.0
|
||||
p50_total, p95_total = (float(np.percentile(total_ms, 50)), float(np.percentile(total_ms, 95))) if succ else (0.0, 0.0)
|
||||
|
||||
# Distribution over workers (by host:port)
|
||||
hosts = [urlparse(r.worker_url).netloc for r in oks if r.worker_url]
|
||||
dist = Counter(hosts)
|
||||
|
||||
# Idle over time (mode per second)
|
||||
idle_ts, idle_vals = [], []
|
||||
if status_samples:
|
||||
buckets = {}
|
||||
for ts, idle in status_samples:
|
||||
k = int(ts)
|
||||
buckets.setdefault(k, []).append(idle)
|
||||
keys = sorted(buckets.keys())
|
||||
idle_ts = keys
|
||||
# Use the most frequent sampled value per second (mode) to keep integer counts
|
||||
idle_vals = []
|
||||
for k in keys:
|
||||
vals_k = [int(v) for v in buckets[k]]
|
||||
if vals_k:
|
||||
cnt = Counter(vals_k)
|
||||
idle_vals.append(cnt.most_common(1)[0][0])
|
||||
else:
|
||||
idle_vals.append(0)
|
||||
|
||||
print(f"\nResults: total={total_reqs} success={succ} errors={len(errs)}")
|
||||
print(f"Avg latency (ms): {avg_total:.1f} p50: {p50_total:.1f} p95: {p95_total:.1f}")
|
||||
print(f"Avg route latency (ms): {avg_route:.1f} Avg worker latency (ms): {avg_worker:.1f}")
|
||||
if errs:
|
||||
print("Sample errors:")
|
||||
for e in errs[:5]:
|
||||
print(f" {e.status_code} {e.error}")
|
||||
|
||||
# Plot: 2x3 grid
|
||||
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
|
||||
fig.suptitle(f"Load test: {endpoint_group_name} n={total_reqs}, rps={requests_per_second}, success={succ}")
|
||||
|
||||
# Dist per worker
|
||||
ax0 = axes[0, 0]
|
||||
if dist:
|
||||
items = sorted(dist.items(), key=lambda kv: kv[1], reverse=True)
|
||||
labels, counts = zip(*items)
|
||||
ax0.bar(range(len(labels)), counts)
|
||||
ax0.set_xticks(range(len(labels)))
|
||||
ax0.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
|
||||
ax0.set_title("Request distribution over workers")
|
||||
ax0.set_ylabel("count")
|
||||
|
||||
# Latency histogram (total)
|
||||
ax1 = axes[0, 1]
|
||||
if succ:
|
||||
ax1.hist(total_ms, bins=30)
|
||||
ax1.set_title("Total latency (ms)")
|
||||
ax1.set_xlabel("ms")
|
||||
ax1.set_ylabel("freq")
|
||||
|
||||
# Eligible workers over time
|
||||
ax_idle = axes[0, 2]
|
||||
if idle_ts:
|
||||
ax_idle.plot(idle_ts, idle_vals, "-o", ms=3)
|
||||
ax_idle.set_title("Eligible workers over time")
|
||||
ax_idle.set_xlabel("time (s)")
|
||||
ax_idle.set_ylabel("eligible count")
|
||||
|
||||
# Throughput over time (completions/sec)
|
||||
ax_idle = axes[1, 0]
|
||||
ax_idle.clear()
|
||||
if succ:
|
||||
per_sec = {}
|
||||
for r in oks:
|
||||
s = int(r.t_end)
|
||||
per_sec[s] = per_sec.get(s, 0) + 1
|
||||
ts = sorted(per_sec.keys())
|
||||
vals = [per_sec[t] for t in ts]
|
||||
ax_idle.plot(ts, vals, "-o", ms=3)
|
||||
ax_idle.set_title("Completions per second")
|
||||
ax_idle.set_xlabel("time (s)")
|
||||
ax_idle.set_ylabel("completions / sec")
|
||||
|
||||
# Summary text
|
||||
ax3 = axes[1, 1]
|
||||
ax3.axis("off")
|
||||
text = (
|
||||
f"Total requests: {total_reqs}\n"
|
||||
f"Success: {succ} Errors: {len(errs)}\n"
|
||||
f"Avg total latency: {avg_total:.1f} ms\n"
|
||||
f"p50: {p50_total:.1f} ms p95: {p95_total:.1f} ms\n"
|
||||
f"Avg route latency: {avg_route:.1f} ms\n"
|
||||
f"Avg worker latency: {avg_worker:.1f} ms\n"
|
||||
f"300 errors: {len([r for r in errs if r.status_code >= 300 and r.status_code < 400])}\n"
|
||||
f"429 errors: {len([r for r in errs if r.status_code == 429])}\n"
|
||||
f"500 errors: {len([r for r in errs if r.status_code >= 500])}\n"
|
||||
f"Other errors: {len([r for r in errs if r.status_code not in [300, 429, 500]])}\n"
|
||||
)
|
||||
ax3.set_title("Summary")
|
||||
ax3.text(0.02, 0.98, text, va="top", ha="left", fontsize=11, transform=ax3.transAxes)
|
||||
|
||||
# Error count over time
|
||||
ax_errors = axes[1, 2]
|
||||
all_end_times = [int(r.t_end) for r in results if r.t_end > 0]
|
||||
if all_end_times:
|
||||
min_second = min(all_end_times)
|
||||
max_second = max(all_end_times)
|
||||
# Count errors per second
|
||||
errors_per_second = {}
|
||||
for result in errs:
|
||||
second = int(result.t_end)
|
||||
errors_per_second[second] = errors_per_second.get(second, 0) + 1
|
||||
# Create complete timeline including zeros
|
||||
time_seconds = list(range(min_second, max_second + 1))
|
||||
error_counts = [errors_per_second.get(sec, 0) for sec in time_seconds]
|
||||
ax_errors.plot(time_seconds, error_counts, "-o", ms=3)
|
||||
ax_errors.set_title("Errors per second")
|
||||
ax_errors.set_xlabel("time (s)")
|
||||
ax_errors.set_ylabel("errors / sec")
|
||||
|
||||
# Ensure unique output path and create directory if needed
|
||||
final_out_path = get_incremented_path(out_path)
|
||||
out_dir = os.path.dirname(final_out_path)
|
||||
if out_dir:
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
plt.tight_layout(rect=[0, 0, 1, 0.96])
|
||||
plt.savefig(final_out_path, dpi=120)
|
||||
print(f"Saved report to: {final_out_path}")
|
||||
|
||||
# Per-worker latency boxplot (top 12 by volume)
|
||||
groups = {}
|
||||
for r in oks:
|
||||
host = urlparse(r.worker_url).netloc
|
||||
groups.setdefault(host, []).append(r.total_ms)
|
||||
items = sorted(groups.items(), key=lambda kv: len(kv[1]), reverse=True)[:12]
|
||||
if items:
|
||||
labels, data = zip(*items)
|
||||
fig2, axb = plt.subplots(1, 1, figsize=(12, 5))
|
||||
axb.boxplot(data, showfliers=False)
|
||||
axb.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
|
||||
axb.set_title("Per-worker latency (ms)")
|
||||
axb.set_ylabel("ms")
|
||||
plt.tight_layout()
|
||||
extra_out = get_incremented_path(os.path.splitext(out_path)[0] + "-workers.png")
|
||||
plt.savefig(extra_out, dpi=120)
|
||||
fig2.tight_layout()
|
||||
fig2.savefig(extra_out, dpi=120)
|
||||
print(f"Saved worker latency plot to: {extra_out}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Check if MODEL_NAME environment variable is set
|
||||
model_name_set = os.environ.get("MODEL_NAME") is not None
|
||||
|
||||
# Add model argument - required only if MODEL_NAME is not set
|
||||
test_args.add_argument(
|
||||
"--model",
|
||||
dest="model",
|
||||
required=not model_name_set,
|
||||
help="Model to use for completions request (required if MODEL_NAME env var not set)",
|
||||
)
|
||||
|
||||
# Parse known args to get model early, before adding load args
|
||||
known_args, _ = test_args.parse_known_args()
|
||||
if hasattr(known_args, "model") and known_args.model:
|
||||
os.environ["MODEL_NAME"] = known_args.model
|
||||
print(f"Set MODEL_NAME environment variable to: {known_args.model}")
|
||||
|
||||
# Load test args
|
||||
test_args.add_argument("-n", dest="num_requests", type=int, required=True, help="total number of requests")
|
||||
test_args.add_argument("-rps", dest="requests_per_second", type=float, required=True, help="requests per second")
|
||||
test_args.add_argument("--out", dest="out_path", type=str, default="load_test_report.png", help="path to save the report image")
|
||||
args = test_args.parse_args()
|
||||
|
||||
server_url = {
|
||||
"prod": "https://run.vast.ai",
|
||||
"alpha": "https://run-alpha.vast.ai",
|
||||
"candidate": "https://run-candidate.vast.ai",
|
||||
"local": "http://localhost:8080"
|
||||
}.get(args.instance, "http://localhost:8080")
|
||||
|
||||
run_load_with_metrics(
|
||||
num_requests=args.num_requests,
|
||||
requests_per_second=args.requests_per_second,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
account_api_key=args.api_key,
|
||||
server_url=server_url,
|
||||
worker_endpoint=WORKER_ENDPOINT,
|
||||
instance=args.instance,
|
||||
out_path=args.out_path,
|
||||
)
|
||||
+53
-111
@@ -1,119 +1,61 @@
|
||||
import logging
|
||||
import sys
|
||||
import json
|
||||
from urllib.parse import urljoin
|
||||
import requests
|
||||
from utils.endpoint_util import Endpoint
|
||||
from vastai import Serverless
|
||||
import asyncio
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
ENDPOINT_NAME = "my-tgi-endpoint" # Change this to match your endpoint name
|
||||
MAX_TOKENS = 1024
|
||||
PROMPT = "Think step by step: Tell me about the Python programming language."
|
||||
|
||||
async def call_generate(client: Serverless) -> None:
|
||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||
|
||||
def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> None:
|
||||
WORKER_ENDPOINT = "/generate"
|
||||
COST = 100
|
||||
route_payload = {
|
||||
"endpoint": endpoint_group_name,
|
||||
"api_key": api_key,
|
||||
"cost": COST,
|
||||
payload = {
|
||||
"inputs": PROMPT,
|
||||
"parameters": {
|
||||
"max_new_tokens": MAX_TOKENS,
|
||||
"temperature": 0.7,
|
||||
"return_full_text": False
|
||||
}
|
||||
response = requests.post(
|
||||
urljoin(server_url, "/route/"),
|
||||
json=route_payload,
|
||||
timeout=4,
|
||||
)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
message = response.json()
|
||||
url = message["url"]
|
||||
|
||||
auth_data = dict(
|
||||
signature=message["signature"],
|
||||
cost=message["cost"],
|
||||
endpoint=message["endpoint"],
|
||||
reqnum=message["reqnum"],
|
||||
url=url,
|
||||
)
|
||||
|
||||
payload = dict(inputs="tell me about cats", parameters=dict(max_new_tokens=500))
|
||||
req_data = dict(payload=payload, auth_data=auth_data)
|
||||
url = urljoin(url, WORKER_ENDPOINT)
|
||||
print(f"url: {url}")
|
||||
response = requests.post(url, json=req_data)
|
||||
response.raise_for_status()
|
||||
res = response.json()
|
||||
print(res)
|
||||
|
||||
|
||||
def call_generate_stream(
|
||||
endpoint_group_name: str, api_key: str, server_url: str
|
||||
) -> None:
|
||||
WORKER_ENDPOINT = "/generate_stream"
|
||||
COST = 100
|
||||
route_payload = {
|
||||
"endpoint": endpoint_group_name,
|
||||
"api_key": api_key,
|
||||
"cost": COST,
|
||||
}
|
||||
response = requests.post(
|
||||
urljoin(server_url, "/route/"),
|
||||
json=route_payload,
|
||||
timeout=4,
|
||||
)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
message = response.json()
|
||||
url = message["url"]
|
||||
print(f"url: {url}")
|
||||
auth_data = dict(
|
||||
signature=message["signature"],
|
||||
cost=message["cost"],
|
||||
endpoint=message["endpoint"],
|
||||
reqnum=message["reqnum"],
|
||||
url=message["url"],
|
||||
)
|
||||
payload = dict(inputs="tell me about dogs", parameters=dict(max_new_tokens=500))
|
||||
req_data = dict(payload=payload, auth_data=auth_data)
|
||||
url = urljoin(url, WORKER_ENDPOINT)
|
||||
response = requests.post(url, json=req_data, stream=True)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
for line in response.iter_lines():
|
||||
payload = line.decode().lstrip("data:").rstrip()
|
||||
if payload:
|
||||
try:
|
||||
data = json.loads(payload)
|
||||
print(data["token"]["text"], end="")
|
||||
sys.stdout.flush()
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
log.warning(f"Failed to parse streaming response: {e}")
|
||||
continue
|
||||
print()
|
||||
|
||||
resp = await endpoint.request("/generate", payload, cost=MAX_TOKENS)
|
||||
|
||||
print(resp["response"]["generated_text"])
|
||||
|
||||
|
||||
async def call_generate_stream(client: Serverless) -> None:
|
||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||
|
||||
payload = {
|
||||
"inputs": PROMPT,
|
||||
"parameters": {
|
||||
"max_new_tokens": MAX_TOKENS,
|
||||
"temperature": 0.7,
|
||||
"do_sample": True,
|
||||
"return_full_text": False,
|
||||
}
|
||||
}
|
||||
|
||||
resp = await endpoint.request(
|
||||
"/generate_stream",
|
||||
payload,
|
||||
cost=MAX_TOKENS,
|
||||
stream=True,
|
||||
)
|
||||
stream = resp["response"]
|
||||
|
||||
printed_answer = False
|
||||
async for event in stream:
|
||||
tok = (event.get("token") or {}).get("text")
|
||||
if tok:
|
||||
if not printed_answer:
|
||||
printed_answer = True
|
||||
print("Answer:\n", end="", flush=True)
|
||||
print(tok, end="", flush=True)
|
||||
|
||||
async def main():
|
||||
async with Serverless() as client:
|
||||
await call_generate(client)
|
||||
await call_generate_stream(client)
|
||||
|
||||
if __name__ == "__main__":
|
||||
from lib.test_utils import test_args
|
||||
|
||||
args = test_args.parse_args()
|
||||
|
||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||
endpoint_name=args.endpoint_group_name,
|
||||
account_api_key=args.api_key,
|
||||
)
|
||||
if endpoint_api_key:
|
||||
try:
|
||||
call_generate(
|
||||
api_key=endpoint_api_key,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
server_url=args.server_url,
|
||||
)
|
||||
call_generate_stream(
|
||||
api_key=endpoint_api_key,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
server_url=args.server_url,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error during API call: {e}")
|
||||
else:
|
||||
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ")
|
||||
asyncio.run(main())
|
||||
|
||||
Reference in New Issue
Block a user