Compare commits
126 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 68d8ce4bfd | |||
| 138fc3ac47 | |||
| 222ac2a0dd | |||
| 40aed9b5f8 | |||
| d4d36bf86e | |||
| e839cfc6e8 | |||
| f04138e13b | |||
| de3aa87c8f | |||
| 6b5b1341a7 | |||
| 8be92c03de | |||
| adedb8ba90 | |||
| 2f543c01ad | |||
| 0bcd2219ea | |||
| 0339b471c5 | |||
| e143162438 | |||
| 7986e51e9e | |||
| 9c6ab78503 | |||
| 45e0c7d9ca | |||
| 7a792fd176 | |||
| e0449cb3c7 | |||
| a4339bd3f1 | |||
| 2b26e5e20c | |||
| d3727d4fd7 | |||
| a47c9d1ed0 | |||
| 0b14562a63 | |||
| de9b50abb9 | |||
| c510801723 | |||
| a12523b1d2 | |||
| eedf81c0a3 | |||
| 3adec1826d | |||
| b55bfa9611 | |||
| 7db54f3bd7 | |||
| d63a060202 | |||
| c6521cb6d4 | |||
| b7fe4ebb91 | |||
| 8ae7b74605 | |||
| 106067d716 | |||
| f5134d4bf5 | |||
| 47e5460532 | |||
| ec2ac0a21a | |||
| 2cde573c56 | |||
| b2e4a5db0c | |||
| 7437028cb2 | |||
| 02c8307af7 | |||
| 7c0f316eeb | |||
| b4025a744f | |||
| d190308329 | |||
| 9f5a432513 | |||
| e09f1fa953 | |||
| ba6f1c2e4b | |||
| 944f83fc03 | |||
| 298590fb88 | |||
| 814c3acd4c | |||
| 22bca74087 | |||
| 9c795e2a01 | |||
| 830b532781 | |||
| d6a6e34c6b | |||
| ac1e109c48 | |||
| d6eb498ee4 | |||
| f56bbc0ebe | |||
| bcecd6df40 | |||
| 4d9bf2048c | |||
| 7788bc4a62 | |||
| 37ad3f8d46 | |||
| 70d51bafe1 | |||
| 63909736bb | |||
| f4f7080df1 | |||
| d51a338e8f | |||
| 92a04bd7af | |||
| 0f13506938 | |||
| 01e752d31f | |||
| 5edfa968ca | |||
| 5b5ef7227a | |||
| 16990ff8ff | |||
| 9748176366 | |||
| b39193ae70 | |||
| 9a6ca5d412 | |||
| e9ba1b03e4 | |||
| c98d661513 | |||
| f6fd1c6ac1 | |||
| 055e346c8c | |||
| 1cedb28acf | |||
| ec25dda3ad | |||
| 0397af719d | |||
| 4fdc314fd9 | |||
| 3786cf978d | |||
| a86d4bcf9c | |||
| e9b6a14a5e | |||
| cadac033e1 | |||
| 639d82f5b4 | |||
| 25db78e39d | |||
| 4e2f2311d0 | |||
| 38782d89bc | |||
| 0185216ccb | |||
| b20d9e714c | |||
| b1eb65d75d | |||
| 1d09d7fe96 | |||
| 1b37054dec | |||
| 1a1e4174b8 | |||
| b8377c4081 | |||
| 1e4fa87437 | |||
| 4c5fa03c7b | |||
| a8fe74f771 | |||
| b482de8394 | |||
| 703435d10e | |||
| 947fc5eea4 | |||
| 7c1a544b19 | |||
| 16b414676e | |||
| ba74ac8136 | |||
| 92ff412679 | |||
| fc75a64684 | |||
| b00bef547c | |||
| 3f4acb29fa | |||
| 58b078f908 | |||
| f9fdf04884 | |||
| 636f17d27f | |||
| 08c88f7527 | |||
| 8797b504af | |||
| cd946b0a9f | |||
| c595b42410 | |||
| 0bf3247a34 | |||
| 52ac4c0c1a | |||
| 8804e17201 | |||
| 4016cf9a53 | |||
| e0be45f39a | |||
| be2aafdb1f |
@@ -3,3 +3,4 @@
|
|||||||
__pycache__
|
__pycache__
|
||||||
bin/
|
bin/
|
||||||
lib64
|
lib64
|
||||||
|
.venv
|
||||||
@@ -39,11 +39,12 @@ reporting these metrics to the autoscaler.
|
|||||||
|
|
||||||
If you are using a Vast.ai template that includes PyWorker integration (marked as autoscaler compatible), it should work out of the box. The template will typically start the appropriate PyWorker server automatically. Here's a few:
|
If you are using a Vast.ai template that includes PyWorker integration (marked as autoscaler compatible), it should work out of the box. The template will typically start the appropriate PyWorker server automatically. Here's a few:
|
||||||
|
|
||||||
* **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=72d8dcb41ea3a58e06c741e2c725bc00)
|
* **vLLM:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=63ae93902bf3978bea033782592b784d)
|
||||||
* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=ad72c8bf7cf695c3c9ddf0eaf6da0447)
|
* **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=6fa6bd5bdf5f0df63db80e40b086037d)
|
||||||
|
* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=e6748878ba688e765e3e9fca29541938)
|
||||||
|
|
||||||
Currently available workers:
|
Currently available workers:
|
||||||
* `hello_world`: A simple example worker for a basic LLM server.
|
* `openai`: A simple example worker for a basic vLLM server.
|
||||||
* `comfyui`: A worker for the ComfyUI image generation backend.
|
* `comfyui`: A worker for the ComfyUI image generation backend.
|
||||||
* `tgi`: A worker for the Text Generation Inference backend.
|
* `tgi`: A worker for the Text Generation Inference backend.
|
||||||
|
|
||||||
|
|||||||
+144
-75
@@ -5,13 +5,14 @@ import base64
|
|||||||
import subprocess
|
import subprocess
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
from asyncio import sleep, gather, Semaphore
|
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
|
||||||
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
|
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from distutils.util import strtobool
|
from distutils.util import strtobool
|
||||||
|
|
||||||
from anyio import open_file
|
from anyio import open_file
|
||||||
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError
|
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
|
||||||
|
import asyncio
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from Crypto.Signature import pkcs1_15
|
from Crypto.Signature import pkcs1_15
|
||||||
@@ -25,8 +26,12 @@ from lib.data_types import (
|
|||||||
LogAction,
|
LogAction,
|
||||||
ApiPayload_T,
|
ApiPayload_T,
|
||||||
JsonDataException,
|
JsonDataException,
|
||||||
|
RequestMetrics,
|
||||||
|
BenchmarkResult
|
||||||
)
|
)
|
||||||
|
|
||||||
|
VERSION = "0.2.1"
|
||||||
|
|
||||||
MSG_HISTORY_LEN = 100
|
MSG_HISTORY_LEN = 100
|
||||||
log = logging.getLogger(__file__)
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
@@ -53,15 +58,25 @@ class Backend:
|
|||||||
EndpointHandler # this endpoint handler will be used for benchmarking
|
EndpointHandler # this endpoint handler will be used for benchmarking
|
||||||
)
|
)
|
||||||
log_actions: List[Tuple[LogAction, str]]
|
log_actions: List[Tuple[LogAction, str]]
|
||||||
|
max_wait_time: float = 10.0
|
||||||
reqnum = -1
|
reqnum = -1
|
||||||
|
version = VERSION
|
||||||
msg_history = []
|
msg_history = []
|
||||||
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
|
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
|
||||||
unsecured: bool = dataclasses.field(
|
unsecured: bool = dataclasses.field(
|
||||||
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
||||||
)
|
)
|
||||||
|
report_addr: str = dataclasses.field(
|
||||||
|
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
|
||||||
|
)
|
||||||
|
mtoken: str = dataclasses.field(
|
||||||
|
default_factory=lambda: os.environ.get("MASTER_TOKEN", "")
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.metrics = Metrics()
|
self.metrics = Metrics()
|
||||||
|
self.metrics._set_version(self.version)
|
||||||
|
self.metrics._set_mtoken(self.mtoken)
|
||||||
self._total_pubkey_fetch_errors = 0
|
self._total_pubkey_fetch_errors = 0
|
||||||
self._pubkey = self._fetch_pubkey()
|
self._pubkey = self._fetch_pubkey()
|
||||||
self.__start_healthcheck: bool = False
|
self.__start_healthcheck: bool = False
|
||||||
@@ -75,7 +90,13 @@ class Backend:
|
|||||||
@cached_property
|
@cached_property
|
||||||
def session(self):
|
def session(self):
|
||||||
log.debug(f"starting session with {self.model_server_url}")
|
log.debug(f"starting session with {self.model_server_url}")
|
||||||
return ClientSession(self.model_server_url)
|
connector = TCPConnector(
|
||||||
|
force_close=True, # Required for long running jobs
|
||||||
|
enable_cleanup_closed=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
timeout = ClientTimeout(total=None)
|
||||||
|
return ClientSession(self.model_server_url, timeout=timeout, connector=connector)
|
||||||
|
|
||||||
def create_handler(
|
def create_handler(
|
||||||
self,
|
self,
|
||||||
@@ -90,23 +111,19 @@ class Backend:
|
|||||||
|
|
||||||
#######################################Private#######################################
|
#######################################Private#######################################
|
||||||
def _fetch_pubkey(self):
|
def _fetch_pubkey(self):
|
||||||
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
|
report_addr = self.report_addr.rstrip("/")
|
||||||
|
command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"]
|
||||||
|
try:
|
||||||
result = subprocess.check_output(command, universal_newlines=True)
|
result = subprocess.check_output(command, universal_newlines=True)
|
||||||
log.debug("public key:")
|
log.debug("public key:")
|
||||||
log.debug(result)
|
log.debug(result)
|
||||||
key = None
|
|
||||||
for _ in range(5):
|
|
||||||
try:
|
|
||||||
key = RSA.import_key(result)
|
key = RSA.import_key(result)
|
||||||
break
|
if key is not None:
|
||||||
except ValueError as e:
|
|
||||||
log.debug(f"Error downloading key: {e}")
|
|
||||||
time.sleep(15)
|
|
||||||
if key is None:
|
|
||||||
self._total_pubkey_fetch_errors += 1
|
|
||||||
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
|
|
||||||
self.backend_errored("Failed to get autoscaler pubkey")
|
|
||||||
return key
|
return key
|
||||||
|
except (ValueError , subprocess.CalledProcessError) as e:
|
||||||
|
log.debug(f"Error downloading key: {e}")
|
||||||
|
self.backend_errored("Failed to get autoscaler pubkey")
|
||||||
|
|
||||||
|
|
||||||
async def __handle_request(
|
async def __handle_request(
|
||||||
self,
|
self,
|
||||||
@@ -122,75 +139,109 @@ class Backend:
|
|||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return web.json_response(dict(error="invalid JSON"), status=422)
|
return web.json_response(dict(error="invalid JSON"), status=422)
|
||||||
workload = payload.count_workload()
|
workload = payload.count_workload()
|
||||||
|
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
|
||||||
|
|
||||||
|
async def cancel_api_call_if_disconnected() -> web.Response:
|
||||||
|
await request.wait_for_disconnection()
|
||||||
|
log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled")
|
||||||
|
self.metrics._request_canceled(request_metrics)
|
||||||
|
raise asyncio.CancelledError
|
||||||
|
|
||||||
async def make_request() -> Union[web.Response, web.StreamResponse]:
|
async def make_request() -> Union[web.Response, web.StreamResponse]:
|
||||||
log.debug(f"got request, {auth_data.reqnum}")
|
|
||||||
self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum)
|
|
||||||
if self.allow_parallel_requests is False:
|
|
||||||
log.debug(f"Waiting to aquire Sem for reqnum:{auth_data.reqnum}")
|
|
||||||
await self.sem.acquire()
|
|
||||||
log.debug(
|
|
||||||
f"Sem acquired for reqnum:{auth_data.reqnum}, starting request..."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
log.debug(f"Starting request for reqnum:{auth_data.reqnum}")
|
|
||||||
try:
|
try:
|
||||||
start_time = time.time()
|
|
||||||
response = await self.__call_api(handler=handler, payload=payload)
|
response = await self.__call_api(handler=handler, payload=payload)
|
||||||
status_code = response.status
|
status_code = response.status
|
||||||
log.debug(
|
log.debug(
|
||||||
" ".join(
|
" ".join(
|
||||||
[
|
[
|
||||||
f"request with reqnum:{auth_data.reqnum}",
|
f"request with reqnum:{request_metrics.reqnum}",
|
||||||
f"returned status code: {status_code},",
|
f"returned status code: {status_code},",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
res = await handler.generate_client_response(request, response)
|
res = await handler.generate_client_response(request, response)
|
||||||
self.metrics._request_end(
|
self.metrics._request_success(request_metrics)
|
||||||
workload=workload,
|
|
||||||
req_response_time=time.time() - start_time,
|
|
||||||
reqnum=auth_data.reqnum,
|
|
||||||
)
|
|
||||||
return res
|
return res
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
log.debug(f"[backend] Request error: {e}")
|
log.debug(f"[backend] Request error: {e}")
|
||||||
self.metrics._request_errored(
|
self.metrics._request_errored(request_metrics)
|
||||||
workload=workload, reqnum=auth_data.reqnum
|
|
||||||
)
|
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
finally:
|
|
||||||
self.sem.release()
|
|
||||||
|
|
||||||
###########
|
###########
|
||||||
|
|
||||||
if self.__check_signature(auth_data) is False:
|
if self.__check_signature(auth_data) is False:
|
||||||
|
self.metrics._request_reject(request_metrics)
|
||||||
return web.Response(status=401)
|
return web.Response(status=401)
|
||||||
|
|
||||||
|
if self.metrics.model_metrics.wait_time > self.max_wait_time:
|
||||||
|
self.metrics._request_reject(request_metrics)
|
||||||
|
return web.Response(status=429)
|
||||||
|
|
||||||
|
acquired = False
|
||||||
try:
|
try:
|
||||||
return await make_request()
|
self.metrics._request_start(request_metrics)
|
||||||
|
if self.allow_parallel_requests is False:
|
||||||
|
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
|
||||||
|
await self.sem.acquire()
|
||||||
|
acquired = True
|
||||||
|
log.debug(
|
||||||
|
f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
|
||||||
|
done, pending = await wait(
|
||||||
|
[
|
||||||
|
create_task(make_request()),
|
||||||
|
create_task(cancel_api_call_if_disconnected()),
|
||||||
|
],
|
||||||
|
return_when=FIRST_COMPLETED,
|
||||||
|
)
|
||||||
|
for t in pending:
|
||||||
|
t.cancel()
|
||||||
|
await asyncio.gather(*pending, return_exceptions=True)
|
||||||
|
|
||||||
|
done_task = done.pop()
|
||||||
|
try:
|
||||||
|
return done_task.result()
|
||||||
|
except Exception as e:
|
||||||
|
log.debug(f"Request task raised exception: {e}")
|
||||||
|
return web.Response(status=500)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Client is gone. Do not write a response; just unwind.
|
||||||
|
return web.Response(status=499)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug(f"Exception in main handler loop {e}")
|
log.debug(f"Exception in main handler loop {e}")
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
finally:
|
finally:
|
||||||
if request.task.cancelled():
|
# Always release the semaphore if it was acquired
|
||||||
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
|
if acquired:
|
||||||
self.metrics._request_canceled(
|
self.sem.release()
|
||||||
workload=workload, reqnum=auth_data.reqnum
|
self.metrics._request_end(request_metrics)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def healthcheck_session(self):
|
||||||
|
"""Dedicated session for healthchecks to avoid conflicts with API session"""
|
||||||
|
log.debug("creating dedicated healthcheck session")
|
||||||
|
connector = TCPConnector(
|
||||||
|
force_close=True, # Keep this for isolation
|
||||||
|
enable_cleanup_closed=True,
|
||||||
)
|
)
|
||||||
|
timeout = ClientTimeout(total=10) # Reasonable timeout for healthchecks
|
||||||
|
return ClientSession(timeout=timeout, connector=connector)
|
||||||
|
|
||||||
async def __healthcheck(self):
|
async def __healthcheck(self):
|
||||||
health_check_url = self.benchmark_handler.healthcheck_endpoint
|
health_check_url = self.benchmark_handler.healthcheck_endpoint
|
||||||
if health_check_url is None:
|
if health_check_url is None:
|
||||||
log.debug("No healthcheck endpoint defined, skipping healthcheck")
|
log.debug("No healthcheck endpoint defined, skipping healthcheck")
|
||||||
return
|
return
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
await sleep(10)
|
await sleep(10)
|
||||||
if self.__start_healthcheck is False:
|
if self.__start_healthcheck is False:
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
log.debug(f"Performing healthcheck on {health_check_url}")
|
log.debug(f"Performing healthcheck on {health_check_url}")
|
||||||
async with self.session.get(health_check_url) as response:
|
async with self.healthcheck_session.get(health_check_url) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
log.debug("Healthcheck successful")
|
log.debug("Healthcheck successful")
|
||||||
elif response.status == 503:
|
elif response.status == 503:
|
||||||
@@ -199,7 +250,6 @@ class Backend:
|
|||||||
f"Healthcheck failed with status: {response.status}"
|
f"Healthcheck failed with status: {response.status}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# endpoint not ready yet so bail
|
|
||||||
log.debug(f"Healthcheck Endpoint not ready: {response.status}")
|
log.debug(f"Healthcheck Endpoint not ready: {response.status}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug(f"Healthcheck failed with exception: {e}")
|
log.debug(f"Healthcheck failed with exception: {e}")
|
||||||
@@ -207,7 +257,7 @@ class Backend:
|
|||||||
|
|
||||||
async def _start_tracking(self) -> None:
|
async def _start_tracking(self) -> None:
|
||||||
await gather(
|
await gather(
|
||||||
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck()
|
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop()
|
||||||
)
|
)
|
||||||
|
|
||||||
def backend_errored(self, msg: str) -> None:
|
def backend_errored(self, msg: str) -> None:
|
||||||
@@ -239,7 +289,7 @@ class Backend:
|
|||||||
message = {
|
message = {
|
||||||
key: value
|
key: value
|
||||||
for (key, value) in (dataclasses.asdict(auth_data).items())
|
for (key, value) in (dataclasses.asdict(auth_data).items())
|
||||||
if key != "signature"
|
if key != "signature" and key != "__request_id"
|
||||||
}
|
}
|
||||||
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
|
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
|
||||||
log.debug(
|
log.debug(
|
||||||
@@ -249,7 +299,7 @@ class Backend:
|
|||||||
elif message in self.msg_history:
|
elif message in self.msg_history:
|
||||||
log.debug(f"message: {message} already in message history")
|
log.debug(f"message: {message} already in message history")
|
||||||
return False
|
return False
|
||||||
elif verify_signature(json.dumps(message, indent=4), auth_data.signature):
|
elif verify_signature(json.dumps(message, indent=4, sort_keys=True), auth_data.signature):
|
||||||
self.reqnum = max(auth_data.reqnum, self.reqnum)
|
self.reqnum = max(auth_data.reqnum, self.reqnum)
|
||||||
self.msg_history.append(message)
|
self.msg_history.append(message)
|
||||||
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
|
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
|
||||||
@@ -268,48 +318,67 @@ class Backend:
|
|||||||
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
||||||
log.debug("already ran benchmark")
|
log.debug("already ran benchmark")
|
||||||
# trigger model load
|
# trigger model load
|
||||||
payload = self.benchmark_handler.make_benchmark_payload()
|
# payload = self.benchmark_handler.make_benchmark_payload()
|
||||||
_ = await self.__call_api(
|
# _ = await self.__call_api(
|
||||||
handler=self.benchmark_handler, payload=payload
|
# handler=self.benchmark_handler, payload=payload
|
||||||
)
|
# )
|
||||||
return float(f.readline())
|
return float(f.readline())
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
max_throughput = 0
|
|
||||||
last_throughput = 0
|
log.debug("Initial run to trigger model loading...")
|
||||||
sum_throughput = 0
|
payload = self.benchmark_handler.make_benchmark_payload()
|
||||||
for run in range(self.benchmark_handler.benchmark_runs + 1):
|
await self.__call_api(handler=self.benchmark_handler, payload=payload)
|
||||||
start = time.time()
|
|
||||||
|
max_throughput = 0
|
||||||
|
sum_throughput = 0
|
||||||
|
concurrent_requests = 10 if self.allow_parallel_requests else 1
|
||||||
|
|
||||||
|
for run in range(1, self.benchmark_handler.benchmark_runs + 1):
|
||||||
|
start = time.time()
|
||||||
|
benchmark_requests = []
|
||||||
|
|
||||||
|
for i in range(concurrent_requests):
|
||||||
payload = self.benchmark_handler.make_benchmark_payload()
|
payload = self.benchmark_handler.make_benchmark_payload()
|
||||||
res = await self.__call_api(
|
|
||||||
handler=self.benchmark_handler, payload=payload
|
|
||||||
)
|
|
||||||
data = await res.json()
|
|
||||||
time_elapsed = time.time() - start
|
|
||||||
# first run triggers one-time loading of the model which is very slow, so we skip counting it
|
|
||||||
if run == 0:
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
workload = payload.count_workload()
|
workload = payload.count_workload()
|
||||||
last_throughput = workload / time_elapsed
|
task = self.__call_api(handler=self.benchmark_handler, payload=payload)
|
||||||
sum_throughput += last_throughput
|
benchmark_requests.append(
|
||||||
max_throughput = max(max_throughput, last_throughput)
|
BenchmarkResult(request_idx=i, workload=workload, task=task)
|
||||||
|
)
|
||||||
|
|
||||||
|
responses = await gather(*[br.task for br in benchmark_requests])
|
||||||
|
for br, response in zip(benchmark_requests, responses):
|
||||||
|
br.response = response
|
||||||
|
|
||||||
|
total_workload = sum(br.workload for br in benchmark_requests if br.is_successful)
|
||||||
|
time_elapsed = time.time() - start
|
||||||
|
successful_responses = sum([1 for br in benchmark_requests if br.is_successful])
|
||||||
|
if successful_responses == 0:
|
||||||
|
self.backend_errored("No successful responses from benchmark")
|
||||||
|
log.debug(f"benchmark failed: {successful_responses}/{concurrent_requests} successful responses")
|
||||||
|
|
||||||
|
throughput = total_workload / time_elapsed
|
||||||
|
sum_throughput += throughput
|
||||||
|
max_throughput = max(max_throughput, throughput)
|
||||||
|
|
||||||
|
# Log results for debugging
|
||||||
log.debug(
|
log.debug(
|
||||||
"\n".join(
|
"\n".join(
|
||||||
[
|
[
|
||||||
"#" * 60,
|
"#" * 60,
|
||||||
f"Run: {run}, workload: {workload} time_elapsed: {time_elapsed}, throughput: {last_throughput}",
|
f"Run: {run}, concurrent_requests: {concurrent_requests}",
|
||||||
"",
|
f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s",
|
||||||
f"response: {data}",
|
f"Throughput: {throughput} workload/s",
|
||||||
|
f"Successful responses: {successful_responses}/{concurrent_requests}",
|
||||||
"#" * 60,
|
"#" * 60,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
average_throughput = sum_throughput / self.benchmark_handler.benchmark_runs
|
average_throughput = sum_throughput / self.benchmark_handler.benchmark_runs
|
||||||
log.debug(
|
log.debug(
|
||||||
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
|
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
|
||||||
)
|
)
|
||||||
# save max_throughput so we don't have to run benchmark again on restart of cold instances
|
|
||||||
with open(BENCHMARK_INDICATOR_FILE, "w") as f:
|
with open(BENCHMARK_INDICATOR_FILE, "w") as f:
|
||||||
f.write(str(max_throughput))
|
f.write(str(max_throughput))
|
||||||
return max_throughput
|
return max_throughput
|
||||||
@@ -327,7 +396,7 @@ class Backend:
|
|||||||
)
|
)
|
||||||
# some backends need a few seconds after logging successful startup before
|
# some backends need a few seconds after logging successful startup before
|
||||||
# they can begin accepting requests
|
# they can begin accepting requests
|
||||||
await sleep(5)
|
# await sleep(5)
|
||||||
try:
|
try:
|
||||||
max_throughput = await run_benchmark()
|
max_throughput = await run_benchmark()
|
||||||
self.__start_healthcheck = True
|
self.__start_healthcheck = True
|
||||||
@@ -348,13 +417,13 @@ class Backend:
|
|||||||
|
|
||||||
async def tail_log():
|
async def tail_log():
|
||||||
log.debug(f"tailing file: {self.model_log_file}")
|
log.debug(f"tailing file: {self.model_log_file}")
|
||||||
async with await open_file(self.model_log_file) as f:
|
async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore') as f:
|
||||||
while True:
|
while True:
|
||||||
line = await f.readline()
|
line = await f.readline()
|
||||||
if line:
|
if line:
|
||||||
await handle_log_line(line.rstrip())
|
await handle_log_line(line.rstrip())
|
||||||
else:
|
else:
|
||||||
time.sleep(LOG_POLL_INTERVAL)
|
await asyncio.sleep(LOG_POLL_INTERVAL)
|
||||||
|
|
||||||
###########
|
###########
|
||||||
|
|
||||||
|
|||||||
+54
-10
@@ -3,12 +3,11 @@ import logging
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type
|
from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type, Awaitable
|
||||||
from aiohttp import web, ClientResponse
|
from aiohttp import web, ClientResponse
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import requests
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -66,10 +65,11 @@ class ApiPayload(ABC):
|
|||||||
class AuthData:
|
class AuthData:
|
||||||
"""data used to authenticate requester"""
|
"""data used to authenticate requester"""
|
||||||
|
|
||||||
signature: str
|
|
||||||
cost: str
|
cost: str
|
||||||
endpoint: str
|
endpoint_id: int
|
||||||
reqnum: int
|
reqnum: int
|
||||||
|
request_idx: int
|
||||||
|
signature: str
|
||||||
url: str
|
url: str
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -190,13 +190,34 @@ class SystemMetrics:
|
|||||||
self.additional_disk_usage = disk_usage - self.last_disk_usage
|
self.additional_disk_usage = disk_usage - self.last_disk_usage
|
||||||
self.last_disk_usage = disk_usage
|
self.last_disk_usage = disk_usage
|
||||||
|
|
||||||
def reset(self):
|
def reset(self, expected: float | None) -> None:
|
||||||
# autoscaler excepts model_loading_time to be populated only once, when the instance has
|
# autoscaler excepts model_loading_time to be populated only once, when the instance has
|
||||||
# finished benchmarking and is ready to receive requests. This applies to restarted instances
|
# finished benchmarking and is ready to receive requests. This applies to restarted instances
|
||||||
# as well: they should send model_loading_time once when they are done loading
|
# as well: they should send model_loading_time once when they are done loading
|
||||||
|
if self.model_loading_time == expected:
|
||||||
self.model_loading_time = None
|
self.model_loading_time = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RequestMetrics:
|
||||||
|
"""Tracks metrics for an active request."""
|
||||||
|
request_idx: int
|
||||||
|
reqnum: int
|
||||||
|
workload: float
|
||||||
|
status: str
|
||||||
|
success: bool = False
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BenchmarkResult:
|
||||||
|
request_idx: int
|
||||||
|
workload: float
|
||||||
|
task: Awaitable[ClientResponse]
|
||||||
|
response: Optional[ClientResponse] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_successful(self) -> bool:
|
||||||
|
return self.response is not None and self.response.status == 200
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelMetrics:
|
class ModelMetrics:
|
||||||
"""Model specific metrics"""
|
"""Model specific metrics"""
|
||||||
@@ -206,13 +227,15 @@ class ModelMetrics:
|
|||||||
workload_received: float
|
workload_received: float
|
||||||
workload_cancelled: float
|
workload_cancelled: float
|
||||||
workload_errored: float
|
workload_errored: float
|
||||||
workload_pending: float
|
workload_rejected: float
|
||||||
# these are not
|
# these are not
|
||||||
cur_perf: float
|
workload_pending: float
|
||||||
error_msg: Optional[str]
|
error_msg: Optional[str]
|
||||||
max_throughput: float
|
max_throughput: float
|
||||||
requests_recieved: Set[int] = field(default_factory=set)
|
requests_recieved: Set[int] = field(default_factory=set)
|
||||||
requests_working: Set[int] = field(default_factory=set)
|
requests_working: dict[int, RequestMetrics] = field(default_factory=dict)
|
||||||
|
requests_deleting: list[RequestMetrics] = field(default_factory=list)
|
||||||
|
last_update: float = field(default_factory=time.time)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def empty(cls):
|
def empty(cls):
|
||||||
@@ -221,7 +244,7 @@ class ModelMetrics:
|
|||||||
workload_served=0.0,
|
workload_served=0.0,
|
||||||
workload_cancelled=0.0,
|
workload_cancelled=0.0,
|
||||||
workload_errored=0.0,
|
workload_errored=0.0,
|
||||||
cur_perf=0.0,
|
workload_rejected=0.0,
|
||||||
workload_received=0.0,
|
workload_received=0.0,
|
||||||
error_msg=None,
|
error_msg=None,
|
||||||
max_throughput=0.0,
|
max_throughput=0.0,
|
||||||
@@ -231,6 +254,20 @@ class ModelMetrics:
|
|||||||
def workload_processing(self) -> float:
|
def workload_processing(self) -> float:
|
||||||
return max(self.workload_received - self.workload_cancelled, 0.0)
|
return max(self.workload_received - self.workload_cancelled, 0.0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def wait_time(self) -> float:
|
||||||
|
if (len(self.requests_working) == 0):
|
||||||
|
return 0.0
|
||||||
|
return sum([request.workload for request in self.requests_working.values()]) / max(self.max_throughput, 0.00001)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cur_load(self) -> float:
|
||||||
|
return sum([request.workload for request in self.requests_working.values()])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def working_request_idxs(self) -> list[int]:
|
||||||
|
return [req.request_idx for req in self.requests_working.values()]
|
||||||
|
|
||||||
def set_errored(self, error_msg):
|
def set_errored(self, error_msg):
|
||||||
self.reset()
|
self.reset()
|
||||||
self.error_msg = error_msg
|
self.error_msg = error_msg
|
||||||
@@ -240,15 +277,21 @@ class ModelMetrics:
|
|||||||
self.workload_received = 0
|
self.workload_received = 0
|
||||||
self.workload_cancelled = 0
|
self.workload_cancelled = 0
|
||||||
self.workload_errored = 0
|
self.workload_errored = 0
|
||||||
|
self.workload_rejected = 0
|
||||||
|
self.last_update = time.time()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AutoScalaerData:
|
class AutoScalerData:
|
||||||
"""Data that is reported to autoscaler"""
|
"""Data that is reported to autoscaler"""
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
|
mtoken: str
|
||||||
|
version: str
|
||||||
loadtime: float
|
loadtime: float
|
||||||
cur_load: float
|
cur_load: float
|
||||||
|
rej_load: float
|
||||||
|
new_load: float
|
||||||
error_msg: str
|
error_msg: str
|
||||||
max_perf: float
|
max_perf: float
|
||||||
cur_perf: float
|
cur_perf: float
|
||||||
@@ -257,6 +300,7 @@ class AutoScalaerData:
|
|||||||
num_requests_working: int
|
num_requests_working: int
|
||||||
num_requests_recieved: int
|
num_requests_recieved: int
|
||||||
additional_disk_usage: float
|
additional_disk_usage: float
|
||||||
|
working_request_idxs: list[int]
|
||||||
url: str
|
url: str
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
+179
-45
@@ -5,13 +5,14 @@ import json
|
|||||||
from asyncio import sleep
|
from asyncio import sleep
|
||||||
from dataclasses import dataclass, asdict, field
|
from dataclasses import dataclass, asdict, field
|
||||||
from functools import cache
|
from functools import cache
|
||||||
|
import asyncio
|
||||||
|
from aiohttp import ClientSession, ClientTimeout, TCPConnector, ClientResponseError
|
||||||
|
|
||||||
import requests
|
from lib.data_types import AutoScalerData, SystemMetrics, ModelMetrics, RequestMetrics
|
||||||
|
|
||||||
from lib.data_types import AutoScalaerData, SystemMetrics, ModelMetrics
|
|
||||||
from typing import Awaitable, NoReturn, List
|
from typing import Awaitable, NoReturn, List
|
||||||
|
|
||||||
METRICS_UPDATE_INTERVAL = 1
|
METRICS_UPDATE_INTERVAL = 1
|
||||||
|
DELETE_REQUESTS_INTERVAL = 1
|
||||||
|
|
||||||
log = logging.getLogger(__file__)
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
@@ -26,7 +27,10 @@ def get_url() -> str:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Metrics:
|
class Metrics:
|
||||||
|
version: str = "0"
|
||||||
|
mtoken: str = ""
|
||||||
last_metric_update: float = 0.0
|
last_metric_update: float = 0.0
|
||||||
|
last_request_served: float = 0.0
|
||||||
update_pending: bool = False
|
update_pending: bool = False
|
||||||
id: int = field(default_factory=lambda: int(os.environ["CONTAINER_ID"]))
|
id: int = field(default_factory=lambda: int(os.environ["CONTAINER_ID"]))
|
||||||
report_addr: List[str] = field(
|
report_addr: List[str] = field(
|
||||||
@@ -35,44 +39,84 @@ class Metrics:
|
|||||||
url: str = field(default_factory=get_url)
|
url: str = field(default_factory=get_url)
|
||||||
system_metrics: SystemMetrics = field(default_factory=SystemMetrics.empty)
|
system_metrics: SystemMetrics = field(default_factory=SystemMetrics.empty)
|
||||||
model_metrics: ModelMetrics = field(default_factory=ModelMetrics.empty)
|
model_metrics: ModelMetrics = field(default_factory=ModelMetrics.empty)
|
||||||
|
_session: ClientSession | None = field(default=None, init=False, repr=False)
|
||||||
|
|
||||||
def _request_start(self, workload: float, reqnum: int) -> None:
|
async def http(self) -> ClientSession:
|
||||||
|
if self._session is None:
|
||||||
|
self._session = ClientSession(
|
||||||
|
timeout=ClientTimeout(total=10),
|
||||||
|
connector=TCPConnector(limit=8, limit_per_host=4, force_close=True, enable_cleanup_closed=True)
|
||||||
|
)
|
||||||
|
return self._session
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
if self._session is not None:
|
||||||
|
await self._session.close()
|
||||||
|
self._session = None
|
||||||
|
|
||||||
|
def _request_start(self, request: RequestMetrics) -> None:
|
||||||
"""
|
"""
|
||||||
this function is called prior to forwarding a request to a model API.
|
this function is called prior to forwarding a request to a model API.
|
||||||
"""
|
"""
|
||||||
log.debug("request start")
|
log.debug("request start")
|
||||||
self.model_metrics.workload_pending += workload
|
request.status = "Started"
|
||||||
self.model_metrics.workload_received += workload
|
self.model_metrics.workload_pending += request.workload
|
||||||
self.model_metrics.requests_recieved.add(reqnum)
|
self.model_metrics.workload_received += request.workload
|
||||||
self.model_metrics.requests_working.add(reqnum)
|
self.model_metrics.requests_recieved.add(request.reqnum)
|
||||||
|
self.model_metrics.requests_working[request.reqnum] = request
|
||||||
def _request_end(
|
|
||||||
self, workload: float, req_response_time: float, reqnum: int
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
this function is called after a response from model API is received.
|
|
||||||
"""
|
|
||||||
self.model_metrics.workload_served += workload
|
|
||||||
self.model_metrics.workload_pending -= workload
|
|
||||||
self.model_metrics.requests_working.discard(reqnum)
|
|
||||||
self.model_metrics.cur_perf = workload / req_response_time
|
|
||||||
self.update_pending = True
|
self.update_pending = True
|
||||||
|
|
||||||
def _request_errored(self, workload: float, reqnum: int) -> None:
|
def _request_end(self, request: RequestMetrics) -> None:
|
||||||
|
"""
|
||||||
|
this function is called after handling of a request ends, regardless of the outcome
|
||||||
|
"""
|
||||||
|
self.model_metrics.workload_pending -= request.workload
|
||||||
|
self.model_metrics.requests_working.pop(request.reqnum, None)
|
||||||
|
self.model_metrics.requests_deleting.append(request)
|
||||||
|
self.last_request_served = time.time()
|
||||||
|
|
||||||
|
def _request_success(self, request: RequestMetrics) -> None:
|
||||||
|
"""
|
||||||
|
this function is called after a response from model API is received and forwarded.
|
||||||
|
"""
|
||||||
|
self.model_metrics.workload_served += request.workload
|
||||||
|
request.status = "Success"
|
||||||
|
request.success = True
|
||||||
|
self.update_pending = True
|
||||||
|
|
||||||
|
def _request_errored(self, request: RequestMetrics) -> None:
|
||||||
"""
|
"""
|
||||||
this function is called if model API returns an error
|
this function is called if model API returns an error
|
||||||
"""
|
"""
|
||||||
self.model_metrics.workload_pending -= workload
|
self.model_metrics.workload_errored += request.workload
|
||||||
self.model_metrics.workload_errored += workload
|
request.status = "Error"
|
||||||
self.model_metrics.requests_working.discard(reqnum)
|
request.success = False
|
||||||
|
self.update_pending = True
|
||||||
|
|
||||||
def _request_canceled(self, workload: float, reqnum: int) -> None:
|
def _request_canceled(self, request: RequestMetrics) -> None:
|
||||||
"""
|
"""
|
||||||
this function is called if client drops connection before model API has responded
|
this function is called if client drops connection before model API has responded
|
||||||
"""
|
"""
|
||||||
self.model_metrics.workload_pending -= workload
|
self.model_metrics.workload_cancelled += request.workload
|
||||||
self.model_metrics.workload_cancelled += workload
|
request.success = True
|
||||||
self.model_metrics.requests_working.discard(reqnum)
|
request.status = "Cancelled"
|
||||||
|
|
||||||
|
def _request_reject(self, request: RequestMetrics):
|
||||||
|
"""
|
||||||
|
this function is called if the current wait time for the model is above max_wait_time
|
||||||
|
"""
|
||||||
|
self.model_metrics.requests_recieved.add(request.reqnum)
|
||||||
|
self.model_metrics.requests_deleting.append(request)
|
||||||
|
self.model_metrics.workload_rejected += request.workload
|
||||||
|
request.success = False
|
||||||
|
request.status = "Rejected"
|
||||||
|
self.update_pending = True
|
||||||
|
|
||||||
|
async def _send_delete_requests_loop(self) -> Awaitable[NoReturn]:
|
||||||
|
while True:
|
||||||
|
await sleep(DELETE_REQUESTS_INTERVAL)
|
||||||
|
if len(self.model_metrics.requests_deleting) > 0:
|
||||||
|
await self.__send_delete_requests_and_reset()
|
||||||
|
|
||||||
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
|
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
|
||||||
while True:
|
while True:
|
||||||
@@ -80,10 +124,10 @@ class Metrics:
|
|||||||
elapsed = time.time() - self.last_metric_update
|
elapsed = time.time() - self.last_metric_update
|
||||||
if self.system_metrics.model_is_loaded is False and elapsed >= 10:
|
if self.system_metrics.model_is_loaded is False and elapsed >= 10:
|
||||||
log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
|
log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
|
||||||
self.__send_metrics_and_reset(elapsed)
|
await self.__send_metrics_and_reset()
|
||||||
elif self.update_pending or elapsed > 10:
|
elif self.update_pending or elapsed > 10:
|
||||||
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
|
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
|
||||||
self.__send_metrics_and_reset(elapsed)
|
await self.__send_metrics_and_reset()
|
||||||
|
|
||||||
def _model_loaded(self, max_throughput: float) -> None:
|
def _model_loaded(self, max_throughput: float) -> None:
|
||||||
self.system_metrics.model_loading_time = (
|
self.system_metrics.model_loading_time = (
|
||||||
@@ -96,57 +140,147 @@ class Metrics:
|
|||||||
self.model_metrics.set_errored(error_msg)
|
self.model_metrics.set_errored(error_msg)
|
||||||
self.system_metrics.model_is_loaded = True
|
self.system_metrics.model_is_loaded = True
|
||||||
|
|
||||||
|
def _set_version(self, version: str) -> None:
|
||||||
|
self.version = version
|
||||||
|
|
||||||
|
def _set_mtoken(self, mtoken: str) -> None:
|
||||||
|
self.mtoken = mtoken
|
||||||
|
|
||||||
#######################################Private#######################################
|
#######################################Private#######################################
|
||||||
|
|
||||||
def __send_metrics_and_reset(self, elapsed):
|
async def __send_delete_requests_and_reset(self):
|
||||||
|
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
|
||||||
|
data = {
|
||||||
|
"worker_id": self.id,
|
||||||
|
"mtoken": self.mtoken,
|
||||||
|
"request_idxs": idxs,
|
||||||
|
"success": success_flag,
|
||||||
|
}
|
||||||
|
log.debug(
|
||||||
|
f"Deleting requests that {'succeeded' if success_flag else 'failed'}: {data['request_idxs']}"
|
||||||
|
)
|
||||||
|
full_path = report_addr.rstrip("/") + "/delete_requests/"
|
||||||
|
for attempt in range(1, 4):
|
||||||
|
try:
|
||||||
|
session = await self.http()
|
||||||
|
async with session.post(full_path, json=data) as res:
|
||||||
|
log.debug(f"delete_requests response: {res.status}")
|
||||||
|
res.raise_for_status()
|
||||||
|
return True
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
log.debug("delete_requests timed out")
|
||||||
|
except (ClientResponseError, Exception) as e:
|
||||||
|
log.debug(f"delete_requests failed with error: {e}")
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
log.debug(f"retrying delete_request, attempt: {attempt}")
|
||||||
|
return False
|
||||||
|
|
||||||
def compute_autoscaler_data() -> AutoScalaerData:
|
# Take a snapshot of what we plan to send this tick.
|
||||||
return AutoScalaerData(
|
# New arrivals after this snapshot will remain in the queue for the next tick.
|
||||||
|
snapshot = list(self.model_metrics.requests_deleting)
|
||||||
|
success_idxs = [r.request_idx for r in snapshot if r.success is True]
|
||||||
|
failed_idxs = [r.request_idx for r in snapshot if r.success is False]
|
||||||
|
|
||||||
|
if not success_idxs and not failed_idxs:
|
||||||
|
return # nothing to do
|
||||||
|
|
||||||
|
for report_addr in self.report_addr:
|
||||||
|
# TODO: Add a Redis subscriber queue for delete_requests
|
||||||
|
if report_addr == "https://cloud.vast.ai/api/v0":
|
||||||
|
# Patch: ignore the Redis API report_addr
|
||||||
|
continue
|
||||||
|
sent_success = True
|
||||||
|
sent_failed = True
|
||||||
|
|
||||||
|
if success_idxs:
|
||||||
|
sent_success = await post(report_addr, success_idxs, True)
|
||||||
|
if failed_idxs:
|
||||||
|
sent_failed = await post(report_addr, failed_idxs, False)
|
||||||
|
|
||||||
|
if sent_success and sent_failed:
|
||||||
|
# Remove only the items we actually sent from the live queue.
|
||||||
|
sent_set = set(success_idxs) | set(failed_idxs)
|
||||||
|
self.model_metrics.requests_deleting[:] = [
|
||||||
|
r for r in self.model_metrics.requests_deleting
|
||||||
|
if r.request_idx not in sent_set
|
||||||
|
]
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
async def __send_metrics_and_reset(self):
|
||||||
|
|
||||||
|
loadtime_snapshot = self.system_metrics.model_loading_time
|
||||||
|
|
||||||
|
def compute_autoscaler_data() -> AutoScalerData:
|
||||||
|
return AutoScalerData(
|
||||||
id=self.id,
|
id=self.id,
|
||||||
loadtime=(self.system_metrics.model_loading_time or 0.0),
|
mtoken=self.mtoken,
|
||||||
cur_load=(self.model_metrics.workload_processing / elapsed),
|
version=self.version,
|
||||||
|
loadtime=(loadtime_snapshot or 0.0),
|
||||||
|
new_load=self.model_metrics.workload_processing,
|
||||||
|
cur_load=self.model_metrics.cur_load,
|
||||||
|
rej_load=self.model_metrics.workload_rejected,
|
||||||
max_perf=self.model_metrics.max_throughput,
|
max_perf=self.model_metrics.max_throughput,
|
||||||
cur_perf=self.model_metrics.cur_perf,
|
cur_perf=self.model_metrics.workload_served,
|
||||||
error_msg=self.model_metrics.error_msg or "",
|
error_msg=self.model_metrics.error_msg or "",
|
||||||
num_requests_working=len(self.model_metrics.requests_working),
|
num_requests_working=len(self.model_metrics.requests_working),
|
||||||
num_requests_recieved=len(self.model_metrics.requests_recieved),
|
num_requests_recieved=len(self.model_metrics.requests_recieved),
|
||||||
additional_disk_usage=self.system_metrics.additional_disk_usage,
|
additional_disk_usage=self.system_metrics.additional_disk_usage,
|
||||||
|
working_request_idxs=self.model_metrics.working_request_idxs,
|
||||||
cur_capacity=0,
|
cur_capacity=0,
|
||||||
max_capacity=0,
|
max_capacity=0,
|
||||||
url=self.url,
|
url=self.url,
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_data(report_addr: str) -> None:
|
async def send_data(report_addr: str) -> bool:
|
||||||
data = compute_autoscaler_data()
|
data = compute_autoscaler_data()
|
||||||
full_path = report_addr.rstrip("/") + "/worker_status/"
|
log_data = asdict(data)
|
||||||
|
def obfuscate(secret: str) -> str:
|
||||||
|
if secret is None:
|
||||||
|
return ""
|
||||||
|
return secret[:7] + "..." if len(secret) > 7 else ("*" * len(secret))
|
||||||
|
|
||||||
|
log_data["mtoken"] = obfuscate(log_data.get("mtoken"))
|
||||||
log.debug(
|
log.debug(
|
||||||
"\n".join(
|
"\n".join(
|
||||||
[
|
[
|
||||||
"#" * 60,
|
"#" * 60,
|
||||||
f"sending data to autoscaler",
|
f"sending data to autoscaler",
|
||||||
f"{json.dumps((asdict(data)), indent=2)}",
|
f"{json.dumps(log_data, indent=2)}",
|
||||||
"#" * 60,
|
"#" * 60,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
full_path = report_addr.rstrip("/") + "/worker_status/"
|
||||||
for attempt in range(1, 4):
|
for attempt in range(1, 4):
|
||||||
try:
|
try:
|
||||||
requests.post(full_path, json=asdict(data), timeout=1)
|
session = await self.http()
|
||||||
break
|
async with session.post(full_path, json=asdict(data)) as res:
|
||||||
except requests.Timeout:
|
res.raise_for_status()
|
||||||
|
return True
|
||||||
|
except asyncio.TimeoutError:
|
||||||
log.debug(f"autoscaler status update timed out")
|
log.debug(f"autoscaler status update timed out")
|
||||||
except Exception as e:
|
except (ClientResponseError, Exception) as e:
|
||||||
log.debug(f"autoscaler status update failed with error: {e}")
|
log.debug(f"autoscaler status update failed with error: {e}")
|
||||||
time.sleep(2)
|
await asyncio.sleep(2)
|
||||||
log.debug(f"retrying autoscaler status update, attempt: {attempt}")
|
log.debug(f"retrying autoscaler status update, attempt: {attempt}")
|
||||||
|
log.debug(f"failed to send update through {report_addr}")
|
||||||
|
return False
|
||||||
|
|
||||||
###########
|
###########
|
||||||
|
|
||||||
self.system_metrics.update_disk_usage()
|
self.system_metrics.update_disk_usage()
|
||||||
|
|
||||||
|
sent = False
|
||||||
for report_addr in self.report_addr:
|
for report_addr in self.report_addr:
|
||||||
send_data(report_addr)
|
if await send_data(report_addr):
|
||||||
|
sent = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if sent:
|
||||||
|
# clear the one-shot loadtime only if we actually sent *this* value
|
||||||
|
self.system_metrics.reset(expected=loadtime_snapshot)
|
||||||
self.update_pending = False
|
self.update_pending = False
|
||||||
self.model_metrics.reset()
|
self.model_metrics.reset()
|
||||||
self.system_metrics.reset()
|
|
||||||
self.last_metric_update = time.time()
|
self.last_metric_update = time.time()
|
||||||
|
|||||||
+22
-2
@@ -3,15 +3,17 @@ import logging
|
|||||||
from typing import List
|
from typing import List
|
||||||
import ssl
|
import ssl
|
||||||
from asyncio import run, gather
|
from asyncio import run, gather
|
||||||
|
import asyncio
|
||||||
|
|
||||||
from lib.backend import Backend
|
from lib.backend import Backend
|
||||||
|
from lib.metrics import Metrics
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
log = logging.getLogger(__file__)
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
||||||
|
try:
|
||||||
log.debug("getting certificate...")
|
log.debug("getting certificate...")
|
||||||
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
||||||
if use_ssl is True:
|
if use_ssl is True:
|
||||||
@@ -27,7 +29,7 @@ def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
|||||||
log.debug("starting server...")
|
log.debug("starting server...")
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
app.add_routes(routes)
|
app.add_routes(routes)
|
||||||
runner = web.AppRunner(app, handler_cancellation=True)
|
runner = web.AppRunner(app)
|
||||||
await runner.setup()
|
await runner.setup()
|
||||||
site = web.TCPSite(
|
site = web.TCPSite(
|
||||||
runner,
|
runner,
|
||||||
@@ -38,3 +40,21 @@ def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
|||||||
await gather(site.start(), backend._start_tracking())
|
await gather(site.start(), backend._start_tracking())
|
||||||
|
|
||||||
run(main())
|
run(main())
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
err_msg = f"PyWorker failed to launch: {e}"
|
||||||
|
log.error(err_msg)
|
||||||
|
|
||||||
|
async def beacon():
|
||||||
|
metrics = Metrics()
|
||||||
|
metrics._set_version(getattr(backend, "version", "0"))
|
||||||
|
metrics._set_mtoken(getattr(backend, "mtoken", ""))
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
metrics._model_errored(err_msg)
|
||||||
|
await metrics._Metrics__send_metrics_and_reset()
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
finally:
|
||||||
|
await metrics.aclose()
|
||||||
|
|
||||||
|
run(beacon())
|
||||||
|
|||||||
+16
-9
@@ -10,6 +10,7 @@ from collections import Counter
|
|||||||
from dataclasses import dataclass, field, asdict
|
from dataclasses import dataclass, field, asdict
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
from utils.endpoint_util import Endpoint
|
from utils.endpoint_util import Endpoint
|
||||||
|
from utils.ssl import get_cert_file_path
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from lib.data_types import AuthData, ApiPayload
|
from lib.data_types import AuthData, ApiPayload
|
||||||
@@ -74,6 +75,7 @@ def print_truncate_res(res: str):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ClientState:
|
class ClientState:
|
||||||
endpoint_group_name: str
|
endpoint_group_name: str
|
||||||
|
endpoint_id: int
|
||||||
api_key: str
|
api_key: str
|
||||||
server_url: str
|
server_url: str
|
||||||
worker_endpoint: str
|
worker_endpoint: str
|
||||||
@@ -94,7 +96,7 @@ class ClientState:
|
|||||||
self.status = ClientStatus.Error
|
self.status = ClientStatus.Error
|
||||||
return
|
return
|
||||||
route_payload = {
|
route_payload = {
|
||||||
"endpoint": self.endpoint_group_name,
|
"endpoint_id": self.endpoint_id,
|
||||||
"api_key": self.api_key,
|
"api_key": self.api_key,
|
||||||
"cost": self.payload.count_workload(),
|
"cost": self.payload.count_workload(),
|
||||||
}
|
}
|
||||||
@@ -120,9 +122,11 @@ class ClientState:
|
|||||||
self.url = worker_address
|
self.url = worker_address
|
||||||
url = urljoin(worker_address, self.worker_endpoint)
|
url = urljoin(worker_address, self.worker_endpoint)
|
||||||
self.status = ClientStatus.Generating
|
self.status = ClientStatus.Generating
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
url,
|
url,
|
||||||
json=req_data,
|
json=req_data,
|
||||||
|
verify=get_cert_file_path(),
|
||||||
)
|
)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
self.infer_error.append(
|
self.infer_error.append(
|
||||||
@@ -241,16 +245,19 @@ def run_test(
|
|||||||
print_thread = threading.Thread(target=print_state, args=(clients, num_requests))
|
print_thread = threading.Thread(target=print_state, args=(clients, num_requests))
|
||||||
print_thread.daemon = True # makes threads get killed on program exit
|
print_thread.daemon = True # makes threads get killed on program exit
|
||||||
print_thread.start()
|
print_thread.start()
|
||||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
endpoint_info = Endpoint.get_endpoint_info(
|
||||||
endpoint_name=endpoint_group_name, account_api_key=api_key, instance=instance
|
endpoint_name=endpoint_group_name, account_api_key=api_key, instance=instance
|
||||||
)
|
)
|
||||||
if not endpoint_api_key:
|
if not endpoint_info:
|
||||||
log.debug(f"Endpoint {endpoint_group_name} not found for API key")
|
log.debug(f"Endpoint {endpoint_group_name} not found for API key")
|
||||||
return
|
return
|
||||||
|
endpoint_id = endpoint_info["id"]
|
||||||
|
endpoint_api_key = endpoint_info["api_key"]
|
||||||
try:
|
try:
|
||||||
for _ in range(num_requests):
|
for _ in range(num_requests):
|
||||||
client = ClientState(
|
client = ClientState(
|
||||||
endpoint_group_name=endpoint_group_name,
|
endpoint_group_name=endpoint_group_name,
|
||||||
|
endpoint_id=endpoint_id,
|
||||||
api_key=endpoint_api_key,
|
api_key=endpoint_api_key,
|
||||||
server_url=server_url,
|
server_url=server_url,
|
||||||
worker_endpoint=worker_endpoint,
|
worker_endpoint=worker_endpoint,
|
||||||
@@ -289,12 +296,12 @@ def test_load_cmd(
|
|||||||
args = arg_parser.parse_args()
|
args = arg_parser.parse_args()
|
||||||
if hasattr(args, "comfy_model"):
|
if hasattr(args, "comfy_model"):
|
||||||
os.environ["COMFY_MODEL"] = args.comfy_model
|
os.environ["COMFY_MODEL"] = args.comfy_model
|
||||||
server_url = dict(
|
server_url = {
|
||||||
prod="https://run.vast.ai",
|
"prod": "https://run.vast.ai",
|
||||||
alpha="https://run-alpha.vast.ai",
|
"alpha": "https://run-alpha.vast.ai",
|
||||||
candidate="https://run-candidate.vast.ai",
|
"candidate": "https://run-candidate.vast.ai",
|
||||||
local="http://localhost:8080",
|
"local": "http://localhost:8080",
|
||||||
)[args.instance]
|
}.get(args.instance, "http://localhost:8080")
|
||||||
run_test(
|
run_test(
|
||||||
num_requests=args.num_requests,
|
num_requests=args.num_requests,
|
||||||
requests_per_second=args.requests_per_second,
|
requests_per_second=args.requests_per_second,
|
||||||
|
|||||||
+3
-2
@@ -1,4 +1,4 @@
|
|||||||
aiohttp~=3.11
|
aiohttp[speedups]==3.10.1
|
||||||
anyio~=4.4
|
anyio~=4.4
|
||||||
lib~=4.0
|
lib~=4.0
|
||||||
nltk~=3.9
|
nltk~=3.9
|
||||||
@@ -6,5 +6,6 @@ psutil~=6.0
|
|||||||
pycryptodome~=3.20
|
pycryptodome~=3.20
|
||||||
Requests~=2.32
|
Requests~=2.32
|
||||||
transformers~=4.52
|
transformers~=4.52
|
||||||
utils~=1.0
|
utils==1.0.*
|
||||||
hf_transfer>=0.1.9
|
hf_transfer>=0.1.9
|
||||||
|
vastai-sdk>=0.2.0
|
||||||
+65
-10
@@ -41,24 +41,45 @@ echo_var DEBUG_LOG
|
|||||||
echo_var PYWORKER_LOG
|
echo_var PYWORKER_LOG
|
||||||
echo_var MODEL_LOG
|
echo_var MODEL_LOG
|
||||||
|
|
||||||
env | grep _ >> /etc/environment;
|
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
|
||||||
|
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
|
||||||
|
if [ -e "$MODEL_LOG" ]; then
|
||||||
|
echo "Rotating model log at $MODEL_LOG to $MODEL_LOG.old"
|
||||||
|
cat "$MODEL_LOG" >> "$MODEL_LOG.old"
|
||||||
|
: > "$MODEL_LOG"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Populate /etc/environment with quoted values
|
||||||
|
if ! grep -q "VAST" /etc/environment; then
|
||||||
|
env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do
|
||||||
|
name=${line%%=*}
|
||||||
|
value=${line#*=}
|
||||||
|
printf '%s="%s"\n' "$name" "$value"
|
||||||
|
done > /etc/environment
|
||||||
|
fi
|
||||||
|
|
||||||
if [ ! -d "$ENV_PATH" ]
|
if [ ! -d "$ENV_PATH" ]
|
||||||
then
|
then
|
||||||
echo "setting up venv"
|
echo "setting up venv"
|
||||||
|
if ! which uv; then
|
||||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
source ~/.local/bin/env
|
source ~/.local/bin/env
|
||||||
git clone https://github.com/vast-ai/pyworker "$SERVER_DIR"
|
fi
|
||||||
|
|
||||||
uv venv --managed-python "$WORKSPACE_DIR/worker-env" -p 3.10
|
# Fork testing
|
||||||
source "$WORKSPACE_DIR/worker-env/bin/activate"
|
[[ ! -d $SERVER_DIR ]] && git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"
|
||||||
|
if [[ -n ${PYWORKER_REF:-} ]]; then
|
||||||
|
(cd "$SERVER_DIR" && git checkout "$PYWORKER_REF")
|
||||||
|
fi
|
||||||
|
|
||||||
uv pip install -r vast-pyworker/requirements.txt
|
uv venv --python-preference only-managed "$ENV_PATH" -p 3.10
|
||||||
|
source "$ENV_PATH/bin/activate"
|
||||||
|
|
||||||
|
uv pip install -r "${SERVER_DIR}/requirements.txt"
|
||||||
|
|
||||||
touch ~/.no_auto_tmux
|
touch ~/.no_auto_tmux
|
||||||
else
|
else
|
||||||
source ~/.local/bin/env
|
[[ -f ~/.local/bin/env ]] && source ~/.local/bin/env
|
||||||
source "$WORKSPACE_DIR/worker-env/bin/activate"
|
source "$WORKSPACE_DIR/worker-env/bin/activate"
|
||||||
echo "environment activated"
|
echo "environment activated"
|
||||||
echo "venv: $VIRTUAL_ENV"
|
echo "venv: $VIRTUAL_ENV"
|
||||||
@@ -111,9 +132,43 @@ cd "$SERVER_DIR"
|
|||||||
|
|
||||||
echo "launching PyWorker server"
|
echo "launching PyWorker server"
|
||||||
|
|
||||||
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
|
set +e
|
||||||
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
|
python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG"
|
||||||
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
|
PY_STATUS=${PIPESTATUS[0]}
|
||||||
|
set -e
|
||||||
|
|
||||||
|
if [ "${PY_STATUS}" -ne 0 ]; then
|
||||||
|
echo "PyWorker exited with status ${PY_STATUS}; notifying autoscaler..."
|
||||||
|
ERROR_MSG="PyWorker exited: code ${PY_STATUS}"
|
||||||
|
MTOKEN="${MASTER_TOKEN:-}"
|
||||||
|
VERSION="${PYWORKER_VERSION:-0}"
|
||||||
|
|
||||||
|
IFS=',' read -r -a REPORT_ADDRS <<< "${REPORT_ADDR}"
|
||||||
|
for addr in "${REPORT_ADDRS[@]}"; do
|
||||||
|
curl -sS -X POST -H 'Content-Type: application/json' \
|
||||||
|
-d "$(cat <<JSON
|
||||||
|
{
|
||||||
|
"id": ${CONTAINER_ID:-0},
|
||||||
|
"mtoken": "${MTOKEN}",
|
||||||
|
"version": "${VERSION}",
|
||||||
|
"loadtime": 0,
|
||||||
|
"new_load": 0,
|
||||||
|
"cur_load": 0,
|
||||||
|
"rej_load": 0,
|
||||||
|
"max_perf": 0,
|
||||||
|
"cur_perf": 0,
|
||||||
|
"error_msg": "${ERROR_MSG}",
|
||||||
|
"num_requests_working": 0,
|
||||||
|
"num_requests_recieved": 0,
|
||||||
|
"additional_disk_usage": 0,
|
||||||
|
"working_request_idxs": [],
|
||||||
|
"cur_capacity": 0,
|
||||||
|
"max_capacity": 0,
|
||||||
|
"url": "${URL}"
|
||||||
|
}
|
||||||
|
JSON
|
||||||
|
)" "${addr%/}/worker_status/" || true
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
|
|
||||||
echo "launching PyWorker server done"
|
echo "launching PyWorker server done"
|
||||||
+61
-5
@@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Optional
|
import time
|
||||||
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@@ -16,6 +17,60 @@ class Endpoint:
|
|||||||
Utility class for handling endpoint operations.
|
Utility class for handling endpoint operations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_endpoint_info(
|
||||||
|
endpoint_name: str, account_api_key: str, instance: str
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
headers = {"Authorization": f"Bearer {account_api_key}"}
|
||||||
|
url = f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}"
|
||||||
|
# Retry a few times to smooth over transient propagation/network delays
|
||||||
|
for attempt in range(4):
|
||||||
|
try:
|
||||||
|
response = requests.get(url, headers=headers, timeout=8)
|
||||||
|
if response.status_code != 200:
|
||||||
|
# brief backoff and retry
|
||||||
|
time.sleep(0.3 * (attempt + 1))
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
data = response.json()
|
||||||
|
except Exception:
|
||||||
|
# JSON parse failed; backoff and retry
|
||||||
|
time.sleep(0.3 * (attempt + 1))
|
||||||
|
continue
|
||||||
|
result = data.get("results", []) if isinstance(data, dict) else []
|
||||||
|
endpoint = next(
|
||||||
|
(item for item in result if item.get("endpoint_name") == endpoint_name),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if endpoint and endpoint.get("id") and endpoint.get("api_key"):
|
||||||
|
return {"id": endpoint.get("id"), "api_key": endpoint.get("api_key")}
|
||||||
|
except Exception:
|
||||||
|
# network or other transient error; retry
|
||||||
|
time.sleep(0.3 * (attempt + 1))
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_autoscaler_server_url(instance: str) -> str:
|
||||||
|
endpoints = {
|
||||||
|
"alpha": "run-alpha",
|
||||||
|
"candidate": "run-candidate",
|
||||||
|
"prod": "run",
|
||||||
|
}
|
||||||
|
host = endpoints.get(instance)
|
||||||
|
if host:
|
||||||
|
return f"https://{host}.vast.ai/"
|
||||||
|
return "http://localhost:8080"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_server_url(instance: str) -> str:
|
||||||
|
endpoints = {
|
||||||
|
"alpha": "alpha",
|
||||||
|
"candidate": "candidate",
|
||||||
|
"prod": "console",
|
||||||
|
}
|
||||||
|
host = endpoints.get(instance, "alpha")
|
||||||
|
return f"https://{host}.vast.ai/api/v0/endptjobs/"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_endpoint_api_key(
|
def get_endpoint_api_key(
|
||||||
endpoint_name: str, account_api_key: str, instance: str
|
endpoint_name: str, account_api_key: str, instance: str
|
||||||
@@ -30,13 +85,14 @@ class Endpoint:
|
|||||||
Returns:
|
Returns:
|
||||||
Endpoint API key if successful, None otherwise
|
Endpoint API key if successful, None otherwise
|
||||||
"""
|
"""
|
||||||
vast_console_url = "https://console.vast.ai/api/v0/endptjobs/"
|
|
||||||
headers = {"Authorization": f"Bearer {account_api_key}"}
|
headers = {"Authorization": f"Bearer {account_api_key}"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}")
|
log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}")
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
f"{vast_console_url}?autoscaler_instance={instance}", headers=headers
|
f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}",
|
||||||
|
headers=headers,
|
||||||
|
timeout=8,
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
@@ -46,14 +102,14 @@ class Endpoint:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
data = response.json()
|
data = response.json()
|
||||||
except requests.exceptions.JSONDecodeError as e:
|
except Exception as e:
|
||||||
log.debug(f"Failed to parse JSON response: {e}")
|
log.debug(f"Failed to parse JSON response: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
result = data.get("results", [])
|
result = data.get("results", [])
|
||||||
|
|
||||||
endpoint: Optional[Dict[str, Any]] = next(
|
endpoint: Optional[Dict[str, Any]] = next(
|
||||||
(item for item in result if item["endpoint_name"] == endpoint_name),
|
(item for item in result if item.get("endpoint_name") == endpoint_name),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
if not endpoint:
|
if not endpoint:
|
||||||
|
|||||||
@@ -0,0 +1,15 @@
|
|||||||
|
import tempfile
|
||||||
|
from functools import cache
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def get_cert_file_path():
|
||||||
|
cert_url = "https://console.vast.ai/static/jvastai_root.cer"
|
||||||
|
response = requests.get(cert_url)
|
||||||
|
response.raise_for_status()
|
||||||
|
# Use a temporary file that is not deleted on close
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".cer", mode="wb") as f:
|
||||||
|
f.write(response.content)
|
||||||
|
return f.name
|
||||||
@@ -0,0 +1,304 @@
|
|||||||
|
# ComfyUI PyWorker
|
||||||
|
|
||||||
|
This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
|
||||||
|
|
||||||
|
The cost for each request has a static value of `1`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node.
|
||||||
|
|
||||||
|
## Instance Setup
|
||||||
|
|
||||||
|
1. Pick a template
|
||||||
|
|
||||||
|
- [ComfyUI (Serverless)](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=ComfyUI%20(Serverless))
|
||||||
|
|
||||||
|
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
This worker requires both [ComfyUI](https://github.com/comfyanonymous/ComfyUI) and [ComfyUI API Wrapper](https://github.com/ai-dock/comfyui-api-wrapper).
|
||||||
|
|
||||||
|
A docker image is provided but you may use any if the above requirements are met.
|
||||||
|
|
||||||
|
## Client
|
||||||
|
|
||||||
|
The client demonstrates how to use the Vast Serverless SDK to generate images, save them locally, and optionally upload to S3-compatible storage.
|
||||||
|
|
||||||
|
### Setup
|
||||||
|
|
||||||
|
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/vast-ai/pyworker
|
||||||
|
cd pyworker
|
||||||
|
pip install uv
|
||||||
|
uv venv -p 3.12
|
||||||
|
source .venv/bin/activate
|
||||||
|
uv pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Set your API key:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export VAST_API_KEY=<your_api_key>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Default prompt
|
||||||
|
python -m workers.comfyui-json.client
|
||||||
|
|
||||||
|
# Custom prompt
|
||||||
|
python -m workers.comfyui-json.client --prompt "a cat sitting on a rainbow"
|
||||||
|
|
||||||
|
# With options
|
||||||
|
python -m workers.comfyui-json.client --prompt "sunset" --width 1024 --height 1024 --steps 30
|
||||||
|
|
||||||
|
# Using a custom workflow file
|
||||||
|
python -m workers.comfyui-json.client --workflow my_workflow.json
|
||||||
|
|
||||||
|
# With S3 upload
|
||||||
|
python -m workers.comfyui-json.client --s3
|
||||||
|
```
|
||||||
|
|
||||||
|
### CLI Flags
|
||||||
|
|
||||||
|
| Flag | Default | Description |
|
||||||
|
|------|---------|-------------|
|
||||||
|
| `--endpoint` | `my-comfyui-endpoint` | Vast endpoint name |
|
||||||
|
| `--prompt` | (default) | Text prompt for image generation |
|
||||||
|
| `--workflow` | (none) | Path to custom workflow JSON file |
|
||||||
|
| `--width` | 512 | Image width in pixels |
|
||||||
|
| `--height` | 512 | Image height in pixels |
|
||||||
|
| `--steps` | 20 | Number of denoising steps |
|
||||||
|
| `--seed` | (random) | Random seed for reproducibility |
|
||||||
|
| `--s3` | (disabled) | Upload generated images to S3 |
|
||||||
|
|
||||||
|
### Output
|
||||||
|
|
||||||
|
Images are saved to `./generated_images/comfy_{seed}.png`.
|
||||||
|
|
||||||
|
### S3 Upload (Optional)
|
||||||
|
|
||||||
|
You can optionally upload generated images to an S3-compatible storage service (AWS S3, Cloudflare R2, Backblaze B2, etc.) by using the `--s3` flag.
|
||||||
|
|
||||||
|
**1. Set environment variables:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export S3_ENDPOINT_URL="https://your-account.r2.cloudflarestorage.com"
|
||||||
|
export S3_BUCKET_NAME="my-bucket"
|
||||||
|
export S3_ACCESS_KEY_ID="your-access-key-id"
|
||||||
|
export S3_SECRET_ACCESS_KEY="your-secret-access-key"
|
||||||
|
```
|
||||||
|
|
||||||
|
**2. Run with S3 upload enabled:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m workers.comfyui-json.client --prompt "a beautiful landscape" --s3
|
||||||
|
```
|
||||||
|
|
||||||
|
Images will be saved locally AND uploaded to `s3://{bucket}/comfyui/{filename}`.
|
||||||
|
|
||||||
|
**Note:** Requires `boto3` (`pip install boto3`).
|
||||||
|
|
||||||
|
## Benchmarking
|
||||||
|
|
||||||
|
### Custom Benchmark Workflows
|
||||||
|
|
||||||
|
You can provide a custom ComfyUI workflow for benchmarking by creating `workers/comfyui-json/misc/benchmark.json`. This allows you to test performance using your preferred models and workflow complexity.
|
||||||
|
|
||||||
|
**Ways to provide the benchmark file:**
|
||||||
|
- Fork this repository and add your `benchmark.json` file
|
||||||
|
- Write the file during worker provisioning (onstart script or setup phase)
|
||||||
|
|
||||||
|
An example file is provided in the repository. To ensure varied generations, use the placeholder `__RANDOM_INT__` in place of static seed values - it will be replaced with a random integer for each benchmark run.
|
||||||
|
|
||||||
|
### Default Benchmark (Fallback)
|
||||||
|
|
||||||
|
If `benchmark.json` is not available, a simple image generation benchmark runs when each worker initializes. This validates GPU performance and helps identify underperforming machines.
|
||||||
|
|
||||||
|
The default benchmark uses Stable Diffusion v1.5 with ComfyUI's standard text-to-image workflow. Configure it using these environment variables:
|
||||||
|
|
||||||
|
| Environment Variable | Default Value | Description |
|
||||||
|
| -------------------- | ------------- | ----------- |
|
||||||
|
| BENCHMARK_TEST_WIDTH | 512 | Image width (pixels) |
|
||||||
|
| BENCHMARK_TEST_HEIGHT | 512 | Image height (pixels) |
|
||||||
|
| BENCHMARK_TEST_STEPS | 20 | Number of denoising steps |
|
||||||
|
|
||||||
|
Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns.
|
||||||
|
|
||||||
|
#### Calibrating Fallback Benchmark Duration
|
||||||
|
|
||||||
|
To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements.
|
||||||
|
|
||||||
|
**Example:** If your typical workflow should complete in 90 seconds on acceptable hardware:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 1. Measure it/sec on your reference machine
|
||||||
|
# RTX 4090 typically achieves ~43 it/sec with SD1.5
|
||||||
|
|
||||||
|
# 2. Calculate required steps
|
||||||
|
# 90 seconds × 43 it/sec = 3870 steps
|
||||||
|
|
||||||
|
# 3. Configure benchmark
|
||||||
|
export BENCHMARK_TEST_STEPS=3870
|
||||||
|
|
||||||
|
# 4. Machines completing significantly slower than 90s indicate hardware issues
|
||||||
|
```
|
||||||
|
|
||||||
|
**Performance expectations:**
|
||||||
|
- Benchmark duration should remain consistent across identical GPU models
|
||||||
|
- Significant variation (>20%) may indicate thermal, power, or configuration issues
|
||||||
|
|
||||||
|
## Endpoint
|
||||||
|
|
||||||
|
The worker provides a single endpoint:
|
||||||
|
|
||||||
|
- `/generate/sync`: Processes ComfyUI workflows using either predefined modifiers or custom workflow JSON
|
||||||
|
|
||||||
|
## Request Format
|
||||||
|
|
||||||
|
The worker accepts requests in the following format. Choose either modifier mode OR custom workflow mode:
|
||||||
|
|
||||||
|
**Modifier Mode:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"input": {
|
||||||
|
"request_id": "uuid-string", // optional - UUID generated if not provided
|
||||||
|
"modifier": "RawWorkflow",
|
||||||
|
"modifications": {
|
||||||
|
"prompt": "a beautiful landscape",
|
||||||
|
"width": 1024,
|
||||||
|
"height": 1024,
|
||||||
|
"steps": 20,
|
||||||
|
"seed": 123456789
|
||||||
|
},
|
||||||
|
"s3": { ... }, // optional
|
||||||
|
"webhook": { ... } // optional
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Custom Workflow Mode:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"input": {
|
||||||
|
"request_id": "uuid-string", // optional - UUID generated if not provided
|
||||||
|
"workflow_json": {
|
||||||
|
// Complete ComfyUI workflow JSON
|
||||||
|
},
|
||||||
|
"s3": { ... }, // optional
|
||||||
|
"webhook": { ... } // optional
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Request Fields
|
||||||
|
|
||||||
|
### Required Fields
|
||||||
|
|
||||||
|
- **`input`**: Contains the main workflow data
|
||||||
|
- **`input.request_id`**: Unique identifier for the request
|
||||||
|
|
||||||
|
### Workflow Mode (Choose One)
|
||||||
|
|
||||||
|
You must provide either `modifier` OR `workflow_json`, but not both:
|
||||||
|
|
||||||
|
#### Option 1: Modifier Mode
|
||||||
|
- **`input.modifier`**: Name of the predefined workflow modifier (e.g., "Text2Image")
|
||||||
|
- **`input.modifications`**: Parameters to pass to the modifier
|
||||||
|
|
||||||
|
#### Option 2: Custom Workflow Mode
|
||||||
|
- **`input.workflow_json`**: Complete ComfyUI workflow JSON
|
||||||
|
|
||||||
|
### Optional Fields
|
||||||
|
|
||||||
|
- **`input.s3`**: S3 configuration for file storage
|
||||||
|
- **`input.webhook`**: Webhook configuration for notifications
|
||||||
|
|
||||||
|
These configurations can be provided in the request JSON or via environment variables. Request-level configuration takes precedence over environment variables.
|
||||||
|
|
||||||
|
#### S3 Configuration
|
||||||
|
|
||||||
|
**Via Request JSON:**
|
||||||
|
```json
|
||||||
|
"s3": {
|
||||||
|
"access_key_id": "your-s3-access-key",
|
||||||
|
"secret_access_key": "your-s3-secret-access-key",
|
||||||
|
"endpoint_url": "https://my-endpoint.backblaze.com",
|
||||||
|
"bucket_name": "your-bucket",
|
||||||
|
"region": "us-east-1"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Via Environment Variables:**
|
||||||
|
```bash
|
||||||
|
S3_ACCESS_KEY_ID=your-key
|
||||||
|
S3_SECRET_ACCESS_KEY=your-secret
|
||||||
|
S3_BUCKET_NAME=your-bucket
|
||||||
|
S3_ENDPOINT_URL=https://s3.amazonaws.com
|
||||||
|
S3_REGION=us-east-1
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Webhook Configuration
|
||||||
|
|
||||||
|
**Via Request JSON:**
|
||||||
|
```json
|
||||||
|
"webhook": {
|
||||||
|
"url": "your-webhook-url",
|
||||||
|
"extra_params": {
|
||||||
|
"custom_field": "value"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Via Environment Variables:**
|
||||||
|
```bash
|
||||||
|
WEBHOOK_URL=https://your-webhook.com # Default webhook URL
|
||||||
|
WEBHOOK_TIMEOUT=30 # Webhook timeout in seconds
|
||||||
|
```
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
### Basic Text-to-Image (Modifier Mode)
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"input": {
|
||||||
|
"modifier": "Text2Image",
|
||||||
|
"modifications": {
|
||||||
|
"prompt": "a cat sitting on a windowsill",
|
||||||
|
"width": 512,
|
||||||
|
"height": 512,
|
||||||
|
"steps": 20,
|
||||||
|
"seed": 42
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Custom Workflow Mode
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"input": {
|
||||||
|
"request_id": "67890", // optional - using custom ID for tracking
|
||||||
|
"workflow_json": {
|
||||||
|
"3": {
|
||||||
|
"inputs": {
|
||||||
|
"seed": 42,
|
||||||
|
"steps": 20,
|
||||||
|
"cfg": 8,
|
||||||
|
"sampler_name": "euler",
|
||||||
|
"scheduler": "normal",
|
||||||
|
"denoise": 1,
|
||||||
|
"model": ["4", 0],
|
||||||
|
"positive": ["6", 0],
|
||||||
|
"negative": ["7", 0],
|
||||||
|
"latent_image": ["5", 0]
|
||||||
|
},
|
||||||
|
"class_type": "KSampler"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
@@ -0,0 +1,312 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
import random
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import argparse
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from vastai import Serverless
|
||||||
|
|
||||||
|
# ---------------------- Config ----------------------
|
||||||
|
DEFAULT_PROMPT = "a beautiful sunset over mountains, digital art, highly detailed"
|
||||||
|
ENDPOINT_NAME = "my-comfyui-endpoint"
|
||||||
|
DEFAULT_WIDTH = 512
|
||||||
|
DEFAULT_HEIGHT = 512
|
||||||
|
DEFAULT_STEPS = 20
|
||||||
|
COST = 100 # Fixed cost for ComfyUI requests
|
||||||
|
|
||||||
|
# Optional S3 Configuration (from environment variables)
|
||||||
|
S3_ENDPOINT_URL = os.getenv("S3_ENDPOINT_URL")
|
||||||
|
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
|
||||||
|
S3_ACCESS_KEY_ID = os.getenv("S3_ACCESS_KEY_ID")
|
||||||
|
S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY")
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_s3_client():
|
||||||
|
"""Create and return an S3 client configured for the S3-compatible endpoint"""
|
||||||
|
try:
|
||||||
|
import boto3
|
||||||
|
from botocore.config import Config
|
||||||
|
except ImportError:
|
||||||
|
log.error("boto3 is required for S3 uploads. Install with: pip install boto3")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not all([S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY]):
|
||||||
|
log.error("S3 environment variables not fully configured. Required:")
|
||||||
|
log.error(" S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return boto3.client(
|
||||||
|
"s3",
|
||||||
|
endpoint_url=S3_ENDPOINT_URL,
|
||||||
|
aws_access_key_id=S3_ACCESS_KEY_ID,
|
||||||
|
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
|
||||||
|
config=Config(signature_version="s3v4"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------- API Functions ----------------------
|
||||||
|
async def call_generate(
|
||||||
|
client: Serverless,
|
||||||
|
*,
|
||||||
|
endpoint_name: str,
|
||||||
|
prompt: str,
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
steps: int,
|
||||||
|
seed: int,
|
||||||
|
) -> dict:
|
||||||
|
"""Generate image using Text2Image modifier"""
|
||||||
|
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||||
|
payload = {
|
||||||
|
"input": {
|
||||||
|
"request_id": str(uuid.uuid4()),
|
||||||
|
"modifier": "Text2Image",
|
||||||
|
"modifications": {
|
||||||
|
"prompt": prompt,
|
||||||
|
"width": width,
|
||||||
|
"height": height,
|
||||||
|
"steps": steps,
|
||||||
|
"seed": seed,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return await endpoint.request("/generate/sync", payload, cost=COST)
|
||||||
|
|
||||||
|
|
||||||
|
async def call_generate_workflow(
|
||||||
|
client: Serverless,
|
||||||
|
*,
|
||||||
|
endpoint_name: str,
|
||||||
|
workflow_json: dict,
|
||||||
|
) -> dict:
|
||||||
|
"""Generate using custom workflow JSON"""
|
||||||
|
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||||
|
payload = {
|
||||||
|
"input": {
|
||||||
|
"request_id": str(uuid.uuid4()),
|
||||||
|
"workflow_json": workflow_json,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return await endpoint.request("/generate/sync", payload, cost=COST)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------- Demo Class ----------------------
|
||||||
|
class APIDemo:
|
||||||
|
def __init__(self, client: Serverless, endpoint_name: str, upload_s3: bool = False):
|
||||||
|
self.client = client
|
||||||
|
self.endpoint_name = endpoint_name
|
||||||
|
self.upload_s3 = upload_s3
|
||||||
|
self.s3_client = get_s3_client() if upload_s3 else None
|
||||||
|
|
||||||
|
if upload_s3 and not self.s3_client:
|
||||||
|
log.warning("S3 upload requested but client creation failed. Images will only be saved locally.")
|
||||||
|
|
||||||
|
def extract_filename(self, response: dict) -> str | None:
|
||||||
|
"""Extract the generated image filename from ComfyUI response"""
|
||||||
|
if "comfyui_response" in response:
|
||||||
|
for data in response["comfyui_response"].values():
|
||||||
|
if isinstance(data, dict) and "outputs" in data:
|
||||||
|
for node_output in data["outputs"].values():
|
||||||
|
if "images" in node_output and node_output["images"]:
|
||||||
|
return node_output["images"][0].get("filename")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def save_image(self, worker_url: str, filename: str, local_name: str) -> str | None:
|
||||||
|
"""Fetch and save image locally from the worker, optionally upload to S3"""
|
||||||
|
os.makedirs("generated_images", exist_ok=True)
|
||||||
|
return await self._fetch_image(worker_url, filename, local_name)
|
||||||
|
|
||||||
|
def _upload_to_s3(self, local_path: str, s3_key: str) -> str | None:
|
||||||
|
"""Upload a local file to S3 and return the S3 URL"""
|
||||||
|
if not self.s3_client:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.s3_client.upload_file(
|
||||||
|
local_path,
|
||||||
|
S3_BUCKET_NAME,
|
||||||
|
s3_key,
|
||||||
|
ExtraArgs={"ContentType": "image/png"}
|
||||||
|
)
|
||||||
|
s3_url = f"{S3_ENDPOINT_URL}/{S3_BUCKET_NAME}/{s3_key}"
|
||||||
|
print(f" ☁️ Uploaded to S3: {s3_key}")
|
||||||
|
return s3_url
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Failed to upload to S3: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _fetch_image(self, worker_url: str, filename: str, local_name: str) -> str | None:
|
||||||
|
"""Fetch image from worker's /view endpoint and save locally"""
|
||||||
|
if not worker_url:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
url = f"{worker_url}/view"
|
||||||
|
params = {"filename": filename, "type": "output"}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(url, params=params, ssl=False) as resp:
|
||||||
|
if resp.status == 200:
|
||||||
|
path = f"generated_images/{local_name}"
|
||||||
|
image_data = await resp.read()
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
f.write(image_data)
|
||||||
|
print(f" 💾 Saved: {path}")
|
||||||
|
|
||||||
|
# Upload to S3 if enabled
|
||||||
|
if self.upload_s3 and self.s3_client:
|
||||||
|
s3_key = f"comfyui/{local_name}"
|
||||||
|
self._upload_to_s3(path, s3_key)
|
||||||
|
|
||||||
|
return path
|
||||||
|
return None
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def demo_prompt(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
steps: int,
|
||||||
|
seed: int | None,
|
||||||
|
):
|
||||||
|
"""Demo: Generate image from text prompt"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("COMFYUI TEXT-TO-IMAGE DEMO")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
if seed is None:
|
||||||
|
seed = random.randint(0, 2**32 - 1)
|
||||||
|
|
||||||
|
print(f"Prompt: {prompt[:100]}..." if len(prompt) > 100 else f"Prompt: {prompt}")
|
||||||
|
print(f"Size: {width}x{height}, Steps: {steps}, Seed: {seed}")
|
||||||
|
print("\n🎨 Generating image...")
|
||||||
|
|
||||||
|
response = await call_generate(
|
||||||
|
self.client,
|
||||||
|
endpoint_name=self.endpoint_name,
|
||||||
|
prompt=prompt,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
steps=steps,
|
||||||
|
seed=seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n✅ Generation complete!")
|
||||||
|
|
||||||
|
# Get worker URL for fetching images
|
||||||
|
worker_url = response.get("url", "")
|
||||||
|
print(f"Worker URL: {worker_url}")
|
||||||
|
|
||||||
|
# Fetch and save image
|
||||||
|
if "response" in response:
|
||||||
|
filename = self.extract_filename(response["response"])
|
||||||
|
if filename:
|
||||||
|
path = await self.save_image(worker_url, filename, f"comfy_{seed}.png")
|
||||||
|
if not path:
|
||||||
|
print(f"❌ Failed to fetch image")
|
||||||
|
else:
|
||||||
|
print("❌ No image in response")
|
||||||
|
else:
|
||||||
|
print("❌ Unexpected response format")
|
||||||
|
|
||||||
|
async def demo_workflow(self, workflow_file: str):
|
||||||
|
"""Demo: Generate using custom workflow file"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("COMFYUI CUSTOM WORKFLOW DEMO")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
if not os.path.exists(workflow_file):
|
||||||
|
log.error(f"Workflow file not found: {workflow_file}")
|
||||||
|
return
|
||||||
|
|
||||||
|
with open(workflow_file, "r") as f:
|
||||||
|
workflow_json = json.load(f)
|
||||||
|
|
||||||
|
print(f"Workflow: {workflow_file}")
|
||||||
|
print("\n🎨 Generating...")
|
||||||
|
|
||||||
|
response = await call_generate_workflow(
|
||||||
|
self.client,
|
||||||
|
endpoint_name=self.endpoint_name,
|
||||||
|
workflow_json=workflow_json,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n✅ Generation complete!")
|
||||||
|
|
||||||
|
worker_url = response.get("url", "")
|
||||||
|
|
||||||
|
if "response" in response:
|
||||||
|
filename = self.extract_filename(response["response"])
|
||||||
|
if filename:
|
||||||
|
path = await self.save_image(worker_url, filename, "workflow.png")
|
||||||
|
if not path:
|
||||||
|
print(f"❌ Failed to fetch image")
|
||||||
|
else:
|
||||||
|
print("❌ No image in response")
|
||||||
|
else:
|
||||||
|
print("❌ Unexpected response format")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------- CLI ----------------------
|
||||||
|
def build_arg_parser() -> argparse.ArgumentParser:
|
||||||
|
p = argparse.ArgumentParser(description="Vast ComfyUI-JSON Demo (Serverless SDK)")
|
||||||
|
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
|
||||||
|
p.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, metavar="TEXT",
|
||||||
|
help=f"Prompt text (default: '{DEFAULT_PROMPT[:30]}...')")
|
||||||
|
p.add_argument("--workflow", type=str, metavar="FILE", help="Use custom workflow JSON file instead")
|
||||||
|
p.add_argument("--width", type=int, default=DEFAULT_WIDTH, help=f"Image width (default: {DEFAULT_WIDTH})")
|
||||||
|
p.add_argument("--height", type=int, default=DEFAULT_HEIGHT, help=f"Image height (default: {DEFAULT_HEIGHT})")
|
||||||
|
p.add_argument("--steps", type=int, default=DEFAULT_STEPS, help=f"Steps (default: {DEFAULT_STEPS})")
|
||||||
|
p.add_argument("--seed", type=int, default=None, help="Seed (default: random)")
|
||||||
|
p.add_argument("--s3", action="store_true",
|
||||||
|
help="Upload generated images to S3 (requires S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY env vars)")
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
async def main_async():
|
||||||
|
args = build_arg_parser().parse_args()
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"Using endpoint: {args.endpoint}")
|
||||||
|
if args.s3:
|
||||||
|
print(f"S3 upload: enabled (bucket: {S3_BUCKET_NAME})")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with Serverless() as client:
|
||||||
|
demo = APIDemo(client, args.endpoint, upload_s3=args.s3)
|
||||||
|
|
||||||
|
if args.workflow:
|
||||||
|
await demo.demo_workflow(workflow_file=args.workflow)
|
||||||
|
else:
|
||||||
|
await demo.demo_prompt(
|
||||||
|
prompt=args.prompt,
|
||||||
|
width=args.width,
|
||||||
|
height=args.height,
|
||||||
|
steps=args.steps,
|
||||||
|
seed=args.seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
except AttributeError as e:
|
||||||
|
if "API key" in str(e):
|
||||||
|
log.error("API key missing. Set VAST_API_KEY environment variable.")
|
||||||
|
else:
|
||||||
|
log.error(f"Error: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main_async())
|
||||||
@@ -0,0 +1,84 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import random
|
||||||
|
import dataclasses
|
||||||
|
from typing import Dict, Any
|
||||||
|
from functools import cache
|
||||||
|
from math import ceil
|
||||||
|
from pathlib import Path
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from lib.data_types import ApiPayload, JsonDataException
|
||||||
|
|
||||||
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
def count_workload() -> float:
|
||||||
|
# Always 100.0 where there is a single instance of ComfyUI handling requests
|
||||||
|
# Results will indicate % or a job completed per second. Avoids sub 0.1 sec performance indication
|
||||||
|
return 100.0
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class ComfyWorkflowData(ApiPayload):
|
||||||
|
input: dict
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def for_test(cls):
|
||||||
|
"""
|
||||||
|
If the user has provided a benchmark workflow we can use it here to properly gauge performance.
|
||||||
|
Otherwise, use the variables available to simulate workflows of the required running time
|
||||||
|
Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090)
|
||||||
|
"""
|
||||||
|
# Try to load benchmark.json
|
||||||
|
benchmark_file = Path("workers/comfyui-json/misc/benchmark.json")
|
||||||
|
|
||||||
|
if benchmark_file.exists():
|
||||||
|
try:
|
||||||
|
with open(benchmark_file, "r") as f:
|
||||||
|
benchmark_workflow = json.load(f)
|
||||||
|
return cls(
|
||||||
|
input={
|
||||||
|
"request_id": f"test-{random.randint(1000, 99999)}",
|
||||||
|
"workflow_json": benchmark_workflow
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except (json.JSONDecodeError, IOError):
|
||||||
|
# JSON is malformed or file can't be read, fall through to default
|
||||||
|
log.error(f"Failed to benchmark using {benchmark_file}")
|
||||||
|
|
||||||
|
# Fallback: read prompts and construct payload
|
||||||
|
log.info("Using fallback method for benchmarking")
|
||||||
|
with open("workers/comfyui-json/misc/test_prompts.txt", "r") as f:
|
||||||
|
test_prompts = f.readlines()
|
||||||
|
|
||||||
|
test_prompt = random.choice(test_prompts).rstrip()
|
||||||
|
return cls(
|
||||||
|
input={
|
||||||
|
"request_id": f"test-{random.randint(1000, 99999)}",
|
||||||
|
"modifier": "Text2Image",
|
||||||
|
"modifications": {
|
||||||
|
"prompt": test_prompt,
|
||||||
|
"width": os.getenv('BENCHMARK_TEST_WIDTH', 512),
|
||||||
|
"height": os.getenv('BENCHMARK_TEST_HEIGHT', 512),
|
||||||
|
"steps": os.getenv('BENCHMARK_TEST_STEPS', 20),
|
||||||
|
"seed": random.randint(0, sys.maxsize),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_payload_json(self) -> Dict[str, Any]:
|
||||||
|
# input is already a dict, just return it wrapped in the expected structure
|
||||||
|
return {"input": self.input}
|
||||||
|
|
||||||
|
def count_workload(self) -> float:
|
||||||
|
return count_workload()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "ComfyWorkflowData":
|
||||||
|
# Extract required fields
|
||||||
|
if "input" not in json_msg:
|
||||||
|
raise JsonDataException({"input": "missing parameter"})
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
input=json_msg["input"]
|
||||||
|
)
|
||||||
@@ -0,0 +1,107 @@
|
|||||||
|
{
|
||||||
|
"3": {
|
||||||
|
"inputs": {
|
||||||
|
"seed": "__RANDOM_INT__",
|
||||||
|
"steps": 20,
|
||||||
|
"cfg": 8,
|
||||||
|
"sampler_name": "euler",
|
||||||
|
"scheduler": "normal",
|
||||||
|
"denoise": 1,
|
||||||
|
"model": [
|
||||||
|
"4",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"positive": [
|
||||||
|
"6",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"negative": [
|
||||||
|
"7",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"latent_image": [
|
||||||
|
"5",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "KSampler",
|
||||||
|
"_meta": {
|
||||||
|
"title": "KSampler"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"inputs": {
|
||||||
|
"ckpt_name": "v1-5-pruned-emaonly-fp16.safetensors"
|
||||||
|
},
|
||||||
|
"class_type": "CheckpointLoaderSimple",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Load Checkpoint"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"5": {
|
||||||
|
"inputs": {
|
||||||
|
"width": 512,
|
||||||
|
"height": 512,
|
||||||
|
"batch_size": 1
|
||||||
|
},
|
||||||
|
"class_type": "EmptyLatentImage",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Empty Latent Image"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"6": {
|
||||||
|
"inputs": {
|
||||||
|
"text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,",
|
||||||
|
"clip": [
|
||||||
|
"4",
|
||||||
|
1
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "CLIP Text Encode (Prompt)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"7": {
|
||||||
|
"inputs": {
|
||||||
|
"text": "text, watermark",
|
||||||
|
"clip": [
|
||||||
|
"4",
|
||||||
|
1
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "CLIP Text Encode (Prompt)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"inputs": {
|
||||||
|
"samples": [
|
||||||
|
"3",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"vae": [
|
||||||
|
"4",
|
||||||
|
2
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "VAEDecode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "VAE Decode"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"9": {
|
||||||
|
"inputs": {
|
||||||
|
"filename_prefix": "ComfyUI",
|
||||||
|
"images": [
|
||||||
|
"8",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "SaveImage",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Save Image"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
cartoon character of a person with a hoodie , in style of cytus and deemo, ork, gold chains, realistic anime cat, dripping black goo, lineage revolution style, thug life, cute anthropomorphic bunny, balrog, arknights, aliased, very buff, black and red and yellow paint, painting illustration collage style, character composition in vector with white background
|
||||||
|
stardew valley, fine details
|
||||||
|
2D Vector Illustration of a child with soccer ball Art for Sublimation, Design Art, Chrome Art, Painting and Stunning Artwork, Highly Detailed Digital Painting, Airbrush Art, Highly Detailed Digital Artwork, Dramatic Artwork, stained antique yellow copper paint, digital airbrush art, detailed by Mark Brooks, Chicano airbrush art, Swagger! snake Culture
|
||||||
|
realistic futuristic city-downtown with short buildings, sunset
|
||||||
|
seascape by Ray Collins and artgerm, front view of a perfect wave, sunny background, ultra detailed water
|
||||||
|
inspired by realflow-cinema4d editor features, create image of a transparent luxury cup with ice fruits and mint, connected with white, yellow and pink cream, Slow - High Speed MO Photography, YouTube Video Screenshot, Abstract Clay, Transparent Cup , molecular gastronomy, wheel, 3D fluid,Simulation rendering, still video, 4k polymer clay futras photography, very surreal, Houdini Fluid Simulation, hyperrealistic CGI and FLUIDS & MULTIPHYSICS SIMULATION effect, with Somali Stain Lurex, Metallic Jacquard, Gold Thread, Mulberry Silk, Toub Saree, Warm background, a fantastic image worthy of an award.
|
||||||
|
biker with backpack on his back riding a motorcycle, Style by Ade Santora, Oilpunk, Cover photo, craig mullins style, on the cover of a magazine, Outdoor Magazine, inspired by Alex Petruk APe, image of a male biker, Cover of an award-winning magazine, the man has a backpack, photo for magazine, with a backpack, magazine cover
|
||||||
|
generate a collage-style illustration inspired by the Procreate raster graphic editor, photographic illustration with the theme, 2D vector, art for textile sublimation, containing surrealistic cartoon cat wearing a baseball cap and jeans standing in front of a poster, inspired by Sadao Watanabe, Doraemon, Japanese cartoon style, Eichiro Oda, Iconic high detail character, Director: Nakahara Nantenbō, Kastuhiro Otomo, image detailed, by Miyamoto, Hidetaka Miyazaki, Katsuhiro illustration, 8k, masterpiece, Minimize noise and grain in photo quality without lose quality and increase brightness and lighting,Symmetry and Alignment, Avoid asymmetrical shapes and out-of-focus points. Focus and Sharpness: Make sure the image is focused and sharp and encourages the viewer to see it as a work of art printed on fabric.
|
||||||
|
fantasy medieval village world inside a glass sphere , high detail, fantasy, realistic, light effect, hyper detail, volumetric lighting, cinematic, macro, depth of field, blur, red light and clouds from the back, highly detailed epic cinematic concept art cg render made in maya, blender and photoshop, octane render, excellent composition, dynamic dramatic cinematic lighting, aesthetic, very inspirational, world inside a glass sphere by james gurney by artgerm with james jean, joe fenton and tristan eaton by ross tran, fine details
|
||||||
|
Iron Man, (Arnold Tsang, Toru Nakayama), Masterpiece, Studio Quality, 6k , toa, toaair, 1boy, glowing, axe, mecha, science_fiction, solo, weapon, jungle , green_background, nature, outdoors, solo, tree, weapon, mask, dynamic lighting, detailed shading, digital texture painting
|
||||||
|
(Pope Francis) wearing leather jacket is a DJ in a nightclub, mixing live on stage, giant mixing table, a masterpiece
|
||||||
|
Pope Francis wearing biker (leather jacket), a masterpiece
|
||||||
|
Luke Skywalker ordering a burger and fries from the Death Star canteen.
|
||||||
|
I want to generate a group avatar for a Feishu group chat. The role of this group is daily software technical communication. Now the subject technology stacks that members of this group discuss daily include: algorithms, data structures, optimization, functional programming, and the programming languages often discussed are: TypeScript, Java, python, etc. I hope this avatar has a simple aesthetic, this avatar is a single person avatar
|
||||||
|
portrait Anime black girl cute-fine-face, pretty face, realistic shaded Perfect face, fine details. Anime. realistic shaded lighting by Ilya Kuvshinov Giuseppe Dangelico Pino and Michael Garmash and Rob Rey, IAMAG premiere, WLOP matte print, cute freckles, masterpiece
|
||||||
|
young Disney socialite wearing a beige miniskirt, dark brown turtleneck sweater, small neckless, cute-fine-face, anime. illustration, realistic shaded perfect face, brown hair, grey eyes, fine details, realistic shaded lighting by ilya kuvshinov giuseppe dangelico pino and michael garmash and rob rey, iamag premiere, wlop matte print, a masterpiece
|
||||||
|
Cute small cat sitting in a movie theater eating chicken wiggs watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
|
||||||
|
Cute small dog sitting in a movie theater eating popcorn watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
|
||||||
|
fox bracelet made of buckskin with fox features, rich details, fine carvings, studio lighting
|
||||||
|
crane buckskin bracelet with crane features, rich details, fine carvings, studio lighting
|
||||||
|
london luxurious interior living-room, light walls
|
||||||
|
Parisian luxurious interior penthouse bedroom, dark walls, wooden panels
|
||||||
|
cute girl, crop-top, blond hair, black glasses, stretching, with background by greg rutkowski makoto shinkai kyoto animation key art feminine mid shot
|
||||||
|
houses in front, houses background, straight houses, digital art, smooth, sharp focus, gravity falls style, doraemon style, shinchan style, anime style
|
||||||
|
Simplified technical drawing, Leonardo da Vinci, Mechanical Dinosaur Skeleton, Minimalistic annotations, Hand-drawn illustrations, Basic design and engineering, Wonder and curiosity
|
||||||
|
High quality 8K painting impressionist style of a Japanese modern city street with a girl on the foreground wearing a traditional wedding dress with a fox mask, staring at the sky, daylight
|
||||||
|
a landscape from the Moon with the Earth setting on the horizon, realistic, detailed
|
||||||
|
Isometric Atlantis city,great architecture with columns, great details, ornaments,seaweed, blue ambiance, 3D cartoon style, soft light, 45° view
|
||||||
|
A hyper realistic avatar of a guy riding on a black honda cbr 650r in leather suit,high detail, high quality,8K,photo realism
|
||||||
|
the street of amedieval fantasy town, at dawn, dark, highly detailed
|
||||||
|
overwhelmingly beautiful eagle framed with vector flowers, long shiny wavy flowing hair, polished, ultra detailed vector floral illustration mixed with hyper realism, muted pastel colors, vector floral details in background, muted colors, hyper detailed ultra intricate overwhelming realism in detailed complex scene with magical fantasy atmosphere, no signature, no watermark
|
||||||
|
a highly detailed matte painting of a man on a hill watching a rocket launch in the distance by studio ghibli, makoto shinkai, by artgerm, by wlop, by greg rutkowski, volumetric lighting, octane render, 4 k resolution, trending on artstation, masterpiece | hyperrealism| highly detailed| insanely detailed| intricate| cinematic lighting| depth of field
|
||||||
|
electronik robot and ofice ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
|
||||||
|
exquisitely intricately detailed illustration, of a small world with a lake and a rainbow, inside a closed glass jar.
|
||||||
@@ -0,0 +1,150 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import dataclasses
|
||||||
|
import base64
|
||||||
|
from typing import Optional, Union, Type
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from aiohttp import web, ClientResponse
|
||||||
|
|
||||||
|
from lib.backend import Backend, LogAction
|
||||||
|
from lib.data_types import EndpointHandler
|
||||||
|
from lib.server import start_server
|
||||||
|
from .data_types import ComfyWorkflowData
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_SERVER_URL = os.getenv("MODEL_SERVER_URL", "http://127.0.0.1:18288")
|
||||||
|
COMFYUI_URL = os.getenv("COMFYUI_URL", "http://127.0.0.1:18188") # Raw ComfyUI server
|
||||||
|
|
||||||
|
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
||||||
|
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: "
|
||||||
|
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||||
|
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
|
||||||
|
"Value not in list: ", # This error is emitted when the model file is not there at all
|
||||||
|
"[ERROR] Provisioning Script failed", # Error inserted by provisioning script if models/nodes fail to download
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG,
|
||||||
|
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_client_response(
|
||||||
|
client_request: web.Request, model_response: ClientResponse
|
||||||
|
) -> Union[web.Response, web.StreamResponse]:
|
||||||
|
# Check if the response is actually streaming based on response headers/content-type
|
||||||
|
is_streaming_response = (
|
||||||
|
model_response.content_type == "text/event-stream"
|
||||||
|
or model_response.content_type == "application/x-ndjson"
|
||||||
|
or model_response.headers.get("Transfer-Encoding") == "chunked"
|
||||||
|
or "stream" in model_response.content_type.lower()
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_streaming_response:
|
||||||
|
log.debug("Detected streaming response...")
|
||||||
|
res = web.StreamResponse()
|
||||||
|
res.content_type = model_response.content_type
|
||||||
|
await res.prepare(client_request)
|
||||||
|
async for chunk in model_response.content:
|
||||||
|
await res.write(chunk)
|
||||||
|
await res.write_eof()
|
||||||
|
log.debug("Done streaming response")
|
||||||
|
return res
|
||||||
|
else:
|
||||||
|
log.debug("Detected non-streaming response...")
|
||||||
|
content = await model_response.read()
|
||||||
|
return web.Response(
|
||||||
|
body=content,
|
||||||
|
status=model_response.status,
|
||||||
|
content_type=model_response.content_type
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class ComfyWorkflowHandler(EndpointHandler[ComfyWorkflowData]):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def endpoint(self) -> str:
|
||||||
|
return "/generate/sync"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def healthcheck_endpoint(self) -> Optional[str]:
|
||||||
|
return f"{MODEL_SERVER_URL}/health"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def payload_cls(cls) -> Type[ComfyWorkflowData]:
|
||||||
|
return ComfyWorkflowData
|
||||||
|
|
||||||
|
def make_benchmark_payload(self) -> ComfyWorkflowData:
|
||||||
|
return ComfyWorkflowData.for_test()
|
||||||
|
|
||||||
|
async def generate_client_response(
|
||||||
|
self, client_request: web.Request, model_response: ClientResponse
|
||||||
|
) -> Union[web.Response, web.StreamResponse]:
|
||||||
|
return await generate_client_response(client_request, model_response)
|
||||||
|
|
||||||
|
|
||||||
|
backend = Backend(
|
||||||
|
model_server_url=MODEL_SERVER_URL,
|
||||||
|
model_log_file=os.environ["MODEL_LOG"],
|
||||||
|
allow_parallel_requests=False,
|
||||||
|
benchmark_handler=ComfyWorkflowHandler(
|
||||||
|
benchmark_runs=3, benchmark_words=100
|
||||||
|
),
|
||||||
|
log_actions=[
|
||||||
|
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
|
||||||
|
(LogAction.Info, "Downloading:"),
|
||||||
|
*[
|
||||||
|
(LogAction.ModelError, error_msg)
|
||||||
|
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_ping(_):
|
||||||
|
return web.Response(body="pong")
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_view(request: web.Request) -> web.Response:
|
||||||
|
"""Proxy /view requests to raw ComfyUI server to fetch generated images"""
|
||||||
|
# Forward query params to raw ComfyUI (not the API wrapper)
|
||||||
|
query_string = request.query_string
|
||||||
|
url = f"{COMFYUI_URL}/view?{query_string}"
|
||||||
|
|
||||||
|
log.debug(f"Proxying /view request to: {url}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(url) as resp:
|
||||||
|
if resp.status == 200:
|
||||||
|
content = await resp.read()
|
||||||
|
return web.Response(
|
||||||
|
body=content,
|
||||||
|
status=200,
|
||||||
|
content_type=resp.content_type or "image/png"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
text = await resp.text()
|
||||||
|
return web.Response(
|
||||||
|
text=text,
|
||||||
|
status=resp.status,
|
||||||
|
content_type="text/plain"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error proxying /view: {e}")
|
||||||
|
return web.Response(text=str(e), status=500)
|
||||||
|
|
||||||
|
|
||||||
|
routes = [
|
||||||
|
web.post("/generate/sync", backend.create_handler(ComfyWorkflowHandler())),
|
||||||
|
web.get("/view", handle_view),
|
||||||
|
web.get("/ping", handle_ping),
|
||||||
|
]
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
start_server(backend, routes)
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
from lib.test_utils import test_load_cmd, test_args
|
||||||
|
from .data_types import ComfyWorkflowData
|
||||||
|
|
||||||
|
WORKER_ENDPOINT = "/generate/sync"
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_load_cmd(ComfyWorkflowData, WORKER_ENDPOINT, arg_parser=test_args)
|
||||||
+22
-22
@@ -5,25 +5,19 @@ import requests
|
|||||||
|
|
||||||
from lib.test_utils import print_truncate_res
|
from lib.test_utils import print_truncate_res
|
||||||
from utils.endpoint_util import Endpoint
|
from utils.endpoint_util import Endpoint
|
||||||
|
from utils.ssl import get_cert_file_path
|
||||||
|
|
||||||
"""
|
from vastai import Serverless
|
||||||
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
|
|
||||||
"""
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.DEBUG,
|
|
||||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
|
||||||
datefmt="%Y-%m-%d %H:%M:%S",
|
|
||||||
)
|
|
||||||
log = logging.getLogger(__file__)
|
|
||||||
|
|
||||||
|
|
||||||
def call_default_workflow(
|
ENDPOINT_NAME = "my-comfyui-endpoint"
|
||||||
endpoint_group_name: str, api_key: str, server_url: str
|
COST = 100 # Use a constant cost for image generation
|
||||||
) -> None:
|
|
||||||
|
def call_default_workflow(endpoint_id: int, api_key: str, server_url: str) -> None:
|
||||||
WORKER_ENDPOINT = "/prompt"
|
WORKER_ENDPOINT = "/prompt"
|
||||||
COST = 100
|
COST = 100
|
||||||
route_payload = {
|
route_payload = {
|
||||||
"endpoint": endpoint_group_name,
|
"endpoint_id": endpoint_id,
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
"cost": COST,
|
"cost": COST,
|
||||||
}
|
}
|
||||||
@@ -38,7 +32,7 @@ def call_default_workflow(
|
|||||||
auth_data = dict(
|
auth_data = dict(
|
||||||
signature=message["signature"],
|
signature=message["signature"],
|
||||||
cost=message["cost"],
|
cost=message["cost"],
|
||||||
endpoint=message["endpoint"],
|
endpoint_id=message["endpoint_id"],
|
||||||
reqnum=message["reqnum"],
|
reqnum=message["reqnum"],
|
||||||
url=message["url"],
|
url=message["url"],
|
||||||
)
|
)
|
||||||
@@ -51,18 +45,19 @@ def call_default_workflow(
|
|||||||
response = requests.post(
|
response = requests.post(
|
||||||
url,
|
url,
|
||||||
json=req_data,
|
json=req_data,
|
||||||
|
verify=get_cert_file_path(),
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
print_truncate_res(str(response.json()))
|
print_truncate_res(str(response.json()))
|
||||||
|
|
||||||
|
|
||||||
def call_custom_workflow_for_sd3(
|
def call_custom_workflow_for_sd3(
|
||||||
endpoint_group_name: str, api_key: str, server_url: str
|
endpoint_id: int, api_key: str, server_url: str
|
||||||
) -> None:
|
) -> None:
|
||||||
WORKER_ENDPOINT = "/custom-workflow"
|
WORKER_ENDPOINT = "/custom-workflow"
|
||||||
COST = 100
|
COST = 100
|
||||||
route_payload = {
|
route_payload = {
|
||||||
"endpoint": endpoint_group_name,
|
"endpoint_id": endpoint_id,
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
"cost": COST,
|
"cost": COST,
|
||||||
}
|
}
|
||||||
@@ -77,9 +72,10 @@ def call_custom_workflow_for_sd3(
|
|||||||
auth_data = dict(
|
auth_data = dict(
|
||||||
signature=message["signature"],
|
signature=message["signature"],
|
||||||
cost=message["cost"],
|
cost=message["cost"],
|
||||||
endpoint=message["endpoint"],
|
endpoint_id=message["endpoint_id"],
|
||||||
reqnum=message["reqnum"],
|
reqnum=message["reqnum"],
|
||||||
url=message["url"],
|
url=message["url"],
|
||||||
|
request_idx=message["request_idx"],
|
||||||
)
|
)
|
||||||
workflow = {
|
workflow = {
|
||||||
"3": {
|
"3": {
|
||||||
@@ -141,6 +137,7 @@ def call_custom_workflow_for_sd3(
|
|||||||
response = requests.post(
|
response = requests.post(
|
||||||
url,
|
url,
|
||||||
json=req_data,
|
json=req_data,
|
||||||
|
verify=get_cert_file_path(),
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
print_truncate_res(str(response.json()))
|
print_truncate_res(str(response.json()))
|
||||||
@@ -149,25 +146,28 @@ def call_custom_workflow_for_sd3(
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from lib.test_utils import test_args
|
from lib.test_utils import test_args
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
args = test_args.parse_args()
|
args = test_args.parse_args()
|
||||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
endpoint_info = Endpoint.get_endpoint_info(
|
||||||
endpoint_name=args.endpoint_group_name,
|
endpoint_name=args.endpoint_group_name,
|
||||||
account_api_key=args.api_key,
|
account_api_key=args.api_key,
|
||||||
instance=args.instance,
|
instance=args.instance,
|
||||||
)
|
)
|
||||||
if endpoint_api_key:
|
if endpoint_info:
|
||||||
|
endpoint_id = endpoint_info["id"]
|
||||||
|
endpoint_api_key = endpoint_info["api_key"]
|
||||||
try:
|
try:
|
||||||
call_default_workflow(
|
call_default_workflow(
|
||||||
|
endpoint_id=endpoint_id,
|
||||||
api_key=endpoint_api_key,
|
api_key=endpoint_api_key,
|
||||||
endpoint_group_name=args.endpoint_group_name,
|
|
||||||
server_url=args.server_url,
|
server_url=args.server_url,
|
||||||
)
|
)
|
||||||
call_custom_workflow_for_sd3(
|
call_custom_workflow_for_sd3(
|
||||||
|
endpoint_id=endpoint_id,
|
||||||
api_key=endpoint_api_key,
|
api_key=endpoint_api_key,
|
||||||
endpoint_group_name=args.endpoint_group_name,
|
|
||||||
server_url=args.server_url,
|
server_url=args.server_url,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error during API call: {e}")
|
log.error(f"Error during API call: {e}")
|
||||||
else:
|
else:
|
||||||
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ")
|
log.error(f"Failed to get endpoint info for {args.endpoint_group_name}")
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from lib.server import start_server
|
|||||||
from .data_types import DefaultComfyWorkflowData, CustomComfyWorkflowData
|
from .data_types import DefaultComfyWorkflowData, CustomComfyWorkflowData
|
||||||
|
|
||||||
|
|
||||||
MODEL_SERVER_URL = "http://0.0.0.0:38188"
|
MODEL_SERVER_URL = "http://127.0.0.1:18288" # API Wrapper Service
|
||||||
|
|
||||||
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
||||||
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: http://127.0.0.1:18188"
|
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: http://127.0.0.1:18188"
|
||||||
|
|||||||
+33
-26
@@ -8,14 +8,13 @@ This is the base PyWorker for OpenAI compatible inference servers. See the [Ser
|
|||||||
|
|
||||||
This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker.
|
This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker.
|
||||||
|
|
||||||
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20%2B%20Qwen%2FQwen3-8B%20(Serverless)) (recommended)
|
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20(Serverless)) (recommended)
|
||||||
- [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless))
|
- [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless))
|
||||||
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20%2B%20Qwen3-8B%20(Serverless))
|
|
||||||
|
|
||||||
|
|
||||||
All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected.
|
All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected.
|
||||||
|
|
||||||
2. Follow the [getting started guide](https://docs.vast.ai/serverless/getting-started) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
||||||
|
|
||||||
## Client Setup (Demo)
|
## Client Setup (Demo)
|
||||||
|
|
||||||
@@ -34,38 +33,20 @@ uv pip install -r requirements.txt
|
|||||||
|
|
||||||
Several examples have been provided in the client to help you get started with your own implementation.
|
Several examples have been provided in the client to help you get started with your own implementation.
|
||||||
|
|
||||||
### Completions
|
First, set your API key as an environment variable:
|
||||||
|
|
||||||
Call to `/v1/completions` with json response
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
|
export VAST_API_KEY=<your_api_key>
|
||||||
```
|
```
|
||||||
|
|
||||||
### Chat Completion (json)
|
The `--model` and `--endpoint` flags are optional. If not provided, they default to `Qwen/Qwen3-8B` and `my-vllm-endpoint` respectively.
|
||||||
|
|
||||||
Call to `/v1/chat/completions` with json response
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
|
|
||||||
```
|
|
||||||
|
|
||||||
### Chat Completion (streaming)
|
### Chat Completion (streaming)
|
||||||
|
|
||||||
Call to `/v1/chat/completions` with streaming response
|
Call to `/v1/chat/completions` with streaming response
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
|
python -m workers.openai.client --chat-stream --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||||
```
|
|
||||||
|
|
||||||
### Tool Use (json)
|
|
||||||
|
|
||||||
Call to `/v1/chat/completions` with tool and json response.
|
|
||||||
|
|
||||||
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Interactive Chat (streaming)
|
### Interactive Chat (streaming)
|
||||||
@@ -75,6 +56,32 @@ Interactive session with calls to `/v1/chat/completions`.
|
|||||||
Type `clear` to clear the chat history or `quit` to exit.
|
Type `clear` to clear the chat history or `quit` to exit.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
|
python -m workers.openai.client --interactive --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Chat Completion (json)
|
||||||
|
|
||||||
|
Call to `/v1/chat/completions` with json response
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m workers.openai.client --chat --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Tool Use (json)
|
||||||
|
|
||||||
|
Call to `/v1/chat/completions` with tool and json response.
|
||||||
|
|
||||||
|
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m workers.openai.client --tools --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Completions
|
||||||
|
|
||||||
|
Call to `/v1/completions` with json response
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m workers.openai.client --completion --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
+355
-388
@@ -1,13 +1,15 @@
|
|||||||
import logging
|
import logging
|
||||||
import sys
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
import subprocess
|
import subprocess
|
||||||
from urllib.parse import urljoin
|
import argparse
|
||||||
from typing import Dict, Any, Optional, Iterator, Union, List
|
from typing import Any, Dict, List, Optional
|
||||||
import requests
|
|
||||||
from utils.endpoint_util import Endpoint
|
|
||||||
from .data_types.client import CompletionConfig, ChatCompletionConfig
|
|
||||||
|
|
||||||
|
from vastai import Serverless
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# ---------------------- Logging ----------------------
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.DEBUG,
|
level=logging.DEBUG,
|
||||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||||
@@ -15,135 +17,20 @@ logging.basicConfig(
|
|||||||
)
|
)
|
||||||
log = logging.getLogger(__file__)
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
COMPLETIONS_PROMPT = "the capital of USA is"
|
# ---------------------- Prompts ----------------------
|
||||||
|
COMPLETIONS_PROMPT = "Zebras are primarily grazers and can subsist on lower-quality vegetation. They are preyed on mainly by"
|
||||||
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
|
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
|
||||||
TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?"
|
TOOLS_PROMPT = (
|
||||||
|
"Can you list the files in the current working directory and tell me what you see? "
|
||||||
class APIClient:
|
"What do you think this directory might be for?"
|
||||||
"""Lightweight client focused solely on API communication"""
|
|
||||||
|
|
||||||
# Remove the generic WORKER_ENDPOINT since we're now going direct
|
|
||||||
DEFAULT_COST = 100
|
|
||||||
DEFAULT_TIMEOUT = 4
|
|
||||||
|
|
||||||
def __init__(self, endpoint_group_name: str, api_key: str, server_url: str):
|
|
||||||
self.endpoint_group_name = endpoint_group_name
|
|
||||||
self.api_key = api_key
|
|
||||||
self.server_url = server_url
|
|
||||||
self.endpoint_api_key = self._get_endpoint_api_key()
|
|
||||||
|
|
||||||
def _get_endpoint_api_key(self) -> Optional[str]:
|
|
||||||
"""Get the endpoint API key"""
|
|
||||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
|
||||||
endpoint_name=self.endpoint_group_name,
|
|
||||||
account_api_key=self.api_key,
|
|
||||||
)
|
|
||||||
if not endpoint_api_key:
|
|
||||||
log.error(f"Failed to get API key for endpoint {self.endpoint_group_name}")
|
|
||||||
return endpoint_api_key
|
|
||||||
|
|
||||||
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
|
|
||||||
"""Get worker URL and auth data from routing service"""
|
|
||||||
if not self.endpoint_api_key:
|
|
||||||
raise ValueError("No valid endpoint API key available")
|
|
||||||
|
|
||||||
route_payload = {
|
|
||||||
"endpoint": self.endpoint_group_name,
|
|
||||||
"api_key": self.endpoint_api_key,
|
|
||||||
"cost": cost,
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post(
|
|
||||||
urljoin(self.server_url, "/route/"),
|
|
||||||
json=route_payload,
|
|
||||||
timeout=self.DEFAULT_TIMEOUT,
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
def _create_auth_data(self, message: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""Create auth data from routing response"""
|
|
||||||
return {
|
|
||||||
"signature": message["signature"],
|
|
||||||
"cost": message["cost"],
|
|
||||||
"endpoint": message["endpoint"],
|
|
||||||
"reqnum": message["reqnum"],
|
|
||||||
"url": message["url"],
|
|
||||||
}
|
|
||||||
|
|
||||||
def _make_request(self, payload: Dict[str, Any], endpoint: str, method: str = "POST",
|
|
||||||
stream: bool = False) -> Union[Dict[str, Any], Iterator[str]]:
|
|
||||||
"""Make request directly to the specific worker endpoint"""
|
|
||||||
# Get worker URL and auth data
|
|
||||||
cost = payload.get('max_tokens')
|
|
||||||
message = self._get_worker_url(cost=cost)
|
|
||||||
worker_url = message["url"]
|
|
||||||
auth_data = self._create_auth_data(message)
|
|
||||||
|
|
||||||
req_data = {
|
|
||||||
"payload": {
|
|
||||||
"input": payload
|
|
||||||
},
|
|
||||||
"auth_data": auth_data
|
|
||||||
}
|
|
||||||
|
|
||||||
url = urljoin(worker_url, endpoint)
|
|
||||||
log.debug(f"Making direct request to: {url}")
|
|
||||||
log.debug(f"Payload: {req_data}")
|
|
||||||
|
|
||||||
# Make the request using the specified method
|
|
||||||
if method.upper() == "POST":
|
|
||||||
response = requests.post(url, json=req_data, stream=stream)
|
|
||||||
elif method.upper() == "GET":
|
|
||||||
response = requests.get(url, params=req_data, stream=stream)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
|
||||||
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
if stream:
|
|
||||||
return self._handle_streaming_response(response)
|
|
||||||
else:
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
def _handle_streaming_response(self, response: requests.Response) -> Iterator[str]:
|
|
||||||
"""Handle streaming response and yield tokens"""
|
|
||||||
try:
|
|
||||||
for line in response.iter_lines(decode_unicode=True):
|
|
||||||
if line:
|
|
||||||
if line.startswith("data: "):
|
|
||||||
data_str = line[6:]
|
|
||||||
if data_str.strip() == "[DONE]":
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
data = json.loads(data_str)
|
|
||||||
yield data # Yield the full chunk
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
continue
|
|
||||||
except Exception as e:
|
|
||||||
log.error(f"Error handling streaming response: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
def call_completions(self, config: CompletionConfig) -> Union[Dict[str, Any], Iterator[str]]:
|
|
||||||
payload = config.to_dict()
|
|
||||||
|
|
||||||
return self._make_request(
|
|
||||||
payload=payload,
|
|
||||||
endpoint="/v1/completions",
|
|
||||||
stream=config.stream
|
|
||||||
)
|
|
||||||
|
|
||||||
def call_chat_completions(self, config: ChatCompletionConfig) -> Union[Dict[str, Any], Iterator[str]]:
|
|
||||||
payload = config.to_dict()
|
|
||||||
|
|
||||||
return self._make_request(
|
|
||||||
payload=payload,
|
|
||||||
endpoint="/v1/chat/completions",
|
|
||||||
stream=config.stream
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ENDPOINT_NAME = "my-vllm-endpoint" # change this to your vLLM endpoint name
|
||||||
|
DEFAULT_MODEL = "Qwen/Qwen3-8B" # must support tool calling
|
||||||
|
MAX_TOKENS = 1024
|
||||||
|
DEFAULT_TEMPERATURE = 0.7
|
||||||
|
|
||||||
|
# ---------------------- Tooling ----------------------
|
||||||
class ToolManager:
|
class ToolManager:
|
||||||
"""Handles tool definitions and execution"""
|
"""Handles tool definitions and execution"""
|
||||||
|
|
||||||
@@ -151,7 +38,9 @@ class ToolManager:
|
|||||||
def list_files() -> str:
|
def list_files() -> str:
|
||||||
"""Execute ls on current directory"""
|
"""Execute ls on current directory"""
|
||||||
try:
|
try:
|
||||||
result = subprocess.run(['ls', '-la', '.'], capture_output=True, text=True, timeout=10)
|
result = subprocess.run(
|
||||||
|
["ls", "-la", "."], capture_output=True, text=True, timeout=10
|
||||||
|
)
|
||||||
if result.returncode == 0:
|
if result.returncode == 0:
|
||||||
return result.stdout
|
return result.stdout
|
||||||
else:
|
else:
|
||||||
@@ -161,296 +50,410 @@ class ToolManager:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_ls_tool_definition() -> List[Dict[str, Any]]:
|
def get_ls_tool_definition() -> List[Dict[str, Any]]:
|
||||||
"""Get the ls tool definition"""
|
"""OpenAI-compatible tool schema"""
|
||||||
return [{
|
return [
|
||||||
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "list_files",
|
"name": "list_files",
|
||||||
"description": "List files and directories in the cwd",
|
"description": "List files and directories in the cwd",
|
||||||
"parameters": {
|
"parameters": {"type": "object", "properties": {}, "required": []},
|
||||||
"type": "object",
|
},
|
||||||
"properties": {},
|
|
||||||
"required": []
|
|
||||||
}
|
}
|
||||||
}
|
]
|
||||||
}]
|
|
||||||
|
|
||||||
def execute_tool_call(self, tool_call: Dict[str, Any]) -> str:
|
def execute_tool_call(self, tool_call: Dict[str, Any]) -> str:
|
||||||
"""Execute a tool call and return the result"""
|
"""Execute a tool call and return the result"""
|
||||||
function_name = tool_call["function"]["name"]
|
function_name = (tool_call.get("function") or {}).get("name")
|
||||||
|
|
||||||
if function_name == "list_files":
|
if function_name == "list_files":
|
||||||
return self.list_files()
|
return self.list_files()
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown tool function: {function_name}")
|
raise ValueError(f"Unknown tool function: {function_name}")
|
||||||
|
|
||||||
|
|
||||||
|
# ----- Helpers to handle streamed tool_calls assembly -----
|
||||||
|
def _merge_tool_call_delta(state: Dict[int, Dict[str, Any]], tc_delta: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
OpenAI-style streaming sends partial tool_calls with an index and partial fields.
|
||||||
|
We merge into a per-index state dict until the assistant message finishes.
|
||||||
|
"""
|
||||||
|
idx = tc_delta.get("index")
|
||||||
|
if idx is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
entry = state.setdefault(idx, {"id": None, "function": {"name": None, "arguments": ""}, "type": "function"})
|
||||||
|
|
||||||
|
if tc_delta.get("id"):
|
||||||
|
entry["id"] = tc_delta["id"]
|
||||||
|
|
||||||
|
fn_delta = tc_delta.get("function") or {}
|
||||||
|
if "name" in fn_delta and fn_delta["name"]:
|
||||||
|
entry["function"]["name"] = fn_delta["name"]
|
||||||
|
if "arguments" in fn_delta and fn_delta["arguments"]:
|
||||||
|
entry["function"]["arguments"] += fn_delta["arguments"]
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||||
|
return [state[i] for i in sorted(state.keys())]
|
||||||
|
|
||||||
|
|
||||||
|
# ---- OpenAI-compatible calls (non-streaming) ----
|
||||||
|
async def call_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs) -> Dict[str, Any]:
|
||||||
|
|
||||||
|
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"input": {
|
||||||
|
"model": model,
|
||||||
|
"prompt": prompt,
|
||||||
|
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||||
|
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.debug("POST /v1/completions %s", json.dumps(payload)[:500])
|
||||||
|
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
|
||||||
|
return resp["response"]
|
||||||
|
|
||||||
|
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs) -> Dict[str, Any]:
|
||||||
|
|
||||||
|
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"input": {
|
||||||
|
"model": model,
|
||||||
|
"messages": messages,
|
||||||
|
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||||
|
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||||
|
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
|
||||||
|
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.debug("POST /v1/chat/completions %s", json.dumps(payload)[:500])
|
||||||
|
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"])
|
||||||
|
return resp["response"]
|
||||||
|
|
||||||
|
# ---- Streaming variants ----
|
||||||
|
async def stream_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs):
|
||||||
|
|
||||||
|
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"input": {
|
||||||
|
"model": model,
|
||||||
|
"prompt": prompt,
|
||||||
|
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||||
|
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||||
|
"stream": True,
|
||||||
|
**({"stop": kwargs["stop"]} if "stop" in kwargs else {}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.debug("STREAM /v1/completions %s", json.dumps(payload)[:500])
|
||||||
|
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
|
||||||
|
return resp["response"] # async generator
|
||||||
|
|
||||||
|
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs):
|
||||||
|
|
||||||
|
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"input": {
|
||||||
|
"model": model,
|
||||||
|
"messages": messages,
|
||||||
|
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||||
|
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||||
|
"stream": True,
|
||||||
|
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
|
||||||
|
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.debug("STREAM /v1/chat/completions %s", json.dumps(payload)[:500])
|
||||||
|
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
|
||||||
|
return resp["response"] # async generator
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------- Demo Runner ----------------------
|
||||||
class APIDemo:
|
class APIDemo:
|
||||||
"""Demo and testing functionality for the API client"""
|
"""Demo and testing functionality for the API client"""
|
||||||
|
|
||||||
def __init__(self, client: APIClient, model: str, tool_manager: ToolManager = None):
|
def __init__(self, client: Serverless, model: str, endpoint_name: str, tool_manager: Optional[ToolManager] = None):
|
||||||
self.client = client
|
self.client = client
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.endpoint_name = endpoint_name
|
||||||
self.tool_manager = tool_manager or ToolManager()
|
self.tool_manager = tool_manager or ToolManager()
|
||||||
|
|
||||||
def handle_streaming_response(self, response_stream, show_reasoning: bool = True) -> str:
|
# ----- Streaming handler -----
|
||||||
"""
|
async def handle_streaming_response(self, stream, show_reasoning: bool = True) -> str:
|
||||||
Handle streaming chat response and display all output.
|
|
||||||
"""
|
|
||||||
|
|
||||||
full_response = ""
|
full_response = ""
|
||||||
reasoning_content = ""
|
reasoning_content = ""
|
||||||
reasoning_started = False
|
printed_reasoning = False
|
||||||
content_started = False
|
printed_answer = False
|
||||||
|
finish_reason = None
|
||||||
|
|
||||||
for chunk in response_stream:
|
async for chunk in stream:
|
||||||
# Normalize the chunk
|
choice = (chunk.get("choices") or [{}])[0]
|
||||||
if isinstance(chunk, str):
|
delta = choice.get("delta", {})
|
||||||
chunk = chunk.strip()
|
|
||||||
if chunk.startswith("data: "):
|
|
||||||
chunk = chunk[6:].strip()
|
|
||||||
if chunk in ["[DONE]", ""]:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
parsed_chunk = json.loads(chunk)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
continue
|
|
||||||
elif isinstance(chunk, dict):
|
|
||||||
parsed_chunk = chunk
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Parse delta from the chunk
|
# Track finish reason
|
||||||
choices = parsed_chunk.get("choices", [])
|
if choice.get("finish_reason"):
|
||||||
if not choices:
|
finish_reason = choice.get("finish_reason")
|
||||||
continue
|
|
||||||
|
|
||||||
delta = choices[0].get("delta", {})
|
# reasoning tokens
|
||||||
reasoning_token = delta.get("reasoning_content", "")
|
rc = delta.get("reasoning_content")
|
||||||
content_token = delta.get("content", "")
|
if rc and show_reasoning:
|
||||||
|
if not printed_reasoning:
|
||||||
# Print reasoning token if applicable
|
|
||||||
if show_reasoning and reasoning_token:
|
|
||||||
if not reasoning_started:
|
|
||||||
print("\n🧠 Reasoning: ", end="", flush=True)
|
print("\n🧠 Reasoning: ", end="", flush=True)
|
||||||
reasoning_started = True
|
printed_reasoning = True
|
||||||
print(f"\033[90m{reasoning_token}\033[0m", end="", flush=True)
|
print(rc, end="", flush=True)
|
||||||
reasoning_content += reasoning_token
|
reasoning_content += rc
|
||||||
|
|
||||||
# Print content token
|
# content tokens
|
||||||
if content_token:
|
content_part = delta.get("content")
|
||||||
if not content_started:
|
if content_part:
|
||||||
if show_reasoning and reasoning_started:
|
if not printed_answer:
|
||||||
print(f"\n💬 Response: ", end="", flush=True)
|
if show_reasoning and printed_reasoning:
|
||||||
|
print("\n💬 Response: ", end="", flush=True)
|
||||||
else:
|
else:
|
||||||
print("Assistant: ", end="", flush=True)
|
print("Assistant: ", end="", flush=True)
|
||||||
content_started = True
|
printed_answer = True
|
||||||
print(content_token, end="", flush=True)
|
print(content_part, end="", flush=True)
|
||||||
full_response += content_token
|
full_response += content_part
|
||||||
|
|
||||||
print() # Ensure newline after response
|
|
||||||
|
|
||||||
|
print() # newline
|
||||||
if show_reasoning:
|
if show_reasoning:
|
||||||
if reasoning_started or content_started:
|
if printed_reasoning or printed_answer:
|
||||||
print("\nStreaming completed.")
|
print("\nStreaming completed.")
|
||||||
if reasoning_started:
|
if printed_reasoning:
|
||||||
print(f"Reasoning tokens: {len(reasoning_content.split())}")
|
print(f"Reasoning tokens: {len(reasoning_content.split())}")
|
||||||
if content_started:
|
if printed_answer:
|
||||||
print(f"Response tokens: {len(full_response.split())}")
|
print(f"Response tokens: {len(full_response.split())}")
|
||||||
|
if finish_reason:
|
||||||
|
print(f"Finish reason: {finish_reason}")
|
||||||
|
|
||||||
return full_response
|
return full_response
|
||||||
|
|
||||||
|
async def demo_completions(self) -> None:
|
||||||
def test_tool_support(self) -> bool:
|
|
||||||
"""Test if the endpoint supports function calling"""
|
|
||||||
log.debug("Testing endpoint tool calling support...")
|
|
||||||
|
|
||||||
# Try a simple request with minimal tools to test support
|
|
||||||
messages = [{"role": "user", "content": "Hello"}]
|
|
||||||
minimal_tool = [{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "test_function",
|
|
||||||
"description": "Test function"
|
|
||||||
}
|
|
||||||
}]
|
|
||||||
|
|
||||||
config = ChatCompletionConfig(
|
|
||||||
model=self.model,
|
|
||||||
messages=messages,
|
|
||||||
max_tokens=10,
|
|
||||||
tools=minimal_tool,
|
|
||||||
tool_choice="none" # Don't actually call the tool
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = self.client.call_chat_completions(config)
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
log.error(f"Error: Endpoint does not support tool calling: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def demo_completions(self) -> None:
|
|
||||||
"""Demo: test basic completions endpoint"""
|
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("COMPLETIONS DEMO")
|
print("COMPLETIONS DEMO")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
config = CompletionConfig(
|
response = await call_completions(
|
||||||
|
client=self.client,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
prompt=COMPLETIONS_PROMPT,
|
prompt=COMPLETIONS_PROMPT,
|
||||||
stream=False
|
endpoint_name=self.endpoint_name,
|
||||||
|
max_tokens=MAX_TOKENS,
|
||||||
|
temperature=DEFAULT_TEMPERATURE,
|
||||||
)
|
)
|
||||||
|
|
||||||
log.info(f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'")
|
|
||||||
response = self.client.call_completions(config)
|
|
||||||
|
|
||||||
if isinstance(response, dict):
|
|
||||||
print("\nResponse:")
|
print("\nResponse:")
|
||||||
print(json.dumps(response, indent=2))
|
print(json.dumps(response, indent=2))
|
||||||
else:
|
|
||||||
log.error("Unexpected response format")
|
|
||||||
|
|
||||||
def demo_chat(self, use_streaming: bool = True) -> None:
|
async def demo_chat(self, use_streaming: bool = True) -> None:
|
||||||
"""
|
|
||||||
Demo: test chat completions endpoint with optional streaming
|
|
||||||
"""
|
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print(f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}")
|
print(f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
config = ChatCompletionConfig(
|
messages = [{"role": "user", "content": CHAT_PROMPT}]
|
||||||
model=self.model,
|
|
||||||
messages=[{"role": "user", "content": CHAT_PROMPT}],
|
|
||||||
stream=use_streaming,
|
|
||||||
)
|
|
||||||
|
|
||||||
log.info(f"Testing chat completions with model '{self.model}'...")
|
|
||||||
response = self.client.call_chat_completions(config)
|
|
||||||
|
|
||||||
if use_streaming:
|
if use_streaming:
|
||||||
|
stream = await stream_chat_completions(
|
||||||
|
client=self.client,
|
||||||
|
model=self.model,
|
||||||
|
messages=messages,
|
||||||
|
endpoint_name=self.endpoint_name,
|
||||||
|
max_tokens=MAX_TOKENS,
|
||||||
|
temperature=DEFAULT_TEMPERATURE
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
self.handle_streaming_response(response, show_reasoning=True)
|
await self.handle_streaming_response(stream, show_reasoning=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"\nError during streaming: {e}")
|
log.error("\nError during streaming: %s", e, exc_info=True)
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
return
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if isinstance(response, dict):
|
response = await call_chat_completions(
|
||||||
choice = response.get("choices", [{}])[0]
|
client=self.client,
|
||||||
|
model=self.model,
|
||||||
|
messages=messages,
|
||||||
|
endpoint_name=self.endpoint_name,
|
||||||
|
max_tokens=MAX_TOKENS,
|
||||||
|
temperature=DEFAULT_TEMPERATURE
|
||||||
|
)
|
||||||
|
choice = (response.get("choices") or [{}])[0]
|
||||||
message = choice.get("message", {})
|
message = choice.get("message", {})
|
||||||
content = message.get("content", "")
|
content = message.get("content", "")
|
||||||
reasoning = message.get("reasoning_content", "") or message.get("reasoning", "")
|
reasoning = message.get("reasoning_content", "") or message.get("reasoning", "")
|
||||||
|
|
||||||
if reasoning:
|
if reasoning:
|
||||||
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
|
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
|
||||||
|
|
||||||
print(f"\n💬 Assistant: {content}")
|
print(f"\n💬 Assistant: {content}")
|
||||||
print(f"\nFull Response:")
|
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
|
||||||
print(json.dumps(response, indent=2))
|
|
||||||
else:
|
|
||||||
log.error("Unexpected response format")
|
|
||||||
|
|
||||||
|
async def test_tool_support(self) -> bool:
|
||||||
|
"""Probe that tool schema is accepted (no actual call)"""
|
||||||
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
|
minimal_tool = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "test_function", "description": "Test function"},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
_ = await call_chat_completions(
|
||||||
|
client=self.client,
|
||||||
|
model=self.model,
|
||||||
|
messages=messages,
|
||||||
|
endpoint_name=self.endpoint_name,
|
||||||
|
tools=minimal_tool,
|
||||||
|
tool_choice="none",
|
||||||
|
max_tokens=10
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
log.error("Endpoint does not support tool calling: %s", e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def demo_ls_tool(self) -> None:
|
||||||
def demo_ls_tool(self) -> None:
|
"""Ask to list files using function calling, then provide final analysis"""
|
||||||
"""Demo: ask LLM to list files in the current directory and describe what it sees"""
|
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("TOOL USE DEMO: List Directory Contents")
|
print("TOOL USE DEMO: List Directory Contents")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
# Test if tools are supported first
|
if not await self.test_tool_support():
|
||||||
if not self.test_tool_support():
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Request with tool available
|
messages: List[Dict[str, Any]] = [{"role": "user", "content": TOOLS_PROMPT}]
|
||||||
messages = [
|
|
||||||
{"role": "user", "content": TOOLS_PROMPT}
|
|
||||||
]
|
|
||||||
|
|
||||||
config = ChatCompletionConfig(
|
# First pass: let the model decide tools, stream tool_calls and partial content
|
||||||
|
stream = await stream_chat_completions(
|
||||||
|
client=self.client,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
endpoint_name=self.endpoint_name,
|
||||||
tools=self.tool_manager.get_ls_tool_definition(),
|
tools=self.tool_manager.get_ls_tool_definition(),
|
||||||
tool_choice="auto"
|
tool_choice="auto",
|
||||||
|
max_tokens=MAX_TOKENS,
|
||||||
|
temperature=DEFAULT_TEMPERATURE,
|
||||||
)
|
)
|
||||||
|
|
||||||
log.info(f"Making initial request with tool using model '{self.model}'...")
|
assistant_content_buf: List[str] = []
|
||||||
response = self.client.call_chat_completions(config)
|
tool_calls_state: Dict[int, Dict[str, Any]] = {}
|
||||||
|
printed_reasoning = False
|
||||||
|
printed_answer = False
|
||||||
|
|
||||||
if not isinstance(response, dict):
|
async for chunk in stream:
|
||||||
raise ValueError("Expected dict response for tool use")
|
choice = (chunk.get("choices") or [{}])[0]
|
||||||
|
delta = choice.get("delta", {})
|
||||||
|
|
||||||
choice = response.get("choices", [{}])[0]
|
rc = delta.get("reasoning_content")
|
||||||
message = choice.get("message", {})
|
if rc:
|
||||||
|
if not printed_reasoning:
|
||||||
|
printed_reasoning = True
|
||||||
|
print("🧠 Reasoning: ", end="", flush=True)
|
||||||
|
print(rc, end="", flush=True)
|
||||||
|
|
||||||
print(f"Assistant response: {message.get('content', 'No content')}")
|
content_part = delta.get("content")
|
||||||
|
if content_part:
|
||||||
|
assistant_content_buf.append(content_part)
|
||||||
|
if not printed_answer:
|
||||||
|
printed_answer = True
|
||||||
|
print("\n💬 Response: ", end="", flush=True)
|
||||||
|
print(content_part, end="", flush=True)
|
||||||
|
|
||||||
# Check for tool calls
|
if "tool_calls" in delta and delta["tool_calls"]:
|
||||||
tool_calls = message.get("tool_calls")
|
for tc_delta in delta["tool_calls"]:
|
||||||
if not tool_calls:
|
_merge_tool_call_delta(tool_calls_state, tc_delta)
|
||||||
raise ValueError("No tool calls made - model may not support function calling")
|
|
||||||
|
|
||||||
print(f"Tool calls detected: {len(tool_calls)}")
|
# If no tool calls, we’re done.
|
||||||
|
if not tool_calls_state:
|
||||||
|
print("\n(No tool calls were made.)")
|
||||||
|
return
|
||||||
|
|
||||||
# Execute the tool call
|
# Build assistant message with tool_calls
|
||||||
for tool_call in tool_calls:
|
assistant_message = {
|
||||||
function_name = tool_call["function"]["name"]
|
"role": "assistant",
|
||||||
print(f"Executing tool: {function_name}")
|
"content": "".join(assistant_content_buf) if assistant_content_buf else None,
|
||||||
|
"tool_calls": _tool_state_to_message_tool_calls(tool_calls_state),
|
||||||
|
}
|
||||||
|
messages.append(assistant_message)
|
||||||
|
|
||||||
tool_result = self.tool_manager.execute_tool_call(tool_call)
|
# Execute tools and feed results back
|
||||||
print(f"Tool result:\n{tool_result}")
|
for tc in assistant_message["tool_calls"]:
|
||||||
|
tool_name = (tc.get("function") or {}).get("name")
|
||||||
|
call_id = tc.get("id")
|
||||||
|
raw_args = (tc.get("function") or {}).get("arguments") or "{}"
|
||||||
|
|
||||||
# Add tool result and continue conversation
|
try:
|
||||||
messages.append(message) # Add assistant's message with tool call
|
args = json.loads(raw_args) if raw_args.strip() else {}
|
||||||
messages.append({
|
except Exception as e:
|
||||||
"role": "tool",
|
tool_result = json.dumps({"error": f"Argument parse failed: {str(e)}", "raw_arguments": raw_args})
|
||||||
"tool_call_id": tool_call["id"],
|
messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result})
|
||||||
"content": tool_result
|
continue
|
||||||
})
|
|
||||||
|
|
||||||
# Get final response
|
try:
|
||||||
final_config = ChatCompletionConfig(
|
if tool_name == "list_files":
|
||||||
|
tool_result = self.tool_manager.list_files()
|
||||||
|
else:
|
||||||
|
tool_result = json.dumps({"error": f"Unknown tool '{tool_name}'"})
|
||||||
|
except Exception as e:
|
||||||
|
tool_result = json.dumps({"error": f"Tool '{tool_name}' failed: {str(e)}"})
|
||||||
|
|
||||||
|
print("\n[Tool executed]", tool_name)
|
||||||
|
print(tool_result[:500] + ("..." if len(tool_result) > 500 else ""))
|
||||||
|
messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result})
|
||||||
|
|
||||||
|
# Second pass: get final streamed answer after tool results
|
||||||
|
stream2 = await stream_chat_completions(
|
||||||
|
client=self.client,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=self.tool_manager.get_ls_tool_definition()
|
endpoint_name=self.endpoint_name,
|
||||||
|
max_tokens=MAX_TOKENS,
|
||||||
|
temperature=DEFAULT_TEMPERATURE,
|
||||||
)
|
)
|
||||||
|
|
||||||
print("Getting final response...")
|
final_buf = []
|
||||||
final_response = self.client.call_chat_completions(final_config)
|
printed_reasoning2 = False
|
||||||
|
printed_answer2 = False
|
||||||
|
|
||||||
if isinstance(final_response, dict):
|
async for chunk in stream2:
|
||||||
final_choice = final_response.get("choices", [{}])[0]
|
choice = (chunk.get("choices") or [{}])[0]
|
||||||
final_message = final_choice.get("message", {})
|
delta = choice.get("delta", {})
|
||||||
final_content = final_message.get("content", "")
|
|
||||||
|
rc2 = delta.get("reasoning_content")
|
||||||
|
if rc2:
|
||||||
|
if not printed_reasoning2:
|
||||||
|
printed_reasoning2 = True
|
||||||
|
print("\n🧠 Reasoning (post-tools): ", end="", flush=True)
|
||||||
|
print(rc2, end="", flush=True)
|
||||||
|
|
||||||
|
c2 = delta.get("content")
|
||||||
|
if c2:
|
||||||
|
final_buf.append(c2)
|
||||||
|
if not printed_answer2:
|
||||||
|
printed_answer2 = True
|
||||||
|
print("\n💬 Response (final): ", end="", flush=True)
|
||||||
|
print(c2, end="", flush=True)
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("FINAL LLM ANALYSIS:")
|
print("FINAL LLM ANALYSIS:")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print(final_content)
|
print("".join(final_buf))
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
def interactive_chat(self) -> None:
|
async def interactive_chat(self) -> None:
|
||||||
"""Interactive chat session with streaming"""
|
"""Interactive chat session with streaming"""
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("INTERACTIVE STREAMING CHAT")
|
print("INTERACTIVE STREAMING CHAT")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print(f"Using model: {self.model}")
|
|
||||||
print("Type 'quit' to exit, 'clear' to clear history")
|
print("Type 'quit' to exit, 'clear' to clear history")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
messages = []
|
messages: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
user_input = input("You: ").strip()
|
user_input = input("You: ").strip()
|
||||||
|
|
||||||
if user_input.lower() == 'quit':
|
if user_input.lower() == "quit":
|
||||||
print("👋 Goodbye!")
|
print("👋 Goodbye!")
|
||||||
break
|
break
|
||||||
elif user_input.lower() == 'clear':
|
elif user_input.lower() == "clear":
|
||||||
messages = []
|
messages = []
|
||||||
print("Chat history cleared")
|
print("Chat history cleared")
|
||||||
continue
|
continue
|
||||||
@@ -459,17 +462,16 @@ class APIDemo:
|
|||||||
|
|
||||||
messages.append({"role": "user", "content": user_input})
|
messages.append({"role": "user", "content": user_input})
|
||||||
|
|
||||||
config = ChatCompletionConfig(
|
print("Assistant: ", end="", flush=True)
|
||||||
|
stream = await stream_chat_completions(
|
||||||
|
client=self.client,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
stream=True,
|
endpoint_name=self.endpoint_name,
|
||||||
|
max_tokens=MAX_TOKENS,
|
||||||
temperature=0.7
|
temperature=0.7
|
||||||
)
|
)
|
||||||
|
assistant_content = await self.handle_streaming_response(stream, show_reasoning=True)
|
||||||
print("Assistant: ", end="", flush=True)
|
|
||||||
|
|
||||||
response = self.client.call_chat_completions(config)
|
|
||||||
assistant_content = self.handle_streaming_response(response, show_reasoning=True)
|
|
||||||
|
|
||||||
# Add assistant response to conversation history
|
# Add assistant response to conversation history
|
||||||
messages.append({"role": "assistant", "content": assistant_content})
|
messages.append({"role": "assistant", "content": assistant_content})
|
||||||
@@ -478,101 +480,66 @@ class APIDemo:
|
|||||||
print("\n👋 Chat interrupted. Goodbye!")
|
print("\n👋 Chat interrupted. Goodbye!")
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"\nError: {e}")
|
log.error("\nError: %s", e)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
||||||
def main():
|
# ---------------------- CLI ----------------------
|
||||||
"""Main function with CLI switches for different tests"""
|
def build_arg_parser() -> argparse.ArgumentParser:
|
||||||
from lib.test_utils import test_args
|
p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
|
||||||
|
p.add_argument("--model", default=DEFAULT_MODEL, help=f"Model to use for requests (default: {DEFAULT_MODEL})")
|
||||||
|
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
|
||||||
|
|
||||||
# Add mandatory model argument
|
modes = p.add_mutually_exclusive_group(required=False)
|
||||||
test_args.add_argument(
|
modes.add_argument("--completion", action="store_true", help="Test completions endpoint")
|
||||||
"--model",
|
modes.add_argument("--chat", action="store_true", help="Test chat completions endpoint (non-streaming)")
|
||||||
required=True,
|
modes.add_argument("--chat-stream", action="store_true", help="Test chat completions endpoint with streaming")
|
||||||
help="Model to use for requests (required)"
|
modes.add_argument("--tools", action="store_true", help="Test function calling with ls tool (non-streaming+streamed phases)")
|
||||||
)
|
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming chat session")
|
||||||
|
return p
|
||||||
|
|
||||||
# Add test mode arguments
|
|
||||||
test_args.add_argument(
|
|
||||||
"--completion",
|
|
||||||
action="store_true",
|
|
||||||
help="Test completions endpoint"
|
|
||||||
)
|
|
||||||
test_args.add_argument(
|
|
||||||
"--chat",
|
|
||||||
action="store_true",
|
|
||||||
help="Test chat completions endpoint (non-streaming)"
|
|
||||||
)
|
|
||||||
test_args.add_argument(
|
|
||||||
"--chat-stream",
|
|
||||||
action="store_true",
|
|
||||||
help="Test chat completions endpoint with streaming"
|
|
||||||
)
|
|
||||||
test_args.add_argument(
|
|
||||||
"--tools",
|
|
||||||
action="store_true",
|
|
||||||
help="Test function calling with ls tool (non-streaming)"
|
|
||||||
)
|
|
||||||
test_args.add_argument(
|
|
||||||
"--interactive",
|
|
||||||
action="store_true",
|
|
||||||
help="Start interactive streaming chat session"
|
|
||||||
)
|
|
||||||
|
|
||||||
args = test_args.parse_args()
|
async def main_async():
|
||||||
|
args = build_arg_parser().parse_args()
|
||||||
|
|
||||||
# Check that only one test mode is selected
|
selected = sum([args.completion, args.chat, args.chat_stream, args.tools, args.interactive])
|
||||||
test_modes = [
|
if selected == 0:
|
||||||
args.completion, args.chat, args.chat_stream,
|
|
||||||
args.tools, args.interactive
|
|
||||||
]
|
|
||||||
selected_count = sum(test_modes)
|
|
||||||
|
|
||||||
if selected_count == 0:
|
|
||||||
print("Please specify exactly one test mode:")
|
print("Please specify exactly one test mode:")
|
||||||
print(" --completion : Test completions endpoint")
|
print(" --completion : Test completions endpoint")
|
||||||
print(" --chat : Test chat completions endpoint (non-streaming)")
|
print(" --chat : Test chat completions endpoint (non-streaming)")
|
||||||
print(" --chat-stream : Test chat completions endpoint with streaming")
|
print(" --chat-stream : Test chat completions endpoint with streaming")
|
||||||
print(" --tools : Test function calling with ls tool (non-streaming)")
|
print(" --tools : Test function calling with ls tool")
|
||||||
print(" --interactive : Start interactive streaming chat session")
|
print(" --interactive : Start interactive streaming chat session")
|
||||||
print(f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT")
|
print(f"\nExample: python {os.path.basename(sys.argv[0])} --model Qwen/Qwen3-8B --chat-stream --endpoint my-vllm-endpoint")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
elif selected_count > 1:
|
elif selected > 1:
|
||||||
print("Please specify exactly one test mode")
|
print("Please specify exactly one test mode")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
try:
|
|
||||||
# Create the core API client
|
|
||||||
client = APIClient(
|
|
||||||
endpoint_group_name=args.endpoint_group_name,
|
|
||||||
api_key=args.api_key,
|
|
||||||
server_url=args.server_url
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create tool manager and demo (passing the model parameter)
|
|
||||||
tool_manager = ToolManager()
|
|
||||||
demo = APIDemo(client, args.model, tool_manager)
|
|
||||||
|
|
||||||
print(f"Using model: {args.model}")
|
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
print(f"Using model: {args.model}")
|
||||||
|
print(f"Using endpoint: {args.endpoint}")
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with Serverless() as client:
|
||||||
|
demo = APIDemo(client, args.model, args.endpoint, ToolManager())
|
||||||
|
|
||||||
# Run the selected test
|
|
||||||
if args.completion:
|
if args.completion:
|
||||||
demo.demo_completions()
|
await demo.demo_completions()
|
||||||
elif args.chat:
|
elif args.chat:
|
||||||
demo.demo_chat(use_streaming=False)
|
await demo.demo_chat(use_streaming=False)
|
||||||
elif args.chat_stream:
|
elif args.chat_stream:
|
||||||
demo.demo_chat(use_streaming=True)
|
await demo.demo_chat(use_streaming=True)
|
||||||
elif args.tools:
|
elif args.tools:
|
||||||
demo.demo_ls_tool()
|
await demo.demo_ls_tool()
|
||||||
elif args.interactive:
|
elif args.interactive:
|
||||||
demo.interactive_chat()
|
await demo.interactive_chat()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error during test: {e}", exc_info=True)
|
log.error("Error during test: %s", e, exc_info=True)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
asyncio.run(main_async())
|
||||||
|
|||||||
@@ -6,8 +6,10 @@ from typing import Optional, List, Dict, Any
|
|||||||
class SerializableDataclass:
|
class SerializableDataclass:
|
||||||
def _serialize_recursive(self, obj: Any) -> Any:
|
def _serialize_recursive(self, obj: Any) -> Any:
|
||||||
if is_dataclass(obj):
|
if is_dataclass(obj):
|
||||||
return {field.name: self._serialize_recursive(getattr(obj, field.name))
|
return {
|
||||||
for field in fields(obj)}
|
field.name: self._serialize_recursive(getattr(obj, field.name))
|
||||||
|
for field in fields(obj)
|
||||||
|
}
|
||||||
elif isinstance(obj, dict):
|
elif isinstance(obj, dict):
|
||||||
return {key: self._serialize_recursive(value) for key, value in obj.items()}
|
return {key: self._serialize_recursive(value) for key, value in obj.items()}
|
||||||
elif isinstance(obj, (list, tuple)):
|
elif isinstance(obj, (list, tuple)):
|
||||||
@@ -27,6 +29,7 @@ class SerializableDataclass:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class CompletionConfig(SerializableDataclass):
|
class CompletionConfig(SerializableDataclass):
|
||||||
"""Configuration for completion requests"""
|
"""Configuration for completion requests"""
|
||||||
|
|
||||||
model: str
|
model: str
|
||||||
prompt: str = "Hello"
|
prompt: str = "Hello"
|
||||||
max_tokens: int = 256
|
max_tokens: int = 256
|
||||||
@@ -39,8 +42,9 @@ class CompletionConfig(SerializableDataclass):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ChatCompletionConfig(SerializableDataclass):
|
class ChatCompletionConfig(SerializableDataclass):
|
||||||
"""Configuration for chat completion requests"""
|
"""Configuration for chat completion requests"""
|
||||||
|
|
||||||
model: str
|
model: str
|
||||||
messages: list = None
|
messages: list = field(default_factory=list)
|
||||||
max_tokens: int = 2096
|
max_tokens: int = 2096
|
||||||
temperature: float = 0.7
|
temperature: float = 0.7
|
||||||
top_k: int = 20
|
top_k: int = 20
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import os, json, random
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from lib.data_types import EndpointHandler, ApiPayload, JsonDataException
|
from lib.data_types import EndpointHandler, ApiPayload, JsonDataException
|
||||||
from typing import Union, Type, Dict, Any
|
from typing import Union, Type, Dict, Any, Optional
|
||||||
from aiohttp import web, ClientResponse
|
from aiohttp import web, ClientResponse
|
||||||
import nltk
|
import nltk
|
||||||
import logging
|
import logging
|
||||||
@@ -14,15 +14,15 @@ log = logging.getLogger(__name__)
|
|||||||
"""
|
"""
|
||||||
Generic dataclass accepts any dictionary in input.
|
Generic dataclass accepts any dictionary in input.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GenericData(ApiPayload, ABC):
|
class GenericData(ApiPayload, ABC):
|
||||||
input: Dict[str, Any]
|
input: Dict[str, Any]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: Dict[str, Any]) -> "GenericData":
|
def from_dict(cls, data: Dict[str, Any]) -> "GenericData":
|
||||||
return cls(
|
return cls(input=data["input"])
|
||||||
input=data["input"]
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "GenericData":
|
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "GenericData":
|
||||||
@@ -39,9 +39,7 @@ class GenericData(ApiPayload, ABC):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Create clean data dict and delegate to from_dict
|
# Create clean data dict and delegate to from_dict
|
||||||
clean_data = {
|
clean_data = {"input": json_msg["input"]}
|
||||||
"input": json_msg["input"]
|
|
||||||
}
|
|
||||||
|
|
||||||
return cls.from_dict(clean_data)
|
return cls.from_dict(clean_data)
|
||||||
|
|
||||||
@@ -60,6 +58,7 @@ class GenericData(ApiPayload, ABC):
|
|||||||
def count_workload(self) -> int:
|
def count_workload(self) -> int:
|
||||||
return self.input.get("max_tokens", 0)
|
return self.input.get("max_tokens", 0)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GenericHandler(EndpointHandler[GenericData], ABC):
|
class GenericHandler(EndpointHandler[GenericData], ABC):
|
||||||
|
|
||||||
@@ -69,8 +68,8 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def healthcheck_endpoint(self) -> str:
|
def healthcheck_endpoint(self) -> Optional[str]:
|
||||||
return os.environ.get('MODEL_HEALTH_ENDPOINT')
|
return os.environ.get("MODEL_HEALTH_ENDPOINT")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def payload_cls(cls) -> Type[GenericData]:
|
def payload_cls(cls) -> Type[GenericData]:
|
||||||
@@ -87,10 +86,10 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
|
|||||||
case 200:
|
case 200:
|
||||||
# Check if the response is actually streaming based on response headers/content-type
|
# Check if the response is actually streaming based on response headers/content-type
|
||||||
is_streaming_response = (
|
is_streaming_response = (
|
||||||
model_response.content_type == "text/event-stream" or
|
model_response.content_type == "text/event-stream"
|
||||||
model_response.content_type == "application/x-ndjson" or
|
or model_response.content_type == "application/x-ndjson"
|
||||||
model_response.headers.get("Transfer-Encoding") == "chunked" or
|
or model_response.headers.get("Transfer-Encoding") == "chunked"
|
||||||
"stream" in model_response.content_type.lower()
|
or "stream" in model_response.content_type.lower()
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_streaming_response:
|
if is_streaming_response:
|
||||||
@@ -109,28 +108,42 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
|
|||||||
return web.Response(
|
return web.Response(
|
||||||
body=content,
|
body=content,
|
||||||
status=200,
|
status=200,
|
||||||
content_type=model_response.content_type
|
content_type=model_response.content_type,
|
||||||
)
|
)
|
||||||
case code:
|
case code:
|
||||||
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||||
return web.Response(status=code)
|
return web.Response(status=code)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CompletionsData(GenericData):
|
class CompletionsData(GenericData):
|
||||||
@classmethod
|
@classmethod
|
||||||
def for_test(cls) -> "CompletionsData":
|
def for_test(cls) -> "CompletionsData":
|
||||||
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
|
||||||
|
|
||||||
|
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
|
||||||
|
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
|
||||||
|
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
|
||||||
|
genus Equus with horses and asses, the three groups being the only living members of the family
|
||||||
|
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
|
||||||
|
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
|
||||||
|
woodlands, shrublands, and mountainous areas.
|
||||||
|
|
||||||
|
Please answer the following question based on the above context."""
|
||||||
|
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
|
||||||
model = os.environ.get("MODEL_NAME")
|
model = os.environ.get("MODEL_NAME")
|
||||||
if not model:
|
if not model:
|
||||||
raise ValueError("MODEL_NAME environment variable not set")
|
raise ValueError("MODEL_NAME environment variable not set")
|
||||||
|
|
||||||
test_input = {
|
test_input = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"prompt": prompt,
|
"prompt": f"{system_prompt}\n\n{unique_question}",
|
||||||
"temperature": 0.7
|
"temperature": 0.7,
|
||||||
|
"max_tokens": 500,
|
||||||
}
|
}
|
||||||
return cls(input=test_input)
|
return cls(input=test_input)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CompletionsHandler(GenericHandler):
|
class CompletionsHandler(GenericHandler):
|
||||||
@property
|
@property
|
||||||
@@ -144,13 +157,25 @@ class CompletionsHandler(GenericHandler):
|
|||||||
def make_benchmark_payload(self) -> CompletionsData:
|
def make_benchmark_payload(self) -> CompletionsData:
|
||||||
return CompletionsData.for_test()
|
return CompletionsData.for_test()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ChatCompletionsData(GenericData):
|
class ChatCompletionsData(GenericData):
|
||||||
"""Chat completions-specific data implementation"""
|
"""Chat completions-specific data implementation"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def for_test(cls) -> "ChatCompletionsData":
|
def for_test(cls) -> "ChatCompletionsData":
|
||||||
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
|
||||||
|
|
||||||
|
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
|
||||||
|
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
|
||||||
|
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
|
||||||
|
genus Equus with horses and asses, the three groups being the only living members of the family
|
||||||
|
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
|
||||||
|
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
|
||||||
|
woodlands, shrublands, and mountainous areas.
|
||||||
|
|
||||||
|
Please answer the following question based on the above context."""
|
||||||
|
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
|
||||||
model = os.environ.get("MODEL_NAME")
|
model = os.environ.get("MODEL_NAME")
|
||||||
if not model:
|
if not model:
|
||||||
raise ValueError("MODEL_NAME environment variable not set")
|
raise ValueError("MODEL_NAME environment variable not set")
|
||||||
@@ -158,11 +183,16 @@ class ChatCompletionsData(GenericData):
|
|||||||
# Chat completions use messages format instead of prompt
|
# Chat completions use messages format instead of prompt
|
||||||
test_input = {
|
test_input = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
"messages": [
|
||||||
"temperature": 0.7
|
{"role": "system", "content": system_prompt}, # Shared prefix
|
||||||
|
{"role": "user", "content": unique_question} # Unique per request
|
||||||
|
],
|
||||||
|
"temperature": 0.7,
|
||||||
|
"max_tokens": 500,
|
||||||
}
|
}
|
||||||
return cls(input=test_input)
|
return cls(input=test_input)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ChatCompletionsHandler(GenericHandler):
|
class ChatCompletionsHandler(GenericHandler):
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ MODEL_SERVER_START_LOG_MSG = [
|
|||||||
"llama runner started", # Ollama
|
"llama runner started", # Ollama
|
||||||
'"message":"Connected","target":"text_generation_router"', # TGI
|
'"message":"Connected","target":"text_generation_router"', # TGI
|
||||||
'"message":"Connected","target":"text_generation_router::server"', # TGI
|
'"message":"Connected","target":"text_generation_router::server"', # TGI
|
||||||
|
"main: model loaded" # llama.cpp
|
||||||
]
|
]
|
||||||
|
|
||||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||||
@@ -31,9 +32,10 @@ logging.basicConfig(
|
|||||||
log = logging.getLogger(__file__)
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
backend = Backend(
|
backend = Backend(
|
||||||
model_server_url=os.environ.get("MODEL_SERVER_URL"),
|
model_server_url=os.environ["MODEL_SERVER_URL"],
|
||||||
model_log_file=os.environ.get("MODEL_LOG"),
|
model_log_file=os.environ["MODEL_LOG"],
|
||||||
allow_parallel_requests=True,
|
allow_parallel_requests=True,
|
||||||
|
max_wait_time=600.0,
|
||||||
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
|
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
|
||||||
log_actions=[
|
log_actions=[
|
||||||
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
|
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
|
||||||
@@ -45,9 +47,11 @@ backend = Backend(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def handle_ping(_):
|
async def handle_ping(_):
|
||||||
return web.Response(body="pong")
|
return web.Response(body="pong")
|
||||||
|
|
||||||
|
|
||||||
routes = [
|
routes = [
|
||||||
web.post("/v1/completions", backend.create_handler(CompletionsHandler())),
|
web.post("/v1/completions", backend.create_handler(CompletionsHandler())),
|
||||||
web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())),
|
web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())),
|
||||||
|
|||||||
+416
-10
@@ -1,8 +1,395 @@
|
|||||||
from lib.test_utils import test_load_cmd, test_args
|
from lib.test_utils import test_args
|
||||||
|
from utils.endpoint_util import Endpoint
|
||||||
|
from utils.ssl import get_cert_file_path
|
||||||
|
from lib.data_types import AuthData
|
||||||
from .data_types.server import CompletionsData
|
from .data_types.server import CompletionsData
|
||||||
import os
|
|
||||||
|
|
||||||
WORKER_ENDPOINT = "/v1/completions"
|
import os
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
import requests
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from collections import Counter
|
||||||
|
from urllib.parse import urljoin, urlparse
|
||||||
|
import re
|
||||||
|
|
||||||
|
# Headless plotting
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
import logging
|
||||||
|
logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED
|
||||||
|
from requests.adapters import HTTPAdapter
|
||||||
|
|
||||||
|
def get_incremented_path(path: str) -> str:
|
||||||
|
base, ext = os.path.splitext(path)
|
||||||
|
if not os.path.exists(path):
|
||||||
|
return path
|
||||||
|
i = 1
|
||||||
|
while os.path.exists(f"{base}-{i}{ext}"):
|
||||||
|
i += 1
|
||||||
|
return f"{base}-{i}{ext}"
|
||||||
|
|
||||||
|
WORKER_ENDPOINT = "/v1/completions" # This will return the full text output at once. Latency metrics reflect that (ie not measuring TTFT)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ReqResult:
|
||||||
|
worker_url: str
|
||||||
|
route_ms: float
|
||||||
|
worker_ms: float
|
||||||
|
total_ms: float
|
||||||
|
ok: bool
|
||||||
|
error: str = ""
|
||||||
|
status_code: int = 0
|
||||||
|
t_start: float = 0.0
|
||||||
|
t_end: float = 0.0
|
||||||
|
workload: float = 0.0
|
||||||
|
|
||||||
|
def do_one(endpoint_name: str,
|
||||||
|
endpoint_id: int,
|
||||||
|
endpoint_api_key: str,
|
||||||
|
server_url: str,
|
||||||
|
worker_endpoint: str,
|
||||||
|
payload,
|
||||||
|
results_list,
|
||||||
|
t0,
|
||||||
|
status_samples,
|
||||||
|
route_session,
|
||||||
|
worker_session):
|
||||||
|
try:
|
||||||
|
workload = payload.count_workload()
|
||||||
|
route_payload = {"endpoint_id": endpoint_id, "api_key": endpoint_api_key, "cost": workload}
|
||||||
|
headers = {"Authorization": f"Bearer {endpoint_api_key}"}
|
||||||
|
start = time.time()
|
||||||
|
r0 = route_session.post(urljoin(server_url, "/route/"), json=route_payload, headers=headers, timeout=4)
|
||||||
|
t_after_route = time.time()
|
||||||
|
if r0.status_code != 200:
|
||||||
|
results_list.append(ReqResult(worker_url="",
|
||||||
|
route_ms=(t_after_route - start) * 1000.0,
|
||||||
|
worker_ms=0.0,
|
||||||
|
total_ms=(t_after_route - start) * 1000.0,
|
||||||
|
ok=False,
|
||||||
|
error=f"route error {r0.reason} {r0.text}",
|
||||||
|
status_code=r0.status_code,
|
||||||
|
t_start=start - t0,
|
||||||
|
t_end=t_after_route - t0,
|
||||||
|
workload=workload))
|
||||||
|
return
|
||||||
|
msg = r0.json()
|
||||||
|
|
||||||
|
# 1) Check if we got a worker back from route
|
||||||
|
worker_url = msg.get("url", "")
|
||||||
|
if not worker_url:
|
||||||
|
status = msg.get("status", "")
|
||||||
|
m = re.search(r"total workers:\s*(\d+).*loading workers:\s*(\d+).*standby workers:\s*(\d+).*error workers:\s*(\d+)", status, re.I | re.S)
|
||||||
|
if m:
|
||||||
|
tot, loading, standby, err = map(int, m.groups())
|
||||||
|
idle = max(tot - loading - standby - err, 0)
|
||||||
|
status_samples.append((time.time() - t0, idle))
|
||||||
|
|
||||||
|
# 2) If we got a worker, send the request
|
||||||
|
if worker_url:
|
||||||
|
req = dict(payload=payload.__dict__, auth_data=AuthData.from_json_msg(msg).__dict__)
|
||||||
|
t_before_worker = time.time()
|
||||||
|
r1 = worker_session.post(
|
||||||
|
urljoin(worker_url, worker_endpoint),
|
||||||
|
json=req,
|
||||||
|
verify=get_cert_file_path(),
|
||||||
|
timeout=(4, 120),
|
||||||
|
)
|
||||||
|
t_after_worker = time.time()
|
||||||
|
if r1.status_code != 200:
|
||||||
|
results_list.append(ReqResult(worker_url=worker_url,
|
||||||
|
route_ms=(t_after_route - start) * 1000.0,
|
||||||
|
worker_ms=(t_after_worker - t_before_worker) * 1000.0,
|
||||||
|
total_ms=(t_after_worker - start) * 1000.0,
|
||||||
|
ok=False,
|
||||||
|
error=f"worker inference error {r1.reason} {r1.text}",
|
||||||
|
status_code=r1.status_code,
|
||||||
|
t_start=start - t0,
|
||||||
|
t_end=t_after_worker - t0,
|
||||||
|
workload=workload))
|
||||||
|
return
|
||||||
|
# Success case
|
||||||
|
results_list.append(ReqResult(worker_url=worker_url,
|
||||||
|
route_ms=(t_after_route - start) * 1000.0,
|
||||||
|
worker_ms=(t_after_worker - t_before_worker) * 1000.0,
|
||||||
|
total_ms=(t_after_worker - start) * 1000.0,
|
||||||
|
ok=True,
|
||||||
|
error="",
|
||||||
|
status_code=200,
|
||||||
|
t_start=start - t0,
|
||||||
|
t_end=t_after_worker - t0,
|
||||||
|
workload=workload))
|
||||||
|
|
||||||
|
# 3) If so, sample via /get_endpoint_workers/ for eligible (idle) worker tracking
|
||||||
|
if worker_url:
|
||||||
|
try:
|
||||||
|
r_status = route_session.post(
|
||||||
|
urljoin(server_url, "/get_endpoint_workers/"),
|
||||||
|
json={"id": endpoint_id},
|
||||||
|
headers={"Authorization": f"Bearer {endpoint_api_key}"},
|
||||||
|
timeout=3,
|
||||||
|
)
|
||||||
|
if r_status.status_code == 200:
|
||||||
|
workers = r_status.json()
|
||||||
|
idle = 0
|
||||||
|
for w in workers:
|
||||||
|
st = str(w.get("status", "")).lower()
|
||||||
|
if (st in ("idle")):
|
||||||
|
idle += 1
|
||||||
|
status_samples.append((time.time() - t0, idle))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
t = time.time()
|
||||||
|
results_list.append(ReqResult(worker_url="",
|
||||||
|
route_ms=0.0,
|
||||||
|
worker_ms=0.0,
|
||||||
|
total_ms=0.0,
|
||||||
|
ok=False,
|
||||||
|
error=f"unknown error {e}",
|
||||||
|
status_code=0,
|
||||||
|
t_start=t - t0,
|
||||||
|
t_end=t - t0,
|
||||||
|
workload=0.0))
|
||||||
|
|
||||||
|
def run_load_with_metrics(num_requests: int,
|
||||||
|
requests_per_second: float,
|
||||||
|
endpoint_group_name: str,
|
||||||
|
account_api_key: str,
|
||||||
|
server_url: str,
|
||||||
|
worker_endpoint: str,
|
||||||
|
instance: str,
|
||||||
|
out_path: str):
|
||||||
|
|
||||||
|
ep_info = Endpoint.get_endpoint_info(endpoint_name=endpoint_group_name,
|
||||||
|
account_api_key=account_api_key,
|
||||||
|
instance=instance)
|
||||||
|
if not ep_info or not ep_info.get("api_key") or not ep_info.get("id"):
|
||||||
|
print(f"Endpoint {endpoint_group_name} not found for API key")
|
||||||
|
return
|
||||||
|
endpoint_id = int(ep_info["id"])
|
||||||
|
endpoint_api_key = ep_info["api_key"]
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
results = []
|
||||||
|
status_samples = []
|
||||||
|
max_concurrency = int(os.environ.get("MAX_CONCURRENCY", "8192"))
|
||||||
|
submit_queue_factor = 2 # cap queued tasks to reduce memory
|
||||||
|
|
||||||
|
# Shared HTTP sessions with connection pooling (persistent connections)
|
||||||
|
def make_session(pool_connections: int, pool_maxsize: int) -> requests.Session:
|
||||||
|
sess = requests.Session()
|
||||||
|
adapter = HTTPAdapter(pool_connections=pool_connections, pool_maxsize=pool_maxsize, max_retries=0)
|
||||||
|
sess.mount("https://", adapter)
|
||||||
|
sess.mount("http://", adapter)
|
||||||
|
return sess
|
||||||
|
|
||||||
|
# Router: mostly single host, small connection pool is sufficient
|
||||||
|
route_session = make_session(pool_connections=1, pool_maxsize=max_concurrency)
|
||||||
|
# Workers: many hosts; allow many pools and per-host concurrency up to max_concurrency
|
||||||
|
worker_session = make_session(pool_connections=64, pool_maxsize=max_concurrency // 8)
|
||||||
|
|
||||||
|
# Fire requests using a thread pool, scheduling at requested RPS
|
||||||
|
inflight = set()
|
||||||
|
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
|
||||||
|
for i in range(num_requests):
|
||||||
|
# Pace submissions to RPS
|
||||||
|
target_time = t0 + i / max(requests_per_second, 1e-9)
|
||||||
|
sleep_s = target_time - time.time()
|
||||||
|
if sleep_s > 0:
|
||||||
|
time.sleep(min(sleep_s, 0.5)) # sleep in chunks to stay responsive
|
||||||
|
|
||||||
|
payload = CompletionsData.for_test()
|
||||||
|
fut = executor.submit(
|
||||||
|
do_one,
|
||||||
|
endpoint_group_name,
|
||||||
|
endpoint_id,
|
||||||
|
endpoint_api_key,
|
||||||
|
server_url,
|
||||||
|
worker_endpoint,
|
||||||
|
payload,
|
||||||
|
results,
|
||||||
|
t0,
|
||||||
|
status_samples,
|
||||||
|
route_session,
|
||||||
|
worker_session,
|
||||||
|
)
|
||||||
|
inflight.add(fut)
|
||||||
|
# Prevent unbounded queue growth
|
||||||
|
if len(inflight) >= max_concurrency * submit_queue_factor:
|
||||||
|
done, not_done = wait(inflight, return_when=FIRST_COMPLETED)
|
||||||
|
inflight = not_done
|
||||||
|
# Wait for all outstanding tasks
|
||||||
|
if inflight:
|
||||||
|
wait(inflight)
|
||||||
|
# Close sessions
|
||||||
|
try:
|
||||||
|
route_session.close()
|
||||||
|
finally:
|
||||||
|
worker_session.close()
|
||||||
|
|
||||||
|
# Aggregate results
|
||||||
|
oks = [r for r in results if r.ok]
|
||||||
|
errs = [r for r in results if not r.ok]
|
||||||
|
total_reqs = len(results)
|
||||||
|
succ = len(oks)
|
||||||
|
|
||||||
|
total_ms = np.array([r.total_ms for r in oks]) if succ else np.array([])
|
||||||
|
worker_ms = np.array([r.worker_ms for r in oks]) if succ else np.array([])
|
||||||
|
route_ms = np.array([r.route_ms for r in oks]) if succ else np.array([])
|
||||||
|
|
||||||
|
avg_total = float(np.mean(total_ms)) if succ else 0.0
|
||||||
|
avg_worker = float(np.mean(worker_ms)) if succ else 0.0
|
||||||
|
avg_route = float(np.mean(route_ms)) if succ else 0.0
|
||||||
|
p50_total, p95_total = (float(np.percentile(total_ms, 50)), float(np.percentile(total_ms, 95))) if succ else (0.0, 0.0)
|
||||||
|
|
||||||
|
# Distribution over workers (by host:port)
|
||||||
|
hosts = [urlparse(r.worker_url).netloc for r in oks if r.worker_url]
|
||||||
|
dist = Counter(hosts)
|
||||||
|
|
||||||
|
# Idle over time (mode per second)
|
||||||
|
idle_ts, idle_vals = [], []
|
||||||
|
if status_samples:
|
||||||
|
buckets = {}
|
||||||
|
for ts, idle in status_samples:
|
||||||
|
k = int(ts)
|
||||||
|
buckets.setdefault(k, []).append(idle)
|
||||||
|
keys = sorted(buckets.keys())
|
||||||
|
idle_ts = keys
|
||||||
|
# Use the most frequent sampled value per second (mode) to keep integer counts
|
||||||
|
idle_vals = []
|
||||||
|
for k in keys:
|
||||||
|
vals_k = [int(v) for v in buckets[k]]
|
||||||
|
if vals_k:
|
||||||
|
cnt = Counter(vals_k)
|
||||||
|
idle_vals.append(cnt.most_common(1)[0][0])
|
||||||
|
else:
|
||||||
|
idle_vals.append(0)
|
||||||
|
|
||||||
|
print(f"\nResults: total={total_reqs} success={succ} errors={len(errs)}")
|
||||||
|
print(f"Avg latency (ms): {avg_total:.1f} p50: {p50_total:.1f} p95: {p95_total:.1f}")
|
||||||
|
print(f"Avg route latency (ms): {avg_route:.1f} Avg worker latency (ms): {avg_worker:.1f}")
|
||||||
|
if errs:
|
||||||
|
print("Sample errors:")
|
||||||
|
for e in errs[:5]:
|
||||||
|
print(f" {e.status_code} {e.error}")
|
||||||
|
|
||||||
|
# Plot: 2x3 grid
|
||||||
|
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
|
||||||
|
fig.suptitle(f"Load test: {endpoint_group_name} n={total_reqs}, rps={requests_per_second}, success={succ}")
|
||||||
|
|
||||||
|
# Dist per worker
|
||||||
|
ax0 = axes[0, 0]
|
||||||
|
if dist:
|
||||||
|
items = sorted(dist.items(), key=lambda kv: kv[1], reverse=True)
|
||||||
|
labels, counts = zip(*items)
|
||||||
|
ax0.bar(range(len(labels)), counts)
|
||||||
|
ax0.set_xticks(range(len(labels)))
|
||||||
|
ax0.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
|
||||||
|
ax0.set_title("Request distribution over workers")
|
||||||
|
ax0.set_ylabel("count")
|
||||||
|
|
||||||
|
# Latency histogram (total)
|
||||||
|
ax1 = axes[0, 1]
|
||||||
|
if succ:
|
||||||
|
ax1.hist(total_ms, bins=30)
|
||||||
|
ax1.set_title("Total latency (ms)")
|
||||||
|
ax1.set_xlabel("ms")
|
||||||
|
ax1.set_ylabel("freq")
|
||||||
|
|
||||||
|
# Eligible workers over time
|
||||||
|
ax_idle = axes[0, 2]
|
||||||
|
if idle_ts:
|
||||||
|
ax_idle.plot(idle_ts, idle_vals, "-o", ms=3)
|
||||||
|
ax_idle.set_title("Eligible workers over time")
|
||||||
|
ax_idle.set_xlabel("time (s)")
|
||||||
|
ax_idle.set_ylabel("eligible count")
|
||||||
|
|
||||||
|
# Throughput over time (completions/sec)
|
||||||
|
ax_idle = axes[1, 0]
|
||||||
|
ax_idle.clear()
|
||||||
|
if succ:
|
||||||
|
per_sec = {}
|
||||||
|
for r in oks:
|
||||||
|
s = int(r.t_end)
|
||||||
|
per_sec[s] = per_sec.get(s, 0) + 1
|
||||||
|
ts = sorted(per_sec.keys())
|
||||||
|
vals = [per_sec[t] for t in ts]
|
||||||
|
ax_idle.plot(ts, vals, "-o", ms=3)
|
||||||
|
ax_idle.set_title("Completions per second")
|
||||||
|
ax_idle.set_xlabel("time (s)")
|
||||||
|
ax_idle.set_ylabel("completions / sec")
|
||||||
|
|
||||||
|
# Summary text
|
||||||
|
ax3 = axes[1, 1]
|
||||||
|
ax3.axis("off")
|
||||||
|
text = (
|
||||||
|
f"Total requests: {total_reqs}\n"
|
||||||
|
f"Success: {succ} Errors: {len(errs)}\n"
|
||||||
|
f"Avg total latency: {avg_total:.1f} ms\n"
|
||||||
|
f"p50: {p50_total:.1f} ms p95: {p95_total:.1f} ms\n"
|
||||||
|
f"Avg route latency: {avg_route:.1f} ms\n"
|
||||||
|
f"Avg worker latency: {avg_worker:.1f} ms\n"
|
||||||
|
f"300 errors: {len([r for r in errs if r.status_code >= 300 and r.status_code < 400])}\n"
|
||||||
|
f"429 errors: {len([r for r in errs if r.status_code == 429])}\n"
|
||||||
|
f"500 errors: {len([r for r in errs if r.status_code >= 500])}\n"
|
||||||
|
f"Other errors: {len([r for r in errs if r.status_code not in [300, 429, 500]])}\n"
|
||||||
|
)
|
||||||
|
ax3.set_title("Summary")
|
||||||
|
ax3.text(0.02, 0.98, text, va="top", ha="left", fontsize=11, transform=ax3.transAxes)
|
||||||
|
|
||||||
|
# Error count over time
|
||||||
|
ax_errors = axes[1, 2]
|
||||||
|
all_end_times = [int(r.t_end) for r in results if r.t_end > 0]
|
||||||
|
if all_end_times:
|
||||||
|
min_second = min(all_end_times)
|
||||||
|
max_second = max(all_end_times)
|
||||||
|
# Count errors per second
|
||||||
|
errors_per_second = {}
|
||||||
|
for result in errs:
|
||||||
|
second = int(result.t_end)
|
||||||
|
errors_per_second[second] = errors_per_second.get(second, 0) + 1
|
||||||
|
# Create complete timeline including zeros
|
||||||
|
time_seconds = list(range(min_second, max_second + 1))
|
||||||
|
error_counts = [errors_per_second.get(sec, 0) for sec in time_seconds]
|
||||||
|
ax_errors.plot(time_seconds, error_counts, "-o", ms=3)
|
||||||
|
ax_errors.set_title("Errors per second")
|
||||||
|
ax_errors.set_xlabel("time (s)")
|
||||||
|
ax_errors.set_ylabel("errors / sec")
|
||||||
|
|
||||||
|
# Ensure unique output path and create directory if needed
|
||||||
|
final_out_path = get_incremented_path(out_path)
|
||||||
|
out_dir = os.path.dirname(final_out_path)
|
||||||
|
if out_dir:
|
||||||
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
|
|
||||||
|
plt.tight_layout(rect=[0, 0, 1, 0.96])
|
||||||
|
plt.savefig(final_out_path, dpi=120)
|
||||||
|
print(f"Saved report to: {final_out_path}")
|
||||||
|
|
||||||
|
# Per-worker latency boxplot (top 12 by volume)
|
||||||
|
groups = {}
|
||||||
|
for r in oks:
|
||||||
|
host = urlparse(r.worker_url).netloc
|
||||||
|
groups.setdefault(host, []).append(r.total_ms)
|
||||||
|
items = sorted(groups.items(), key=lambda kv: len(kv[1]), reverse=True)[:12]
|
||||||
|
if items:
|
||||||
|
labels, data = zip(*items)
|
||||||
|
fig2, axb = plt.subplots(1, 1, figsize=(12, 5))
|
||||||
|
axb.boxplot(data, showfliers=False)
|
||||||
|
axb.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
|
||||||
|
axb.set_title("Per-worker latency (ms)")
|
||||||
|
axb.set_ylabel("ms")
|
||||||
|
plt.tight_layout()
|
||||||
|
extra_out = get_incremented_path(os.path.splitext(out_path)[0] + "-workers.png")
|
||||||
|
plt.savefig(extra_out, dpi=120)
|
||||||
|
fig2.tight_layout()
|
||||||
|
fig2.savefig(extra_out, dpi=120)
|
||||||
|
print(f"Saved worker latency plot to: {extra_out}")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Check if MODEL_NAME environment variable is set
|
# Check if MODEL_NAME environment variable is set
|
||||||
@@ -13,16 +400,35 @@ if __name__ == "__main__":
|
|||||||
"--model",
|
"--model",
|
||||||
dest="model",
|
dest="model",
|
||||||
required=not model_name_set,
|
required=not model_name_set,
|
||||||
help="Model to use for completions request (required if MODEL_NAME env var not set)"
|
help="Model to use for completions request (required if MODEL_NAME env var not set)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse known args to get model early, before test_load_cmd adds its args
|
# Parse known args to get model early, before adding load args
|
||||||
known_args, _ = test_args.parse_known_args()
|
known_args, _ = test_args.parse_known_args()
|
||||||
|
if hasattr(known_args, "model") and known_args.model:
|
||||||
# Set environment variable if model was provided
|
|
||||||
if hasattr(known_args, 'model') and known_args.model:
|
|
||||||
os.environ["MODEL_NAME"] = known_args.model
|
os.environ["MODEL_NAME"] = known_args.model
|
||||||
print(f"Set MODEL_NAME environment variable to: {known_args.model}")
|
print(f"Set MODEL_NAME environment variable to: {known_args.model}")
|
||||||
|
|
||||||
# Now call test_load_cmd normally - it will add its own args and re-parse
|
# Load test args
|
||||||
test_load_cmd(CompletionsData, WORKER_ENDPOINT, arg_parser=test_args)
|
test_args.add_argument("-n", dest="num_requests", type=int, required=True, help="total number of requests")
|
||||||
|
test_args.add_argument("-rps", dest="requests_per_second", type=float, required=True, help="requests per second")
|
||||||
|
test_args.add_argument("--out", dest="out_path", type=str, default="load_test_report.png", help="path to save the report image")
|
||||||
|
args = test_args.parse_args()
|
||||||
|
|
||||||
|
server_url = {
|
||||||
|
"prod": "https://run.vast.ai",
|
||||||
|
"alpha": "https://run-alpha.vast.ai",
|
||||||
|
"candidate": "https://run-candidate.vast.ai",
|
||||||
|
"local": "http://localhost:8080"
|
||||||
|
}.get(args.instance, "http://localhost:8080")
|
||||||
|
|
||||||
|
run_load_with_metrics(
|
||||||
|
num_requests=args.num_requests,
|
||||||
|
requests_per_second=args.requests_per_second,
|
||||||
|
endpoint_group_name=args.endpoint_group_name,
|
||||||
|
account_api_key=args.api_key,
|
||||||
|
server_url=server_url,
|
||||||
|
worker_endpoint=WORKER_ENDPOINT,
|
||||||
|
instance=args.instance,
|
||||||
|
out_path=args.out_path,
|
||||||
|
)
|
||||||
+93
-9
@@ -1,19 +1,103 @@
|
|||||||
This is the base PyWorker for TGI, designed to create PyWorkers that can utilize various LLMs. It offers two primary endpoints:
|
# HuggingFace TGI PyWorker
|
||||||
|
|
||||||
1. `generate`: Generates the LLM's response to a given prompt in a single request.
|
This is the base PyWorker for HuggingFace Text Generation Inference (TGI) servers. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
|
||||||
2. `generate_stream`: Streams the LLM's response token by token.
|
|
||||||
|
|
||||||
Both endpoints use the following API payload format:
|
## Instance Setup
|
||||||
|
|
||||||
|
1. Pick a template
|
||||||
|
|
||||||
|
This worker is compatible with any TGI backend. We have a template you can use or you can create your own.
|
||||||
|
|
||||||
|
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20(Serverless))
|
||||||
|
|
||||||
|
The template can be configured via the template interface. You may want to change the model or startup arguments.
|
||||||
|
|
||||||
|
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
||||||
|
|
||||||
|
## Client Setup (Demo)
|
||||||
|
|
||||||
|
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/vast-ai/pyworker
|
||||||
|
cd pyworker
|
||||||
|
pip install uv
|
||||||
|
uv venv -p 3.12
|
||||||
|
source .venv/bin/activate
|
||||||
|
uv pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
## Using the Test Client
|
||||||
|
|
||||||
|
The test client demonstrates both streaming and non-streaming generation using TGI's native API.
|
||||||
|
|
||||||
|
First, set your API key as an environment variable:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export VAST_API_KEY=<your_api_key>
|
||||||
|
```
|
||||||
|
|
||||||
|
The `--endpoint` flag is optional. If not provided, it defaults to `my-tgi-endpoint`.
|
||||||
|
|
||||||
|
### Generate (Streaming)
|
||||||
|
|
||||||
|
Call to `/generate_stream` with streaming response:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m workers.tgi.client --generate-stream --endpoint <ENDPOINT_NAME>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Generate (Non-Streaming)
|
||||||
|
|
||||||
|
Call to `/generate` with json response:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m workers.tgi.client --generate --endpoint <ENDPOINT_NAME>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Interactive Session (Streaming)
|
||||||
|
|
||||||
|
Interactive session with streaming responses. Type `quit` to exit.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m workers.tgi.client --interactive --endpoint <ENDPOINT_NAME>
|
||||||
|
```
|
||||||
|
|
||||||
|
## API Endpoints
|
||||||
|
|
||||||
|
TGI provides two primary endpoints:
|
||||||
|
|
||||||
|
### Generate (Non-Streaming)
|
||||||
|
|
||||||
|
`/generate` - Returns the complete response in a single request.
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"inputs": "PROMPT",
|
"inputs": "Your prompt here",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"max_new_tokens": 250
|
"max_new_tokens": 1024,
|
||||||
|
"temperature": 0.7,
|
||||||
|
"return_full_text": false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Note that the max_new_tokens parameter, rather than the prompt size, impacts performance. For example, if an
|
### Generate Stream (Streaming)
|
||||||
instance is benchmarked to process 100 tokens per second, a request with max_new_tokens = 200 will take
|
|
||||||
approximately 2 seconds to complete.
|
`/generate_stream` - Streams the response token by token.
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"inputs": "Your prompt here",
|
||||||
|
"parameters": {
|
||||||
|
"max_new_tokens": 1024,
|
||||||
|
"temperature": 0.7,
|
||||||
|
"do_sample": true,
|
||||||
|
"return_full_text": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Performance Notes
|
||||||
|
|
||||||
|
The `max_new_tokens` parameter (not the prompt size) primarily impacts performance. For example, if an instance is benchmarked to process 100 tokens per second, a request with `max_new_tokens = 200` will take approximately 2 seconds to complete.
|
||||||
|
|||||||
+202
-100
@@ -1,10 +1,13 @@
|
|||||||
import logging
|
import logging
|
||||||
import sys
|
|
||||||
import json
|
import json
|
||||||
from urllib.parse import urljoin
|
import os
|
||||||
import requests
|
import sys
|
||||||
from utils.endpoint_util import Endpoint
|
import argparse
|
||||||
|
|
||||||
|
from vastai import Serverless
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# ---------------------- Logging ----------------------
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.DEBUG,
|
level=logging.DEBUG,
|
||||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||||
@@ -12,109 +15,208 @@ logging.basicConfig(
|
|||||||
)
|
)
|
||||||
log = logging.getLogger(__file__)
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
# ---------------------- Defaults ----------------------
|
||||||
|
DEFAULT_PROMPT = "Think step by step: Tell me about the Python programming language."
|
||||||
|
|
||||||
def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> None:
|
ENDPOINT_NAME = "TGI-Prod2" # change this to your TGI endpoint name
|
||||||
WORKER_ENDPOINT = "/generate"
|
MAX_TOKENS = 1024
|
||||||
COST = 100
|
DEFAULT_TEMPERATURE = 0.7
|
||||||
route_payload = {
|
|
||||||
"endpoint": endpoint_group_name,
|
|
||||||
"api_key": api_key,
|
# ---------------------- API Calls ----------------------
|
||||||
"cost": COST,
|
async def call_generate(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs) -> dict:
|
||||||
|
"""Non-streaming generation via /generate endpoint"""
|
||||||
|
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"inputs": prompt,
|
||||||
|
"parameters": {
|
||||||
|
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||||
|
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||||
|
"return_full_text": False,
|
||||||
}
|
}
|
||||||
response = requests.post(
|
|
||||||
urljoin(server_url, "/route/"),
|
|
||||||
json=route_payload,
|
|
||||||
timeout=4,
|
|
||||||
)
|
|
||||||
response.raise_for_status() # Raise an exception for bad status codes
|
|
||||||
message = response.json()
|
|
||||||
url = message["url"]
|
|
||||||
|
|
||||||
auth_data = dict(
|
|
||||||
signature=message["signature"],
|
|
||||||
cost=message["cost"],
|
|
||||||
endpoint=message["endpoint"],
|
|
||||||
reqnum=message["reqnum"],
|
|
||||||
url=url,
|
|
||||||
)
|
|
||||||
|
|
||||||
payload = dict(inputs="tell me about cats", parameters=dict(max_new_tokens=500))
|
|
||||||
req_data = dict(payload=payload, auth_data=auth_data)
|
|
||||||
url = urljoin(url, WORKER_ENDPOINT)
|
|
||||||
print(f"url: {url}")
|
|
||||||
response = requests.post(url, json=req_data)
|
|
||||||
response.raise_for_status()
|
|
||||||
res = response.json()
|
|
||||||
print(res)
|
|
||||||
|
|
||||||
|
|
||||||
def call_generate_stream(
|
|
||||||
endpoint_group_name: str, api_key: str, server_url: str
|
|
||||||
) -> None:
|
|
||||||
WORKER_ENDPOINT = "/generate_stream"
|
|
||||||
COST = 100
|
|
||||||
route_payload = {
|
|
||||||
"endpoint": endpoint_group_name,
|
|
||||||
"api_key": api_key,
|
|
||||||
"cost": COST,
|
|
||||||
}
|
}
|
||||||
response = requests.post(
|
log.debug("POST /generate %s", json.dumps(payload)[:500])
|
||||||
urljoin(server_url, "/route/"),
|
resp = await endpoint.request("/generate", payload, cost=payload["parameters"]["max_new_tokens"])
|
||||||
json=route_payload,
|
return resp["response"]
|
||||||
timeout=4,
|
|
||||||
|
|
||||||
|
async def call_generate_stream(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs):
|
||||||
|
"""Streaming generation via /generate_stream endpoint"""
|
||||||
|
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"inputs": prompt,
|
||||||
|
"parameters": {
|
||||||
|
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||||
|
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||||
|
"do_sample": True,
|
||||||
|
"return_full_text": False,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.debug("STREAM /generate_stream %s", json.dumps(payload)[:500])
|
||||||
|
resp = await endpoint.request(
|
||||||
|
"/generate_stream",
|
||||||
|
payload,
|
||||||
|
cost=payload["parameters"]["max_new_tokens"],
|
||||||
|
stream=True,
|
||||||
)
|
)
|
||||||
response.raise_for_status() # Raise an exception for bad status codes
|
return resp["response"] # async generator
|
||||||
message = response.json()
|
|
||||||
url = message["url"]
|
|
||||||
print(f"url: {url}")
|
# ---------------------- Demo Runner ----------------------
|
||||||
auth_data = dict(
|
class APIDemo:
|
||||||
signature=message["signature"],
|
"""Demo and testing functionality for the TGI API client"""
|
||||||
cost=message["cost"],
|
|
||||||
endpoint=message["endpoint"],
|
def __init__(self, client: Serverless, endpoint_name: str):
|
||||||
reqnum=message["reqnum"],
|
self.client = client
|
||||||
url=message["url"],
|
self.endpoint_name = endpoint_name
|
||||||
|
|
||||||
|
async def handle_streaming_response(self, stream) -> str:
|
||||||
|
"""Process streaming response and print tokens"""
|
||||||
|
full_response = ""
|
||||||
|
printed_answer = False
|
||||||
|
|
||||||
|
async for event in stream:
|
||||||
|
tok = (event.get("token") or {}).get("text")
|
||||||
|
if tok:
|
||||||
|
if not printed_answer:
|
||||||
|
printed_answer = True
|
||||||
|
print("\n💬 Response: ", end="", flush=True)
|
||||||
|
print(tok, end="", flush=True)
|
||||||
|
full_response += tok
|
||||||
|
|
||||||
|
print() # newline
|
||||||
|
if printed_answer:
|
||||||
|
print(f"\nStreaming completed. Response tokens: {len(full_response.split())}")
|
||||||
|
|
||||||
|
return full_response
|
||||||
|
|
||||||
|
async def demo_generate(self) -> None:
|
||||||
|
"""Demo non-streaming generation"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("GENERATE DEMO (NON-STREAMING)")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
response = await call_generate(
|
||||||
|
client=self.client,
|
||||||
|
endpoint_name=self.endpoint_name,
|
||||||
|
prompt=DEFAULT_PROMPT,
|
||||||
|
max_tokens=MAX_TOKENS,
|
||||||
|
temperature=DEFAULT_TEMPERATURE,
|
||||||
)
|
)
|
||||||
payload = dict(inputs="tell me about dogs", parameters=dict(max_new_tokens=500))
|
|
||||||
req_data = dict(payload=payload, auth_data=auth_data)
|
print(f"\n💬 Response: {response.get('generated_text', '')}")
|
||||||
url = urljoin(url, WORKER_ENDPOINT)
|
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
|
||||||
response = requests.post(url, json=req_data, stream=True)
|
|
||||||
response.raise_for_status() # Raise an exception for bad status codes
|
async def demo_generate_stream(self) -> None:
|
||||||
for line in response.iter_lines():
|
"""Demo streaming generation"""
|
||||||
payload = line.decode().lstrip("data:").rstrip()
|
print("=" * 60)
|
||||||
if payload:
|
print("GENERATE DEMO (STREAMING)")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
stream = await call_generate_stream(
|
||||||
|
client=self.client,
|
||||||
|
endpoint_name=self.endpoint_name,
|
||||||
|
prompt=DEFAULT_PROMPT,
|
||||||
|
max_tokens=MAX_TOKENS,
|
||||||
|
temperature=DEFAULT_TEMPERATURE,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = json.loads(payload)
|
await self.handle_streaming_response(stream)
|
||||||
print(data["token"]["text"], end="")
|
except Exception as e:
|
||||||
sys.stdout.flush()
|
log.error("\nError during streaming: %s", e, exc_info=True)
|
||||||
except (json.JSONDecodeError, KeyError) as e:
|
|
||||||
log.warning(f"Failed to parse streaming response: {e}")
|
async def interactive_chat(self) -> None:
|
||||||
continue
|
"""Interactive session with streaming generation"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("INTERACTIVE STREAMING SESSION")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"Using endpoint: {self.endpoint_name}")
|
||||||
|
print("Type 'quit' to exit")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
user_input = input("You: ").strip()
|
||||||
|
|
||||||
|
if user_input.lower() == "quit":
|
||||||
|
print("👋 Goodbye!")
|
||||||
|
break
|
||||||
|
elif not user_input:
|
||||||
|
continue
|
||||||
|
|
||||||
|
print("Assistant: ", end="", flush=True)
|
||||||
|
stream = await call_generate_stream(
|
||||||
|
client=self.client,
|
||||||
|
endpoint_name=self.endpoint_name,
|
||||||
|
prompt=user_input,
|
||||||
|
max_tokens=MAX_TOKENS,
|
||||||
|
temperature=DEFAULT_TEMPERATURE,
|
||||||
|
)
|
||||||
|
|
||||||
|
full_response = ""
|
||||||
|
async for event in stream:
|
||||||
|
tok = (event.get("token") or {}).get("text")
|
||||||
|
if tok:
|
||||||
|
print(tok, end="", flush=True)
|
||||||
|
full_response += tok
|
||||||
|
print() # newline
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n👋 Session interrupted. Goodbye!")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
log.error("\nError: %s", e)
|
||||||
|
continue
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------- CLI ----------------------
|
||||||
|
def build_arg_parser() -> argparse.ArgumentParser:
|
||||||
|
p = argparse.ArgumentParser(description="Vast TGI Demo (Serverless SDK)")
|
||||||
|
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
|
||||||
|
|
||||||
|
modes = p.add_mutually_exclusive_group(required=False)
|
||||||
|
modes.add_argument("--generate", action="store_true", help="Test generate endpoint (non-streaming)")
|
||||||
|
modes.add_argument("--generate-stream", action="store_true", help="Test generate endpoint with streaming")
|
||||||
|
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming session")
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
async def main_async():
|
||||||
|
args = build_arg_parser().parse_args()
|
||||||
|
|
||||||
|
selected = sum([args.generate, args.generate_stream, args.interactive])
|
||||||
|
if selected == 0:
|
||||||
|
print("Please specify exactly one test mode:")
|
||||||
|
print(" --generate : Test generate endpoint (non-streaming)")
|
||||||
|
print(" --generate-stream : Test generate endpoint with streaming")
|
||||||
|
print(" --interactive : Start interactive streaming session")
|
||||||
|
print(f"\nExample: python {os.path.basename(sys.argv[0])} --generate-stream --endpoint my-tgi-endpoint")
|
||||||
|
sys.exit(1)
|
||||||
|
elif selected > 1:
|
||||||
|
print("Please specify exactly one test mode")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"Using endpoint: {args.endpoint}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with Serverless() as client:
|
||||||
|
demo = APIDemo(client, args.endpoint)
|
||||||
|
|
||||||
|
if args.generate:
|
||||||
|
await demo.demo_generate()
|
||||||
|
elif args.generate_stream:
|
||||||
|
await demo.demo_generate_stream()
|
||||||
|
elif args.interactive:
|
||||||
|
await demo.interactive_chat()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log.error("Error during test: %s", e, exc_info=True)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from lib.test_utils import test_args
|
asyncio.run(main_async())
|
||||||
|
|
||||||
args = test_args.parse_args()
|
|
||||||
|
|
||||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
|
||||||
endpoint_name=args.endpoint_group_name,
|
|
||||||
account_api_key=args.api_key,
|
|
||||||
instance=args.instance,
|
|
||||||
)
|
|
||||||
if endpoint_api_key:
|
|
||||||
try:
|
|
||||||
call_generate(
|
|
||||||
api_key=endpoint_api_key,
|
|
||||||
endpoint_group_name=args.endpoint_group_name,
|
|
||||||
server_url=args.server_url,
|
|
||||||
)
|
|
||||||
call_generate_stream(
|
|
||||||
api_key=endpoint_api_key,
|
|
||||||
endpoint_group_name=args.endpoint_group_name,
|
|
||||||
server_url=args.server_url,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
log.error(f"Error during API call: {e}")
|
|
||||||
else:
|
|
||||||
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ")
|
|
||||||
|
|||||||
Reference in New Issue
Block a user