Compare commits

..

4 Commits

Author SHA1 Message Date
Colter Downing e756f61b9a graphing errors over time 2025-10-25 12:14:27 -07:00
Colter Downing 8cb98c84f9 non vibe coded test_load 2025-10-24 19:08:36 -07:00
Colter Downing e251afda2b improved test load 2025-10-24 12:53:35 -07:00
Lucas Armand 74bd932327 Suppress matplot debug logs 2025-10-24 12:30:20 -07:00
24 changed files with 805 additions and 1573 deletions
-1
View File
@@ -3,4 +3,3 @@
__pycache__
bin/
lib64
.venv
+3 -4
View File
@@ -39,12 +39,11 @@ reporting these metrics to the autoscaler.
If you are using a Vast.ai template that includes PyWorker integration (marked as autoscaler compatible), it should work out of the box. The template will typically start the appropriate PyWorker server automatically. Here are a few:
* **vLLM:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=63ae93902bf3978bea033782592b784d)
* **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=6fa6bd5bdf5f0df63db80e40b086037d)
* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=e6748878ba688e765e3e9fca29541938)
* **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=72d8dcb41ea3a58e06c741e2c725bc00)
* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=ad72c8bf7cf695c3c9ddf0eaf6da0447)
Currently available workers:
* `openai`: A simple example worker for a basic vLLM server.
* `hello_world`: A simple example worker for a basic LLM server.
* `comfyui`: A worker for the ComfyUI image generation backend.
* `tgi`: A worker for the Text Generation Inference backend.
+29 -41
View File
@@ -26,11 +26,10 @@ from lib.data_types import (
LogAction,
ApiPayload_T,
JsonDataException,
RequestMetrics,
BenchmarkResult
RequestMetrics
)
VERSION = "0.2.1"
VERSION = "0.1.0"
MSG_HISTORY_LEN = 100
log = logging.getLogger(__file__)
@@ -66,17 +65,10 @@ class Backend:
unsecured: bool = dataclasses.field(
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
)
report_addr: str = dataclasses.field(
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
)
mtoken: str = dataclasses.field(
default_factory=lambda: os.environ.get("MASTER_TOKEN", "")
)
def __post_init__(self):
self.metrics = Metrics()
self.metrics._set_version(self.version)
self.metrics._set_mtoken(self.mtoken)
self._total_pubkey_fetch_errors = 0
self._pubkey = self._fetch_pubkey()
self.__start_healthcheck: bool = False
@@ -111,19 +103,23 @@ class Backend:
#######################################Private#######################################
def _fetch_pubkey(self):
report_addr = self.report_addr.rstrip("/")
command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"]
try:
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
result = subprocess.check_output(command, universal_newlines=True)
log.debug("public key:")
log.debug(result)
key = None
for _ in range(5):
try:
key = RSA.import_key(result)
if key is not None:
return key
except (ValueError , subprocess.CalledProcessError) as e:
break
except ValueError as e:
log.debug(f"Error downloading key: {e}")
time.sleep(15)
if key is None:
self._total_pubkey_fetch_errors += 1
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
self.backend_errored("Failed to get autoscaler pubkey")
return key
async def __handle_request(
self,
@@ -289,7 +285,7 @@ class Backend:
message = {
key: value
for (key, value) in (dataclasses.asdict(auth_data).items())
if key != "signature" and key != "__request_id"
if key != "signature"
}
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
log.debug(
@@ -299,7 +295,7 @@ class Backend:
elif message in self.msg_history:
log.debug(f"message: {message} already in message history")
return False
elif verify_signature(json.dumps(message, indent=4, sort_keys=True), auth_data.signature):
elif verify_signature(json.dumps(message, indent=4), auth_data.signature):
self.reqnum = max(auth_data.reqnum, self.reqnum)
self.msg_history.append(message)
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
@@ -318,10 +314,10 @@ class Backend:
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
log.debug("already ran benchmark")
# trigger model load
# payload = self.benchmark_handler.make_benchmark_payload()
# _ = await self.__call_api(
# handler=self.benchmark_handler, payload=payload
# )
payload = self.benchmark_handler.make_benchmark_payload()
_ = await self.__call_api(
handler=self.benchmark_handler, payload=payload
)
return float(f.readline())
except FileNotFoundError:
pass
@@ -336,26 +332,18 @@ class Backend:
for run in range(1, self.benchmark_handler.benchmark_runs + 1):
start = time.time()
benchmark_requests = []
tasks = []
total_workload = 0
for i in range(concurrent_requests):
for _ in range(concurrent_requests):
payload = self.benchmark_handler.make_benchmark_payload()
workload = payload.count_workload()
task = self.__call_api(handler=self.benchmark_handler, payload=payload)
benchmark_requests.append(
BenchmarkResult(request_idx=i, workload=workload, task=task)
total_workload += payload.count_workload()
tasks.append(
self.__call_api(handler=self.benchmark_handler, payload=payload)
)
responses = await gather(*[br.task for br in benchmark_requests])
for br, response in zip(benchmark_requests, responses):
br.response = response
total_workload = sum(br.workload for br in benchmark_requests if br.is_successful)
responses = await gather(*tasks)
time_elapsed = time.time() - start
successful_responses = sum([1 for br in benchmark_requests if br.is_successful])
if successful_responses == 0:
self.backend_errored("No successful responses from benchmark")
log.debug(f"benchmark failed: {successful_responses}/{concurrent_requests} successful responses")
throughput = total_workload / time_elapsed
sum_throughput += throughput
@@ -369,7 +357,7 @@ class Backend:
f"Run: {run}, concurrent_requests: {concurrent_requests}",
f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s",
f"Throughput: {throughput} workload/s",
f"Successful responses: {successful_responses}/{concurrent_requests}",
f"Successful responses: {len([r for r in responses if r.status == 200])}",
"#" * 60,
]
)
@@ -396,7 +384,7 @@ class Backend:
)
# some backends need a few seconds after logging successful startup before
# they can begin accepting requests
# await sleep(5)
await sleep(5)
try:
max_throughput = await run_benchmark()
self.__start_healthcheck = True
@@ -417,7 +405,7 @@ class Backend:
async def tail_log():
log.debug(f"tailing file: {self.model_log_file}")
async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore') as f:
async with await open_file(self.model_log_file) as f:
while True:
line = await f.readline()
if line:
+5 -18
View File
@@ -3,7 +3,7 @@ import logging
from dataclasses import dataclass, field
from enum import Enum
from abc import ABC, abstractmethod
from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type, Awaitable
from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type
from aiohttp import web, ClientResponse
import inspect
@@ -65,12 +65,12 @@ class ApiPayload(ABC):
class AuthData:
"""data used to authenticate requester"""
signature: str
cost: str
endpoint: str
reqnum: int
request_idx: int
signature: str
url: str
request_idx: int
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]):
@@ -190,11 +190,10 @@ class SystemMetrics:
self.additional_disk_usage = disk_usage - self.last_disk_usage
self.last_disk_usage = disk_usage
def reset(self, expected: float | None) -> None:
def reset(self):
# autoscaler expects model_loading_time to be populated only once, when the instance has
# finished benchmarking and is ready to receive requests. This applies to restarted instances
# as well: they should send model_loading_time once when they are done loading
if self.model_loading_time == expected:
self.model_loading_time = None
@@ -207,17 +206,6 @@ class RequestMetrics:
status: str
success: bool = False
@dataclass
class BenchmarkResult:
    """Bookkeeping record for a single request issued during a benchmark run."""

    # index of the request within its benchmark run
    request_idx: int
    # workload attributed to this request
    workload: float
    # awaitable that performs the request and resolves to the client response
    task: Awaitable[ClientResponse]
    # response recorded after the task has been awaited; None until then
    response: Optional[ClientResponse] = None

    @property
    def is_successful(self) -> bool:
        """Return True when a response has been recorded and its status is 200."""
        return self.response is not None and self.response.status == 200
@dataclass
class ModelMetrics:
"""Model specific metrics"""
@@ -258,7 +246,7 @@ class ModelMetrics:
def wait_time(self) -> float:
if (len(self.requests_working) == 0):
return 0.0
return sum([request.workload for request in self.requests_working.values()]) / max(self.max_throughput, 0.00001)
return sum([request.workload for request in self.requests_working.values()]) / self.max_throughput
@property
def cur_load(self) -> float:
@@ -286,7 +274,6 @@ class AutoScalerData:
"""Data that is reported to autoscaler"""
id: int
mtoken: str
version: str
loadtime: float
cur_load: float
+14 -63
View File
@@ -28,7 +28,6 @@ def get_url() -> str:
@dataclass
class Metrics:
version: str = "0"
mtoken: str = ""
last_metric_update: float = 0.0
last_request_served: float = 0.0
update_pending: bool = False
@@ -143,80 +142,44 @@ class Metrics:
def _set_version(self, version: str) -> None:
self.version = version
def _set_mtoken(self, mtoken: str) -> None:
self.mtoken = mtoken
#######################################Private#######################################
async def __send_delete_requests_and_reset(self):
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
async def send_data(report_addr: str, success: bool) -> bool:
data = {
"worker_id": self.id,
"mtoken": self.mtoken,
"request_idxs": idxs,
"success": success_flag,
"request_idxs": [r.request_idx for r in self.model_metrics.requests_deleting if r.success == success],
"success": success
}
log.debug(
f"Deleting requests that {'succeeded' if success_flag else 'failed'}: {data['request_idxs']}"
)
full_path = report_addr.rstrip("/") + "/delete_requests/"
for attempt in range(1, 4):
try:
session = await self.http()
async with session.post(full_path, json=data) as res:
log.debug(f"delete_requests response: {res.status}")
res.raise_for_status()
return True
except asyncio.TimeoutError:
log.debug("delete_requests timed out")
log.debug(f"delete_requests timed out")
except (ClientResponseError, Exception) as e:
log.debug(f"delete_requests failed with error: {e}")
await asyncio.sleep(2)
log.debug(f"retrying delete_request, attempt: {attempt}")
return False
# Take a snapshot of what we plan to send this tick.
# New arrivals after this snapshot will remain in the queue for the next tick.
snapshot = list(self.model_metrics.requests_deleting)
success_idxs = [r.request_idx for r in snapshot if r.success is True]
failed_idxs = [r.request_idx for r in snapshot if r.success is False]
if not success_idxs and not failed_idxs:
return # nothing to do
for report_addr in self.report_addr:
# TODO: Add a Redis subscriber queue for delete_requests
if report_addr == "https://cloud.vast.ai/api/v0":
# Patch: ignore the Redis API report_addr
continue
sent_success = True
sent_failed = True
if success_idxs:
sent_success = await post(report_addr, success_idxs, True)
if failed_idxs:
sent_failed = await post(report_addr, failed_idxs, False)
if sent_success and sent_failed:
# Remove only the items we actually sent from the live queue.
sent_set = set(success_idxs) | set(failed_idxs)
self.model_metrics.requests_deleting[:] = [
r for r in self.model_metrics.requests_deleting
if r.request_idx not in sent_set
]
success = await send_data(report_addr, success=True) and await send_data(report_addr, success=False)
if success is True:
self.model_metrics.requests_deleting.clear()
break
async def __send_metrics_and_reset(self):
loadtime_snapshot = self.system_metrics.model_loading_time
def compute_autoscaler_data() -> AutoScalerData:
return AutoScalerData(
id=self.id,
mtoken=self.mtoken,
version=self.version,
loadtime=(loadtime_snapshot or 0.0),
loadtime=(self.system_metrics.model_loading_time or 0.0),
new_load=self.model_metrics.workload_processing,
cur_load=self.model_metrics.cur_load,
rej_load=self.model_metrics.workload_rejected,
@@ -234,25 +197,17 @@ class Metrics:
async def send_data(report_addr: str) -> bool:
data = compute_autoscaler_data()
log_data = asdict(data)
def obfuscate(secret: str) -> str:
    """Return a log-safe rendering of ``secret``.

    ``None`` becomes an empty string. Secrets longer than seven characters
    keep their first seven characters followed by ``...``; shorter ones are
    replaced by one ``*`` per character.
    """
    if secret is None:
        return ""
    length = len(secret)
    if length > 7:
        # long enough to show a short prefix without revealing the whole value
        return secret[:7] + "..."
    return "*" * length
log_data["mtoken"] = obfuscate(log_data.get("mtoken"))
full_path = report_addr.rstrip("/") + "/worker_status/"
log.debug(
"\n".join(
[
"#" * 60,
f"sending data to autoscaler",
f"{json.dumps(log_data, indent=2)}",
f"{json.dumps((asdict(data)), indent=2)}",
"#" * 60,
]
)
)
full_path = report_addr.rstrip("/") + "/worker_status/"
for attempt in range(1, 4):
try:
session = await self.http()
@@ -272,15 +227,11 @@ class Metrics:
self.system_metrics.update_disk_usage()
sent = False
for report_addr in self.report_addr:
if await send_data(report_addr):
sent = True
success = await send_data(report_addr)
if success is True:
break
if sent:
# clear the one-shot loadtime only if we actually sent *this* value
self.system_metrics.reset(expected=loadtime_snapshot)
self.update_pending = False
self.model_metrics.reset()
self.system_metrics.reset()
self.last_metric_update = time.time()
+1 -21
View File
@@ -3,17 +3,15 @@ import logging
from typing import List
import ssl
from asyncio import run, gather
import asyncio
from lib.backend import Backend
from lib.metrics import Metrics
from aiohttp import web
log = logging.getLogger(__file__)
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
try:
log.debug("getting certificate...")
use_ssl = os.environ.get("USE_SSL", "false") == "true"
if use_ssl is True:
@@ -40,21 +38,3 @@ def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
await gather(site.start(), backend._start_tracking())
run(main())
except Exception as e:
err_msg = f"PyWorker failed to launch: {e}"
log.error(err_msg)
async def beacon():
metrics = Metrics()
metrics._set_version(getattr(backend, "version", "0"))
metrics._set_mtoken(getattr(backend, "mtoken", ""))
try:
while True:
metrics._model_errored(err_msg)
await metrics._Metrics__send_metrics_and_reset()
await asyncio.sleep(10)
finally:
await metrics.aclose()
run(beacon())
-1
View File
@@ -8,4 +8,3 @@ Requests~=2.32
transformers~=4.52
utils==1.0.*
hf_transfer>=0.1.9
git+https://github.com/vast-ai/vast-sdk.git@worker-sdk
+5 -58
View File
@@ -9,7 +9,7 @@ ENV_PATH="$WORKSPACE_DIR/worker-env"
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}"
REPORT_ADDR="${REPORT_ADDR:-https://cloud.vast.ai/api/v0,https://run.vast.ai}"
USE_SSL="${USE_SSL:-true}"
WORKER_PORT="${WORKER_PORT:-3000}"
mkdir -p "$WORKSPACE_DIR"
@@ -41,14 +41,6 @@ echo_var DEBUG_LOG
echo_var PYWORKER_LOG
echo_var MODEL_LOG
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
if [ -e "$MODEL_LOG" ]; then
echo "Rotating model log at $MODEL_LOG to $MODEL_LOG.old"
cat "$MODEL_LOG" >> "$MODEL_LOG.old"
: > "$MODEL_LOG"
fi
# Populate /etc/environment with quoted values
if ! grep -q "VAST" /etc/environment; then
env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do
@@ -132,54 +124,9 @@ cd "$SERVER_DIR"
echo "launching PyWorker server"
set +e
# Try worker entrypoint first
echo "trying workers.${BACKEND}.worker"
python3 -m "workers.${BACKEND}.worker" |& tee -a "$PYWORKER_LOG"
PY_STATUS=${PIPESTATUS[0]}
# If that fails, fall back to server
if [ "${PY_STATUS}" -ne 0 ]; then
echo "workers.${BACKEND}.worker failed with status ${PY_STATUS}, trying workers.${BACKEND}.server"
python3 -m "workers.${BACKEND}.server" |& tee -a "$PYWORKER_LOG"
PY_STATUS=${PIPESTATUS[0]}
fi
set -e
if [ "${PY_STATUS}" -ne 0 ]; then
echo "PyWorker exited with status ${PY_STATUS}; notifying autoscaler..."
ERROR_MSG="PyWorker exited: code ${PY_STATUS}"
MTOKEN="${MASTER_TOKEN:-}"
VERSION="${PYWORKER_VERSION:-0}"
IFS=',' read -r -a REPORT_ADDRS <<< "${REPORT_ADDR}"
for addr in "${REPORT_ADDRS[@]}"; do
curl -sS -X POST -H 'Content-Type: application/json' \
-d "$(cat <<JSON
{
"id": ${CONTAINER_ID:-0},
"mtoken": "${MTOKEN}",
"version": "${VERSION}",
"loadtime": 0,
"new_load": 0,
"cur_load": 0,
"rej_load": 0,
"max_perf": 0,
"cur_perf": 0,
"error_msg": "${ERROR_MSG}",
"num_requests_working": 0,
"num_requests_recieved": 0,
"additional_disk_usage": 0,
"working_request_idxs": [],
"cur_capacity": 0,
"max_capacity": 0,
"url": "${URL}"
}
JSON
)" "${addr%/}/worker_status/" || true
done
fi
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
echo "launching PyWorker server done"
-184
View File
@@ -1,184 +0,0 @@
import random
import sys
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
# ComfyUI model configuration
MODEL_SERVER_URL = 'http://127.0.0.1'
MODEL_SERVER_PORT = 18288
MODEL_LOG_FILE = '/var/log/portal/comfyui.log'
MODEL_HEALTHCHECK_ENDPOINT = "/health"
# ComfyUI-specific log messages
MODEL_LOAD_LOG_MSG = [
"To see the GUI go to: "
]
MODEL_ERROR_LOG_MSGS = [
"MetadataIncompleteBuffer",
"Value not in list: ",
"[ERROR] Provisioning Script failed"
]
MODEL_INFO_LOG_MSGS = [
'"message":"Downloading'
]
benchmark_lyrics = [
"[verse]\nGuardian cloaked in twilight hue\nShadows melt where he breaks through\nEchoes swirl in mystic flight\nHooded hero owns the night\n\n[verse]\nThrough the chaos shapes arise\nFeral whispers, glowing eyes\nOrcs and creatures side by side\nMarch within the inky tide\n\n[chorus]\nRise above the fear and gloom\nLet your courage fully bloom\nIn the darkness stand your ground\nHear the night proclaim your sound",
"[verse]\nMorning sun on fields of gold\nGentle stories unfold\nEvery breeze a quiet song\nWhere the peaceful hearts belong\n\n[verse]\nLanterns glow at stable doors\nRustling leaves on orchard floors\nSimple joys in every hand\nLife grows soft in fertile land\n\n[chorus]\nLet the day drift slow and free\nRoot your soul where you can be\nIn this haven warm and bright\nFeel the earth breathe pure delight",
"[verse]\nLittle feet on dusty ground\nChasing dreams without a sound\nSoccer ball in morning light\nHopes take wing in youthful flight\n\n[verse]\nChrome reflections paint the day\nSwagger in the steps that play\nCopper tones in shining air\nChildhood gleaming everywhere\n\n[chorus]\nKick the world with boundless cheer\nHold the magic close and near\nIn each moment bold and true\nLet the sky belong to you",
"[verse]\nSunset bleeds across the street\nGilded calm in summer heat\nLow-rise towers rimmed with fire\nDreams ignite as lights climb higher\n\n[verse]\nFootsteps scatter through the haze\nFutures shimmer in the blaze\nEvery window tells a tale\nFloating through a tangerine veil\n\n[chorus]\nLet the neon softly glow\nLet your restless heartbeat slow\nIn this city forged in light\nCarry hope into the night",
"[verse]\nOcean breathes in rolling arcs\nSprays of diamond, glowing sparks\nWaves unfold a perfect line\nNatures rhythm feels divine\n\n[verse]\nSun above in golden sweep\nPaints the rise of every deep\nShimmer drifting through the blue\nWorld reborn in every view\n\n[chorus]\nLet the tide pull you along\nHear the waters ancient song\nIn the cresting waves youll find\nQuiet peace for heart and mind",
"[verse]\nGlass aglow with swirling light\nFruits and mints in colors bright\nIcy whispers clink and chime\nFlowing forms suspend in time\n\n[verse]\nCreamy spirals drift within\nGentle currents slowly spin\nWarm reflections lingering sweet\nMixing flavors at your feet\n\n[chorus]\nSip the glow and let it rise\nTaste the sunset in disguise\nIn this moment clear and true\nLet the warmth flow into you",
"[verse]\nEngines rumble down the lane\nCopper clouds of steam and rain\nOilpunk dreams in metal shine\nRider drifting down the line\n\n[verse]\nLeather jacket, steady glare\nStories sparking in the air\nMagazine lights frame his face\nKing of roads in timeless grace\n\n[chorus]\nThrottle up beyond the bend\nFeel the force of steel ascend\nRide the night and hold on tight\nClaim the world in streaks of light",
"[verse]\nCut-out shapes in swirling play\nTextures dance in bold array\nCats in denim, grinning wide\nStrut across the patterned tide\n\n[verse]\nPosters hum with neon glow\nSurreal scenes begin to grow\nColors crisp as folded art\nPatchwork beating like a heart\n\n[chorus]\nLet the collage come alive\nWatch the vibrant pieces thrive\nIn this joyful, crafted space\nEvery shape finds its own place",
"[verse]\nTiny world in crystal glass\nAncient tales behind the mass\nVillage lights in winter gleam\nFrozen in a mystic dream\n\n[verse]\nLantern beams in swirling air\nSoft enchantment everywhere\nShadows drift with gentle grace\nMagic sealed within the space\n\n[chorus]\nHold the sphere and you will see\nEchoes of a memory\nIn the glow of fragile light\nLives a realm of pure delight",
"[verse]\nArmor hums with power bright\nChopping sparks in jungle night\nMecha spirits shift and scream\nThrough the ferns like shattered beams\n\n[verse]\nAxes blaze in glowing arcs\nLighting up the shadowed marks\nNature roars in trembling air\nClash of steel and cosmic flare\n\n[chorus]\nRaise the fire, strike the ground\nLet your legend shake the sound\nIn the wild where echoes roam\nForge the fight and carve your home",
"[verse]\nCrowds ignite in vibrant flare\nBeats explode through smoky air\nDJ robes replaced with flame\nPope on decks in holy frame\n\n[verse]\nLeather gleams in blinding light\nTurntables spin with sacred might\nChoirs echo in the bass\nHeaven pulses through the place\n\n[chorus]\nLift the roof and shake the floor\nSacred rhythm evermore\nLet the music take control\nFeel the blessing in your soul",
]
benchmark_dataset = [
{
"input": {
"request_id": "",
"workflow_json": {
"14": {
"inputs": {
"tags": "funk, pop, soul, rock, melodic, guitar, drums, bass, keyboard, percussion, 105 BPM, energetic, upbeat, groovy, vibrant, dynamic",
"lyrics": lyrics,
"lyrics_strength": 0.99,
"clip": ["40", 1]
},
"class_type": "TextEncodeAceStepAudio",
"_meta": {
"title": "TextEncodeAceStepAudio"
}
},
"17": {
"inputs": {
"seconds": 180,
"batch_size": 1
},
"class_type": "EmptyAceStepLatentAudio",
"_meta": {
"title": "EmptyAceStepLatentAudio"
}
},
"18": {
"inputs": {
"samples": ["52", 0],
"vae": ["40", 2]
},
"class_type": "VAEDecodeAudio",
"_meta": {
"title": "VAE Decode Audio"
}
},
"40": {
"inputs": {
"ckpt_name": "ace_step_v1_3.5b.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"44": {
"inputs": {
"conditioning": ["14", 0]
},
"class_type": "ConditioningZeroOut",
"_meta": {
"title": "ConditioningZeroOut"
}
},
"49": {
"inputs": {
"model": ["51", 0],
"operation": ["50", 0]
},
"class_type": "LatentApplyOperationCFG",
"_meta": {
"title": "LatentApplyOperationCFG"
}
},
"50": {
"inputs": {
"multiplier": 1.15
},
"class_type": "LatentOperationTonemapReinhard",
"_meta": {
"title": "LatentOperationTonemapReinhard"
}
},
"51": {
"inputs": {
"shift": 6,
"model": ["40", 0]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "ModelSamplingSD3"
}
},
"52": {
"inputs": {
"seed": "__RANDOM_INT__",
"steps": 65,
"cfg": 4,
"sampler_name": "er_sde",
"scheduler": "linear_quadratic",
"denoise": 1,
"model": ["49", 0],
"positive": ["14", 0],
"negative": ["44", 0],
"latent_image": ["17", 0]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"59": {
"inputs": {
"filename_prefix": "audio/ComfyUI",
"quality": "V0",
"audioUI": "",
"audio": ["18", 0]
},
"class_type": "SaveAudioMP3",
"_meta": {
"title": "Save Audio (MP3)"
}
}
}
}
} for lyrics in benchmark_lyrics
]
worker_config = WorkerConfig(
model_server_url=MODEL_SERVER_URL,
model_server_port=MODEL_SERVER_PORT,
model_log_file=MODEL_LOG_FILE,
model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
handlers=[
HandlerConfig(
route="/generate/sync",
allow_parallel_requests=False,
max_queue_time=10.0,
benchmark_config=BenchmarkConfig(
dataset=benchmark_dataset,
runs=1
),
workload_calculator= lambda _ : 1000.0
)
],
log_action_config=LogActionConfig(
on_load=MODEL_LOAD_LOG_MSG,
on_error=MODEL_ERROR_LOG_MSGS,
on_info=MODEL_INFO_LOG_MSGS
)
)
Worker(worker_config).run()
+3 -15
View File
@@ -12,21 +12,9 @@ A docker image is provided but you may use any if the above requirements are met
## Benchmarking
### Custom Benchmark Workflows
A simple image generation benchmark runs when each worker initializes to validate GPU performance and identify underperforming machines.
You can provide a custom ComfyUI workflow for benchmarking by creating `workers/comfyui-json/misc/benchmark.json`. This allows you to test performance using your preferred models and workflow complexity.
**Ways to provide the benchmark file:**
- Fork this repository and add your `benchmark.json` file
- Write the file during worker provisioning (onstart script or setup phase)
An example file is provided in the repository. To ensure varied generations, use the placeholder `__RANDOM_INT__` in place of static seed values - it will be replaced with a random integer for each benchmark run.
### Default Benchmark (Fallback)
If `benchmark.json` is not available, a simple image generation benchmark runs when each worker initializes. This validates GPU performance and helps identify underperforming machines.
The default benchmark uses Stable Diffusion v1.5 with ComfyUI's standard text-to-image workflow. Configure it using these environment variables:
The benchmark uses Stable Diffusion v1.5 with ComfyUI's default text-to-image workflow. Configure the benchmark complexity and duration using these variables:
| Environment Variable | Default Value | Description |
| -------------------- | ------------- | ----------- |
@@ -36,7 +24,7 @@ The default benchmark uses Stable Diffusion v1.5 with ComfyUI's standard text-to
Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns.
#### Calibrating Fallback Benchmark Duration
### Calibrating Benchmark Duration
To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements.
+133 -13
View File
@@ -1,16 +1,107 @@
from .data_types import count_workload
import logging
import uuid
import random
import asyncio
import random
from urllib.parse import urljoin
import json
from vastai import Serverless
import requests
async def main():
async with Serverless() as client:
endpoint = await client.get_endpoint(name="my-comfy-endpoint") # Change this to your endpoint name
from lib.test_utils import print_truncate_res
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from .data_types import count_workload
payload = {
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
def call_text2image_workflow(
endpoint_group_name: str, api_key: str, server_url: str
) -> None:
"""Simple Text2Image using the new modifier-based approach"""
def make_request(url: str, payload: dict, timeout: int = None, verify=True, context: str = "request"):
    """POST ``payload`` as JSON to ``url`` and return the decoded JSON response.

    Args:
        url: Full URL to send the POST request to.
        payload: JSON-serialisable request body.
        timeout: Timeout in seconds handed to ``requests.post``; ``None``
            means no timeout.
        verify: TLS verification setting handed to ``requests.post``
            (``True``/``False`` or a path to a certificate file).
        context: Short label used in log messages to identify the call.

    Returns:
        The decoded JSON body on success, or ``None`` when the request
        fails for any reason; the failure is logged rather than raised.
    """
    # Bound up front so every handler can tell whether a response was received.
    response = None
    try:
        response = requests.post(
            url,
            json=payload,
            timeout=timeout,
            verify=verify,
        )
        response.raise_for_status()
        return response.json()
    except requests.exceptions.HTTPError as http_err:
        log.error(f"HTTP error occurred during {context}: {http_err}")
        if response is not None:
            log.error(f"Status Code: {response.status_code}")
            # Formatted into the message: passing the text as a bare extra
            # argument makes logging fail to render the record.
            log.error(f"Response content: {response.text}")
        return None
    except requests.exceptions.Timeout:
        log.error(f"Timeout occurred during {context}: {url}")
        return None
    except requests.exceptions.ConnectionError:
        log.error(f"Connection error occurred during {context}: {url}")
        return None
    except json.JSONDecodeError as json_err:
        log.error(f"Failed to decode JSON response during {context}: {json_err}")
        if response is not None:
            log.error(f"Response content: {response.text}")
        return None
    except Exception as err:
        # Last-resort boundary: this helper reports failure by returning None.
        log.error(f"An unexpected error occurred during {context}: {err}")
        if response is not None:
            log.error(f"Response content (if available): {response.text}")
        return None
WORKER_ENDPOINT = "/generate/sync"
# This worker has concurrency = 1. All workloads have cost value 1.0
COST = count_workload()
# Route to get worker URL
route_payload = {
"endpoint": endpoint_group_name,
"api_key": api_key,
"cost": COST,
}
# First request - get routing information
route_response = make_request(
url=urljoin(server_url, "/route/"),
payload=route_payload,
timeout=4,
context="route request"
)
if route_response is None:
return None
if "url" not in route_response or not route_response["url"]:
log.error("Error: No worker in 'Ready' state. Please wait while the serverless engine removes errored workers or finishes loading new workers.")
return None
if "status" in route_response:
print(f"Autoscaler status: {route_response['status']}")
return None
# Extract data from route response
url = route_response["url"]
auth_data = dict(
signature=route_response["signature"],
cost=route_response["cost"],
endpoint=route_response["endpoint"],
reqnum=route_response["reqnum"],
url=route_response["url"],
)
# Build the payload for the worker request
worker_payload = {
"input": {
"request_id": str(uuid.uuid4()),
"modifier": "Text2Image",
@@ -25,11 +116,40 @@ async def main():
}
}
response = await endpoint.request("/generate/sync", payload, cost=count_workload())
req_data = dict(payload=worker_payload, auth_data=auth_data)
worker_url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {worker_url}")
# Second request - call the worker endpoint
worker_response = make_request(
url=worker_url,
payload=req_data,
verify=get_cert_file_path(),
context="worker request"
)
return worker_response
# Get the file from the path on the local machine using SCP or SFTP
# or configure S3 to upload to cloud storage.
print(response["response"]["output"][0]["local_path"])
if __name__ == "__main__":
asyncio.run(main())
from lib.test_utils import test_args
args = test_args.parse_args()
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if endpoint_api_key:
result = call_text2image_workflow(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
if result is None:
log.error("Text2Image workflow failed")
else:
print(result)
else:
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name}")
+4 -28
View File
@@ -5,13 +5,12 @@ import dataclasses
from typing import Dict, Any
from functools import cache
from math import ceil
from pathlib import Path
import json
import logging
from lib.data_types import ApiPayload, JsonDataException
log = logging.getLogger(__file__)
with open("workers/comfyui/misc/test_prompts.txt", "r") as f:
test_prompts = f.readlines()
def count_workload() -> float:
# Always 100.0 where there is a single instance of ComfyUI handling requests
@@ -25,32 +24,9 @@ class ComfyWorkflowData(ApiPayload):
@classmethod
def for_test(cls):
"""
If the user has provided a benchmark workflow we can use it here to properly gauge performance.
Otherwise, use the variables available to simulate workflows of the required running time
Use the variables available to simulate workflows of the required running time
Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090)
"""
# Try to load benchmark.json
benchmark_file = Path("workers/comfyui-json/misc/benchmark.json")
if benchmark_file.exists():
try:
with open(benchmark_file, "r") as f:
benchmark_workflow = json.load(f)
return cls(
input={
"request_id": f"test-{random.randint(1000, 99999)}",
"workflow_json": benchmark_workflow
}
)
except (json.JSONDecodeError, IOError):
# JSON is malformed or file can't be read, fall through to default
log.error(f"Failed to benchmark using {benchmark_file}")
# Fallback: read prompts and construct payload
log.info("Using fallback method for benchmarking")
with open("workers/comfyui-json/misc/test_prompts.txt", "r") as f:
test_prompts = f.readlines()
test_prompt = random.choice(test_prompts).rstrip()
return cls(
input={
@@ -1,107 +0,0 @@
{
"3": {
"inputs": {
"seed": "__RANDOM_INT__",
"steps": 20,
"cfg": 8,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1,
"model": [
"4",
0
],
"positive": [
"6",
0
],
"negative": [
"7",
0
],
"latent_image": [
"5",
0
]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"4": {
"inputs": {
"ckpt_name": "v1-5-pruned-emaonly-fp16.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"5": {
"inputs": {
"width": 512,
"height": 512,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"6": {
"inputs": {
"text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"7": {
"inputs": {
"text": "text, watermark",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"8": {
"inputs": {
"samples": [
"3",
0
],
"vae": [
"4",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"8",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
}
}
-1
View File
@@ -19,7 +19,6 @@ MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: "
MODEL_SERVER_ERROR_LOG_MSGS = [
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
"Value not in list: ", # This error is emitted when the model file is not there at all
"[ERROR] Provisioning Script failed", # Error inserted by provisioning script if models/nodes fail to download
]
-81
View File
@@ -1,81 +0,0 @@
import random
import sys
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
# ComfyUI model configuration
MODEL_SERVER_URL = 'http://127.0.0.1'
MODEL_SERVER_PORT = 18288
MODEL_LOG_FILE = '/var/log/portal/comfyui.log'
MODEL_HEALTHCHECK_ENDPOINT = "/health"
# ComfyUI-specific log messages
MODEL_LOAD_LOG_MSG = [
"To see the GUI go to: "
]
MODEL_ERROR_LOG_MSGS = [
"MetadataIncompleteBuffer",
"Value not in list: ",
"[ERROR] Provisioning Script failed"
]
MODEL_INFO_LOG_MSGS = [
'"message":"Downloading'
]
# Text prompts used to build the benchmark dataset below; one request payload
# is generated per prompt.
benchmark_prompts = [
"Cartoon hoodie hero; orc, anime cat, bunny; black goo; buff; vector on white.",
"Cozy farming-game scene with fine details.",
"2D vector child with soccer ball; airbrush chrome; swagger; antique copper.",
"Realistic futuristic downtown of low buildings at sunset.",
"Perfect wave front view; sunny seascape; ultra-detailed water; artful feel.",
"Clear cup with ice, fruit, mint; creamy swirls; fluid-sim CGI; warm glow.",
"Male biker with backpack on motorcycle; oilpunk; award-worthy magazine cover.",
"Collage for textile; surreal cartoon cat in cap/jeans before poster; crisp.",
"Medieval village inside glass sphere; volumetric light; macro focus.",
"Iron Man with glowing axe; mecha sci-fi; jungle scene; dynamic light.",
"Pope Francis DJ in leather jacket, mixing on giant console; dramatic.",
]
# One "Text2Image" request payload per benchmark prompt, all at 512x512 with
# 20 steps. The comprehension is evaluated once, when this module is executed,
# so each payload's request_id and seed are drawn a single time and then stay
# fixed for the life of the process.
benchmark_dataset = [
{
"input": {
# Random suffix only; uniqueness across payloads is not enforced here.
"request_id": f"test-{random.randint(1000, 99999)}",
"modifier": "Text2Image",
"modifications": {
"prompt": prompt,
"width": 512,
"height": 512,
"steps": 20,
"seed": random.randint(0, sys.maxsize)
}
}
} for prompt in benchmark_prompts
]
# Assemble the worker configuration from the constants defined above: where the
# model server listens, which log file to watch, the health-check path, the
# single request handler, and the log lines mapped to load/error/info actions.
worker_config = WorkerConfig(
model_server_url=MODEL_SERVER_URL,
model_server_port=MODEL_SERVER_PORT,
model_log_file=MODEL_LOG_FILE,
model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
handlers=[
# Only one route is registered: /generate/sync.
HandlerConfig(
route="/generate/sync",
# NOTE(review): presumably serialises requests to the single ComfyUI
# instance — confirm against the vastai HandlerConfig documentation.
allow_parallel_requests=False,
# NOTE(review): unit not visible here (presumably seconds) — confirm.
max_queue_time=10.0,
# Benchmark payloads come from benchmark_dataset (one per prompt).
benchmark_config=BenchmarkConfig(
dataset=benchmark_dataset,
)
)
],
log_action_config=LogActionConfig(
on_load=MODEL_LOAD_LOG_MSG,
on_error=MODEL_ERROR_LOG_MSGS,
on_info=MODEL_INFO_LOG_MSGS
)
)
# Module-level side effect: build the Worker from the config and call run().
Worker(worker_config).run()
+12 -6
View File
@@ -7,13 +7,20 @@ from lib.test_utils import print_truncate_res
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from vastai import Serverless
"""
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
"""
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
ENDPOINT_NAME = "my-comfyui-endpoint"
COST = 100 # Use a constant cost for image generation
def call_default_workflow(client: Serverless) -> None:
def call_default_workflow(
endpoint_group_name: str, api_key: str, server_url: str
) -> None:
WORKER_ENDPOINT = "/prompt"
COST = 100
route_payload = {
@@ -75,7 +82,6 @@ def call_custom_workflow_for_sd3(
endpoint=message["endpoint"],
reqnum=message["reqnum"],
url=message["url"],
request_idx=message["request_idx"],
)
workflow = {
"3": {
+411 -341
View File
@@ -1,15 +1,14 @@
import logging
import json
import os
import sys
import json
import subprocess
import argparse
from typing import Any, Dict, List, Optional
from urllib.parse import urljoin
from typing import Dict, Any, Optional, Iterator, Union, List
import requests
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from .data_types.client import CompletionConfig, ChatCompletionConfig
from vastai import Serverless
import asyncio
# ---------------------- Logging ----------------------
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
@@ -17,20 +16,135 @@ logging.basicConfig(
)
log = logging.getLogger(__file__)
# ---------------------- Prompts ----------------------
COMPLETIONS_PROMPT = "the capital of USA is"
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
TOOLS_PROMPT = (
"Can you list the files in the current working directory and tell me what you see? "
"What do you think this directory might be for?"
TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?"
class APIClient:
    """Lightweight client focused solely on API communication.

    Every call is two HTTP requests: a POST to the routing service's
    ``/route/`` path to obtain a worker URL plus signed auth fields, followed
    by a request sent directly to that worker at the caller-supplied path.
    """

    # There is no generic WORKER_ENDPOINT: each call names its own worker path.
    DEFAULT_COST = 100  # cost sent to /route/ when the payload has no "max_tokens"
    DEFAULT_TIMEOUT = 4  # passed as ``timeout`` to the /route/ request only

    def __init__(
        self,
        endpoint_group_name: str,
        api_key: str,
        server_url: str,
        endpoint_api_key: str,
    ):
        """Store the connection settings; no network I/O happens here."""
        self.endpoint_group_name = endpoint_group_name
        self.api_key = api_key
        self.server_url = server_url
        self.endpoint_api_key = endpoint_api_key

    def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
        """Get worker URL and auth data from the routing service.

        Returns the decoded JSON body of the ``/route/`` response.

        Raises:
            ValueError: if no endpoint API key was configured.
            requests.HTTPError: if the routing service answers with an error
                status (via ``raise_for_status``).
        """
        if not self.endpoint_api_key:
            raise ValueError("No valid endpoint API key available")

        route_payload = {
            "endpoint": self.endpoint_group_name,
            "api_key": self.endpoint_api_key,
            "cost": cost,
        }
        response = requests.post(
            urljoin(self.server_url, "/route/"),
            json=route_payload,
            timeout=self.DEFAULT_TIMEOUT,
        )
        response.raise_for_status()
        return response.json()

    def _create_auth_data(self, message: Dict[str, Any]) -> Dict[str, Any]:
        """Copy the auth fields the worker expects out of a routing response.

        Raises ``KeyError`` if the routing response lacks any of the fields.
        """
        return {
            "signature": message["signature"],
            "cost": message["cost"],
            "endpoint": message["endpoint"],
            "reqnum": message["reqnum"],
            "url": message["url"],
        }

    def _make_request(
        self,
        payload: Dict[str, Any],
        endpoint: str,
        method: str = "POST",
        stream: bool = False,
    ) -> Union[Dict[str, Any], Iterator[Dict[str, Any]]]:
        """Route the request, then call the worker endpoint directly.

        Args:
            payload: request body; it is wrapped as ``{"input": payload}``.
            endpoint: path on the worker, joined onto the routed worker URL.
            method: ``"POST"`` or ``"GET"`` (case-insensitive).
            stream: when true, return a generator of parsed stream chunks
                instead of the decoded JSON body.

        Raises:
            ValueError: for an unsupported HTTP method.
            requests.HTTPError: if the worker answers with an error status.
        """
        # The routing cost is the request's token budget when one is given.
        cost = payload.get("max_tokens", self.DEFAULT_COST)
        message = self._get_worker_url(cost=cost)
        worker_url = message["url"]
        auth_data = self._create_auth_data(message)

        req_data = {"payload": {"input": payload}, "auth_data": auth_data}
        url = urljoin(worker_url, endpoint)
        log.debug(f"Making direct request to: {url}")
        log.debug(f"Payload: {req_data}")

        # NOTE: unlike the /route/ call, no timeout is passed to the worker
        # request, so it waits for as long as the worker takes to answer.
        http_method = method.upper()
        if http_method == "POST":
            response = requests.post(
                url, json=req_data, stream=stream, verify=get_cert_file_path()
            )
        elif http_method == "GET":
            response = requests.get(
                url, params=req_data, stream=stream, verify=get_cert_file_path()
            )
        else:
            raise ValueError(f"Unsupported HTTP method: {method}")
        response.raise_for_status()

        if stream:
            return self._handle_streaming_response(response)
        return response.json()

    def _handle_streaming_response(
        self, response: "requests.Response"
    ) -> Iterator[Dict[str, Any]]:
        """Yield each parsed ``data: ...`` chunk of a streamed response.

        Blank lines, lines without the ``data: `` prefix and chunks that are
        not valid JSON are skipped. Iteration stops at ``data: [DONE]``.

        The response is always closed when the generator finishes, fails or
        is abandoned by the consumer, so the streamed connection is released
        instead of being left open.
        """
        try:
            # Read raw lines and decode them as UTF-8 ourselves rather than
            # relying on the charset requests guesses from the headers.
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                if isinstance(raw_line, bytes):
                    line = raw_line.decode("utf-8")
                else:
                    line = raw_line
                if not line.startswith("data: "):
                    continue
                data_str = line[6:]
                if data_str.strip() == "[DONE]":
                    break
                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    continue
                yield data  # Yield the full chunk
        except Exception as e:
            log.error(f"Error handling streaming response: {e}")
            raise
        finally:
            # Runs on normal exhaustion, on [DONE], on errors and when the
            # consumer stops iterating early.
            response.close()

    def call_completions(
        self, config: "CompletionConfig"
    ) -> Union[Dict[str, Any], Iterator[Dict[str, Any]]]:
        """Send ``config`` to the worker's ``/v1/completions`` endpoint."""
        payload = config.to_dict()
        return self._make_request(
            payload=payload, endpoint="/v1/completions", stream=config.stream
        )

    def call_chat_completions(
        self, config: "ChatCompletionConfig"
    ) -> Union[Dict[str, Any], Iterator[Dict[str, Any]]]:
        """Send ``config`` to the worker's ``/v1/chat/completions`` endpoint."""
        payload = config.to_dict()
        return self._make_request(
            payload=payload, endpoint="/v1/chat/completions", stream=config.stream
        )
ENDPOINT_NAME = "my-vllm-endpoint" # change this to your vLLM endpoint name
DEFAULT_MODEL = "Qwen/Qwen3-8B" # must support tool calling
MAX_TOKENS = 1024
DEFAULT_TEMPERATURE = 0.7
# ---------------------- Tooling ----------------------
class ToolManager:
"""Handles tool definitions and execution"""
@@ -50,7 +164,7 @@ class ToolManager:
@staticmethod
def get_ls_tool_definition() -> List[Dict[str, Any]]:
"""OpenAI-compatible tool schema"""
"""Get the ls tool definition"""
return [
{
"type": "function",
@@ -64,217 +178,98 @@ class ToolManager:
def execute_tool_call(self, tool_call: Dict[str, Any]) -> str:
"""Execute a tool call and return the result"""
function_name = (tool_call.get("function") or {}).get("name")
function_name = tool_call["function"]["name"]
if function_name == "list_files":
return self.list_files()
else:
raise ValueError(f"Unknown tool function: {function_name}")
# ----- Helpers to handle streamed tool_calls assembly -----
def _merge_tool_call_delta(state: Dict[int, Dict[str, Any]], tc_delta: Dict[str, Any]) -> None:
"""
OpenAI-style streaming sends partial tool_calls with an index and partial fields.
We merge into a per-index state dict until the assistant message finishes.
"""
idx = tc_delta.get("index")
if idx is None:
return
entry = state.setdefault(idx, {"id": None, "function": {"name": None, "arguments": ""}, "type": "function"})
if tc_delta.get("id"):
entry["id"] = tc_delta["id"]
fn_delta = tc_delta.get("function") or {}
if "name" in fn_delta and fn_delta["name"]:
entry["function"]["name"] = fn_delta["name"]
if "arguments" in fn_delta and fn_delta["arguments"]:
entry["function"]["arguments"] += fn_delta["arguments"]
def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[Dict[str, Any]]:
return [state[i] for i in sorted(state.keys())]
# ---- OpenAI-compatible calls (non-streaming) ----
async def call_completions(client: Serverless, *, model: str, prompt: str, **kwargs) -> Dict[str, Any]:
    """
    Send a non-streaming request to ``/v1/completions`` on the endpoint named
    by ENDPOINT_NAME and return the ``"response"`` field of the result.

    ``max_tokens`` and ``temperature`` may be passed as keyword arguments and
    fall back to MAX_TOKENS / DEFAULT_TEMPERATURE; ``max_tokens`` is also used
    as the request cost.
    """
    endpoint = await client.get_endpoint(name=ENDPOINT_NAME)

    token_budget = kwargs.get("max_tokens", MAX_TOKENS)
    request_input = {
        "model": model,
        "prompt": prompt,
        "max_tokens": token_budget,
        "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
    }
    payload = {"input": request_input}

    log.debug("POST /v1/completions %s", json.dumps(payload)[:500])
    result = await endpoint.request("/v1/completions", payload, cost=token_budget)
    return result["response"]
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
    """
    Send a non-streaming request to ``/v1/chat/completions`` on the endpoint
    named by ENDPOINT_NAME and return the ``"response"`` field of the result.

    Optional keyword arguments: ``max_tokens`` (also used as the request
    cost), ``temperature``, ``tools`` and ``tool_choice``. The last two are
    only included in the payload when the caller supplies them.
    """
    endpoint = await client.get_endpoint(name=ENDPOINT_NAME)

    token_budget = kwargs.get("max_tokens", MAX_TOKENS)
    request_input = {
        "model": model,
        "messages": messages,
        "max_tokens": token_budget,
        "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
    }
    # Pass-through options, added in this order only when supplied.
    for optional_key in ("tools", "tool_choice"):
        if optional_key in kwargs:
            request_input[optional_key] = kwargs[optional_key]
    payload = {"input": request_input}

    log.debug("POST /v1/chat/completions %s", json.dumps(payload)[:500])
    result = await endpoint.request("/v1/chat/completions", payload, cost=token_budget)
    return result["response"]
# ---- Streaming variants ----
async def stream_completions(client: Serverless, *, model: str, prompt: str, **kwargs):
    """
    Start a streaming request to ``/v1/completions`` on the endpoint named by
    ENDPOINT_NAME and return the ``"response"`` field of the result, which the
    caller iterates asynchronously.

    Optional keyword arguments: ``max_tokens`` (also used as the request
    cost), ``temperature`` and ``stop``; ``stop`` is only included in the
    payload when supplied.
    """
    endpoint = await client.get_endpoint(name=ENDPOINT_NAME)

    token_budget = kwargs.get("max_tokens", MAX_TOKENS)
    request_input = {
        "model": model,
        "prompt": prompt,
        "max_tokens": token_budget,
        "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
        "stream": True,
    }
    if "stop" in kwargs:
        request_input["stop"] = kwargs["stop"]
    payload = {"input": request_input}

    log.debug("STREAM /v1/completions %s", json.dumps(payload)[:500])
    result = await endpoint.request("/v1/completions", payload, cost=token_budget, stream=True)
    return result["response"]  # async generator
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs):
    """
    Start a streaming request to ``/v1/chat/completions`` on the endpoint
    named by ENDPOINT_NAME and return the ``"response"`` field of the result,
    which the caller iterates asynchronously.

    Optional keyword arguments: ``max_tokens`` (also used as the request
    cost), ``temperature``, ``tools`` and ``tool_choice``. The last two are
    only included in the payload when the caller supplies them.
    """
    endpoint = await client.get_endpoint(name=ENDPOINT_NAME)

    token_budget = kwargs.get("max_tokens", MAX_TOKENS)
    request_input = {
        "model": model,
        "messages": messages,
        "max_tokens": token_budget,
        "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
        "stream": True,
    }
    # Pass-through options, added in this order only when supplied.
    for optional_key in ("tools", "tool_choice"):
        if optional_key in kwargs:
            request_input[optional_key] = kwargs[optional_key]
    payload = {"input": request_input}

    log.debug("STREAM /v1/chat/completions %s", json.dumps(payload)[:500])
    result = await endpoint.request("/v1/chat/completions", payload, cost=token_budget, stream=True)
    return result["response"]  # async generator
# ---------------------- Demo Runner ----------------------
class APIDemo:
"""Demo and testing functionality for the API client"""
def __init__(self, client: Serverless, model: str, tool_manager: Optional[ToolManager] = None):
def __init__(
self, client: APIClient, model: str, tool_manager: Optional[ToolManager] = None
):
self.client = client
self.model = model
self.tool_manager = tool_manager or ToolManager()
# ----- Streaming handler -----
async def handle_streaming_response(self, stream, show_reasoning: bool = True) -> str:
def handle_streaming_response(
self, response_stream, show_reasoning: bool = True
) -> str:
"""
Handle streaming chat response and display all output.
"""
full_response = ""
reasoning_content = ""
printed_reasoning = False
printed_answer = False
reasoning_started = False
content_started = False
async for chunk in stream:
choice = (chunk.get("choices") or [{}])[0]
delta = choice.get("delta", {})
for chunk in response_stream:
# Normalize the chunk
if isinstance(chunk, str):
chunk = chunk.strip()
if chunk.startswith("data: "):
chunk = chunk[6:].strip()
if chunk in ["[DONE]", ""]:
continue
try:
parsed_chunk = json.loads(chunk)
except json.JSONDecodeError:
continue
elif isinstance(chunk, dict):
parsed_chunk = chunk
else:
continue
# reasoning tokens
rc = delta.get("reasoning_content")
if rc and show_reasoning:
if not printed_reasoning:
# Parse delta from the chunk
choices = parsed_chunk.get("choices", [])
if not choices:
continue
delta = choices[0].get("delta", {})
reasoning_token = delta.get("reasoning_content", "")
content_token = delta.get("content", "")
# Print reasoning token if applicable
if show_reasoning and reasoning_token:
if not reasoning_started:
print("\n🧠 Reasoning: ", end="", flush=True)
printed_reasoning = True
print(rc, end="", flush=True)
reasoning_content += rc
reasoning_started = True
print(f"\033[90m{reasoning_token}\033[0m", end="", flush=True)
reasoning_content += reasoning_token
# content tokens
content_part = delta.get("content")
if content_part:
if not printed_answer:
if show_reasoning and printed_reasoning:
print("\n💬 Response: ", end="", flush=True)
# Print content token
if content_token:
if not content_started:
if show_reasoning and reasoning_started:
print(f"\n💬 Response: ", end="", flush=True)
else:
print("Assistant: ", end="", flush=True)
printed_answer = True
print(content_part, end="", flush=True)
full_response += content_part
content_started = True
print(content_token, end="", flush=True)
full_response += content_token
print() # Ensure newline after response
print() # newline
if show_reasoning:
if printed_reasoning or printed_answer:
if reasoning_started or content_started:
print("\nStreaming completed.")
if printed_reasoning:
if reasoning_started:
print(f"Reasoning tokens: {len(reasoning_content.split())}")
if printed_answer:
if content_started:
print(f"Response tokens: {len(full_response.split())}")
return full_response
async def demo_completions(self) -> None:
print("=" * 60)
print("COMPLETIONS DEMO")
print("=" * 60)
def test_tool_support(self) -> bool:
"""Test if the endpoint supports function calling"""
log.debug("Testing endpoint tool calling support...")
response = await call_completions(
client=self.client,
model=self.model,
prompt=COMPLETIONS_PROMPT,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
print("\nResponse:")
print(json.dumps(response, indent=2))
async def demo_chat(self, use_streaming: bool = True) -> None:
print("=" * 60)
print(f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}")
print("=" * 60)
messages = [{"role": "user", "content": CHAT_PROMPT}]
if use_streaming:
stream = await stream_chat_completions(
client=self.client,
model=self.model,
messages=messages,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE
)
try:
await self.handle_streaming_response(stream, show_reasoning=True)
except Exception as e:
log.error("\nError during streaming: %s", e, exc_info=True)
else:
response = await call_chat_completions(
client=self.client,
model=self.model,
messages=messages,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE
)
choice = (response.get("choices") or [{}])[0]
message = choice.get("message", {})
content = message.get("content", "")
reasoning = message.get("reasoning_content", "") or message.get("reasoning", "")
if reasoning:
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
print(f"\n💬 Assistant: {content}")
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
async def test_tool_support(self) -> bool:
"""Probe that tool schema is accepted (no actual call)"""
# Try a simple request with minimal tools to test support
messages = [{"role": "user", "content": "Hello"}]
minimal_tool = [
{
@@ -282,147 +277,170 @@ class APIDemo:
"function": {"name": "test_function", "description": "Test function"},
}
]
try:
_ = await call_chat_completions(
client=self.client,
config = ChatCompletionConfig(
model=self.model,
messages=messages,
max_tokens=10,
tools=minimal_tool,
tool_choice="none",
max_tokens=10
tool_choice="none", # Don't actually call the tool
)
try:
response = self.client.call_chat_completions(config)
return True
except Exception as e:
log.error("Endpoint does not support tool calling: %s", e)
log.error(f"Error: Endpoint does not support tool calling: {e}")
return False
async def demo_ls_tool(self) -> None:
"""Ask to list files using function calling, then provide final analysis"""
def demo_completions(self) -> None:
"""Demo: test basic completions endpoint"""
print("=" * 60)
print("COMPLETIONS DEMO")
print("=" * 60)
config = CompletionConfig(
model=self.model, prompt=COMPLETIONS_PROMPT, stream=False
)
log.info(
f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'"
)
response = self.client.call_completions(config)
if isinstance(response, dict):
print("\nResponse:")
print(json.dumps(response, indent=2))
else:
log.error("Unexpected response format")
def demo_chat(self, use_streaming: bool = True) -> None:
"""
Demo: test chat completions endpoint with optional streaming
"""
print("=" * 60)
print(
f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}"
)
print("=" * 60)
config = ChatCompletionConfig(
model=self.model,
messages=[{"role": "user", "content": CHAT_PROMPT}],
stream=use_streaming,
)
log.info(f"Testing chat completions with model '{self.model}'...")
response = self.client.call_chat_completions(config)
if use_streaming:
try:
self.handle_streaming_response(response, show_reasoning=True)
except Exception as e:
log.error(f"\nError during streaming: {e}")
import traceback
traceback.print_exc()
return
else:
if isinstance(response, dict):
choice = response.get("choices", [{}])[0]
message = choice.get("message", {})
content = message.get("content", "")
reasoning = message.get("reasoning_content", "") or message.get(
"reasoning", ""
)
if reasoning:
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
print(f"\n💬 Assistant: {content}")
print(f"\nFull Response:")
print(json.dumps(response, indent=2))
else:
log.error("Unexpected response format")
def demo_ls_tool(self) -> None:
"""Demo: ask LLM to list files in the current directory and describe what it sees"""
print("=" * 60)
print("TOOL USE DEMO: List Directory Contents")
print("=" * 60)
if not await self.test_tool_support():
# Test if tools are supported first
if not self.test_tool_support():
return
messages: List[Dict[str, Any]] = [{"role": "user", "content": TOOLS_PROMPT}]
# Request with tool available
messages = [{"role": "user", "content": TOOLS_PROMPT}]
# First pass: let the model decide tools, stream tool_calls and partial content
stream = await stream_chat_completions(
client=self.client,
config = ChatCompletionConfig(
model=self.model,
messages=messages,
tools=self.tool_manager.get_ls_tool_definition(),
tool_choice="auto",
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
assistant_content_buf: List[str] = []
tool_calls_state: Dict[int, Dict[str, Any]] = {}
printed_reasoning = False
printed_answer = False
log.info(f"Making initial request with tool using model '{self.model}'...")
response = self.client.call_chat_completions(config)
async for chunk in stream:
choice = (chunk.get("choices") or [{}])[0]
delta = choice.get("delta", {})
if not isinstance(response, dict):
raise ValueError("Expected dict response for tool use")
rc = delta.get("reasoning_content")
if rc:
if not printed_reasoning:
printed_reasoning = True
print("🧠 Reasoning: ", end="", flush=True)
print(rc, end="", flush=True)
choice = response.get("choices", [{}])[0]
message = choice.get("message", {})
content_part = delta.get("content")
if content_part:
assistant_content_buf.append(content_part)
if not printed_answer:
printed_answer = True
print("\n💬 Response: ", end="", flush=True)
print(content_part, end="", flush=True)
print(f"Assistant response: {message.get('content', 'No content')}")
if "tool_calls" in delta and delta["tool_calls"]:
for tc_delta in delta["tool_calls"]:
_merge_tool_call_delta(tool_calls_state, tc_delta)
# Check for tool calls
tool_calls = message.get("tool_calls")
if not tool_calls:
raise ValueError(
"No tool calls made - model may not support function calling"
)
# If no tool calls, we're done.
if not tool_calls_state:
print("\n(No tool calls were made.)")
return
print(f"Tool calls detected: {len(tool_calls)}")
# Build assistant message with tool_calls
assistant_message = {
"role": "assistant",
"content": "".join(assistant_content_buf) if assistant_content_buf else None,
"tool_calls": _tool_state_to_message_tool_calls(tool_calls_state),
# Execute the tool call
for tool_call in tool_calls:
function_name = tool_call["function"]["name"]
print(f"Executing tool: {function_name}")
tool_result = self.tool_manager.execute_tool_call(tool_call)
print(f"Tool result:\n{tool_result}")
# Add tool result and continue conversation
messages.append(message) # Add assistant's message with tool call
messages.append(
{
"role": "tool",
"tool_call_id": tool_call["id"],
"content": tool_result,
}
messages.append(assistant_message)
)
# Execute tools and feed results back
for tc in assistant_message["tool_calls"]:
tool_name = (tc.get("function") or {}).get("name")
call_id = tc.get("id")
raw_args = (tc.get("function") or {}).get("arguments") or "{}"
try:
args = json.loads(raw_args) if raw_args.strip() else {}
except Exception as e:
tool_result = json.dumps({"error": f"Argument parse failed: {str(e)}", "raw_arguments": raw_args})
messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result})
continue
try:
if tool_name == "list_files":
tool_result = self.tool_manager.list_files()
else:
tool_result = json.dumps({"error": f"Unknown tool '{tool_name}'"})
except Exception as e:
tool_result = json.dumps({"error": f"Tool '{tool_name}' failed: {str(e)}"})
print("\n[Tool executed]", tool_name)
print(tool_result[:500] + ("..." if len(tool_result) > 500 else ""))
messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result})
# Second pass: get final streamed answer after tool results
stream2 = await stream_chat_completions(
client=self.client,
# Get final response
final_config = ChatCompletionConfig(
model=self.model,
messages=messages,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
tools=self.tool_manager.get_ls_tool_definition(),
)
final_buf = []
printed_reasoning2 = False
printed_answer2 = False
print("Getting final response...")
final_response = self.client.call_chat_completions(final_config)
async for chunk in stream2:
choice = (chunk.get("choices") or [{}])[0]
delta = choice.get("delta", {})
rc2 = delta.get("reasoning_content")
if rc2:
if not printed_reasoning2:
printed_reasoning2 = True
print("\n🧠 Reasoning (post-tools): ", end="", flush=True)
print(rc2, end="", flush=True)
c2 = delta.get("content")
if c2:
final_buf.append(c2)
if not printed_answer2:
printed_answer2 = True
print("\n💬 Response (final): ", end="", flush=True)
print(c2, end="", flush=True)
if isinstance(final_response, dict):
final_choice = final_response.get("choices", [{}])[0]
final_message = final_choice.get("message", {})
final_content = final_message.get("content", "")
print("\n" + "=" * 60)
print("FINAL LLM ANALYSIS:")
print("=" * 60)
print("".join(final_buf))
print(final_content)
print("=" * 60)
async def interactive_chat(self) -> None:
def interactive_chat(self) -> None:
"""Interactive chat session with streaming"""
print("=" * 60)
print("INTERACTIVE STREAMING CHAT")
@@ -431,7 +449,7 @@ class APIDemo:
print("Type 'quit' to exit, 'clear' to clear history")
print()
messages: List[Dict[str, Any]] = []
messages = []
while True:
try:
@@ -449,15 +467,16 @@ class APIDemo:
messages.append({"role": "user", "content": user_input})
print("Assistant: ", end="", flush=True)
stream = await stream_chat_completions(
client=self.client,
model=self.model,
messages=messages,
max_tokens=MAX_TOKENS,
temperature=0.7
config = ChatCompletionConfig(
model=self.model, messages=messages, stream=True, temperature=0.7
)
print("Assistant: ", end="", flush=True)
response = self.client.call_chat_completions(config)
assistant_content = self.handle_streaming_response(
response, show_reasoning=True
)
assistant_content = await self.handle_streaming_response(stream, show_reasoning=True)
# Add assistant response to conversation history
messages.append({"role": "assistant", "content": assistant_content})
@@ -466,64 +485,115 @@ class APIDemo:
print("\n👋 Chat interrupted. Goodbye!")
break
except Exception as e:
log.error("\nError: %s", e)
log.error(f"\nError: {e}")
continue
# ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
p.add_argument("--model", required=True, help="Model to use for requests (required)")
p.add_argument("--endpoint", default="my-vllm-endpoint", help="Vast endpoint name (default: my-vllm-endpoint)")
def main():
"""Main function with CLI switches for different tests"""
from lib.test_utils import test_args
modes = p.add_mutually_exclusive_group(required=False)
modes.add_argument("--completion", action="store_true", help="Test completions endpoint")
modes.add_argument("--chat", action="store_true", help="Test chat completions endpoint (non-streaming)")
modes.add_argument("--chat-stream", action="store_true", help="Test chat completions endpoint with streaming")
modes.add_argument("--tools", action="store_true", help="Test function calling with ls tool (non-streaming+streamed phases)")
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming chat session")
return p
# Add mandatory model argument
test_args.add_argument(
"--model", required=True, help="Model to use for requests (required)"
)
# Add test mode arguments
test_args.add_argument(
"--completion", action="store_true", help="Test completions endpoint"
)
test_args.add_argument(
"--chat",
action="store_true",
help="Test chat completions endpoint (non-streaming)",
)
test_args.add_argument(
"--chat-stream",
action="store_true",
help="Test chat completions endpoint with streaming",
)
test_args.add_argument(
"--tools",
action="store_true",
help="Test function calling with ls tool (non-streaming)",
)
test_args.add_argument(
"--interactive",
action="store_true",
help="Start interactive streaming chat session",
)
async def main_async():
args = build_arg_parser().parse_args()
args = test_args.parse_args()
selected = sum([args.completion, args.chat, args.chat_stream, args.tools, args.interactive])
if selected == 0:
# Check that only one test mode is selected
test_modes = [
args.completion,
args.chat,
args.chat_stream,
args.tools,
args.interactive,
]
selected_count = sum(test_modes)
if selected_count == 0:
print("Please specify exactly one test mode:")
print(" --completion : Test completions endpoint")
print(" --chat : Test chat completions endpoint (non-streaming)")
print(" --chat-stream : Test chat completions endpoint with streaming")
print(" --tools : Test function calling with ls tool")
print(" --tools : Test function calling with ls tool (non-streaming)")
print(" --interactive : Start interactive streaming chat session")
print(f"\nExample: python {os.path.basename(sys.argv[0])} --model Qwen/Qwen3-8B --chat-stream --endpoint my-vllm-endpoint")
print(
f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT"
)
sys.exit(1)
elif selected > 1:
elif selected_count > 1:
print("Please specify exactly one test mode")
sys.exit(1)
try:
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if not endpoint_api_key:
log.error(
f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting."
)
sys.exit(1)
# Create the core API client
client = APIClient(
endpoint_group_name=args.endpoint_group_name,
api_key=args.api_key,
server_url=Endpoint.get_autoscaler_server_url(args.instance),
endpoint_api_key=endpoint_api_key,
)
# Create tool manager and demo (passing the model parameter)
tool_manager = ToolManager()
demo = APIDemo(client, args.model, tool_manager)
print(f"Using model: {args.model}")
print("=" * 60)
try:
async with Serverless() as client:
demo = APIDemo(client, args.model, ToolManager())
# Run the selected test
if args.completion:
await demo.demo_completions()
demo.demo_completions()
elif args.chat:
await demo.demo_chat(use_streaming=False)
demo.demo_chat(use_streaming=False)
elif args.chat_stream:
await demo.demo_chat(use_streaming=True)
demo.demo_chat(use_streaming=True)
elif args.tools:
await demo.demo_ls_tool()
demo.demo_ls_tool()
elif args.interactive:
await demo.interactive_chat()
demo.interactive_chat()
except Exception as e:
log.error("Error during test: %s", e, exc_info=True)
log.error(f"Error during test: {e}", exc_info=True)
sys.exit(1)
if __name__ == "__main__":
asyncio.run(main_async())
main()
+4 -29
View File
@@ -119,25 +119,14 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
class CompletionsData(GenericData):
@classmethod
def for_test(cls) -> "CompletionsData":
system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
genus Equus with horses and asses, the three groups being the only living members of the family
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
woodlands, shrublands, and mountainous areas.
Please answer the following question based on the above context."""
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
model = os.environ.get("MODEL_NAME")
if not model:
raise ValueError("MODEL_NAME environment variable not set")
test_input = {
"model": model,
"prompt": f"{system_prompt}\n\n{unique_question}",
"prompt": prompt,
"temperature": 0.7,
"max_tokens": 500,
}
@@ -164,18 +153,7 @@ class ChatCompletionsData(GenericData):
@classmethod
def for_test(cls) -> "ChatCompletionsData":
system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
genus Equus with horses and asses, the three groups being the only living members of the family
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
woodlands, shrublands, and mountainous areas.
Please answer the following question based on the above context."""
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
model = os.environ.get("MODEL_NAME")
if not model:
raise ValueError("MODEL_NAME environment variable not set")
@@ -183,10 +161,7 @@ class ChatCompletionsData(GenericData):
# Chat completions use messages format instead of prompt
test_input = {
"model": model,
"messages": [
{"role": "system", "content": system_prompt}, # Shared prefix
{"role": "user", "content": unique_question} # Unique per request
],
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.7,
"max_tokens": 500,
}
-1
View File
@@ -11,7 +11,6 @@ MODEL_SERVER_START_LOG_MSG = [
"llama runner started", # Ollama
'"message":"Connected","target":"text_generation_router"', # TGI
'"message":"Connected","target":"text_generation_router::server"', # TGI
"main: model loaded" # llama.cpp
]
MODEL_SERVER_ERROR_LOG_MSGS = [
-1
View File
@@ -82,7 +82,6 @@ def do_one(endpoint_name: str,
# 1) Check if we got a worker back from route
worker_url = msg.get("url", "")
if not worker_url:
status = msg.get("status", "")
m = re.search(r"total workers:\s*(\d+).*loading workers:\s*(\d+).*standby workers:\s*(\d+).*error workers:\s*(\d+)", status, re.I | re.S)
if m:
tot, loading, standby, err = map(int, m.groups())
-78
View File
@@ -1,78 +0,0 @@
import nltk
import random
import os
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
# vLLM model configuration
MODEL_SERVER_URL = 'http://127.0.0.1'
MODEL_SERVER_PORT = 18000
MODEL_LOG_FILE = '/var/log/portal/vllm.log'
MODEL_HEALTHCHECK_ENDPOINT = "/health"
# vLLM-specific log messages
MODEL_LOAD_LOG_MSG = [
"Application startup complete.",
]
MODEL_ERROR_LOG_MSGS = [
"INFO exited: vllm",
"RuntimeError: Engine",
"Traceback (most recent call last):"
]
MODEL_INFO_LOG_MSGS = [
'"message":"Download'
]
nltk.download("words")
WORD_LIST = nltk.corpus.words.words()
def completions_benchmark_generator() -> dict:
    """Build one randomized completions payload for benchmarking.

    The prompt is 250 words drawn with replacement from ``WORD_LIST``; the
    model name is read from the ``MODEL_NAME`` environment variable.

    Returns:
        A dict with ``model``, ``prompt``, ``temperature`` and ``max_tokens``.

    Raises:
        ValueError: if ``MODEL_NAME`` is unset or empty.
    """
    # Sample first, then validate the environment, matching the original
    # order of operations (the RNG is consumed even when validation fails).
    sampled_words = random.choices(WORD_LIST, k=250)
    random_prompt = " ".join(sampled_words)

    model_name = os.environ.get("MODEL_NAME")
    if model_name:
        return {
            "model": model_name,
            "prompt": random_prompt,
            "temperature": 0.7,
            "max_tokens": 500,
        }
    raise ValueError("MODEL_NAME environment variable not set")
# Worker configuration for the vLLM backend: where the model server lives,
# which routes are exposed, and which log lines are watched.
worker_config = WorkerConfig(
model_server_url=MODEL_SERVER_URL,
model_server_port=MODEL_SERVER_PORT,
model_log_file=MODEL_LOG_FILE,
model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
handlers=[
# Text-completions route; this is the only handler that carries a
# benchmark_config, fed by completions_benchmark_generator.
HandlerConfig(
route="/v1/completions",
# Workload is the request's "max_tokens", or 0 when the key is absent.
workload_calculator= lambda data: data.get("max_tokens", 0),
allow_parallel_requests=True,
# NOTE(review): presumably seconds -- confirm against the vastai SDK.
max_queue_time=60.0,
benchmark_config=BenchmarkConfig(
generator=completions_benchmark_generator,
concurrency=100,
runs=2
)
),
# Chat-completions route; same workload rule, no benchmark attached.
HandlerConfig(
route="/v1/chat/completions",
workload_calculator= lambda data: data.get("max_tokens", 0),
allow_parallel_requests=True,
max_queue_time=60.0,
)
],
# Log substrings grouped by the action they are registered under
# (load / error / info).
log_action_config=LogActionConfig(
on_load=MODEL_LOAD_LOG_MSG,
on_error=MODEL_ERROR_LOG_MSGS,
on_info=MODEL_INFO_LOG_MSGS
)
)
Worker(worker_config).run()
+119 -55
View File
@@ -1,61 +1,125 @@
from vastai import Serverless
import asyncio
import logging
import sys
import json
from urllib.parse import urljoin
import requests
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
ENDPOINT_NAME = "my-tgi-endpoint" # Change this to match your endpoint name
MAX_TOKENS = 1024
PROMPT = "Think step by step: Tell me about the Python programming language."
async def call_generate(client: Serverless) -> None:
    """Send one non-streaming ``/generate`` request and print the result.

    Looks up the endpoint named ``ENDPOINT_NAME`` on *client*, submits
    ``PROMPT`` with a ``MAX_TOKENS`` generation budget (also passed as the
    request cost), and prints the ``generated_text`` field of the response.
    """
    generation_params = {
        "max_new_tokens": MAX_TOKENS,
        "temperature": 0.7,
        "return_full_text": False,
    }
    request_body = {"inputs": PROMPT, "parameters": generation_params}

    tgi_endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
    result = await tgi_endpoint.request("/generate", request_body, cost=MAX_TOKENS)

    generated_text = result["response"]["generated_text"]
    print(generated_text)
async def call_generate_stream(client: Serverless) -> None:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = {
"inputs": PROMPT,
"parameters": {
"max_new_tokens": MAX_TOKENS,
"temperature": 0.7,
"do_sample": True,
"return_full_text": False,
}
}
resp = await endpoint.request(
"/generate_stream",
payload,
cost=MAX_TOKENS,
stream=True,
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
stream = resp["response"]
log = logging.getLogger(__file__)
printed_answer = False
async for event in stream:
tok = (event.get("token") or {}).get("text")
if tok:
if not printed_answer:
printed_answer = True
print("Answer:\n", end="", flush=True)
print(tok, end="", flush=True)
async def main():
    """Run the non-streaming demo, then the streaming demo, on one client."""
    # A single Serverless client is shared by both calls and is released
    # when the async context exits.
    async with Serverless() as serverless_client:
        await call_generate(serverless_client)
        await call_generate_stream(serverless_client)
def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> None:
    """Route one request via the autoscaler and call the worker's ``/generate``.

    Args:
        endpoint_group_name: Endpoint name sent in the routing request.
        api_key: API key sent in the routing request.
        server_url: Base URL of the server that exposes ``/route/``.

    Raises:
        requests.HTTPError: if either HTTP call returns an error status.
        KeyError: if the routing reply lacks one of the expected fields.
    """
    worker_route = "/generate"
    request_cost = 100

    # Step 1: ask the routing server which worker should take the request.
    routing_reply = requests.post(
        urljoin(server_url, "/route/"),
        json={
            "endpoint": endpoint_group_name,
            "api_key": api_key,
            "cost": request_cost,
        },
        timeout=4,
    )
    routing_reply.raise_for_status()
    routing = routing_reply.json()
    worker_base_url = routing["url"]

    # Step 2: forward the routing fields to the worker alongside the payload.
    auth_data = {
        "signature": routing["signature"],
        "cost": routing["cost"],
        "endpoint": routing["endpoint"],
        "reqnum": routing["reqnum"],
        "url": worker_base_url,
    }
    payload = {"inputs": "tell me about cats", "parameters": {"max_new_tokens": 500}}

    url = urljoin(worker_base_url, worker_route)
    print(f"url: {url}")
    worker_reply = requests.post(
        url,
        json={"payload": payload, "auth_data": auth_data},
        verify=get_cert_file_path(),
    )
    worker_reply.raise_for_status()
    print(worker_reply.json())
def call_generate_stream(
    endpoint_group_name: str, api_key: str, server_url: str
) -> None:
    """Route one request via the autoscaler and stream ``/generate_stream``.

    Each non-empty line of the streamed response is treated as a
    ``data:``-prefixed JSON event; the ``token.text`` field of every event is
    printed as it arrives. Lines that cannot be parsed are logged and skipped.

    Args:
        endpoint_group_name: Endpoint name sent in the routing request.
        api_key: API key sent in the routing request.
        server_url: Base URL of the server that exposes ``/route/``.

    Raises:
        requests.HTTPError: if either HTTP call returns an error status.
        KeyError: if the routing reply lacks one of the expected fields.
    """
    WORKER_ENDPOINT = "/generate_stream"
    COST = 100
    EVENT_PREFIX = "data:"

    route_payload = {
        "endpoint": endpoint_group_name,
        "api_key": api_key,
        "cost": COST,
    }
    response = requests.post(
        urljoin(server_url, "/route/"),
        json=route_payload,
        timeout=4,
    )
    response.raise_for_status()  # Raise an exception for bad status codes
    message = response.json()
    url = message["url"]
    print(f"url: {url}")
    auth_data = dict(
        signature=message["signature"],
        cost=message["cost"],
        endpoint=message["endpoint"],
        reqnum=message["reqnum"],
        url=message["url"],
    )
    payload = dict(inputs="tell me about dogs", parameters=dict(max_new_tokens=500))
    req_data = dict(payload=payload, auth_data=auth_data)
    url = urljoin(url, WORKER_ENDPOINT)

    # Use the same certificate bundle as the non-streaming call, and close the
    # streamed connection deterministically even if parsing or printing fails.
    with requests.post(
        url,
        json=req_data,
        stream=True,
        verify=get_cert_file_path(),
    ) as response:
        response.raise_for_status()  # Raise an exception for bad status codes
        for line in response.iter_lines():
            text = line.decode().strip()
            # Drop the literal "data:" prefix. str.lstrip("data:") would strip
            # any run of the characters d/a/t/: instead of the prefix itself.
            if text.startswith(EVENT_PREFIX):
                text = text[len(EVENT_PREFIX):].strip()
            if not text:
                continue
            try:
                data = json.loads(text)
                print(data["token"]["text"], end="", flush=True)
            except (json.JSONDecodeError, KeyError, TypeError) as e:
                # TypeError covers frames whose JSON is not an object, or whose
                # "token" is null; they are skipped like other malformed frames.
                log.warning(f"Failed to parse streaming response: {e}")
    print()
if __name__ == "__main__":
asyncio.run(main())
from lib.test_utils import test_args
args = test_args.parse_args()
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if endpoint_api_key:
try:
call_generate(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
call_generate_stream(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
except Exception as e:
log.error(f"Error during API call: {e}")
else:
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ")
-76
View File
@@ -1,76 +0,0 @@
import nltk
import random
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
# TGI model configuration
MODEL_SERVER_URL = 'http://0.0.0.0'
MODEL_SERVER_PORT = 5001
MODEL_LOG_FILE = "/workspace/infer.log"
MODEL_HEALTHCHECK_ENDPOINT = "/health"
# TGI-specific log messages
MODEL_LOAD_LOG_MSG = [
'"message":"Connected","target":"text_generation_router"',
'"message":"Connected","target":"text_generation_router::server"',
]
MODEL_ERROR_LOG_MSGS = [
"Error: WebserverFailed",
"Error: DownloadError",
"Error: ShardCannotStart",
]
MODEL_INFO_LOG_MSGS = [
'"message":"Download'
]
nltk.download("words")
WORD_LIST = nltk.corpus.words.words()
def benchmark_generator() -> dict:
    """Build one randomized generation payload for benchmarking.

    The ``inputs`` prompt is 250 words drawn with replacement from
    ``WORD_LIST``; generation parameters are fixed.

    Returns:
        A dict with ``inputs`` and a nested ``parameters`` dict.
    """
    sampled_words = random.choices(WORD_LIST, k=250)
    return {
        "inputs": " ".join(sampled_words),
        "parameters": {
            "max_new_tokens": 128,
            "temperature": 0.7,
            "return_full_text": False,
        },
    }
# Worker configuration for the TGI backend: where the model server lives,
# which routes are exposed, and which log lines are watched.
worker_config = WorkerConfig(
model_server_url=MODEL_SERVER_URL,
model_server_port=MODEL_SERVER_PORT,
model_log_file=MODEL_LOG_FILE,
model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
handlers=[
# Non-streaming route; the only handler with a benchmark_config, fed by
# benchmark_generator.
HandlerConfig(
route="/generate",
allow_parallel_requests=True,
# NOTE(review): presumably seconds -- confirm against the vastai SDK.
max_queue_time=60.0,
benchmark_config=BenchmarkConfig(
generator=benchmark_generator,
concurrency=50
),
# Workload is parameters.max_new_tokens; a request missing either key
# makes this lambda raise KeyError.
workload_calculator= lambda x: x["parameters"]["max_new_tokens"]
),
# Streaming route; same workload rule, no benchmark attached.
HandlerConfig(
route="/generate_stream",
allow_parallel_requests=True,
max_queue_time=60.0,
workload_calculator= lambda x: x["parameters"]["max_new_tokens"]
)
],
# Log substrings grouped by the action they are registered under
# (load / error / info).
log_action_config=LogActionConfig(
on_load=MODEL_LOAD_LOG_MSG,
on_error=MODEL_ERROR_LOG_MSGS,
on_info=MODEL_INFO_LOG_MSGS
)
)
Worker(worker_config).run()
-288
View File
@@ -1,288 +0,0 @@
import random
import sys
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
# ComfyUI model configuration
MODEL_SERVER_URL = 'http://127.0.0.1'
MODEL_SERVER_PORT = 18288
MODEL_LOG_FILE = '/var/log/portal/comfyui.log'
MODEL_HEALTHCHECK_ENDPOINT = "/health"
# ComfyUI-specific log messages
MODEL_LOAD_LOG_MSG = [
"To see the GUI go to: "
]
MODEL_ERROR_LOG_MSGS = [
"MetadataIncompleteBuffer",
"Value not in list: ",
"[ERROR] Provisioning Script failed"
]
MODEL_INFO_LOG_MSGS = [
'"message":"Downloading'
]
benchmark_prompts = [
"Cartoon hoodie hero; orc, anime cat, bunny; black goo; buff; vector on white.",
"Cozy farming-game scene with fine details.",
"2D vector child with soccer ball; airbrush chrome; swagger; antique copper.",
"Realistic futuristic downtown of low buildings at sunset.",
"Perfect wave front view; sunny seascape; ultra-detailed water; artful feel.",
"Clear cup with ice, fruit, mint; creamy swirls; fluid-sim CGI; warm glow.",
"Male biker with backpack on motorcycle; oilpunk; award-worthy magazine cover.",
"Collage for textile; surreal cartoon cat in cap/jeans before poster; crisp.",
"Medieval village inside glass sphere; volumetric light; macro focus.",
"Iron Man with glowing axe; mecha sci-fi; jungle scene; dynamic light.",
"Pope Francis DJ in leather jacket, mixing on giant console; dramatic.",
]
benchmark_dataset = [
{
"input": {
"workflow_json": {
"90": {
"inputs": {
"clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors",
"type": "wan",
"device": "default"
},
"class_type": "CLIPLoader",
"_meta": {
"title": "Load CLIP"
}
},
"91": {
"inputs": {
"text": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,裸露,NSFW",
"clip": [
"90",
0
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Negative Prompt)"
}
},
"92": {
"inputs": {
"vae_name": "wan_2.1_vae.safetensors"
},
"class_type": "VAELoader",
"_meta": {
"title": "Load VAE"
}
},
"93": {
"inputs": {
"shift": 8.000000000000002,
"model": [
"101",
0
]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "ModelSamplingSD3"
}
},
"94": {
"inputs": {
"shift": 8,
"model": [
"102",
0
]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "ModelSamplingSD3"
}
},
"95": {
"inputs": {
"add_noise": "disable",
"noise_seed": 0,
"steps": 20,
"cfg": 3.5,
"sampler_name": "euler",
"scheduler": "simple",
"start_at_step": 10,
"end_at_step": 10000,
"return_with_leftover_noise": "disable",
"model": [
"94",
0
],
"positive": [
"99",
0
],
"negative": [
"91",
0
],
"latent_image": [
"96",
0
]
},
"class_type": "KSamplerAdvanced",
"_meta": {
"title": "KSampler (Advanced)"
}
},
"96": {
"inputs": {
"add_noise": "enable",
"noise_seed": "__RANDOM_INT__",
"steps": 20,
"cfg": 3.5,
"sampler_name": "euler",
"scheduler": "simple",
"start_at_step": 0,
"end_at_step": 10,
"return_with_leftover_noise": "enable",
"model": [
"93",
0
],
"positive": [
"99",
0
],
"negative": [
"91",
0
],
"latent_image": [
"104",
0
]
},
"class_type": "KSamplerAdvanced",
"_meta": {
"title": "KSampler (Advanced)"
}
},
"97": {
"inputs": {
"samples": [
"95",
0
],
"vae": [
"92",
0
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"98": {
"inputs": {
"filename_prefix": "video/ComfyUI",
"format": "auto",
"codec": "auto",
"video": [
"100",
0
]
},
"class_type": "SaveVideo",
"_meta": {
"title": "Save Video"
}
},
"99": {
"inputs": {
"text":prompt,
"clip": [
"90",
0
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Positive Prompt)"
}
},
"100": {
"inputs": {
"fps": 16,
"images": [
"97",
0
]
},
"class_type": "CreateVideo",
"_meta": {
"title": "Create Video"
}
},
"101": {
"inputs": {
"unet_name": "wan2.2_t2v_high_noise_14B_fp8_scaled.safetensors",
"weight_dtype": "default"
},
"class_type": "UNETLoader",
"_meta": {
"title": "Load Diffusion Model"
}
},
"102": {
"inputs": {
"unet_name": "wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors",
"weight_dtype": "default"
},
"class_type": "UNETLoader",
"_meta": {
"title": "Load Diffusion Model"
}
},
"104": {
"inputs": {
"width": 640,
"height": 640,
"length": 81,
"batch_size": 1
},
"class_type": "EmptyHunyuanLatentVideo",
"_meta": {
"title": "EmptyHunyuanLatentVideo"
}
}
}
}
} for prompt in benchmark_prompts
]
# Worker configuration for the ComfyUI backend: where the model server lives,
# the single exposed route, and which log lines are watched.
worker_config = WorkerConfig(
model_server_url=MODEL_SERVER_URL,
model_server_port=MODEL_SERVER_PORT,
model_log_file=MODEL_LOG_FILE,
model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
handlers=[
# Only route exposed; parallel requests are disabled for it.
HandlerConfig(
route="/generate/sync",
allow_parallel_requests=False,
# NOTE(review): presumably seconds -- confirm against the vastai SDK.
max_queue_time=10.0,
# Benchmarks replay the fixed benchmark_dataset once.
benchmark_config=BenchmarkConfig(
dataset=benchmark_dataset,
runs=1
),
# Every request is assigned the same constant workload, regardless of
# its contents.
workload_calculator= lambda _ : 10000.0
)
],
# Log substrings grouped by the action they are registered under
# (load / error / info).
log_action_config=LogActionConfig(
on_load=MODEL_LOAD_LOG_MSG,
on_error=MODEL_ERROR_LOG_MSGS,
on_info=MODEL_INFO_LOG_MSGS
)
)
Worker(worker_config).run()