Compare commits
65 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 63550d5af3 | |||
| 7ec0e11938 | |||
| 191fbbfe18 | |||
| 9a4a39c71b | |||
| a4339bd3f1 | |||
| 2b26e5e20c | |||
| d3727d4fd7 | |||
| eedf81c0a3 | |||
| 3adec1826d | |||
| b55bfa9611 | |||
| 7db54f3bd7 | |||
| d63a060202 | |||
| c6521cb6d4 | |||
| b7fe4ebb91 | |||
| 8ae7b74605 | |||
| 106067d716 | |||
| f5134d4bf5 | |||
| 47e5460532 | |||
| ec2ac0a21a | |||
| 2cde573c56 | |||
| b2e4a5db0c | |||
| 7437028cb2 | |||
| 02c8307af7 | |||
| 7c0f316eeb | |||
| b4025a744f | |||
| d190308329 | |||
| 9f5a432513 | |||
| e09f1fa953 | |||
| ba6f1c2e4b | |||
| 944f83fc03 | |||
| 298590fb88 | |||
| 814c3acd4c | |||
| 22bca74087 | |||
| 9c795e2a01 | |||
| 830b532781 | |||
| d6a6e34c6b | |||
| ac1e109c48 | |||
| d6eb498ee4 | |||
| f56bbc0ebe | |||
| bcecd6df40 | |||
| 4d9bf2048c | |||
| 7788bc4a62 | |||
| 37ad3f8d46 | |||
| 70d51bafe1 | |||
| 63909736bb | |||
| f4f7080df1 | |||
| d51a338e8f | |||
| 92a04bd7af | |||
| 0f13506938 | |||
| 01e752d31f | |||
| 5edfa968ca | |||
| 5b5ef7227a | |||
| 16990ff8ff | |||
| 9748176366 | |||
| b39193ae70 | |||
| 9a6ca5d412 | |||
| e9ba1b03e4 | |||
| c98d661513 | |||
| f6fd1c6ac1 | |||
| 1cedb28acf | |||
| ec25dda3ad | |||
| 3786cf978d | |||
| a86d4bcf9c | |||
| e9b6a14a5e | |||
| cadac033e1 |
@@ -3,3 +3,4 @@
|
||||
__pycache__
|
||||
bin/
|
||||
lib64
|
||||
.venv
|
||||
@@ -39,11 +39,12 @@ reporting these metrics to the autoscaler.
|
||||
|
||||
If you are using a Vast.ai template that includes PyWorker integration (marked as autoscaler compatible), it should work out of the box. The template will typically start the appropriate PyWorker server automatically. Here's a few:
|
||||
|
||||
* **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=72d8dcb41ea3a58e06c741e2c725bc00)
|
||||
* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=ad72c8bf7cf695c3c9ddf0eaf6da0447)
|
||||
* **vLLM:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=63ae93902bf3978bea033782592b784d)
|
||||
* **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=6fa6bd5bdf5f0df63db80e40b086037d)
|
||||
* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=e6748878ba688e765e3e9fca29541938)
|
||||
|
||||
Currently available workers:
|
||||
* `hello_world`: A simple example worker for a basic LLM server.
|
||||
* `openai`: A simple example worker for a basic vLLM server.
|
||||
* `comfyui`: A worker for the ComfyUI image generation backend.
|
||||
* `tgi`: A worker for the Text Generation Inference backend.
|
||||
|
||||
|
||||
+96
-60
@@ -12,6 +12,7 @@ from distutils.util import strtobool
|
||||
|
||||
from anyio import open_file
|
||||
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
|
||||
import asyncio
|
||||
|
||||
import requests
|
||||
from Crypto.Signature import pkcs1_15
|
||||
@@ -25,8 +26,12 @@ from lib.data_types import (
|
||||
LogAction,
|
||||
ApiPayload_T,
|
||||
JsonDataException,
|
||||
RequestMetrics,
|
||||
BenchmarkResult
|
||||
)
|
||||
|
||||
VERSION = "0.2.0"
|
||||
|
||||
MSG_HISTORY_LEN = 100
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
@@ -53,15 +58,25 @@ class Backend:
|
||||
EndpointHandler # this endpoint handler will be used for benchmarking
|
||||
)
|
||||
log_actions: List[Tuple[LogAction, str]]
|
||||
max_wait_time: float = 10.0
|
||||
reqnum = -1
|
||||
version = VERSION
|
||||
msg_history = []
|
||||
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
|
||||
unsecured: bool = dataclasses.field(
|
||||
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
||||
)
|
||||
report_addr: str = dataclasses.field(
|
||||
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
|
||||
)
|
||||
mtoken: str = dataclasses.field(
|
||||
default_factory=lambda: os.environ.get("MASTER_TOKEN", "")
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
self.metrics = Metrics()
|
||||
self.metrics._set_version(self.version)
|
||||
self.metrics._set_mtoken(self.mtoken)
|
||||
self._total_pubkey_fetch_errors = 0
|
||||
self._pubkey = self._fetch_pubkey()
|
||||
self.__start_healthcheck: bool = False
|
||||
@@ -96,23 +111,19 @@ class Backend:
|
||||
|
||||
#######################################Private#######################################
|
||||
def _fetch_pubkey(self):
|
||||
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
|
||||
result = subprocess.check_output(command, universal_newlines=True)
|
||||
log.debug("public key:")
|
||||
log.debug(result)
|
||||
key = None
|
||||
for _ in range(5):
|
||||
try:
|
||||
key = RSA.import_key(result)
|
||||
break
|
||||
except ValueError as e:
|
||||
log.debug(f"Error downloading key: {e}")
|
||||
time.sleep(15)
|
||||
if key is None:
|
||||
self._total_pubkey_fetch_errors += 1
|
||||
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
|
||||
self.backend_errored("Failed to get autoscaler pubkey")
|
||||
return key
|
||||
report_addr = self.report_addr.rstrip("/")
|
||||
command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"]
|
||||
try:
|
||||
result = subprocess.check_output(command, universal_newlines=True)
|
||||
log.debug("public key:")
|
||||
log.debug(result)
|
||||
key = RSA.import_key(result)
|
||||
if key is not None:
|
||||
return key
|
||||
except (ValueError , subprocess.CalledProcessError) as e:
|
||||
log.debug(f"Error downloading key: {e}")
|
||||
self.backend_errored("Failed to get autoscaler pubkey")
|
||||
|
||||
|
||||
async def __handle_request(
|
||||
self,
|
||||
@@ -128,55 +139,56 @@ class Backend:
|
||||
except json.JSONDecodeError:
|
||||
return web.json_response(dict(error="invalid JSON"), status=422)
|
||||
workload = payload.count_workload()
|
||||
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
|
||||
|
||||
async def cancel_api_call_if_disconnected() -> web.Response:
|
||||
await request.wait_for_disconnection()
|
||||
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
|
||||
self.metrics._request_canceled(workload=workload)
|
||||
return web.Response(status=500)
|
||||
log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled")
|
||||
self.metrics._request_canceled(request_metrics)
|
||||
raise asyncio.CancelledError
|
||||
|
||||
async def make_request() -> Union[web.Response, web.StreamResponse]:
|
||||
log.debug(f"got request, {auth_data.reqnum}")
|
||||
self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum)
|
||||
if self.allow_parallel_requests is False:
|
||||
log.debug(f"Waiting to aquire Sem for reqnum:{auth_data.reqnum}")
|
||||
await self.sem.acquire()
|
||||
log.debug(
|
||||
f"Sem acquired for reqnum:{auth_data.reqnum}, starting request..."
|
||||
)
|
||||
else:
|
||||
log.debug(f"Starting request for reqnum:{auth_data.reqnum}")
|
||||
try:
|
||||
response = await self.__call_api(handler=handler, payload=payload)
|
||||
status_code = response.status
|
||||
log.debug(
|
||||
" ".join(
|
||||
[
|
||||
f"request with reqnum:{auth_data.reqnum}",
|
||||
f"request with reqnum:{request_metrics.reqnum}",
|
||||
f"returned status code: {status_code},",
|
||||
]
|
||||
)
|
||||
)
|
||||
res = await handler.generate_client_response(request, response)
|
||||
self.metrics._request_success(workload=workload)
|
||||
self.metrics._request_success(request_metrics)
|
||||
return res
|
||||
except requests.exceptions.RequestException as e:
|
||||
log.debug(f"[backend] Request error: {e}")
|
||||
self.metrics._request_errored(workload=workload)
|
||||
self.metrics._request_errored(request_metrics)
|
||||
return web.Response(status=500)
|
||||
finally:
|
||||
self.metrics._request_end(
|
||||
workload=workload,
|
||||
reqnum=auth_data.reqnum,
|
||||
)
|
||||
self.sem.release()
|
||||
|
||||
###########
|
||||
|
||||
if self.__check_signature(auth_data) is False:
|
||||
self.metrics._request_reject(request_metrics)
|
||||
return web.Response(status=401)
|
||||
|
||||
if self.metrics.model_metrics.wait_time > self.max_wait_time:
|
||||
self.metrics._request_reject(request_metrics)
|
||||
return web.Response(status=429)
|
||||
|
||||
acquired = False
|
||||
try:
|
||||
self.metrics._request_start(request_metrics)
|
||||
if self.allow_parallel_requests is False:
|
||||
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
|
||||
await self.sem.acquire()
|
||||
acquired = True
|
||||
log.debug(
|
||||
f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..."
|
||||
)
|
||||
else:
|
||||
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
|
||||
done, pending = await wait(
|
||||
[
|
||||
create_task(make_request()),
|
||||
@@ -184,11 +196,27 @@ class Backend:
|
||||
],
|
||||
return_when=FIRST_COMPLETED,
|
||||
)
|
||||
[task.cancel() for task in pending]
|
||||
return done.pop().result()
|
||||
for t in pending:
|
||||
t.cancel()
|
||||
await asyncio.gather(*pending, return_exceptions=True)
|
||||
|
||||
done_task = done.pop()
|
||||
try:
|
||||
return done_task.result()
|
||||
except Exception as e:
|
||||
log.debug(f"Request task raised exception: {e}")
|
||||
return web.Response(status=500)
|
||||
except asyncio.CancelledError:
|
||||
# Client is gone. Do not write a response; just unwind.
|
||||
return web.Response(status=499)
|
||||
except Exception as e:
|
||||
log.debug(f"Exception in main handler loop {e}")
|
||||
return web.Response(status=500)
|
||||
finally:
|
||||
# Always release the semaphore if it was acquired
|
||||
if acquired:
|
||||
self.sem.release()
|
||||
self.metrics._request_end(request_metrics)
|
||||
|
||||
@cached_property
|
||||
def healthcheck_session(self):
|
||||
@@ -229,7 +257,7 @@ class Backend:
|
||||
|
||||
async def _start_tracking(self) -> None:
|
||||
await gather(
|
||||
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck()
|
||||
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop()
|
||||
)
|
||||
|
||||
def backend_errored(self, msg: str) -> None:
|
||||
@@ -261,7 +289,7 @@ class Backend:
|
||||
message = {
|
||||
key: value
|
||||
for (key, value) in (dataclasses.asdict(auth_data).items())
|
||||
if key != "signature"
|
||||
if key != "signature" and key != "__request_id"
|
||||
}
|
||||
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
|
||||
log.debug(
|
||||
@@ -271,7 +299,7 @@ class Backend:
|
||||
elif message in self.msg_history:
|
||||
log.debug(f"message: {message} already in message history")
|
||||
return False
|
||||
elif verify_signature(json.dumps(message, indent=4), auth_data.signature):
|
||||
elif verify_signature(json.dumps(message, indent=4, sort_keys=True), auth_data.signature):
|
||||
self.reqnum = max(auth_data.reqnum, self.reqnum)
|
||||
self.msg_history.append(message)
|
||||
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
|
||||
@@ -290,10 +318,10 @@ class Backend:
|
||||
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
||||
log.debug("already ran benchmark")
|
||||
# trigger model load
|
||||
payload = self.benchmark_handler.make_benchmark_payload()
|
||||
_ = await self.__call_api(
|
||||
handler=self.benchmark_handler, payload=payload
|
||||
)
|
||||
# payload = self.benchmark_handler.make_benchmark_payload()
|
||||
# _ = await self.__call_api(
|
||||
# handler=self.benchmark_handler, payload=payload
|
||||
# )
|
||||
return float(f.readline())
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
@@ -308,18 +336,26 @@ class Backend:
|
||||
|
||||
for run in range(1, self.benchmark_handler.benchmark_runs + 1):
|
||||
start = time.time()
|
||||
tasks = []
|
||||
total_workload = 0
|
||||
benchmark_requests = []
|
||||
|
||||
for _ in range(concurrent_requests):
|
||||
for i in range(concurrent_requests):
|
||||
payload = self.benchmark_handler.make_benchmark_payload()
|
||||
total_workload += payload.count_workload()
|
||||
tasks.append(
|
||||
self.__call_api(handler=self.benchmark_handler, payload=payload)
|
||||
workload = payload.count_workload()
|
||||
task = self.__call_api(handler=self.benchmark_handler, payload=payload)
|
||||
benchmark_requests.append(
|
||||
BenchmarkResult(request_idx=i, workload=workload, task=task)
|
||||
)
|
||||
|
||||
responses = await gather(*tasks)
|
||||
responses = await gather(*[br.task for br in benchmark_requests])
|
||||
for br, response in zip(benchmark_requests, responses):
|
||||
br.response = response
|
||||
|
||||
total_workload = sum(br.workload for br in benchmark_requests if br.is_successful)
|
||||
time_elapsed = time.time() - start
|
||||
successful_responses = sum([1 for br in benchmark_requests if br.is_successful])
|
||||
if successful_responses == 0:
|
||||
self.backend_errored("No successful responses from benchmark")
|
||||
log.debug(f"benchmark failed: {successful_responses}/{concurrent_requests} successful responses")
|
||||
|
||||
throughput = total_workload / time_elapsed
|
||||
sum_throughput += throughput
|
||||
@@ -333,7 +369,7 @@ class Backend:
|
||||
f"Run: {run}, concurrent_requests: {concurrent_requests}",
|
||||
f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s",
|
||||
f"Throughput: {throughput} workload/s",
|
||||
f"Successful responses: {len([r for r in responses if r.status == 200])}",
|
||||
f"Successful responses: {successful_responses}/{concurrent_requests}",
|
||||
"#" * 60,
|
||||
]
|
||||
)
|
||||
@@ -360,7 +396,7 @@ class Backend:
|
||||
)
|
||||
# some backends need a few seconds after logging successful startup before
|
||||
# they can begin accepting requests
|
||||
await sleep(5)
|
||||
# await sleep(5)
|
||||
try:
|
||||
max_throughput = await run_benchmark()
|
||||
self.__start_healthcheck = True
|
||||
@@ -381,13 +417,13 @@ class Backend:
|
||||
|
||||
async def tail_log():
|
||||
log.debug(f"tailing file: {self.model_log_file}")
|
||||
async with await open_file(self.model_log_file) as f:
|
||||
async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore') as f:
|
||||
while True:
|
||||
line = await f.readline()
|
||||
if line:
|
||||
await handle_log_line(line.rstrip())
|
||||
else:
|
||||
time.sleep(LOG_POLL_INTERVAL)
|
||||
await asyncio.sleep(LOG_POLL_INTERVAL)
|
||||
|
||||
###########
|
||||
|
||||
@@ -395,4 +431,4 @@ class Backend:
|
||||
if os.path.isfile(self.model_log_file) is True:
|
||||
return await tail_log()
|
||||
else:
|
||||
await sleep(1)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
+51
-10
@@ -3,7 +3,7 @@ import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type
|
||||
from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type, Awaitable
|
||||
from aiohttp import web, ClientResponse
|
||||
import inspect
|
||||
|
||||
@@ -65,10 +65,11 @@ class ApiPayload(ABC):
|
||||
class AuthData:
|
||||
"""data used to authenticate requester"""
|
||||
|
||||
signature: str
|
||||
cost: str
|
||||
endpoint: str
|
||||
reqnum: int
|
||||
request_idx: int
|
||||
signature: str
|
||||
url: str
|
||||
|
||||
@classmethod
|
||||
@@ -189,13 +190,34 @@ class SystemMetrics:
|
||||
self.additional_disk_usage = disk_usage - self.last_disk_usage
|
||||
self.last_disk_usage = disk_usage
|
||||
|
||||
def reset(self):
|
||||
def reset(self, expected: float | None) -> None:
|
||||
# autoscaler excepts model_loading_time to be populated only once, when the instance has
|
||||
# finished benchmarking and is ready to receive requests. This applies to restarted instances
|
||||
# as well: they should send model_loading_time once when they are done loading
|
||||
self.model_loading_time = None
|
||||
if self.model_loading_time == expected:
|
||||
self.model_loading_time = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestMetrics:
|
||||
"""Tracks metrics for an active request."""
|
||||
request_idx: int
|
||||
reqnum: int
|
||||
workload: float
|
||||
status: str
|
||||
success: bool = False
|
||||
|
||||
@dataclass
|
||||
class BenchmarkResult:
|
||||
request_idx: int
|
||||
workload: float
|
||||
task: Awaitable[ClientResponse]
|
||||
response: Optional[ClientResponse] = None
|
||||
|
||||
@property
|
||||
def is_successful(self) -> bool:
|
||||
return self.response is not None and self.response.status == 200
|
||||
|
||||
@dataclass
|
||||
class ModelMetrics:
|
||||
"""Model specific metrics"""
|
||||
@@ -205,12 +227,14 @@ class ModelMetrics:
|
||||
workload_received: float
|
||||
workload_cancelled: float
|
||||
workload_errored: float
|
||||
workload_rejected: float
|
||||
# these are not
|
||||
workload_pending: float
|
||||
error_msg: Optional[str]
|
||||
max_throughput: float
|
||||
requests_recieved: Set[int] = field(default_factory=set)
|
||||
requests_working: Set[int] = field(default_factory=set)
|
||||
requests_working: dict[int, RequestMetrics] = field(default_factory=dict)
|
||||
requests_deleting: list[RequestMetrics] = field(default_factory=list)
|
||||
last_update: float = field(default_factory=time.time)
|
||||
|
||||
@classmethod
|
||||
@@ -220,19 +244,30 @@ class ModelMetrics:
|
||||
workload_served=0.0,
|
||||
workload_cancelled=0.0,
|
||||
workload_errored=0.0,
|
||||
workload_rejected=0.0,
|
||||
workload_received=0.0,
|
||||
error_msg=None,
|
||||
max_throughput=0.0,
|
||||
)
|
||||
|
||||
@property
|
||||
def cur_perf(self) -> float:
|
||||
return max(self.workload_served / (time.time() - self.last_update), 0.0)
|
||||
|
||||
@property
|
||||
def workload_processing(self) -> float:
|
||||
return max(self.workload_received - self.workload_cancelled, 0.0)
|
||||
|
||||
@property
|
||||
def wait_time(self) -> float:
|
||||
if (len(self.requests_working) == 0):
|
||||
return 0.0
|
||||
return sum([request.workload for request in self.requests_working.values()]) / max(self.max_throughput, 0.00001)
|
||||
|
||||
@property
|
||||
def cur_load(self) -> float:
|
||||
return sum([request.workload for request in self.requests_working.values()])
|
||||
|
||||
@property
|
||||
def working_request_idxs(self) -> list[int]:
|
||||
return [req.request_idx for req in self.requests_working.values()]
|
||||
|
||||
def set_errored(self, error_msg):
|
||||
self.reset()
|
||||
self.error_msg = error_msg
|
||||
@@ -242,16 +277,21 @@ class ModelMetrics:
|
||||
self.workload_received = 0
|
||||
self.workload_cancelled = 0
|
||||
self.workload_errored = 0
|
||||
self.workload_rejected = 0
|
||||
self.last_update = time.time()
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoScalaerData:
|
||||
class AutoScalerData:
|
||||
"""Data that is reported to autoscaler"""
|
||||
|
||||
id: int
|
||||
mtoken: str
|
||||
version: str
|
||||
loadtime: float
|
||||
cur_load: float
|
||||
rej_load: float
|
||||
new_load: float
|
||||
error_msg: str
|
||||
max_perf: float
|
||||
cur_perf: float
|
||||
@@ -260,6 +300,7 @@ class AutoScalaerData:
|
||||
num_requests_working: int
|
||||
num_requests_recieved: int
|
||||
additional_disk_usage: float
|
||||
working_request_idxs: list[int]
|
||||
url: str
|
||||
|
||||
|
||||
|
||||
+169
-39
@@ -5,13 +5,14 @@ import json
|
||||
from asyncio import sleep
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from functools import cache
|
||||
import asyncio
|
||||
from aiohttp import ClientSession, ClientTimeout, TCPConnector, ClientResponseError
|
||||
|
||||
import requests
|
||||
|
||||
from lib.data_types import AutoScalaerData, SystemMetrics, ModelMetrics
|
||||
from lib.data_types import AutoScalerData, SystemMetrics, ModelMetrics, RequestMetrics
|
||||
from typing import Awaitable, NoReturn, List
|
||||
|
||||
METRICS_UPDATE_INTERVAL = 1
|
||||
DELETE_REQUESTS_INTERVAL = 1
|
||||
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
@@ -26,7 +27,10 @@ def get_url() -> str:
|
||||
|
||||
@dataclass
|
||||
class Metrics:
|
||||
version: str = "0"
|
||||
mtoken: str = ""
|
||||
last_metric_update: float = 0.0
|
||||
last_request_served: float = 0.0
|
||||
update_pending: bool = False
|
||||
id: int = field(default_factory=lambda: int(os.environ["CONTAINER_ID"]))
|
||||
report_addr: List[str] = field(
|
||||
@@ -35,43 +39,84 @@ class Metrics:
|
||||
url: str = field(default_factory=get_url)
|
||||
system_metrics: SystemMetrics = field(default_factory=SystemMetrics.empty)
|
||||
model_metrics: ModelMetrics = field(default_factory=ModelMetrics.empty)
|
||||
_session: ClientSession | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def _request_start(self, workload: float, reqnum: int) -> None:
|
||||
async def http(self) -> ClientSession:
|
||||
if self._session is None:
|
||||
self._session = ClientSession(
|
||||
timeout=ClientTimeout(total=10),
|
||||
connector=TCPConnector(limit=8, limit_per_host=4, force_close=True, enable_cleanup_closed=True)
|
||||
)
|
||||
return self._session
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if self._session is not None:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
def _request_start(self, request: RequestMetrics) -> None:
|
||||
"""
|
||||
this function is called prior to forwarding a request to a model API.
|
||||
"""
|
||||
log.debug("request start")
|
||||
self.model_metrics.workload_pending += workload
|
||||
self.model_metrics.workload_received += workload
|
||||
self.model_metrics.requests_recieved.add(reqnum)
|
||||
self.model_metrics.requests_working.add(reqnum)
|
||||
request.status = "Started"
|
||||
self.model_metrics.workload_pending += request.workload
|
||||
self.model_metrics.workload_received += request.workload
|
||||
self.model_metrics.requests_recieved.add(request.reqnum)
|
||||
self.model_metrics.requests_working[request.reqnum] = request
|
||||
self.update_pending = True
|
||||
|
||||
def _request_end(self, workload: float, reqnum: int) -> None:
|
||||
def _request_end(self, request: RequestMetrics) -> None:
|
||||
"""
|
||||
this function is called after handling of a request ends, regardless of the outcome
|
||||
"""
|
||||
self.model_metrics.workload_pending -= workload
|
||||
self.model_metrics.requests_working.discard(reqnum)
|
||||
self.model_metrics.workload_pending -= request.workload
|
||||
self.model_metrics.requests_working.pop(request.reqnum, None)
|
||||
self.model_metrics.requests_deleting.append(request)
|
||||
self.last_request_served = time.time()
|
||||
|
||||
def _request_success(self, workload: float) -> None:
|
||||
def _request_success(self, request: RequestMetrics) -> None:
|
||||
"""
|
||||
this function is called after a response from model API is received and forwarded.
|
||||
"""
|
||||
self.model_metrics.workload_served += workload
|
||||
self.model_metrics.workload_served += request.workload
|
||||
request.status = "Success"
|
||||
request.success = True
|
||||
self.update_pending = True
|
||||
|
||||
def _request_errored(self, workload: float) -> None:
|
||||
def _request_errored(self, request: RequestMetrics) -> None:
|
||||
"""
|
||||
this function is called if model API returns an error
|
||||
"""
|
||||
self.model_metrics.workload_errored += workload
|
||||
self.model_metrics.workload_errored += request.workload
|
||||
request.status = "Error"
|
||||
request.success = False
|
||||
self.update_pending = True
|
||||
|
||||
def _request_canceled(self, workload: float) -> None:
|
||||
def _request_canceled(self, request: RequestMetrics) -> None:
|
||||
"""
|
||||
this function is called if client drops connection before model API has responded
|
||||
"""
|
||||
self.model_metrics.workload_cancelled += workload
|
||||
self.model_metrics.workload_cancelled += request.workload
|
||||
request.success = True
|
||||
request.status = "Cancelled"
|
||||
|
||||
def _request_reject(self, request: RequestMetrics):
|
||||
"""
|
||||
this function is called if the current wait time for the model is above max_wait_time
|
||||
"""
|
||||
self.model_metrics.requests_recieved.add(request.reqnum)
|
||||
self.model_metrics.requests_deleting.append(request)
|
||||
self.model_metrics.workload_rejected += request.workload
|
||||
request.success = False
|
||||
request.status = "Rejected"
|
||||
self.update_pending = True
|
||||
|
||||
async def _send_delete_requests_loop(self) -> Awaitable[NoReturn]:
|
||||
while True:
|
||||
await sleep(DELETE_REQUESTS_INTERVAL)
|
||||
if len(self.model_metrics.requests_deleting) > 0:
|
||||
await self.__send_delete_requests_and_reset()
|
||||
|
||||
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
|
||||
while True:
|
||||
@@ -79,10 +124,10 @@ class Metrics:
|
||||
elapsed = time.time() - self.last_metric_update
|
||||
if self.system_metrics.model_is_loaded is False and elapsed >= 10:
|
||||
log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
|
||||
self.__send_metrics_and_reset(elapsed)
|
||||
await self.__send_metrics_and_reset()
|
||||
elif self.update_pending or elapsed > 10:
|
||||
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
|
||||
self.__send_metrics_and_reset(elapsed)
|
||||
await self.__send_metrics_and_reset()
|
||||
|
||||
def _model_loaded(self, max_throughput: float) -> None:
|
||||
self.system_metrics.model_loading_time = (
|
||||
@@ -95,49 +140,130 @@ class Metrics:
|
||||
self.model_metrics.set_errored(error_msg)
|
||||
self.system_metrics.model_is_loaded = True
|
||||
|
||||
def _set_version(self, version: str) -> None:
|
||||
self.version = version
|
||||
|
||||
def _set_mtoken(self, mtoken: str) -> None:
|
||||
self.mtoken = mtoken
|
||||
|
||||
#######################################Private#######################################
|
||||
|
||||
def __send_metrics_and_reset(self, elapsed):
|
||||
async def __send_delete_requests_and_reset(self):
|
||||
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
|
||||
data = {
|
||||
"worker_id": self.id,
|
||||
"mtoken": self.mtoken,
|
||||
"request_idxs": idxs,
|
||||
"success": success_flag,
|
||||
}
|
||||
log.debug(
|
||||
f"Deleting requests that {'succeeded' if success_flag else 'failed'}: {data['request_idxs']}"
|
||||
)
|
||||
full_path = report_addr.rstrip("/") + "/delete_requests/"
|
||||
for attempt in range(1, 4):
|
||||
try:
|
||||
session = await self.http()
|
||||
async with session.post(full_path, json=data) as res:
|
||||
log.debug(f"delete_requests response: {res.status}")
|
||||
res.raise_for_status()
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
log.debug("delete_requests timed out")
|
||||
except (ClientResponseError, Exception) as e:
|
||||
log.debug(f"delete_requests failed with error: {e}")
|
||||
await asyncio.sleep(2)
|
||||
log.debug(f"retrying delete_request, attempt: {attempt}")
|
||||
return False
|
||||
|
||||
def compute_autoscaler_data() -> AutoScalaerData:
|
||||
return AutoScalaerData(
|
||||
# Take a snapshot of what we plan to send this tick.
|
||||
# New arrivals after this snapshot will remain in the queue for the next tick.
|
||||
snapshot = list(self.model_metrics.requests_deleting)
|
||||
success_idxs = [r.request_idx for r in snapshot if r.success is True]
|
||||
failed_idxs = [r.request_idx for r in snapshot if r.success is False]
|
||||
|
||||
if not success_idxs and not failed_idxs:
|
||||
return # nothing to do
|
||||
|
||||
for report_addr in self.report_addr:
|
||||
# TODO: Add a Redis subscriber queue for delete_requests
|
||||
if report_addr == "https://cloud.vast.ai/api/v0":
|
||||
# Patch: ignore the Redis API report_addr
|
||||
continue
|
||||
sent_success = True
|
||||
sent_failed = True
|
||||
|
||||
if success_idxs:
|
||||
sent_success = await post(report_addr, success_idxs, True)
|
||||
if failed_idxs:
|
||||
sent_failed = await post(report_addr, failed_idxs, False)
|
||||
|
||||
if sent_success and sent_failed:
|
||||
# Remove only the items we actually sent from the live queue.
|
||||
sent_set = set(success_idxs) | set(failed_idxs)
|
||||
self.model_metrics.requests_deleting[:] = [
|
||||
r for r in self.model_metrics.requests_deleting
|
||||
if r.request_idx not in sent_set
|
||||
]
|
||||
break
|
||||
|
||||
|
||||
async def __send_metrics_and_reset(self):
|
||||
|
||||
loadtime_snapshot = self.system_metrics.model_loading_time
|
||||
|
||||
def compute_autoscaler_data() -> AutoScalerData:
|
||||
return AutoScalerData(
|
||||
id=self.id,
|
||||
loadtime=(self.system_metrics.model_loading_time or 0.0),
|
||||
cur_load=(self.model_metrics.workload_processing / elapsed),
|
||||
mtoken=self.mtoken,
|
||||
version=self.version,
|
||||
loadtime=(loadtime_snapshot or 0.0),
|
||||
new_load=self.model_metrics.workload_processing,
|
||||
cur_load=self.model_metrics.cur_load,
|
||||
rej_load=self.model_metrics.workload_rejected,
|
||||
max_perf=self.model_metrics.max_throughput,
|
||||
cur_perf=self.model_metrics.cur_perf,
|
||||
cur_perf=self.model_metrics.workload_served,
|
||||
error_msg=self.model_metrics.error_msg or "",
|
||||
num_requests_working=len(self.model_metrics.requests_working),
|
||||
num_requests_recieved=len(self.model_metrics.requests_recieved),
|
||||
additional_disk_usage=self.system_metrics.additional_disk_usage,
|
||||
working_request_idxs=self.model_metrics.working_request_idxs,
|
||||
cur_capacity=0,
|
||||
max_capacity=0,
|
||||
url=self.url,
|
||||
)
|
||||
|
||||
def send_data(report_addr: str) -> bool:
|
||||
async def send_data(report_addr: str) -> bool:
|
||||
data = compute_autoscaler_data()
|
||||
full_path = report_addr.rstrip("/") + "/worker_status/"
|
||||
log_data = asdict(data)
|
||||
def obfuscate(secret: str) -> str:
|
||||
if secret is None:
|
||||
return ""
|
||||
return secret[:7] + "..." if len(secret) > 7 else ("*" * len(secret))
|
||||
|
||||
log_data["mtoken"] = obfuscate(log_data.get("mtoken"))
|
||||
log.debug(
|
||||
"\n".join(
|
||||
[
|
||||
"#" * 60,
|
||||
f"sending data to autoscaler",
|
||||
f"{json.dumps((asdict(data)), indent=2)}",
|
||||
f"{json.dumps(log_data, indent=2)}",
|
||||
"#" * 60,
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
full_path = report_addr.rstrip("/") + "/worker_status/"
|
||||
for attempt in range(1, 4):
|
||||
try:
|
||||
res = requests.post(full_path, json=asdict(data), timeout=1)
|
||||
res.raise_for_status()
|
||||
session = await self.http()
|
||||
async with session.post(full_path, json=asdict(data)) as res:
|
||||
res.raise_for_status()
|
||||
return True
|
||||
except requests.Timeout:
|
||||
except asyncio.TimeoutError:
|
||||
log.debug(f"autoscaler status update timed out")
|
||||
except Exception as e:
|
||||
except (ClientResponseError, Exception) as e:
|
||||
log.debug(f"autoscaler status update failed with error: {e}")
|
||||
time.sleep(2)
|
||||
await asyncio.sleep(2)
|
||||
log.debug(f"retrying autoscaler status update, attempt: {attempt}")
|
||||
log.debug(f"failed to send update through {report_addr}")
|
||||
return False
|
||||
@@ -146,11 +272,15 @@ class Metrics:
|
||||
|
||||
self.system_metrics.update_disk_usage()
|
||||
|
||||
sent = False
|
||||
for report_addr in self.report_addr:
|
||||
success = send_data(report_addr)
|
||||
if success is True:
|
||||
if await send_data(report_addr):
|
||||
sent = True
|
||||
break
|
||||
self.update_pending = False
|
||||
self.model_metrics.reset()
|
||||
self.system_metrics.reset()
|
||||
self.last_metric_update = time.time()
|
||||
|
||||
if sent:
|
||||
# clear the one-shot loadtime only if we actually sent *this* value
|
||||
self.system_metrics.reset(expected=loadtime_snapshot)
|
||||
self.update_pending = False
|
||||
self.model_metrics.reset()
|
||||
self.last_metric_update = time.time()
|
||||
|
||||
+6
-6
@@ -292,12 +292,12 @@ def test_load_cmd(
|
||||
args = arg_parser.parse_args()
|
||||
if hasattr(args, "comfy_model"):
|
||||
os.environ["COMFY_MODEL"] = args.comfy_model
|
||||
server_url = dict(
|
||||
prod="https://run.vast.ai",
|
||||
alpha="https://run-alpha.vast.ai",
|
||||
candidate="https://run-candidate.vast.ai",
|
||||
local="http://localhost:8080",
|
||||
)[args.instance]
|
||||
server_url = {
|
||||
"prod": "https://run.vast.ai",
|
||||
"alpha": "https://run-alpha.vast.ai",
|
||||
"candidate": "https://run-candidate.vast.ai",
|
||||
"local": "http://localhost:8080",
|
||||
}.get(args.instance, "http://localhost:8080")
|
||||
run_test(
|
||||
num_requests=args.num_requests,
|
||||
requests_per_second=args.requests_per_second,
|
||||
|
||||
@@ -8,3 +8,4 @@ Requests~=2.32
|
||||
transformers~=4.52
|
||||
utils==1.0.*
|
||||
hf_transfer>=0.1.9
|
||||
vastai-sdk>=0.2.0
|
||||
+1
-5
@@ -9,7 +9,7 @@ ENV_PATH="$WORKSPACE_DIR/worker-env"
|
||||
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
|
||||
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
|
||||
|
||||
REPORT_ADDR="${REPORT_ADDR:-https://cloud.vast.ai/api/v0,https://run.vast.ai}"
|
||||
REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}"
|
||||
USE_SSL="${USE_SSL:-true}"
|
||||
WORKER_PORT="${WORKER_PORT:-3000}"
|
||||
mkdir -p "$WORKSPACE_DIR"
|
||||
@@ -124,9 +124,5 @@ cd "$SERVER_DIR"
|
||||
|
||||
echo "launching PyWorker server"
|
||||
|
||||
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
|
||||
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
|
||||
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
|
||||
|
||||
(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
|
||||
echo "launching PyWorker server done"
|
||||
|
||||
+43
-5
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
import time
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
@@ -16,6 +17,38 @@ class Endpoint:
|
||||
Utility class for handling endpoint operations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_endpoint_info(
|
||||
endpoint_name: str, account_api_key: str, instance: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
headers = {"Authorization": f"Bearer {account_api_key}"}
|
||||
url = f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}"
|
||||
# Retry a few times to smooth over transient propagation/network delays
|
||||
for attempt in range(4):
|
||||
try:
|
||||
response = requests.get(url, headers=headers, timeout=8)
|
||||
if response.status_code != 200:
|
||||
# brief backoff and retry
|
||||
time.sleep(0.3 * (attempt + 1))
|
||||
continue
|
||||
try:
|
||||
data = response.json()
|
||||
except Exception:
|
||||
# JSON parse failed; backoff and retry
|
||||
time.sleep(0.3 * (attempt + 1))
|
||||
continue
|
||||
result = data.get("results", []) if isinstance(data, dict) else []
|
||||
endpoint = next(
|
||||
(item for item in result if item.get("endpoint_name") == endpoint_name),
|
||||
None,
|
||||
)
|
||||
if endpoint and endpoint.get("id") and endpoint.get("api_key"):
|
||||
return {"id": endpoint.get("id"), "api_key": endpoint.get("api_key")}
|
||||
except Exception:
|
||||
# network or other transient error; retry
|
||||
time.sleep(0.3 * (attempt + 1))
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_autoscaler_server_url(instance: str) -> str:
|
||||
endpoints = {
|
||||
@@ -23,7 +56,10 @@ class Endpoint:
|
||||
"candidate": "run-candidate",
|
||||
"prod": "run",
|
||||
}
|
||||
return f"https://{endpoints[instance]}.vast.ai/"
|
||||
host = endpoints.get(instance)
|
||||
if host:
|
||||
return f"https://{host}.vast.ai/"
|
||||
return "http://localhost:8080"
|
||||
|
||||
@staticmethod
|
||||
def get_server_url(instance: str) -> str:
|
||||
@@ -32,7 +68,8 @@ class Endpoint:
|
||||
"candidate": "candidate",
|
||||
"prod": "console",
|
||||
}
|
||||
return f"https://{endpoints[instance]}.vast.ai/api/v0/endptjobs/"
|
||||
host = endpoints.get(instance, "alpha")
|
||||
return f"https://{host}.vast.ai/api/v0/endptjobs/"
|
||||
|
||||
@staticmethod
|
||||
def get_endpoint_api_key(
|
||||
@@ -55,6 +92,7 @@ class Endpoint:
|
||||
response = requests.get(
|
||||
f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}",
|
||||
headers=headers,
|
||||
timeout=8,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
@@ -64,14 +102,14 @@ class Endpoint:
|
||||
|
||||
try:
|
||||
data = response.json()
|
||||
except requests.exceptions.JSONDecodeError as e:
|
||||
except Exception as e:
|
||||
log.debug(f"Failed to parse JSON response: {e}")
|
||||
return None
|
||||
|
||||
result = data.get("results", [])
|
||||
|
||||
endpoint: Optional[Dict[str, Any]] = next(
|
||||
(item for item in result if item["endpoint_name"] == endpoint_name),
|
||||
(item for item in result if item.get("endpoint_name") == endpoint_name),
|
||||
None,
|
||||
)
|
||||
if not endpoint:
|
||||
|
||||
@@ -12,9 +12,21 @@ A docker image is provided but you may use any if the above requirements are met
|
||||
|
||||
## Benchmarking
|
||||
|
||||
A simple image generation benchmark runs when each worker initializes to validate GPU performance and identify underperforming machines.
|
||||
### Custom Benchmark Workflows
|
||||
|
||||
The benchmark uses Stable Diffusion v1.5 with ComfyUI's default text-to-image workflow. Configure the benchmark complexity and duration using these variables:
|
||||
You can provide a custom ComfyUI workflow for benchmarking by creating `workers/comfyui-json/misc/benchmark.json`. This allows you to test performance using your preferred models and workflow complexity.
|
||||
|
||||
**Ways to provide the benchmark file:**
|
||||
- Fork this repository and add your `benchmark.json` file
|
||||
- Write the file during worker provisioning (onstart script or setup phase)
|
||||
|
||||
An example file is provided in the repository. To ensure varied generations, use the placeholder `__RANDOM_INT__` in place of static seed values - it will be replaced with a random integer for each benchmark run.
|
||||
|
||||
### Default Benchmark (Fallback)
|
||||
|
||||
If `benchmark.json` is not available, a simple image generation benchmark runs when each worker initializes. This validates GPU performance and helps identify underperforming machines.
|
||||
|
||||
The default benchmark uses Stable Diffusion v1.5 with ComfyUI's standard text-to-image workflow. Configure it using these environment variables:
|
||||
|
||||
| Environment Variable | Default Value | Description |
|
||||
| -------------------- | ------------- | ----------- |
|
||||
@@ -24,7 +36,7 @@ The benchmark uses Stable Diffusion v1.5 with ComfyUI's default text-to-image wo
|
||||
|
||||
Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns.
|
||||
|
||||
### Calibrating Benchmark Duration
|
||||
#### Calibrating Fallback Benchmark Duration
|
||||
|
||||
To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements.
|
||||
|
||||
|
||||
+25
-145
@@ -1,155 +1,35 @@
|
||||
import logging
|
||||
from .data_types import count_workload
|
||||
import uuid
|
||||
import random
|
||||
from urllib.parse import urljoin
|
||||
import json
|
||||
import asyncio
|
||||
import random
|
||||
|
||||
import requests
|
||||
from vastai import Serverless
|
||||
|
||||
from lib.test_utils import print_truncate_res
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
from .data_types import count_workload
|
||||
async def main():
|
||||
async with Serverless() as client:
|
||||
endpoint = await client.get_endpoint(name="my-comfy-endpoint") # Change this to your endpoint name
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
def call_text2image_workflow(
|
||||
endpoint_group_name: str, api_key: str, server_url: str
|
||||
) -> None:
|
||||
"""Simple Text2Image using the new modifier-based approach"""
|
||||
|
||||
def make_request(url: str, payload: dict, timeout: int = None, verify=True, context: str = "request"):
|
||||
"""Helper function for making requests with consistent error handling"""
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
json=payload,
|
||||
timeout=timeout,
|
||||
verify=verify
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
log.error(f"HTTP error occurred during {context}: {http_err}")
|
||||
log.error(f"Status Code: {response.status_code}")
|
||||
log.error("Response content:", response.text)
|
||||
return None
|
||||
except requests.exceptions.Timeout:
|
||||
log.error(f"Timeout occurred during {context}: {url}")
|
||||
return None
|
||||
except requests.exceptions.ConnectionError:
|
||||
log.error(f"Connection error occurred during {context}: {url}")
|
||||
return None
|
||||
except json.JSONDecodeError as json_err:
|
||||
log.error(f"Failed to decode JSON response during {context}: {json_err}")
|
||||
if 'response' in locals():
|
||||
print("Response content:", response.text)
|
||||
return None
|
||||
except Exception as err:
|
||||
log.error(f"An unexpected error occurred during {context}: {err}")
|
||||
if 'response' in locals():
|
||||
log.error("Response content (if available):", response.text)
|
||||
return None
|
||||
|
||||
WORKER_ENDPOINT = "/generate/sync"
|
||||
|
||||
# This worker has concurrency = 1. All workloads have cost value 1.0
|
||||
COST = count_workload()
|
||||
|
||||
# Route to get worker URL
|
||||
route_payload = {
|
||||
"endpoint": endpoint_group_name,
|
||||
"api_key": api_key,
|
||||
"cost": COST,
|
||||
}
|
||||
|
||||
# First request - get routing information
|
||||
route_response = make_request(
|
||||
url=urljoin(server_url, "/route/"),
|
||||
payload=route_payload,
|
||||
timeout=4,
|
||||
context="route request"
|
||||
)
|
||||
|
||||
if route_response is None:
|
||||
return None
|
||||
|
||||
if "url" not in route_response or not route_response["url"]:
|
||||
log.error("Error: No worker in 'Ready' state. Please wait while the serverless engine removes errored workers or finishes loading new workers.")
|
||||
return None
|
||||
|
||||
if "status" in route_response:
|
||||
print(f"Autoscaler status: {route_response['status']}")
|
||||
return None
|
||||
|
||||
# Extract data from route response
|
||||
url = route_response["url"]
|
||||
auth_data = dict(
|
||||
signature=route_response["signature"],
|
||||
cost=route_response["cost"],
|
||||
endpoint=route_response["endpoint"],
|
||||
reqnum=route_response["reqnum"],
|
||||
url=route_response["url"],
|
||||
)
|
||||
|
||||
# Build the payload for the worker request
|
||||
worker_payload = {
|
||||
"input": {
|
||||
"request_id": str(uuid.uuid4()),
|
||||
"modifier": "Text2Image",
|
||||
"modifications": {
|
||||
"prompt": "a beautiful landscape with mountains and lakes",
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"steps": 20,
|
||||
"seed": random.randint(0, 2**32 - 1)
|
||||
},
|
||||
"workflow_json": {} # Empty since using modifier approach
|
||||
payload = {
|
||||
"input": {
|
||||
"request_id": str(uuid.uuid4()),
|
||||
"modifier": "Text2Image",
|
||||
"modifications": {
|
||||
"prompt": "a beautiful landscape with mountains and lakes",
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"steps": 20,
|
||||
"seed": random.randint(0, 2**32 - 1)
|
||||
},
|
||||
"workflow_json": {} # Empty since using modifier approach
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
req_data = dict(payload=worker_payload, auth_data=auth_data)
|
||||
worker_url = urljoin(url, WORKER_ENDPOINT)
|
||||
print(f"url: {worker_url}")
|
||||
|
||||
# Second request - call the worker endpoint
|
||||
worker_response = make_request(
|
||||
url=worker_url,
|
||||
payload=req_data,
|
||||
verify=get_cert_file_path(),
|
||||
context="worker request"
|
||||
)
|
||||
|
||||
return worker_response
|
||||
response = await endpoint.request("/generate/sync", payload, cost=count_workload())
|
||||
|
||||
# Get the file from the path on the local machine using SCP or SFTP
|
||||
# or configure S3 to upload to cloud storage.
|
||||
print(response["response"]["output"][0]["local_path"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
from lib.test_utils import test_args
|
||||
|
||||
args = test_args.parse_args()
|
||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||
endpoint_name=args.endpoint_group_name,
|
||||
account_api_key=args.api_key,
|
||||
instance=args.instance,
|
||||
)
|
||||
|
||||
if endpoint_api_key:
|
||||
result = call_text2image_workflow(
|
||||
api_key=endpoint_api_key,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
server_url=args.server_url,
|
||||
)
|
||||
if result is None:
|
||||
log.error("Text2Image workflow failed")
|
||||
else:
|
||||
print(result)
|
||||
else:
|
||||
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name}")
|
||||
asyncio.run(main())
|
||||
@@ -5,12 +5,13 @@ import dataclasses
|
||||
from typing import Dict, Any
|
||||
from functools import cache
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
import json
|
||||
import logging
|
||||
|
||||
from lib.data_types import ApiPayload, JsonDataException
|
||||
|
||||
|
||||
with open("workers/comfyui/misc/test_prompts.txt", "r") as f:
|
||||
test_prompts = f.readlines()
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
def count_workload() -> float:
|
||||
# Always 100.0 where there is a single instance of ComfyUI handling requests
|
||||
@@ -24,9 +25,32 @@ class ComfyWorkflowData(ApiPayload):
|
||||
@classmethod
|
||||
def for_test(cls):
|
||||
"""
|
||||
Use the variables available to simulate workflows of the required running time
|
||||
If the user has provided a benchmark workflow we can use it here to properly gauge performance.
|
||||
Otherwise, use the variables available to simulate workflows of the required running time
|
||||
Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090)
|
||||
"""
|
||||
# Try to load benchmark.json
|
||||
benchmark_file = Path("workers/comfyui-json/misc/benchmark.json")
|
||||
|
||||
if benchmark_file.exists():
|
||||
try:
|
||||
with open(benchmark_file, "r") as f:
|
||||
benchmark_workflow = json.load(f)
|
||||
return cls(
|
||||
input={
|
||||
"request_id": f"test-{random.randint(1000, 99999)}",
|
||||
"workflow_json": benchmark_workflow
|
||||
}
|
||||
)
|
||||
except (json.JSONDecodeError, IOError):
|
||||
# JSON is malformed or file can't be read, fall through to default
|
||||
log.error(f"Failed to benchmark using {benchmark_file}")
|
||||
|
||||
# Fallback: read prompts and construct payload
|
||||
log.info("Using fallback method for benchmarking")
|
||||
with open("workers/comfyui-json/misc/test_prompts.txt", "r") as f:
|
||||
test_prompts = f.readlines()
|
||||
|
||||
test_prompt = random.choice(test_prompts).rstrip()
|
||||
return cls(
|
||||
input={
|
||||
|
||||
@@ -0,0 +1,107 @@
|
||||
{
|
||||
"3": {
|
||||
"inputs": {
|
||||
"seed": "__RANDOM_INT__",
|
||||
"steps": 20,
|
||||
"cfg": 8,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "normal",
|
||||
"denoise": 1,
|
||||
"model": [
|
||||
"4",
|
||||
0
|
||||
],
|
||||
"positive": [
|
||||
"6",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"7",
|
||||
0
|
||||
],
|
||||
"latent_image": [
|
||||
"5",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "KSampler",
|
||||
"_meta": {
|
||||
"title": "KSampler"
|
||||
}
|
||||
},
|
||||
"4": {
|
||||
"inputs": {
|
||||
"ckpt_name": "v1-5-pruned-emaonly-fp16.safetensors"
|
||||
},
|
||||
"class_type": "CheckpointLoaderSimple",
|
||||
"_meta": {
|
||||
"title": "Load Checkpoint"
|
||||
}
|
||||
},
|
||||
"5": {
|
||||
"inputs": {
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
"batch_size": 1
|
||||
},
|
||||
"class_type": "EmptyLatentImage",
|
||||
"_meta": {
|
||||
"title": "Empty Latent Image"
|
||||
}
|
||||
},
|
||||
"6": {
|
||||
"inputs": {
|
||||
"text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,",
|
||||
"clip": [
|
||||
"4",
|
||||
1
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"7": {
|
||||
"inputs": {
|
||||
"text": "text, watermark",
|
||||
"clip": [
|
||||
"4",
|
||||
1
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"8": {
|
||||
"inputs": {
|
||||
"samples": [
|
||||
"3",
|
||||
0
|
||||
],
|
||||
"vae": [
|
||||
"4",
|
||||
2
|
||||
]
|
||||
},
|
||||
"class_type": "VAEDecode",
|
||||
"_meta": {
|
||||
"title": "VAE Decode"
|
||||
}
|
||||
},
|
||||
"9": {
|
||||
"inputs": {
|
||||
"filename_prefix": "ComfyUI",
|
||||
"images": [
|
||||
"8",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "SaveImage",
|
||||
"_meta": {
|
||||
"title": "Save Image"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -19,6 +19,7 @@ MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: "
|
||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
|
||||
"Value not in list: ", # This error is emitted when the model file is not there at all
|
||||
"[ERROR] Provisioning Script failed", # Error inserted by provisioning script if models/nodes fail to download
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -7,20 +7,13 @@ from lib.test_utils import print_truncate_res
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
|
||||
"""
|
||||
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
|
||||
"""
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
from vastai import Serverless
|
||||
|
||||
|
||||
def call_default_workflow(
|
||||
endpoint_group_name: str, api_key: str, server_url: str
|
||||
) -> None:
|
||||
ENDPOINT_NAME = "my-comfyui-endpoint"
|
||||
COST = 100 # Use a constant cost for image generation
|
||||
|
||||
def call_default_workflow(client: Serverless) -> None:
|
||||
WORKER_ENDPOINT = "/prompt"
|
||||
COST = 100
|
||||
route_payload = {
|
||||
@@ -82,6 +75,7 @@ def call_custom_workflow_for_sd3(
|
||||
endpoint=message["endpoint"],
|
||||
reqnum=message["reqnum"],
|
||||
url=message["url"],
|
||||
request_idx=message["request_idx"],
|
||||
)
|
||||
workflow = {
|
||||
"3": {
|
||||
|
||||
+357
-427
@@ -1,14 +1,15 @@
|
||||
import logging
|
||||
import sys
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
from urllib.parse import urljoin
|
||||
from typing import Dict, Any, Optional, Iterator, Union, List
|
||||
import requests
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
from .data_types.client import CompletionConfig, ChatCompletionConfig
|
||||
import argparse
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from vastai import Serverless
|
||||
import asyncio
|
||||
|
||||
# ---------------------- Logging ----------------------
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
@@ -16,135 +17,20 @@ logging.basicConfig(
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
# ---------------------- Prompts ----------------------
|
||||
COMPLETIONS_PROMPT = "the capital of USA is"
|
||||
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
|
||||
TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?"
|
||||
|
||||
|
||||
class APIClient:
|
||||
"""Lightweight client focused solely on API communication"""
|
||||
|
||||
# Remove the generic WORKER_ENDPOINT since we're now going direct
|
||||
DEFAULT_COST = 100
|
||||
DEFAULT_TIMEOUT = 4
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint_group_name: str,
|
||||
api_key: str,
|
||||
server_url: str,
|
||||
endpoint_api_key: str,
|
||||
):
|
||||
self.endpoint_group_name = endpoint_group_name
|
||||
self.api_key = api_key
|
||||
self.server_url = server_url
|
||||
self.endpoint_api_key = endpoint_api_key
|
||||
|
||||
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
|
||||
"""Get worker URL and auth data from routing service"""
|
||||
if not self.endpoint_api_key:
|
||||
raise ValueError("No valid endpoint API key available")
|
||||
|
||||
route_payload = {
|
||||
"endpoint": self.endpoint_group_name,
|
||||
"api_key": self.endpoint_api_key,
|
||||
"cost": cost,
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
urljoin(self.server_url, "/route/"),
|
||||
json=route_payload,
|
||||
timeout=self.DEFAULT_TIMEOUT,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def _create_auth_data(self, message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Create auth data from routing response"""
|
||||
return {
|
||||
"signature": message["signature"],
|
||||
"cost": message["cost"],
|
||||
"endpoint": message["endpoint"],
|
||||
"reqnum": message["reqnum"],
|
||||
"url": message["url"],
|
||||
}
|
||||
|
||||
def _make_request(
|
||||
self,
|
||||
payload: Dict[str, Any],
|
||||
endpoint: str,
|
||||
method: str = "POST",
|
||||
stream: bool = False,
|
||||
) -> Union[Dict[str, Any], Iterator[str]]:
|
||||
"""Make request directly to the specific worker endpoint"""
|
||||
# Get worker URL and auth data
|
||||
cost = payload.get("max_tokens", self.DEFAULT_COST)
|
||||
message = self._get_worker_url(cost=cost)
|
||||
worker_url = message["url"]
|
||||
auth_data = self._create_auth_data(message)
|
||||
|
||||
req_data = {"payload": {"input": payload}, "auth_data": auth_data}
|
||||
|
||||
url = urljoin(worker_url, endpoint)
|
||||
log.debug(f"Making direct request to: {url}")
|
||||
log.debug(f"Payload: {req_data}")
|
||||
|
||||
# Make the request using the specified method
|
||||
if method.upper() == "POST":
|
||||
response = requests.post(
|
||||
url, json=req_data, stream=stream, verify=get_cert_file_path()
|
||||
)
|
||||
elif method.upper() == "GET":
|
||||
response = requests.get(
|
||||
url, params=req_data, stream=stream, verify=get_cert_file_path()
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
if stream:
|
||||
return self._handle_streaming_response(response)
|
||||
else:
|
||||
return response.json()
|
||||
|
||||
def _handle_streaming_response(self, response: requests.Response) -> Iterator[str]:
|
||||
"""Handle streaming response and yield tokens"""
|
||||
try:
|
||||
for line in response.iter_lines(decode_unicode=True):
|
||||
if line:
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
if data_str.strip() == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
yield data # Yield the full chunk
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
except Exception as e:
|
||||
log.error(f"Error handling streaming response: {e}")
|
||||
raise
|
||||
|
||||
def call_completions(
|
||||
self, config: CompletionConfig
|
||||
) -> Union[Dict[str, Any], Iterator[str]]:
|
||||
payload = config.to_dict()
|
||||
|
||||
return self._make_request(
|
||||
payload=payload, endpoint="/v1/completions", stream=config.stream
|
||||
)
|
||||
|
||||
def call_chat_completions(
|
||||
self, config: ChatCompletionConfig
|
||||
) -> Union[Dict[str, Any], Iterator[str]]:
|
||||
payload = config.to_dict()
|
||||
|
||||
return self._make_request(
|
||||
payload=payload, endpoint="/v1/chat/completions", stream=config.stream
|
||||
)
|
||||
TOOLS_PROMPT = (
|
||||
"Can you list the files in the current working directory and tell me what you see? "
|
||||
"What do you think this directory might be for?"
|
||||
)
|
||||
|
||||
ENDPOINT_NAME = "my-vllm-endpoint" # change this to your vLLM endpoint name
|
||||
DEFAULT_MODEL = "Qwen/Qwen3-8B" # must support tool calling
|
||||
MAX_TOKENS = 1024
|
||||
DEFAULT_TEMPERATURE = 0.7
|
||||
|
||||
# ---------------------- Tooling ----------------------
|
||||
class ToolManager:
|
||||
"""Handles tool definitions and execution"""
|
||||
|
||||
@@ -164,7 +50,7 @@ class ToolManager:
|
||||
|
||||
@staticmethod
|
||||
def get_ls_tool_definition() -> List[Dict[str, Any]]:
|
||||
"""Get the ls tool definition"""
|
||||
"""OpenAI-compatible tool schema"""
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
@@ -178,98 +64,217 @@ class ToolManager:
|
||||
|
||||
def execute_tool_call(self, tool_call: Dict[str, Any]) -> str:
|
||||
"""Execute a tool call and return the result"""
|
||||
function_name = tool_call["function"]["name"]
|
||||
|
||||
function_name = (tool_call.get("function") or {}).get("name")
|
||||
if function_name == "list_files":
|
||||
return self.list_files()
|
||||
else:
|
||||
raise ValueError(f"Unknown tool function: {function_name}")
|
||||
raise ValueError(f"Unknown tool function: {function_name}")
|
||||
|
||||
|
||||
# ----- Helpers to handle streamed tool_calls assembly -----
|
||||
def _merge_tool_call_delta(state: Dict[int, Dict[str, Any]], tc_delta: Dict[str, Any]) -> None:
|
||||
"""
|
||||
OpenAI-style streaming sends partial tool_calls with an index and partial fields.
|
||||
We merge into a per-index state dict until the assistant message finishes.
|
||||
"""
|
||||
idx = tc_delta.get("index")
|
||||
if idx is None:
|
||||
return
|
||||
|
||||
entry = state.setdefault(idx, {"id": None, "function": {"name": None, "arguments": ""}, "type": "function"})
|
||||
|
||||
if tc_delta.get("id"):
|
||||
entry["id"] = tc_delta["id"]
|
||||
|
||||
fn_delta = tc_delta.get("function") or {}
|
||||
if "name" in fn_delta and fn_delta["name"]:
|
||||
entry["function"]["name"] = fn_delta["name"]
|
||||
if "arguments" in fn_delta and fn_delta["arguments"]:
|
||||
entry["function"]["arguments"] += fn_delta["arguments"]
|
||||
|
||||
|
||||
def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
return [state[i] for i in sorted(state.keys())]
|
||||
|
||||
|
||||
# ---- OpenAI-compatible calls (non-streaming) ----
|
||||
async def call_completions(client: Serverless, *, model: str, prompt: str, **kwargs) -> Dict[str, Any]:
|
||||
|
||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
}
|
||||
}
|
||||
log.debug("POST /v1/completions %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
|
||||
return resp["response"]
|
||||
|
||||
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
|
||||
|
||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
|
||||
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
|
||||
}
|
||||
}
|
||||
log.debug("POST /v1/chat/completions %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"])
|
||||
return resp["response"]
|
||||
|
||||
# ---- Streaming variants ----
|
||||
async def stream_completions(client: Serverless, *, model: str, prompt: str, **kwargs):
|
||||
|
||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
"stream": True,
|
||||
**({"stop": kwargs["stop"]} if "stop" in kwargs else {}),
|
||||
}
|
||||
}
|
||||
log.debug("STREAM /v1/completions %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
|
||||
return resp["response"] # async generator
|
||||
|
||||
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs):
|
||||
|
||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
"stream": True,
|
||||
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
|
||||
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
|
||||
}
|
||||
}
|
||||
log.debug("STREAM /v1/chat/completions %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
|
||||
return resp["response"] # async generator
|
||||
|
||||
|
||||
# ---------------------- Demo Runner ----------------------
|
||||
class APIDemo:
|
||||
"""Demo and testing functionality for the API client"""
|
||||
|
||||
def __init__(
|
||||
self, client: APIClient, model: str, tool_manager: Optional[ToolManager] = None
|
||||
):
|
||||
def __init__(self, client: Serverless, model: str, tool_manager: Optional[ToolManager] = None):
|
||||
self.client = client
|
||||
self.model = model
|
||||
self.tool_manager = tool_manager or ToolManager()
|
||||
|
||||
def handle_streaming_response(
|
||||
self, response_stream, show_reasoning: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Handle streaming chat response and display all output.
|
||||
"""
|
||||
|
||||
# ----- Streaming handler -----
|
||||
async def handle_streaming_response(self, stream, show_reasoning: bool = True) -> str:
|
||||
full_response = ""
|
||||
reasoning_content = ""
|
||||
reasoning_started = False
|
||||
content_started = False
|
||||
printed_reasoning = False
|
||||
printed_answer = False
|
||||
|
||||
for chunk in response_stream:
|
||||
# Normalize the chunk
|
||||
if isinstance(chunk, str):
|
||||
chunk = chunk.strip()
|
||||
if chunk.startswith("data: "):
|
||||
chunk = chunk[6:].strip()
|
||||
if chunk in ["[DONE]", ""]:
|
||||
continue
|
||||
try:
|
||||
parsed_chunk = json.loads(chunk)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
elif isinstance(chunk, dict):
|
||||
parsed_chunk = chunk
|
||||
else:
|
||||
continue
|
||||
async for chunk in stream:
|
||||
choice = (chunk.get("choices") or [{}])[0]
|
||||
delta = choice.get("delta", {})
|
||||
|
||||
# Parse delta from the chunk
|
||||
choices = parsed_chunk.get("choices", [])
|
||||
if not choices:
|
||||
continue
|
||||
|
||||
delta = choices[0].get("delta", {})
|
||||
reasoning_token = delta.get("reasoning_content", "")
|
||||
content_token = delta.get("content", "")
|
||||
|
||||
# Print reasoning token if applicable
|
||||
if show_reasoning and reasoning_token:
|
||||
if not reasoning_started:
|
||||
# reasoning tokens
|
||||
rc = delta.get("reasoning_content")
|
||||
if rc and show_reasoning:
|
||||
if not printed_reasoning:
|
||||
print("\n🧠 Reasoning: ", end="", flush=True)
|
||||
reasoning_started = True
|
||||
print(f"\033[90m{reasoning_token}\033[0m", end="", flush=True)
|
||||
reasoning_content += reasoning_token
|
||||
printed_reasoning = True
|
||||
print(rc, end="", flush=True)
|
||||
reasoning_content += rc
|
||||
|
||||
# Print content token
|
||||
if content_token:
|
||||
if not content_started:
|
||||
if show_reasoning and reasoning_started:
|
||||
print(f"\n💬 Response: ", end="", flush=True)
|
||||
# content tokens
|
||||
content_part = delta.get("content")
|
||||
if content_part:
|
||||
if not printed_answer:
|
||||
if show_reasoning and printed_reasoning:
|
||||
print("\n💬 Response: ", end="", flush=True)
|
||||
else:
|
||||
print("Assistant: ", end="", flush=True)
|
||||
content_started = True
|
||||
print(content_token, end="", flush=True)
|
||||
full_response += content_token
|
||||
|
||||
print() # Ensure newline after response
|
||||
printed_answer = True
|
||||
print(content_part, end="", flush=True)
|
||||
full_response += content_part
|
||||
|
||||
print() # newline
|
||||
if show_reasoning:
|
||||
if reasoning_started or content_started:
|
||||
if printed_reasoning or printed_answer:
|
||||
print("\nStreaming completed.")
|
||||
if reasoning_started:
|
||||
if printed_reasoning:
|
||||
print(f"Reasoning tokens: {len(reasoning_content.split())}")
|
||||
if content_started:
|
||||
if printed_answer:
|
||||
print(f"Response tokens: {len(full_response.split())}")
|
||||
|
||||
return full_response
|
||||
|
||||
def test_tool_support(self) -> bool:
|
||||
"""Test if the endpoint supports function calling"""
|
||||
log.debug("Testing endpoint tool calling support...")
|
||||
async def demo_completions(self) -> None:
|
||||
print("=" * 60)
|
||||
print("COMPLETIONS DEMO")
|
||||
print("=" * 60)
|
||||
|
||||
# Try a simple request with minimal tools to test support
|
||||
response = await call_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
prompt=COMPLETIONS_PROMPT,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
)
|
||||
print("\nResponse:")
|
||||
print(json.dumps(response, indent=2))
|
||||
|
||||
async def demo_chat(self, use_streaming: bool = True) -> None:
|
||||
print("=" * 60)
|
||||
print(f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}")
|
||||
print("=" * 60)
|
||||
|
||||
messages = [{"role": "user", "content": CHAT_PROMPT}]
|
||||
|
||||
if use_streaming:
|
||||
stream = await stream_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE
|
||||
)
|
||||
try:
|
||||
await self.handle_streaming_response(stream, show_reasoning=True)
|
||||
except Exception as e:
|
||||
log.error("\nError during streaming: %s", e, exc_info=True)
|
||||
else:
|
||||
response = await call_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE
|
||||
)
|
||||
choice = (response.get("choices") or [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
content = message.get("content", "")
|
||||
reasoning = message.get("reasoning_content", "") or message.get("reasoning", "")
|
||||
if reasoning:
|
||||
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
|
||||
print(f"\n💬 Assistant: {content}")
|
||||
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
|
||||
|
||||
async def test_tool_support(self) -> bool:
|
||||
"""Probe that tool schema is accepted (no actual call)"""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
minimal_tool = [
|
||||
{
|
||||
@@ -277,170 +282,147 @@ class APIDemo:
|
||||
"function": {"name": "test_function", "description": "Test function"},
|
||||
}
|
||||
]
|
||||
|
||||
config = ChatCompletionConfig(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
tools=minimal_tool,
|
||||
tool_choice="none", # Don't actually call the tool
|
||||
)
|
||||
|
||||
try:
|
||||
response = self.client.call_chat_completions(config)
|
||||
_ = await call_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
tools=minimal_tool,
|
||||
tool_choice="none",
|
||||
max_tokens=10
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"Error: Endpoint does not support tool calling: {e}")
|
||||
log.error("Endpoint does not support tool calling: %s", e)
|
||||
return False
|
||||
|
||||
def demo_completions(self) -> None:
|
||||
"""Demo: test basic completions endpoint"""
|
||||
print("=" * 60)
|
||||
print("COMPLETIONS DEMO")
|
||||
print("=" * 60)
|
||||
|
||||
config = CompletionConfig(
|
||||
model=self.model, prompt=COMPLETIONS_PROMPT, stream=False
|
||||
)
|
||||
|
||||
log.info(
|
||||
f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'"
|
||||
)
|
||||
response = self.client.call_completions(config)
|
||||
|
||||
if isinstance(response, dict):
|
||||
print("\nResponse:")
|
||||
print(json.dumps(response, indent=2))
|
||||
else:
|
||||
log.error("Unexpected response format")
|
||||
|
||||
def demo_chat(self, use_streaming: bool = True) -> None:
|
||||
"""
|
||||
Demo: test chat completions endpoint with optional streaming
|
||||
"""
|
||||
print("=" * 60)
|
||||
print(
|
||||
f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}"
|
||||
)
|
||||
print("=" * 60)
|
||||
|
||||
config = ChatCompletionConfig(
|
||||
model=self.model,
|
||||
messages=[{"role": "user", "content": CHAT_PROMPT}],
|
||||
stream=use_streaming,
|
||||
)
|
||||
|
||||
log.info(f"Testing chat completions with model '{self.model}'...")
|
||||
response = self.client.call_chat_completions(config)
|
||||
|
||||
if use_streaming:
|
||||
try:
|
||||
self.handle_streaming_response(response, show_reasoning=True)
|
||||
except Exception as e:
|
||||
log.error(f"\nError during streaming: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
else:
|
||||
if isinstance(response, dict):
|
||||
choice = response.get("choices", [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
content = message.get("content", "")
|
||||
reasoning = message.get("reasoning_content", "") or message.get(
|
||||
"reasoning", ""
|
||||
)
|
||||
|
||||
if reasoning:
|
||||
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
|
||||
|
||||
print(f"\n💬 Assistant: {content}")
|
||||
print(f"\nFull Response:")
|
||||
print(json.dumps(response, indent=2))
|
||||
else:
|
||||
log.error("Unexpected response format")
|
||||
|
||||
def demo_ls_tool(self) -> None:
|
||||
"""Demo: ask LLM to list files in the current directory and describe what it sees"""
|
||||
async def demo_ls_tool(self) -> None:
|
||||
"""Ask to list files using function calling, then provide final analysis"""
|
||||
print("=" * 60)
|
||||
print("TOOL USE DEMO: List Directory Contents")
|
||||
print("=" * 60)
|
||||
|
||||
# Test if tools are supported first
|
||||
if not self.test_tool_support():
|
||||
if not await self.test_tool_support():
|
||||
return
|
||||
|
||||
# Request with tool available
|
||||
messages = [{"role": "user", "content": TOOLS_PROMPT}]
|
||||
messages: List[Dict[str, Any]] = [{"role": "user", "content": TOOLS_PROMPT}]
|
||||
|
||||
config = ChatCompletionConfig(
|
||||
# First pass: let the model decide tools, stream tool_calls and partial content
|
||||
stream = await stream_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
tools=self.tool_manager.get_ls_tool_definition(),
|
||||
tool_choice="auto",
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
)
|
||||
|
||||
log.info(f"Making initial request with tool using model '{self.model}'...")
|
||||
response = self.client.call_chat_completions(config)
|
||||
assistant_content_buf: List[str] = []
|
||||
tool_calls_state: Dict[int, Dict[str, Any]] = {}
|
||||
printed_reasoning = False
|
||||
printed_answer = False
|
||||
|
||||
if not isinstance(response, dict):
|
||||
raise ValueError("Expected dict response for tool use")
|
||||
async for chunk in stream:
|
||||
choice = (chunk.get("choices") or [{}])[0]
|
||||
delta = choice.get("delta", {})
|
||||
|
||||
choice = response.get("choices", [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
rc = delta.get("reasoning_content")
|
||||
if rc:
|
||||
if not printed_reasoning:
|
||||
printed_reasoning = True
|
||||
print("🧠 Reasoning: ", end="", flush=True)
|
||||
print(rc, end="", flush=True)
|
||||
|
||||
print(f"Assistant response: {message.get('content', 'No content')}")
|
||||
content_part = delta.get("content")
|
||||
if content_part:
|
||||
assistant_content_buf.append(content_part)
|
||||
if not printed_answer:
|
||||
printed_answer = True
|
||||
print("\n💬 Response: ", end="", flush=True)
|
||||
print(content_part, end="", flush=True)
|
||||
|
||||
# Check for tool calls
|
||||
tool_calls = message.get("tool_calls")
|
||||
if not tool_calls:
|
||||
raise ValueError(
|
||||
"No tool calls made - model may not support function calling"
|
||||
)
|
||||
if "tool_calls" in delta and delta["tool_calls"]:
|
||||
for tc_delta in delta["tool_calls"]:
|
||||
_merge_tool_call_delta(tool_calls_state, tc_delta)
|
||||
|
||||
print(f"Tool calls detected: {len(tool_calls)}")
|
||||
# If no tool calls, we’re done.
|
||||
if not tool_calls_state:
|
||||
print("\n(No tool calls were made.)")
|
||||
return
|
||||
|
||||
# Execute the tool call
|
||||
for tool_call in tool_calls:
|
||||
function_name = tool_call["function"]["name"]
|
||||
print(f"Executing tool: {function_name}")
|
||||
# Build assistant message with tool_calls
|
||||
assistant_message = {
|
||||
"role": "assistant",
|
||||
"content": "".join(assistant_content_buf) if assistant_content_buf else None,
|
||||
"tool_calls": _tool_state_to_message_tool_calls(tool_calls_state),
|
||||
}
|
||||
messages.append(assistant_message)
|
||||
|
||||
tool_result = self.tool_manager.execute_tool_call(tool_call)
|
||||
print(f"Tool result:\n{tool_result}")
|
||||
# Execute tools and feed results back
|
||||
for tc in assistant_message["tool_calls"]:
|
||||
tool_name = (tc.get("function") or {}).get("name")
|
||||
call_id = tc.get("id")
|
||||
raw_args = (tc.get("function") or {}).get("arguments") or "{}"
|
||||
|
||||
# Add tool result and continue conversation
|
||||
messages.append(message) # Add assistant's message with tool call
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call["id"],
|
||||
"content": tool_result,
|
||||
}
|
||||
)
|
||||
try:
|
||||
args = json.loads(raw_args) if raw_args.strip() else {}
|
||||
except Exception as e:
|
||||
tool_result = json.dumps({"error": f"Argument parse failed: {str(e)}", "raw_arguments": raw_args})
|
||||
messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result})
|
||||
continue
|
||||
|
||||
# Get final response
|
||||
final_config = ChatCompletionConfig(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
tools=self.tool_manager.get_ls_tool_definition(),
|
||||
)
|
||||
try:
|
||||
if tool_name == "list_files":
|
||||
tool_result = self.tool_manager.list_files()
|
||||
else:
|
||||
tool_result = json.dumps({"error": f"Unknown tool '{tool_name}'"})
|
||||
except Exception as e:
|
||||
tool_result = json.dumps({"error": f"Tool '{tool_name}' failed: {str(e)}"})
|
||||
|
||||
print("Getting final response...")
|
||||
final_response = self.client.call_chat_completions(final_config)
|
||||
print("\n[Tool executed]", tool_name)
|
||||
print(tool_result[:500] + ("..." if len(tool_result) > 500 else ""))
|
||||
messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result})
|
||||
|
||||
if isinstance(final_response, dict):
|
||||
final_choice = final_response.get("choices", [{}])[0]
|
||||
final_message = final_choice.get("message", {})
|
||||
final_content = final_message.get("content", "")
|
||||
# Second pass: get final streamed answer after tool results
|
||||
stream2 = await stream_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("FINAL LLM ANALYSIS:")
|
||||
print("=" * 60)
|
||||
print(final_content)
|
||||
print("=" * 60)
|
||||
final_buf = []
|
||||
printed_reasoning2 = False
|
||||
printed_answer2 = False
|
||||
|
||||
def interactive_chat(self) -> None:
|
||||
async for chunk in stream2:
|
||||
choice = (chunk.get("choices") or [{}])[0]
|
||||
delta = choice.get("delta", {})
|
||||
|
||||
rc2 = delta.get("reasoning_content")
|
||||
if rc2:
|
||||
if not printed_reasoning2:
|
||||
printed_reasoning2 = True
|
||||
print("\n🧠 Reasoning (post-tools): ", end="", flush=True)
|
||||
print(rc2, end="", flush=True)
|
||||
|
||||
c2 = delta.get("content")
|
||||
if c2:
|
||||
final_buf.append(c2)
|
||||
if not printed_answer2:
|
||||
printed_answer2 = True
|
||||
print("\n💬 Response (final): ", end="", flush=True)
|
||||
print(c2, end="", flush=True)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("FINAL LLM ANALYSIS:")
|
||||
print("=" * 60)
|
||||
print("".join(final_buf))
|
||||
print("=" * 60)
|
||||
|
||||
async def interactive_chat(self) -> None:
|
||||
"""Interactive chat session with streaming"""
|
||||
print("=" * 60)
|
||||
print("INTERACTIVE STREAMING CHAT")
|
||||
@@ -449,7 +431,7 @@ class APIDemo:
|
||||
print("Type 'quit' to exit, 'clear' to clear history")
|
||||
print()
|
||||
|
||||
messages = []
|
||||
messages: List[Dict[str, Any]] = []
|
||||
|
||||
while True:
|
||||
try:
|
||||
@@ -467,16 +449,15 @@ class APIDemo:
|
||||
|
||||
messages.append({"role": "user", "content": user_input})
|
||||
|
||||
config = ChatCompletionConfig(
|
||||
model=self.model, messages=messages, stream=True, temperature=0.7
|
||||
)
|
||||
|
||||
print("Assistant: ", end="", flush=True)
|
||||
|
||||
response = self.client.call_chat_completions(config)
|
||||
assistant_content = self.handle_streaming_response(
|
||||
response, show_reasoning=True
|
||||
stream = await stream_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=0.7
|
||||
)
|
||||
assistant_content = await self.handle_streaming_response(stream, show_reasoning=True)
|
||||
|
||||
# Add assistant response to conversation history
|
||||
messages.append({"role": "assistant", "content": assistant_content})
|
||||
@@ -485,115 +466,64 @@ class APIDemo:
|
||||
print("\n👋 Chat interrupted. Goodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
log.error(f"\nError: {e}")
|
||||
log.error("\nError: %s", e)
|
||||
continue
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function with CLI switches for different tests"""
|
||||
from lib.test_utils import test_args
|
||||
# ---------------------- CLI ----------------------
|
||||
def build_arg_parser() -> argparse.ArgumentParser:
|
||||
p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
|
||||
p.add_argument("--model", required=True, help="Model to use for requests (required)")
|
||||
p.add_argument("--endpoint", default="my-vllm-endpoint", help="Vast endpoint name (default: my-vllm-endpoint)")
|
||||
|
||||
# Add mandatory model argument
|
||||
test_args.add_argument(
|
||||
"--model", required=True, help="Model to use for requests (required)"
|
||||
)
|
||||
modes = p.add_mutually_exclusive_group(required=False)
|
||||
modes.add_argument("--completion", action="store_true", help="Test completions endpoint")
|
||||
modes.add_argument("--chat", action="store_true", help="Test chat completions endpoint (non-streaming)")
|
||||
modes.add_argument("--chat-stream", action="store_true", help="Test chat completions endpoint with streaming")
|
||||
modes.add_argument("--tools", action="store_true", help="Test function calling with ls tool (non-streaming+streamed phases)")
|
||||
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming chat session")
|
||||
return p
|
||||
|
||||
# Add test mode arguments
|
||||
test_args.add_argument(
|
||||
"--completion", action="store_true", help="Test completions endpoint"
|
||||
)
|
||||
test_args.add_argument(
|
||||
"--chat",
|
||||
action="store_true",
|
||||
help="Test chat completions endpoint (non-streaming)",
|
||||
)
|
||||
test_args.add_argument(
|
||||
"--chat-stream",
|
||||
action="store_true",
|
||||
help="Test chat completions endpoint with streaming",
|
||||
)
|
||||
test_args.add_argument(
|
||||
"--tools",
|
||||
action="store_true",
|
||||
help="Test function calling with ls tool (non-streaming)",
|
||||
)
|
||||
test_args.add_argument(
|
||||
"--interactive",
|
||||
action="store_true",
|
||||
help="Start interactive streaming chat session",
|
||||
)
|
||||
|
||||
args = test_args.parse_args()
|
||||
async def main_async():
|
||||
args = build_arg_parser().parse_args()
|
||||
|
||||
# Check that only one test mode is selected
|
||||
test_modes = [
|
||||
args.completion,
|
||||
args.chat,
|
||||
args.chat_stream,
|
||||
args.tools,
|
||||
args.interactive,
|
||||
]
|
||||
selected_count = sum(test_modes)
|
||||
|
||||
if selected_count == 0:
|
||||
selected = sum([args.completion, args.chat, args.chat_stream, args.tools, args.interactive])
|
||||
if selected == 0:
|
||||
print("Please specify exactly one test mode:")
|
||||
print(" --completion : Test completions endpoint")
|
||||
print(" --chat : Test chat completions endpoint (non-streaming)")
|
||||
print(" --chat-stream : Test chat completions endpoint with streaming")
|
||||
print(" --tools : Test function calling with ls tool (non-streaming)")
|
||||
print(" --tools : Test function calling with ls tool")
|
||||
print(" --interactive : Start interactive streaming chat session")
|
||||
print(
|
||||
f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT"
|
||||
)
|
||||
print(f"\nExample: python {os.path.basename(sys.argv[0])} --model Qwen/Qwen3-8B --chat-stream --endpoint my-vllm-endpoint")
|
||||
sys.exit(1)
|
||||
elif selected_count > 1:
|
||||
elif selected > 1:
|
||||
print("Please specify exactly one test mode")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Using model: {args.model}")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||
endpoint_name=args.endpoint_group_name,
|
||||
account_api_key=args.api_key,
|
||||
instance=args.instance,
|
||||
)
|
||||
async with Serverless() as client:
|
||||
demo = APIDemo(client, args.model, ToolManager())
|
||||
|
||||
if not endpoint_api_key:
|
||||
log.error(
|
||||
f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Create the core API client
|
||||
client = APIClient(
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
api_key=args.api_key,
|
||||
server_url=Endpoint.get_autoscaler_server_url(args.instance),
|
||||
endpoint_api_key=endpoint_api_key,
|
||||
)
|
||||
|
||||
# Create tool manager and demo (passing the model parameter)
|
||||
tool_manager = ToolManager()
|
||||
demo = APIDemo(client, args.model, tool_manager)
|
||||
|
||||
print(f"Using model: {args.model}")
|
||||
print("=" * 60)
|
||||
|
||||
# Run the selected test
|
||||
if args.completion:
|
||||
demo.demo_completions()
|
||||
elif args.chat:
|
||||
demo.demo_chat(use_streaming=False)
|
||||
elif args.chat_stream:
|
||||
demo.demo_chat(use_streaming=True)
|
||||
elif args.tools:
|
||||
demo.demo_ls_tool()
|
||||
elif args.interactive:
|
||||
demo.interactive_chat()
|
||||
if args.completion:
|
||||
await demo.demo_completions()
|
||||
elif args.chat:
|
||||
await demo.demo_chat(use_streaming=False)
|
||||
elif args.chat_stream:
|
||||
await demo.demo_chat(use_streaming=True)
|
||||
elif args.tools:
|
||||
await demo.demo_ls_tool()
|
||||
elif args.interactive:
|
||||
await demo.interactive_chat()
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error during test: {e}", exc_info=True)
|
||||
log.error("Error during test: %s", e, exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
asyncio.run(main_async())
|
||||
|
||||
@@ -119,14 +119,25 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
|
||||
class CompletionsData(GenericData):
|
||||
@classmethod
|
||||
def for_test(cls) -> "CompletionsData":
|
||||
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
||||
system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
|
||||
|
||||
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
|
||||
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
|
||||
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
|
||||
genus Equus with horses and asses, the three groups being the only living members of the family
|
||||
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
|
||||
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
|
||||
woodlands, shrublands, and mountainous areas.
|
||||
|
||||
Please answer the following question based on the above context."""
|
||||
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
|
||||
model = os.environ.get("MODEL_NAME")
|
||||
if not model:
|
||||
raise ValueError("MODEL_NAME environment variable not set")
|
||||
|
||||
test_input = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"prompt": f"{system_prompt}\n\n{unique_question}",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 500,
|
||||
}
|
||||
@@ -153,7 +164,18 @@ class ChatCompletionsData(GenericData):
|
||||
|
||||
@classmethod
|
||||
def for_test(cls) -> "ChatCompletionsData":
|
||||
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
||||
system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
|
||||
|
||||
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
|
||||
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
|
||||
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
|
||||
genus Equus with horses and asses, the three groups being the only living members of the family
|
||||
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
|
||||
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
|
||||
woodlands, shrublands, and mountainous areas.
|
||||
|
||||
Please answer the following question based on the above context."""
|
||||
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
|
||||
model = os.environ.get("MODEL_NAME")
|
||||
if not model:
|
||||
raise ValueError("MODEL_NAME environment variable not set")
|
||||
@@ -161,7 +183,10 @@ class ChatCompletionsData(GenericData):
|
||||
# Chat completions use messages format instead of prompt
|
||||
test_input = {
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt}, # Shared prefix
|
||||
{"role": "user", "content": unique_question} # Unique per request
|
||||
],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 500,
|
||||
}
|
||||
|
||||
+414
-8
@@ -1,8 +1,395 @@
|
||||
from lib.test_utils import test_load_cmd, test_args
|
||||
from lib.test_utils import test_args
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
from lib.data_types import AuthData
|
||||
from .data_types.server import CompletionsData
|
||||
import os
|
||||
|
||||
WORKER_ENDPOINT = "/v1/completions"
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
import requests
|
||||
from dataclasses import dataclass
|
||||
from collections import Counter
|
||||
from urllib.parse import urljoin, urlparse
|
||||
import re
|
||||
|
||||
# Headless plotting
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import logging
|
||||
logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED
|
||||
from requests.adapters import HTTPAdapter
|
||||
|
||||
def get_incremented_path(path: str) -> str:
|
||||
base, ext = os.path.splitext(path)
|
||||
if not os.path.exists(path):
|
||||
return path
|
||||
i = 1
|
||||
while os.path.exists(f"{base}-{i}{ext}"):
|
||||
i += 1
|
||||
return f"{base}-{i}{ext}"
|
||||
|
||||
WORKER_ENDPOINT = "/v1/completions" # This will return the full text output at once. Latency metrics reflect that (ie not measuring TTFT)
|
||||
|
||||
@dataclass
|
||||
class ReqResult:
|
||||
worker_url: str
|
||||
route_ms: float
|
||||
worker_ms: float
|
||||
total_ms: float
|
||||
ok: bool
|
||||
error: str = ""
|
||||
status_code: int = 0
|
||||
t_start: float = 0.0
|
||||
t_end: float = 0.0
|
||||
workload: float = 0.0
|
||||
|
||||
def do_one(endpoint_name: str,
|
||||
endpoint_id: int,
|
||||
endpoint_api_key: str,
|
||||
server_url: str,
|
||||
worker_endpoint: str,
|
||||
payload,
|
||||
results_list,
|
||||
t0,
|
||||
status_samples,
|
||||
route_session,
|
||||
worker_session):
|
||||
try:
|
||||
workload = payload.count_workload()
|
||||
route_payload = {"endpoint": endpoint_name, "api_key": endpoint_api_key, "cost": workload}
|
||||
headers = {"Authorization": f"Bearer {endpoint_api_key}"}
|
||||
start = time.time()
|
||||
r0 = route_session.post(urljoin(server_url, "/route/"), json=route_payload, headers=headers, timeout=4)
|
||||
t_after_route = time.time()
|
||||
if r0.status_code != 200:
|
||||
results_list.append(ReqResult(worker_url="",
|
||||
route_ms=(t_after_route - start) * 1000.0,
|
||||
worker_ms=0.0,
|
||||
total_ms=(t_after_route - start) * 1000.0,
|
||||
ok=False,
|
||||
error=f"route error {r0.reason} {r0.text}",
|
||||
status_code=r0.status_code,
|
||||
t_start=start - t0,
|
||||
t_end=t_after_route - t0,
|
||||
workload=workload))
|
||||
return
|
||||
msg = r0.json()
|
||||
|
||||
# 1) Check if we got a worker back from route
|
||||
worker_url = msg.get("url", "")
|
||||
if not worker_url:
|
||||
status = msg.get("status", "")
|
||||
m = re.search(r"total workers:\s*(\d+).*loading workers:\s*(\d+).*standby workers:\s*(\d+).*error workers:\s*(\d+)", status, re.I | re.S)
|
||||
if m:
|
||||
tot, loading, standby, err = map(int, m.groups())
|
||||
idle = max(tot - loading - standby - err, 0)
|
||||
status_samples.append((time.time() - t0, idle))
|
||||
|
||||
# 2) If we got a worker, send the request
|
||||
if worker_url:
|
||||
req = dict(payload=payload.__dict__, auth_data=AuthData.from_json_msg(msg).__dict__)
|
||||
t_before_worker = time.time()
|
||||
r1 = worker_session.post(
|
||||
urljoin(worker_url, worker_endpoint),
|
||||
json=req,
|
||||
verify=get_cert_file_path(),
|
||||
timeout=(4, 120),
|
||||
)
|
||||
t_after_worker = time.time()
|
||||
if r1.status_code != 200:
|
||||
results_list.append(ReqResult(worker_url=worker_url,
|
||||
route_ms=(t_after_route - start) * 1000.0,
|
||||
worker_ms=(t_after_worker - t_before_worker) * 1000.0,
|
||||
total_ms=(t_after_worker - start) * 1000.0,
|
||||
ok=False,
|
||||
error=f"worker inference error {r1.reason} {r1.text}",
|
||||
status_code=r1.status_code,
|
||||
t_start=start - t0,
|
||||
t_end=t_after_worker - t0,
|
||||
workload=workload))
|
||||
return
|
||||
# Success case
|
||||
results_list.append(ReqResult(worker_url=worker_url,
|
||||
route_ms=(t_after_route - start) * 1000.0,
|
||||
worker_ms=(t_after_worker - t_before_worker) * 1000.0,
|
||||
total_ms=(t_after_worker - start) * 1000.0,
|
||||
ok=True,
|
||||
error="",
|
||||
status_code=200,
|
||||
t_start=start - t0,
|
||||
t_end=t_after_worker - t0,
|
||||
workload=workload))
|
||||
|
||||
# 3) If so, sample via /get_endpoint_workers/ for eligible (idle) worker tracking
|
||||
if worker_url:
|
||||
try:
|
||||
r_status = route_session.post(
|
||||
urljoin(server_url, "/get_endpoint_workers/"),
|
||||
json={"id": endpoint_id},
|
||||
headers={"Authorization": f"Bearer {endpoint_api_key}"},
|
||||
timeout=3,
|
||||
)
|
||||
if r_status.status_code == 200:
|
||||
workers = r_status.json()
|
||||
idle = 0
|
||||
for w in workers:
|
||||
st = str(w.get("status", "")).lower()
|
||||
if (st in ("idle")):
|
||||
idle += 1
|
||||
status_samples.append((time.time() - t0, idle))
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
t = time.time()
|
||||
results_list.append(ReqResult(worker_url="",
|
||||
route_ms=0.0,
|
||||
worker_ms=0.0,
|
||||
total_ms=0.0,
|
||||
ok=False,
|
||||
error=f"unknown error {e}",
|
||||
status_code=0,
|
||||
t_start=t - t0,
|
||||
t_end=t - t0,
|
||||
workload=0.0))
|
||||
|
||||
def run_load_with_metrics(num_requests: int,
|
||||
requests_per_second: float,
|
||||
endpoint_group_name: str,
|
||||
account_api_key: str,
|
||||
server_url: str,
|
||||
worker_endpoint: str,
|
||||
instance: str,
|
||||
out_path: str):
|
||||
|
||||
ep_info = Endpoint.get_endpoint_info(endpoint_name=endpoint_group_name,
|
||||
account_api_key=account_api_key,
|
||||
instance=instance)
|
||||
if not ep_info or not ep_info.get("api_key") or not ep_info.get("id"):
|
||||
print(f"Endpoint {endpoint_group_name} not found for API key")
|
||||
return
|
||||
endpoint_id = int(ep_info["id"])
|
||||
endpoint_api_key = ep_info["api_key"]
|
||||
|
||||
t0 = time.time()
|
||||
results = []
|
||||
status_samples = []
|
||||
max_concurrency = int(os.environ.get("MAX_CONCURRENCY", "8192"))
|
||||
submit_queue_factor = 2 # cap queued tasks to reduce memory
|
||||
|
||||
# Shared HTTP sessions with connection pooling (persistent connections)
|
||||
def make_session(pool_connections: int, pool_maxsize: int) -> requests.Session:
|
||||
sess = requests.Session()
|
||||
adapter = HTTPAdapter(pool_connections=pool_connections, pool_maxsize=pool_maxsize, max_retries=0)
|
||||
sess.mount("https://", adapter)
|
||||
sess.mount("http://", adapter)
|
||||
return sess
|
||||
|
||||
# Router: mostly single host, small connection pool is sufficient
|
||||
route_session = make_session(pool_connections=1, pool_maxsize=max_concurrency)
|
||||
# Workers: many hosts; allow many pools and per-host concurrency up to max_concurrency
|
||||
worker_session = make_session(pool_connections=64, pool_maxsize=max_concurrency // 8)
|
||||
|
||||
# Fire requests using a thread pool, scheduling at requested RPS
|
||||
inflight = set()
|
||||
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
|
||||
for i in range(num_requests):
|
||||
# Pace submissions to RPS
|
||||
target_time = t0 + i / max(requests_per_second, 1e-9)
|
||||
sleep_s = target_time - time.time()
|
||||
if sleep_s > 0:
|
||||
time.sleep(min(sleep_s, 0.5)) # sleep in chunks to stay responsive
|
||||
|
||||
payload = CompletionsData.for_test()
|
||||
fut = executor.submit(
|
||||
do_one,
|
||||
endpoint_group_name,
|
||||
endpoint_id,
|
||||
endpoint_api_key,
|
||||
server_url,
|
||||
worker_endpoint,
|
||||
payload,
|
||||
results,
|
||||
t0,
|
||||
status_samples,
|
||||
route_session,
|
||||
worker_session,
|
||||
)
|
||||
inflight.add(fut)
|
||||
# Prevent unbounded queue growth
|
||||
if len(inflight) >= max_concurrency * submit_queue_factor:
|
||||
done, not_done = wait(inflight, return_when=FIRST_COMPLETED)
|
||||
inflight = not_done
|
||||
# Wait for all outstanding tasks
|
||||
if inflight:
|
||||
wait(inflight)
|
||||
# Close sessions
|
||||
try:
|
||||
route_session.close()
|
||||
finally:
|
||||
worker_session.close()
|
||||
|
||||
# Aggregate results
|
||||
oks = [r for r in results if r.ok]
|
||||
errs = [r for r in results if not r.ok]
|
||||
total_reqs = len(results)
|
||||
succ = len(oks)
|
||||
|
||||
total_ms = np.array([r.total_ms for r in oks]) if succ else np.array([])
|
||||
worker_ms = np.array([r.worker_ms for r in oks]) if succ else np.array([])
|
||||
route_ms = np.array([r.route_ms for r in oks]) if succ else np.array([])
|
||||
|
||||
avg_total = float(np.mean(total_ms)) if succ else 0.0
|
||||
avg_worker = float(np.mean(worker_ms)) if succ else 0.0
|
||||
avg_route = float(np.mean(route_ms)) if succ else 0.0
|
||||
p50_total, p95_total = (float(np.percentile(total_ms, 50)), float(np.percentile(total_ms, 95))) if succ else (0.0, 0.0)
|
||||
|
||||
# Distribution over workers (by host:port)
|
||||
hosts = [urlparse(r.worker_url).netloc for r in oks if r.worker_url]
|
||||
dist = Counter(hosts)
|
||||
|
||||
# Idle over time (mode per second)
|
||||
idle_ts, idle_vals = [], []
|
||||
if status_samples:
|
||||
buckets = {}
|
||||
for ts, idle in status_samples:
|
||||
k = int(ts)
|
||||
buckets.setdefault(k, []).append(idle)
|
||||
keys = sorted(buckets.keys())
|
||||
idle_ts = keys
|
||||
# Use the most frequent sampled value per second (mode) to keep integer counts
|
||||
idle_vals = []
|
||||
for k in keys:
|
||||
vals_k = [int(v) for v in buckets[k]]
|
||||
if vals_k:
|
||||
cnt = Counter(vals_k)
|
||||
idle_vals.append(cnt.most_common(1)[0][0])
|
||||
else:
|
||||
idle_vals.append(0)
|
||||
|
||||
print(f"\nResults: total={total_reqs} success={succ} errors={len(errs)}")
|
||||
print(f"Avg latency (ms): {avg_total:.1f} p50: {p50_total:.1f} p95: {p95_total:.1f}")
|
||||
print(f"Avg route latency (ms): {avg_route:.1f} Avg worker latency (ms): {avg_worker:.1f}")
|
||||
if errs:
|
||||
print("Sample errors:")
|
||||
for e in errs[:5]:
|
||||
print(f" {e.status_code} {e.error}")
|
||||
|
||||
# Plot: 2x3 grid
|
||||
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
|
||||
fig.suptitle(f"Load test: {endpoint_group_name} n={total_reqs}, rps={requests_per_second}, success={succ}")
|
||||
|
||||
# Dist per worker
|
||||
ax0 = axes[0, 0]
|
||||
if dist:
|
||||
items = sorted(dist.items(), key=lambda kv: kv[1], reverse=True)
|
||||
labels, counts = zip(*items)
|
||||
ax0.bar(range(len(labels)), counts)
|
||||
ax0.set_xticks(range(len(labels)))
|
||||
ax0.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
|
||||
ax0.set_title("Request distribution over workers")
|
||||
ax0.set_ylabel("count")
|
||||
|
||||
# Latency histogram (total)
|
||||
ax1 = axes[0, 1]
|
||||
if succ:
|
||||
ax1.hist(total_ms, bins=30)
|
||||
ax1.set_title("Total latency (ms)")
|
||||
ax1.set_xlabel("ms")
|
||||
ax1.set_ylabel("freq")
|
||||
|
||||
# Eligible workers over time
|
||||
ax_idle = axes[0, 2]
|
||||
if idle_ts:
|
||||
ax_idle.plot(idle_ts, idle_vals, "-o", ms=3)
|
||||
ax_idle.set_title("Eligible workers over time")
|
||||
ax_idle.set_xlabel("time (s)")
|
||||
ax_idle.set_ylabel("eligible count")
|
||||
|
||||
# Throughput over time (completions/sec)
|
||||
ax_idle = axes[1, 0]
|
||||
ax_idle.clear()
|
||||
if succ:
|
||||
per_sec = {}
|
||||
for r in oks:
|
||||
s = int(r.t_end)
|
||||
per_sec[s] = per_sec.get(s, 0) + 1
|
||||
ts = sorted(per_sec.keys())
|
||||
vals = [per_sec[t] for t in ts]
|
||||
ax_idle.plot(ts, vals, "-o", ms=3)
|
||||
ax_idle.set_title("Completions per second")
|
||||
ax_idle.set_xlabel("time (s)")
|
||||
ax_idle.set_ylabel("completions / sec")
|
||||
|
||||
# Summary text
|
||||
ax3 = axes[1, 1]
|
||||
ax3.axis("off")
|
||||
text = (
|
||||
f"Total requests: {total_reqs}\n"
|
||||
f"Success: {succ} Errors: {len(errs)}\n"
|
||||
f"Avg total latency: {avg_total:.1f} ms\n"
|
||||
f"p50: {p50_total:.1f} ms p95: {p95_total:.1f} ms\n"
|
||||
f"Avg route latency: {avg_route:.1f} ms\n"
|
||||
f"Avg worker latency: {avg_worker:.1f} ms\n"
|
||||
f"300 errors: {len([r for r in errs if r.status_code >= 300 and r.status_code < 400])}\n"
|
||||
f"429 errors: {len([r for r in errs if r.status_code == 429])}\n"
|
||||
f"500 errors: {len([r for r in errs if r.status_code >= 500])}\n"
|
||||
f"Other errors: {len([r for r in errs if r.status_code not in [300, 429, 500]])}\n"
|
||||
)
|
||||
ax3.set_title("Summary")
|
||||
ax3.text(0.02, 0.98, text, va="top", ha="left", fontsize=11, transform=ax3.transAxes)
|
||||
|
||||
# Error count over time
|
||||
ax_errors = axes[1, 2]
|
||||
all_end_times = [int(r.t_end) for r in results if r.t_end > 0]
|
||||
if all_end_times:
|
||||
min_second = min(all_end_times)
|
||||
max_second = max(all_end_times)
|
||||
# Count errors per second
|
||||
errors_per_second = {}
|
||||
for result in errs:
|
||||
second = int(result.t_end)
|
||||
errors_per_second[second] = errors_per_second.get(second, 0) + 1
|
||||
# Create complete timeline including zeros
|
||||
time_seconds = list(range(min_second, max_second + 1))
|
||||
error_counts = [errors_per_second.get(sec, 0) for sec in time_seconds]
|
||||
ax_errors.plot(time_seconds, error_counts, "-o", ms=3)
|
||||
ax_errors.set_title("Errors per second")
|
||||
ax_errors.set_xlabel("time (s)")
|
||||
ax_errors.set_ylabel("errors / sec")
|
||||
|
||||
# Ensure unique output path and create directory if needed
|
||||
final_out_path = get_incremented_path(out_path)
|
||||
out_dir = os.path.dirname(final_out_path)
|
||||
if out_dir:
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
plt.tight_layout(rect=[0, 0, 1, 0.96])
|
||||
plt.savefig(final_out_path, dpi=120)
|
||||
print(f"Saved report to: {final_out_path}")
|
||||
|
||||
# Per-worker latency boxplot (top 12 by volume)
|
||||
groups = {}
|
||||
for r in oks:
|
||||
host = urlparse(r.worker_url).netloc
|
||||
groups.setdefault(host, []).append(r.total_ms)
|
||||
items = sorted(groups.items(), key=lambda kv: len(kv[1]), reverse=True)[:12]
|
||||
if items:
|
||||
labels, data = zip(*items)
|
||||
fig2, axb = plt.subplots(1, 1, figsize=(12, 5))
|
||||
axb.boxplot(data, showfliers=False)
|
||||
axb.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
|
||||
axb.set_title("Per-worker latency (ms)")
|
||||
axb.set_ylabel("ms")
|
||||
plt.tight_layout()
|
||||
extra_out = get_incremented_path(os.path.splitext(out_path)[0] + "-workers.png")
|
||||
plt.savefig(extra_out, dpi=120)
|
||||
fig2.tight_layout()
|
||||
fig2.savefig(extra_out, dpi=120)
|
||||
print(f"Saved worker latency plot to: {extra_out}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Check if MODEL_NAME environment variable is set
|
||||
@@ -16,13 +403,32 @@ if __name__ == "__main__":
|
||||
help="Model to use for completions request (required if MODEL_NAME env var not set)",
|
||||
)
|
||||
|
||||
# Parse known args to get model early, before test_load_cmd adds its args
|
||||
# Parse known args to get model early, before adding load args
|
||||
known_args, _ = test_args.parse_known_args()
|
||||
|
||||
# Set environment variable if model was provided
|
||||
if hasattr(known_args, "model") and known_args.model:
|
||||
os.environ["MODEL_NAME"] = known_args.model
|
||||
print(f"Set MODEL_NAME environment variable to: {known_args.model}")
|
||||
|
||||
# Now call test_load_cmd normally - it will add its own args and re-parse
|
||||
test_load_cmd(CompletionsData, WORKER_ENDPOINT, arg_parser=test_args)
|
||||
# Load test args
|
||||
test_args.add_argument("-n", dest="num_requests", type=int, required=True, help="total number of requests")
|
||||
test_args.add_argument("-rps", dest="requests_per_second", type=float, required=True, help="requests per second")
|
||||
test_args.add_argument("--out", dest="out_path", type=str, default="load_test_report.png", help="path to save the report image")
|
||||
args = test_args.parse_args()
|
||||
|
||||
server_url = {
|
||||
"prod": "https://run.vast.ai",
|
||||
"alpha": "https://run-alpha.vast.ai",
|
||||
"candidate": "https://run-candidate.vast.ai",
|
||||
"local": "http://localhost:8080"
|
||||
}.get(args.instance, "http://localhost:8080")
|
||||
|
||||
run_load_with_metrics(
|
||||
num_requests=args.num_requests,
|
||||
requests_per_second=args.requests_per_second,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
account_api_key=args.api_key,
|
||||
server_url=server_url,
|
||||
worker_endpoint=WORKER_ENDPOINT,
|
||||
instance=args.instance,
|
||||
out_path=args.out_path,
|
||||
)
|
||||
+49
-113
@@ -1,125 +1,61 @@
|
||||
import logging
|
||||
import sys
|
||||
import json
|
||||
from urllib.parse import urljoin
|
||||
import requests
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
from vastai import Serverless
|
||||
import asyncio
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
ENDPOINT_NAME = "my-tgi-endpoint" # Change this to match your endpoint name
|
||||
MAX_TOKENS = 1024
|
||||
PROMPT = "Think step by step: Tell me about the Python programming language."
|
||||
|
||||
async def call_generate(client: Serverless) -> None:
|
||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||
|
||||
def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> None:
|
||||
WORKER_ENDPOINT = "/generate"
|
||||
COST = 100
|
||||
route_payload = {
|
||||
"endpoint": endpoint_group_name,
|
||||
"api_key": api_key,
|
||||
"cost": COST,
|
||||
payload = {
|
||||
"inputs": PROMPT,
|
||||
"parameters": {
|
||||
"max_new_tokens": MAX_TOKENS,
|
||||
"temperature": 0.7,
|
||||
"return_full_text": False
|
||||
}
|
||||
}
|
||||
response = requests.post(
|
||||
urljoin(server_url, "/route/"),
|
||||
json=route_payload,
|
||||
timeout=4,
|
||||
)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
message = response.json()
|
||||
url = message["url"]
|
||||
|
||||
auth_data = dict(
|
||||
signature=message["signature"],
|
||||
cost=message["cost"],
|
||||
endpoint=message["endpoint"],
|
||||
reqnum=message["reqnum"],
|
||||
url=url,
|
||||
)
|
||||
resp = await endpoint.request("/generate", payload, cost=MAX_TOKENS)
|
||||
|
||||
payload = dict(inputs="tell me about cats", parameters=dict(max_new_tokens=500))
|
||||
req_data = dict(payload=payload, auth_data=auth_data)
|
||||
url = urljoin(url, WORKER_ENDPOINT)
|
||||
print(f"url: {url}")
|
||||
response = requests.post(
|
||||
url,
|
||||
json=req_data,
|
||||
verify=get_cert_file_path(),
|
||||
)
|
||||
response.raise_for_status()
|
||||
res = response.json()
|
||||
print(res)
|
||||
print(resp["response"]["generated_text"])
|
||||
|
||||
|
||||
def call_generate_stream(
|
||||
endpoint_group_name: str, api_key: str, server_url: str
|
||||
) -> None:
|
||||
WORKER_ENDPOINT = "/generate_stream"
|
||||
COST = 100
|
||||
route_payload = {
|
||||
"endpoint": endpoint_group_name,
|
||||
"api_key": api_key,
|
||||
"cost": COST,
|
||||
async def call_generate_stream(client: Serverless) -> None:
|
||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||
|
||||
payload = {
|
||||
"inputs": PROMPT,
|
||||
"parameters": {
|
||||
"max_new_tokens": MAX_TOKENS,
|
||||
"temperature": 0.7,
|
||||
"do_sample": True,
|
||||
"return_full_text": False,
|
||||
}
|
||||
}
|
||||
response = requests.post(
|
||||
urljoin(server_url, "/route/"),
|
||||
json=route_payload,
|
||||
timeout=4,
|
||||
)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
message = response.json()
|
||||
url = message["url"]
|
||||
print(f"url: {url}")
|
||||
auth_data = dict(
|
||||
signature=message["signature"],
|
||||
cost=message["cost"],
|
||||
endpoint=message["endpoint"],
|
||||
reqnum=message["reqnum"],
|
||||
url=message["url"],
|
||||
)
|
||||
payload = dict(inputs="tell me about dogs", parameters=dict(max_new_tokens=500))
|
||||
req_data = dict(payload=payload, auth_data=auth_data)
|
||||
url = urljoin(url, WORKER_ENDPOINT)
|
||||
response = requests.post(url, json=req_data, stream=True)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
for line in response.iter_lines():
|
||||
payload = line.decode().lstrip("data:").rstrip()
|
||||
if payload:
|
||||
try:
|
||||
data = json.loads(payload)
|
||||
print(data["token"]["text"], end="")
|
||||
sys.stdout.flush()
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
log.warning(f"Failed to parse streaming response: {e}")
|
||||
continue
|
||||
print()
|
||||
|
||||
resp = await endpoint.request(
|
||||
"/generate_stream",
|
||||
payload,
|
||||
cost=MAX_TOKENS,
|
||||
stream=True,
|
||||
)
|
||||
stream = resp["response"]
|
||||
|
||||
printed_answer = False
|
||||
async for event in stream:
|
||||
tok = (event.get("token") or {}).get("text")
|
||||
if tok:
|
||||
if not printed_answer:
|
||||
printed_answer = True
|
||||
print("Answer:\n", end="", flush=True)
|
||||
print(tok, end="", flush=True)
|
||||
|
||||
async def main():
|
||||
async with Serverless() as client:
|
||||
await call_generate(client)
|
||||
await call_generate_stream(client)
|
||||
|
||||
if __name__ == "__main__":
|
||||
from lib.test_utils import test_args
|
||||
|
||||
args = test_args.parse_args()
|
||||
|
||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||
endpoint_name=args.endpoint_group_name,
|
||||
account_api_key=args.api_key,
|
||||
instance=args.instance,
|
||||
)
|
||||
if endpoint_api_key:
|
||||
try:
|
||||
call_generate(
|
||||
api_key=endpoint_api_key,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
server_url=args.server_url,
|
||||
)
|
||||
call_generate_stream(
|
||||
api_key=endpoint_api_key,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
server_url=args.server_url,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error during API call: {e}")
|
||||
else:
|
||||
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ")
|
||||
asyncio.run(main())
|
||||
|
||||
Reference in New Issue
Block a user