Compare commits
41 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6b5b1341a7 | |||
| 8be92c03de | |||
| adedb8ba90 | |||
| 2f543c01ad | |||
| 0bcd2219ea | |||
| 0339b471c5 | |||
| e143162438 | |||
| 7986e51e9e | |||
| 9c6ab78503 | |||
| 45e0c7d9ca | |||
| 7a792fd176 | |||
| e0449cb3c7 | |||
| a4339bd3f1 | |||
| 2b26e5e20c | |||
| d3727d4fd7 | |||
| a47c9d1ed0 | |||
| 0b14562a63 | |||
| de9b50abb9 | |||
| c510801723 | |||
| a12523b1d2 | |||
| eedf81c0a3 | |||
| 3adec1826d | |||
| b55bfa9611 | |||
| 7db54f3bd7 | |||
| d63a060202 | |||
| c6521cb6d4 | |||
| b7fe4ebb91 | |||
| 8ae7b74605 | |||
| 106067d716 | |||
| f5134d4bf5 | |||
| 47e5460532 | |||
| ec2ac0a21a | |||
| 2cde573c56 | |||
| b2e4a5db0c | |||
| 7437028cb2 | |||
| 02c8307af7 | |||
| 7c0f316eeb | |||
| b4025a744f | |||
| d190308329 | |||
| 944f83fc03 | |||
| f56bbc0ebe |
@@ -3,3 +3,4 @@
|
||||
__pycache__
|
||||
bin/
|
||||
lib64
|
||||
.venv
|
||||
@@ -39,11 +39,12 @@ reporting these metrics to the autoscaler.
|
||||
|
||||
If you are using a Vast.ai template that includes PyWorker integration (marked as autoscaler compatible), it should work out of the box. The template will typically start the appropriate PyWorker server automatically. Here's a few:
|
||||
|
||||
* **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=72d8dcb41ea3a58e06c741e2c725bc00)
|
||||
* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=ad72c8bf7cf695c3c9ddf0eaf6da0447)
|
||||
* **vLLM:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=63ae93902bf3978bea033782592b784d)
|
||||
* **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=6fa6bd5bdf5f0df63db80e40b086037d)
|
||||
* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=e6748878ba688e765e3e9fca29541938)
|
||||
|
||||
Currently available workers:
|
||||
* `hello_world`: A simple example worker for a basic LLM server.
|
||||
* `openai`: A simple example worker for a basic vLLM server.
|
||||
* `comfyui`: A worker for the ComfyUI image generation backend.
|
||||
* `tgi`: A worker for the Text Generation Inference backend.
|
||||
|
||||
|
||||
+22
-19
@@ -30,7 +30,7 @@ from lib.data_types import (
|
||||
BenchmarkResult
|
||||
)
|
||||
|
||||
VERSION = "0.1.0"
|
||||
VERSION = "0.2.1"
|
||||
|
||||
MSG_HISTORY_LEN = 100
|
||||
log = logging.getLogger(__file__)
|
||||
@@ -66,10 +66,17 @@ class Backend:
|
||||
unsecured: bool = dataclasses.field(
|
||||
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
||||
)
|
||||
report_addr: str = dataclasses.field(
|
||||
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
|
||||
)
|
||||
mtoken: str = dataclasses.field(
|
||||
default_factory=lambda: os.environ.get("MASTER_TOKEN", "")
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
self.metrics = Metrics()
|
||||
self.metrics._set_version(self.version)
|
||||
self.metrics._set_mtoken(self.mtoken)
|
||||
self._total_pubkey_fetch_errors = 0
|
||||
self._pubkey = self._fetch_pubkey()
|
||||
self.__start_healthcheck: bool = False
|
||||
@@ -104,23 +111,19 @@ class Backend:
|
||||
|
||||
#######################################Private#######################################
|
||||
def _fetch_pubkey(self):
|
||||
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
|
||||
report_addr = self.report_addr.rstrip("/")
|
||||
command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"]
|
||||
try:
|
||||
result = subprocess.check_output(command, universal_newlines=True)
|
||||
log.debug("public key:")
|
||||
log.debug(result)
|
||||
key = None
|
||||
for _ in range(5):
|
||||
try:
|
||||
key = RSA.import_key(result)
|
||||
break
|
||||
except ValueError as e:
|
||||
log.debug(f"Error downloading key: {e}")
|
||||
time.sleep(15)
|
||||
if key is None:
|
||||
self._total_pubkey_fetch_errors += 1
|
||||
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
|
||||
self.backend_errored("Failed to get autoscaler pubkey")
|
||||
if key is not None:
|
||||
return key
|
||||
except (ValueError , subprocess.CalledProcessError) as e:
|
||||
log.debug(f"Error downloading key: {e}")
|
||||
self.backend_errored("Failed to get autoscaler pubkey")
|
||||
|
||||
|
||||
async def __handle_request(
|
||||
self,
|
||||
@@ -315,10 +318,10 @@ class Backend:
|
||||
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
||||
log.debug("already ran benchmark")
|
||||
# trigger model load
|
||||
payload = self.benchmark_handler.make_benchmark_payload()
|
||||
_ = await self.__call_api(
|
||||
handler=self.benchmark_handler, payload=payload
|
||||
)
|
||||
# payload = self.benchmark_handler.make_benchmark_payload()
|
||||
# _ = await self.__call_api(
|
||||
# handler=self.benchmark_handler, payload=payload
|
||||
# )
|
||||
return float(f.readline())
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
@@ -393,7 +396,7 @@ class Backend:
|
||||
)
|
||||
# some backends need a few seconds after logging successful startup before
|
||||
# they can begin accepting requests
|
||||
await sleep(5)
|
||||
# await sleep(5)
|
||||
try:
|
||||
max_throughput = await run_benchmark()
|
||||
self.__start_healthcheck = True
|
||||
@@ -414,7 +417,7 @@ class Backend:
|
||||
|
||||
async def tail_log():
|
||||
log.debug(f"tailing file: {self.model_log_file}")
|
||||
async with await open_file(self.model_log_file) as f:
|
||||
async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore') as f:
|
||||
while True:
|
||||
line = await f.readline()
|
||||
if line:
|
||||
|
||||
@@ -286,6 +286,7 @@ class AutoScalerData:
|
||||
"""Data that is reported to autoscaler"""
|
||||
|
||||
id: int
|
||||
mtoken: str
|
||||
version: str
|
||||
loadtime: float
|
||||
cur_load: float
|
||||
|
||||
+16
-2
@@ -28,6 +28,7 @@ def get_url() -> str:
|
||||
@dataclass
|
||||
class Metrics:
|
||||
version: str = "0"
|
||||
mtoken: str = ""
|
||||
last_metric_update: float = 0.0
|
||||
last_request_served: float = 0.0
|
||||
update_pending: bool = False
|
||||
@@ -142,12 +143,16 @@ class Metrics:
|
||||
def _set_version(self, version: str) -> None:
|
||||
self.version = version
|
||||
|
||||
def _set_mtoken(self, mtoken: str) -> None:
|
||||
self.mtoken = mtoken
|
||||
|
||||
#######################################Private#######################################
|
||||
|
||||
async def __send_delete_requests_and_reset(self):
|
||||
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
|
||||
data = {
|
||||
"worker_id": self.id,
|
||||
"mtoken": self.mtoken,
|
||||
"request_idxs": idxs,
|
||||
"success": success_flag,
|
||||
}
|
||||
@@ -209,6 +214,7 @@ class Metrics:
|
||||
def compute_autoscaler_data() -> AutoScalerData:
|
||||
return AutoScalerData(
|
||||
id=self.id,
|
||||
mtoken=self.mtoken,
|
||||
version=self.version,
|
||||
loadtime=(loadtime_snapshot or 0.0),
|
||||
new_load=self.model_metrics.workload_processing,
|
||||
@@ -228,17 +234,25 @@ class Metrics:
|
||||
|
||||
async def send_data(report_addr: str) -> bool:
|
||||
data = compute_autoscaler_data()
|
||||
full_path = report_addr.rstrip("/") + "/worker_status/"
|
||||
log_data = asdict(data)
|
||||
def obfuscate(secret: str) -> str:
|
||||
if secret is None:
|
||||
return ""
|
||||
return secret[:7] + "..." if len(secret) > 7 else ("*" * len(secret))
|
||||
|
||||
log_data["mtoken"] = obfuscate(log_data.get("mtoken"))
|
||||
log.debug(
|
||||
"\n".join(
|
||||
[
|
||||
"#" * 60,
|
||||
f"sending data to autoscaler",
|
||||
f"{json.dumps((asdict(data)), indent=2)}",
|
||||
f"{json.dumps(log_data, indent=2)}",
|
||||
"#" * 60,
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
full_path = report_addr.rstrip("/") + "/worker_status/"
|
||||
for attempt in range(1, 4):
|
||||
try:
|
||||
session = await self.http()
|
||||
|
||||
+21
-1
@@ -3,15 +3,17 @@ import logging
|
||||
from typing import List
|
||||
import ssl
|
||||
from asyncio import run, gather
|
||||
|
||||
import asyncio
|
||||
|
||||
from lib.backend import Backend
|
||||
from lib.metrics import Metrics
|
||||
from aiohttp import web
|
||||
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
||||
try:
|
||||
log.debug("getting certificate...")
|
||||
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
||||
if use_ssl is True:
|
||||
@@ -38,3 +40,21 @@ def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
||||
await gather(site.start(), backend._start_tracking())
|
||||
|
||||
run(main())
|
||||
|
||||
except Exception as e:
|
||||
err_msg = f"PyWorker failed to launch: {e}"
|
||||
log.error(err_msg)
|
||||
|
||||
async def beacon():
|
||||
metrics = Metrics()
|
||||
metrics._set_version(getattr(backend, "version", "0"))
|
||||
metrics._set_mtoken(getattr(backend, "mtoken", ""))
|
||||
try:
|
||||
while True:
|
||||
metrics._model_errored(err_msg)
|
||||
await metrics._Metrics__send_metrics_and_reset()
|
||||
await asyncio.sleep(10)
|
||||
finally:
|
||||
await metrics.aclose()
|
||||
|
||||
run(beacon())
|
||||
|
||||
@@ -8,3 +8,4 @@ Requests~=2.32
|
||||
transformers~=4.52
|
||||
utils==1.0.*
|
||||
hf_transfer>=0.1.9
|
||||
vastai-sdk>=0.2.0
|
||||
+47
-5
@@ -9,7 +9,7 @@ ENV_PATH="$WORKSPACE_DIR/worker-env"
|
||||
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
|
||||
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
|
||||
|
||||
REPORT_ADDR="${REPORT_ADDR:-https://cloud.vast.ai/api/v0,https://run.vast.ai}"
|
||||
REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}"
|
||||
USE_SSL="${USE_SSL:-true}"
|
||||
WORKER_PORT="${WORKER_PORT:-3000}"
|
||||
mkdir -p "$WORKSPACE_DIR"
|
||||
@@ -41,6 +41,14 @@ echo_var DEBUG_LOG
|
||||
echo_var PYWORKER_LOG
|
||||
echo_var MODEL_LOG
|
||||
|
||||
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
|
||||
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
|
||||
if [ -e "$MODEL_LOG" ]; then
|
||||
echo "Rotating model log at $MODEL_LOG to $MODEL_LOG.old"
|
||||
cat "$MODEL_LOG" >> "$MODEL_LOG.old"
|
||||
: > "$MODEL_LOG"
|
||||
fi
|
||||
|
||||
# Populate /etc/environment with quoted values
|
||||
if ! grep -q "VAST" /etc/environment; then
|
||||
env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do
|
||||
@@ -124,9 +132,43 @@ cd "$SERVER_DIR"
|
||||
|
||||
echo "launching PyWorker server"
|
||||
|
||||
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
|
||||
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
|
||||
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
|
||||
set +e
|
||||
python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG"
|
||||
PY_STATUS=${PIPESTATUS[0]}
|
||||
set -e
|
||||
|
||||
if [ "${PY_STATUS}" -ne 0 ]; then
|
||||
echo "PyWorker exited with status ${PY_STATUS}; notifying autoscaler..."
|
||||
ERROR_MSG="PyWorker exited: code ${PY_STATUS}"
|
||||
MTOKEN="${MASTER_TOKEN:-}"
|
||||
VERSION="${PYWORKER_VERSION:-0}"
|
||||
|
||||
IFS=',' read -r -a REPORT_ADDRS <<< "${REPORT_ADDR}"
|
||||
for addr in "${REPORT_ADDRS[@]}"; do
|
||||
curl -sS -X POST -H 'Content-Type: application/json' \
|
||||
-d "$(cat <<JSON
|
||||
{
|
||||
"id": ${CONTAINER_ID:-0},
|
||||
"mtoken": "${MTOKEN}",
|
||||
"version": "${VERSION}",
|
||||
"loadtime": 0,
|
||||
"new_load": 0,
|
||||
"cur_load": 0,
|
||||
"rej_load": 0,
|
||||
"max_perf": 0,
|
||||
"cur_perf": 0,
|
||||
"error_msg": "${ERROR_MSG}",
|
||||
"num_requests_working": 0,
|
||||
"num_requests_recieved": 0,
|
||||
"additional_disk_usage": 0,
|
||||
"working_request_idxs": [],
|
||||
"cur_capacity": 0,
|
||||
"max_capacity": 0,
|
||||
"url": "${URL}"
|
||||
}
|
||||
JSON
|
||||
)" "${addr%/}/worker_status/" || true
|
||||
done
|
||||
fi
|
||||
|
||||
(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
|
||||
echo "launching PyWorker server done"
|
||||
+13
-133
@@ -1,107 +1,16 @@
|
||||
import logging
|
||||
from .data_types import count_workload
|
||||
import uuid
|
||||
import random
|
||||
from urllib.parse import urljoin
|
||||
import json
|
||||
import asyncio
|
||||
import random
|
||||
|
||||
import requests
|
||||
from vastai import Serverless
|
||||
|
||||
from lib.test_utils import print_truncate_res
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
from .data_types import count_workload
|
||||
async def main():
|
||||
async with Serverless() as client:
|
||||
endpoint = await client.get_endpoint(name="my-comfy-endpoint") # Change this to your endpoint name
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
def call_text2image_workflow(
|
||||
endpoint_group_name: str, api_key: str, server_url: str
|
||||
) -> None:
|
||||
"""Simple Text2Image using the new modifier-based approach"""
|
||||
|
||||
def make_request(url: str, payload: dict, timeout: int = None, verify=True, context: str = "request"):
|
||||
"""Helper function for making requests with consistent error handling"""
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
json=payload,
|
||||
timeout=timeout,
|
||||
verify=verify
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
log.error(f"HTTP error occurred during {context}: {http_err}")
|
||||
log.error(f"Status Code: {response.status_code}")
|
||||
log.error("Response content:", response.text)
|
||||
return None
|
||||
except requests.exceptions.Timeout:
|
||||
log.error(f"Timeout occurred during {context}: {url}")
|
||||
return None
|
||||
except requests.exceptions.ConnectionError:
|
||||
log.error(f"Connection error occurred during {context}: {url}")
|
||||
return None
|
||||
except json.JSONDecodeError as json_err:
|
||||
log.error(f"Failed to decode JSON response during {context}: {json_err}")
|
||||
if 'response' in locals():
|
||||
print("Response content:", response.text)
|
||||
return None
|
||||
except Exception as err:
|
||||
log.error(f"An unexpected error occurred during {context}: {err}")
|
||||
if 'response' in locals():
|
||||
log.error("Response content (if available):", response.text)
|
||||
return None
|
||||
|
||||
WORKER_ENDPOINT = "/generate/sync"
|
||||
|
||||
# This worker has concurrency = 1. All workloads have cost value 1.0
|
||||
COST = count_workload()
|
||||
|
||||
# Route to get worker URL
|
||||
route_payload = {
|
||||
"endpoint": endpoint_group_name,
|
||||
"api_key": api_key,
|
||||
"cost": COST,
|
||||
}
|
||||
|
||||
# First request - get routing information
|
||||
route_response = make_request(
|
||||
url=urljoin(server_url, "/route/"),
|
||||
payload=route_payload,
|
||||
timeout=4,
|
||||
context="route request"
|
||||
)
|
||||
|
||||
if route_response is None:
|
||||
return None
|
||||
|
||||
if "url" not in route_response or not route_response["url"]:
|
||||
log.error("Error: No worker in 'Ready' state. Please wait while the serverless engine removes errored workers or finishes loading new workers.")
|
||||
return None
|
||||
|
||||
if "status" in route_response:
|
||||
print(f"Autoscaler status: {route_response['status']}")
|
||||
return None
|
||||
|
||||
# Extract data from route response
|
||||
url = route_response["url"]
|
||||
auth_data = dict(
|
||||
signature=route_response["signature"],
|
||||
cost=route_response["cost"],
|
||||
endpoint=route_response["endpoint"],
|
||||
reqnum=route_response["reqnum"],
|
||||
url=route_response["url"],
|
||||
)
|
||||
|
||||
# Build the payload for the worker request
|
||||
worker_payload = {
|
||||
payload = {
|
||||
"input": {
|
||||
"request_id": str(uuid.uuid4()),
|
||||
"modifier": "Text2Image",
|
||||
@@ -116,40 +25,11 @@ def call_text2image_workflow(
|
||||
}
|
||||
}
|
||||
|
||||
req_data = dict(payload=worker_payload, auth_data=auth_data)
|
||||
worker_url = urljoin(url, WORKER_ENDPOINT)
|
||||
print(f"url: {worker_url}")
|
||||
|
||||
# Second request - call the worker endpoint
|
||||
worker_response = make_request(
|
||||
url=worker_url,
|
||||
payload=req_data,
|
||||
verify=get_cert_file_path(),
|
||||
context="worker request"
|
||||
)
|
||||
|
||||
return worker_response
|
||||
response = await endpoint.request("/generate/sync", payload, cost=count_workload())
|
||||
|
||||
# Get the file from the path on the local machine using SCP or SFTP
|
||||
# or configure S3 to upload to cloud storage.
|
||||
print(response["response"]["output"][0]["local_path"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
from lib.test_utils import test_args
|
||||
|
||||
args = test_args.parse_args()
|
||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||
endpoint_name=args.endpoint_group_name,
|
||||
account_api_key=args.api_key,
|
||||
instance=args.instance,
|
||||
)
|
||||
|
||||
if endpoint_api_key:
|
||||
result = call_text2image_workflow(
|
||||
api_key=endpoint_api_key,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
server_url=args.server_url,
|
||||
)
|
||||
if result is None:
|
||||
log.error("Text2Image workflow failed")
|
||||
else:
|
||||
print(result)
|
||||
else:
|
||||
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name}")
|
||||
asyncio.run(main())
|
||||
@@ -7,20 +7,13 @@ from lib.test_utils import print_truncate_res
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
|
||||
"""
|
||||
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
|
||||
"""
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
from vastai import Serverless
|
||||
|
||||
|
||||
def call_default_workflow(
|
||||
endpoint_group_name: str, api_key: str, server_url: str
|
||||
) -> None:
|
||||
ENDPOINT_NAME = "my-comfyui-endpoint"
|
||||
COST = 100 # Use a constant cost for image generation
|
||||
|
||||
def call_default_workflow(client: Serverless) -> None:
|
||||
WORKER_ENDPOINT = "/prompt"
|
||||
COST = 100
|
||||
route_payload = {
|
||||
@@ -82,6 +75,7 @@ def call_custom_workflow_for_sd3(
|
||||
endpoint=message["endpoint"],
|
||||
reqnum=message["reqnum"],
|
||||
url=message["url"],
|
||||
request_idx=message["request_idx"],
|
||||
)
|
||||
workflow = {
|
||||
"3": {
|
||||
|
||||
+33
-26
@@ -8,14 +8,13 @@ This is the base PyWorker for OpenAI compatible inference servers. See the [Ser
|
||||
|
||||
This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker.
|
||||
|
||||
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20%2B%20Qwen%2FQwen3-8B%20(Serverless)) (recommended)
|
||||
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20(Serverless)) (recommended)
|
||||
- [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless))
|
||||
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20%2B%20Qwen3-8B%20(Serverless))
|
||||
|
||||
|
||||
All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected.
|
||||
|
||||
2. Follow the [getting started guide](https://docs.vast.ai/serverless/getting-started) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
||||
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
||||
|
||||
## Client Setup (Demo)
|
||||
|
||||
@@ -34,38 +33,20 @@ uv pip install -r requirements.txt
|
||||
|
||||
Several examples have been provided in the client to help you get started with your own implementation.
|
||||
|
||||
### Completions
|
||||
|
||||
Call to `/v1/completions` with json response
|
||||
First, set your API key as an environment variable:
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
|
||||
export VAST_API_KEY=<your_api_key>
|
||||
```
|
||||
|
||||
### Chat Completion (json)
|
||||
|
||||
Call to `/v1/chat/completions` with json response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
|
||||
```
|
||||
The `--model` and `--endpoint` flags are optional. If not provided, they default to `Qwen/Qwen3-8B` and `my-vllm-endpoint` respectively.
|
||||
|
||||
### Chat Completion (streaming)
|
||||
|
||||
Call to `/v1/chat/completions` with streaming response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Tool Use (json)
|
||||
|
||||
Call to `/v1/chat/completions` with tool and json response.
|
||||
|
||||
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
|
||||
python -m workers.openai.client --chat-stream --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Interactive Chat (streaming)
|
||||
@@ -75,6 +56,32 @@ Interactive session with calls to `/v1/chat/completions`.
|
||||
Type `clear` to clear the chat history or `quit` to exit.
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
|
||||
python -m workers.openai.client --interactive --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Chat Completion (json)
|
||||
|
||||
Call to `/v1/chat/completions` with json response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client --chat --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Tool Use (json)
|
||||
|
||||
Call to `/v1/chat/completions` with tool and json response.
|
||||
|
||||
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client --tools --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Completions
|
||||
|
||||
Call to `/v1/completions` with json response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client --completion --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
|
||||
+358
-412
@@ -1,14 +1,15 @@
|
||||
import logging
|
||||
import sys
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
from urllib.parse import urljoin
|
||||
from typing import Dict, Any, Optional, Iterator, Union, List
|
||||
import requests
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
from .data_types.client import CompletionConfig, ChatCompletionConfig
|
||||
import argparse
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from vastai import Serverless
|
||||
import asyncio
|
||||
|
||||
# ---------------------- Logging ----------------------
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
@@ -16,135 +17,20 @@ logging.basicConfig(
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
COMPLETIONS_PROMPT = "the capital of USA is"
|
||||
# ---------------------- Prompts ----------------------
|
||||
COMPLETIONS_PROMPT = "Zebras are primarily grazers and can subsist on lower-quality vegetation. They are preyed on mainly by"
|
||||
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
|
||||
TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?"
|
||||
|
||||
|
||||
class APIClient:
|
||||
"""Lightweight client focused solely on API communication"""
|
||||
|
||||
# Remove the generic WORKER_ENDPOINT since we're now going direct
|
||||
DEFAULT_COST = 100
|
||||
DEFAULT_TIMEOUT = 4
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint_group_name: str,
|
||||
api_key: str,
|
||||
server_url: str,
|
||||
endpoint_api_key: str,
|
||||
):
|
||||
self.endpoint_group_name = endpoint_group_name
|
||||
self.api_key = api_key
|
||||
self.server_url = server_url
|
||||
self.endpoint_api_key = endpoint_api_key
|
||||
|
||||
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
|
||||
"""Get worker URL and auth data from routing service"""
|
||||
if not self.endpoint_api_key:
|
||||
raise ValueError("No valid endpoint API key available")
|
||||
|
||||
route_payload = {
|
||||
"endpoint": self.endpoint_group_name,
|
||||
"api_key": self.endpoint_api_key,
|
||||
"cost": cost,
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
urljoin(self.server_url, "/route/"),
|
||||
json=route_payload,
|
||||
timeout=self.DEFAULT_TIMEOUT,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def _create_auth_data(self, message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Create auth data from routing response"""
|
||||
return {
|
||||
"signature": message["signature"],
|
||||
"cost": message["cost"],
|
||||
"endpoint": message["endpoint"],
|
||||
"reqnum": message["reqnum"],
|
||||
"url": message["url"],
|
||||
}
|
||||
|
||||
def _make_request(
|
||||
self,
|
||||
payload: Dict[str, Any],
|
||||
endpoint: str,
|
||||
method: str = "POST",
|
||||
stream: bool = False,
|
||||
) -> Union[Dict[str, Any], Iterator[str]]:
|
||||
"""Make request directly to the specific worker endpoint"""
|
||||
# Get worker URL and auth data
|
||||
cost = payload.get("max_tokens", self.DEFAULT_COST)
|
||||
message = self._get_worker_url(cost=cost)
|
||||
worker_url = message["url"]
|
||||
auth_data = self._create_auth_data(message)
|
||||
|
||||
req_data = {"payload": {"input": payload}, "auth_data": auth_data}
|
||||
|
||||
url = urljoin(worker_url, endpoint)
|
||||
log.debug(f"Making direct request to: {url}")
|
||||
log.debug(f"Payload: {req_data}")
|
||||
|
||||
# Make the request using the specified method
|
||||
if method.upper() == "POST":
|
||||
response = requests.post(
|
||||
url, json=req_data, stream=stream, verify=get_cert_file_path()
|
||||
)
|
||||
elif method.upper() == "GET":
|
||||
response = requests.get(
|
||||
url, params=req_data, stream=stream, verify=get_cert_file_path()
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
if stream:
|
||||
return self._handle_streaming_response(response)
|
||||
else:
|
||||
return response.json()
|
||||
|
||||
def _handle_streaming_response(self, response: requests.Response) -> Iterator[str]:
|
||||
"""Handle streaming response and yield tokens"""
|
||||
try:
|
||||
for line in response.iter_lines(decode_unicode=True):
|
||||
if line:
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
if data_str.strip() == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
yield data # Yield the full chunk
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
except Exception as e:
|
||||
log.error(f"Error handling streaming response: {e}")
|
||||
raise
|
||||
|
||||
def call_completions(
|
||||
self, config: CompletionConfig
|
||||
) -> Union[Dict[str, Any], Iterator[str]]:
|
||||
payload = config.to_dict()
|
||||
|
||||
return self._make_request(
|
||||
payload=payload, endpoint="/v1/completions", stream=config.stream
|
||||
)
|
||||
|
||||
def call_chat_completions(
|
||||
self, config: ChatCompletionConfig
|
||||
) -> Union[Dict[str, Any], Iterator[str]]:
|
||||
payload = config.to_dict()
|
||||
|
||||
return self._make_request(
|
||||
payload=payload, endpoint="/v1/chat/completions", stream=config.stream
|
||||
TOOLS_PROMPT = (
|
||||
"Can you list the files in the current working directory and tell me what you see? "
|
||||
"What do you think this directory might be for?"
|
||||
)
|
||||
|
||||
ENDPOINT_NAME = "my-vllm-endpoint" # change this to your vLLM endpoint name
|
||||
DEFAULT_MODEL = "Qwen/Qwen3-8B" # must support tool calling
|
||||
MAX_TOKENS = 1024
|
||||
DEFAULT_TEMPERATURE = 0.7
|
||||
|
||||
# ---------------------- Tooling ----------------------
|
||||
class ToolManager:
|
||||
"""Handles tool definitions and execution"""
|
||||
|
||||
@@ -164,7 +50,7 @@ class ToolManager:
|
||||
|
||||
@staticmethod
|
||||
def get_ls_tool_definition() -> List[Dict[str, Any]]:
|
||||
"""Get the ls tool definition"""
|
||||
"""OpenAI-compatible tool schema"""
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
@@ -178,98 +64,228 @@ class ToolManager:
|
||||
|
||||
def execute_tool_call(self, tool_call: Dict[str, Any]) -> str:
|
||||
"""Execute a tool call and return the result"""
|
||||
function_name = tool_call["function"]["name"]
|
||||
|
||||
function_name = (tool_call.get("function") or {}).get("name")
|
||||
if function_name == "list_files":
|
||||
return self.list_files()
|
||||
else:
|
||||
raise ValueError(f"Unknown tool function: {function_name}")
|
||||
|
||||
|
||||
# ----- Helpers to handle streamed tool_calls assembly -----
|
||||
def _merge_tool_call_delta(state: Dict[int, Dict[str, Any]], tc_delta: Dict[str, Any]) -> None:
|
||||
"""
|
||||
OpenAI-style streaming sends partial tool_calls with an index and partial fields.
|
||||
We merge into a per-index state dict until the assistant message finishes.
|
||||
"""
|
||||
idx = tc_delta.get("index")
|
||||
if idx is None:
|
||||
return
|
||||
|
||||
entry = state.setdefault(idx, {"id": None, "function": {"name": None, "arguments": ""}, "type": "function"})
|
||||
|
||||
if tc_delta.get("id"):
|
||||
entry["id"] = tc_delta["id"]
|
||||
|
||||
fn_delta = tc_delta.get("function") or {}
|
||||
if "name" in fn_delta and fn_delta["name"]:
|
||||
entry["function"]["name"] = fn_delta["name"]
|
||||
if "arguments" in fn_delta and fn_delta["arguments"]:
|
||||
entry["function"]["arguments"] += fn_delta["arguments"]
|
||||
|
||||
|
||||
def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
return [state[i] for i in sorted(state.keys())]
|
||||
|
||||
|
||||
# ---- OpenAI-compatible calls (non-streaming) ----
|
||||
async def call_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs) -> Dict[str, Any]:
|
||||
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
}
|
||||
}
|
||||
log.debug("POST /v1/completions %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
|
||||
return resp["response"]
|
||||
|
||||
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs) -> Dict[str, Any]:
|
||||
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
|
||||
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
|
||||
}
|
||||
}
|
||||
log.debug("POST /v1/chat/completions %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"])
|
||||
return resp["response"]
|
||||
|
||||
# ---- Streaming variants ----
|
||||
async def stream_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs):
|
||||
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
"stream": True,
|
||||
**({"stop": kwargs["stop"]} if "stop" in kwargs else {}),
|
||||
}
|
||||
}
|
||||
log.debug("STREAM /v1/completions %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
|
||||
return resp["response"] # async generator
|
||||
|
||||
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs):
|
||||
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
"stream": True,
|
||||
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
|
||||
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
|
||||
}
|
||||
}
|
||||
log.debug("STREAM /v1/chat/completions %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
|
||||
return resp["response"] # async generator
|
||||
|
||||
|
||||
# ---------------------- Demo Runner ----------------------
|
||||
class APIDemo:
|
||||
"""Demo and testing functionality for the API client"""
|
||||
|
||||
def __init__(
|
||||
self, client: APIClient, model: str, tool_manager: Optional[ToolManager] = None
|
||||
):
|
||||
def __init__(self, client: Serverless, model: str, endpoint_name: str, tool_manager: Optional[ToolManager] = None):
|
||||
self.client = client
|
||||
self.model = model
|
||||
self.endpoint_name = endpoint_name
|
||||
self.tool_manager = tool_manager or ToolManager()
|
||||
|
||||
def handle_streaming_response(
|
||||
self, response_stream, show_reasoning: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Handle streaming chat response and display all output.
|
||||
"""
|
||||
|
||||
# ----- Streaming handler -----
|
||||
async def handle_streaming_response(self, stream, show_reasoning: bool = True) -> str:
|
||||
full_response = ""
|
||||
reasoning_content = ""
|
||||
reasoning_started = False
|
||||
content_started = False
|
||||
printed_reasoning = False
|
||||
printed_answer = False
|
||||
finish_reason = None
|
||||
|
||||
for chunk in response_stream:
|
||||
# Normalize the chunk
|
||||
if isinstance(chunk, str):
|
||||
chunk = chunk.strip()
|
||||
if chunk.startswith("data: "):
|
||||
chunk = chunk[6:].strip()
|
||||
if chunk in ["[DONE]", ""]:
|
||||
continue
|
||||
try:
|
||||
parsed_chunk = json.loads(chunk)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
elif isinstance(chunk, dict):
|
||||
parsed_chunk = chunk
|
||||
else:
|
||||
continue
|
||||
async for chunk in stream:
|
||||
choice = (chunk.get("choices") or [{}])[0]
|
||||
delta = choice.get("delta", {})
|
||||
|
||||
# Parse delta from the chunk
|
||||
choices = parsed_chunk.get("choices", [])
|
||||
if not choices:
|
||||
continue
|
||||
# Track finish reason
|
||||
if choice.get("finish_reason"):
|
||||
finish_reason = choice.get("finish_reason")
|
||||
|
||||
delta = choices[0].get("delta", {})
|
||||
reasoning_token = delta.get("reasoning_content", "")
|
||||
content_token = delta.get("content", "")
|
||||
|
||||
# Print reasoning token if applicable
|
||||
if show_reasoning and reasoning_token:
|
||||
if not reasoning_started:
|
||||
# reasoning tokens
|
||||
rc = delta.get("reasoning_content")
|
||||
if rc and show_reasoning:
|
||||
if not printed_reasoning:
|
||||
print("\n🧠 Reasoning: ", end="", flush=True)
|
||||
reasoning_started = True
|
||||
print(f"\033[90m{reasoning_token}\033[0m", end="", flush=True)
|
||||
reasoning_content += reasoning_token
|
||||
printed_reasoning = True
|
||||
print(rc, end="", flush=True)
|
||||
reasoning_content += rc
|
||||
|
||||
# Print content token
|
||||
if content_token:
|
||||
if not content_started:
|
||||
if show_reasoning and reasoning_started:
|
||||
print(f"\n💬 Response: ", end="", flush=True)
|
||||
# content tokens
|
||||
content_part = delta.get("content")
|
||||
if content_part:
|
||||
if not printed_answer:
|
||||
if show_reasoning and printed_reasoning:
|
||||
print("\n💬 Response: ", end="", flush=True)
|
||||
else:
|
||||
print("Assistant: ", end="", flush=True)
|
||||
content_started = True
|
||||
print(content_token, end="", flush=True)
|
||||
full_response += content_token
|
||||
|
||||
print() # Ensure newline after response
|
||||
printed_answer = True
|
||||
print(content_part, end="", flush=True)
|
||||
full_response += content_part
|
||||
|
||||
print() # newline
|
||||
if show_reasoning:
|
||||
if reasoning_started or content_started:
|
||||
if printed_reasoning or printed_answer:
|
||||
print("\nStreaming completed.")
|
||||
if reasoning_started:
|
||||
if printed_reasoning:
|
||||
print(f"Reasoning tokens: {len(reasoning_content.split())}")
|
||||
if content_started:
|
||||
if printed_answer:
|
||||
print(f"Response tokens: {len(full_response.split())}")
|
||||
if finish_reason:
|
||||
print(f"Finish reason: {finish_reason}")
|
||||
|
||||
return full_response
|
||||
|
||||
def test_tool_support(self) -> bool:
|
||||
"""Test if the endpoint supports function calling"""
|
||||
log.debug("Testing endpoint tool calling support...")
|
||||
async def demo_completions(self) -> None:
|
||||
print("=" * 60)
|
||||
print("COMPLETIONS DEMO")
|
||||
print("=" * 60)
|
||||
|
||||
# Try a simple request with minimal tools to test support
|
||||
response = await call_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
prompt=COMPLETIONS_PROMPT,
|
||||
endpoint_name=self.endpoint_name,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
)
|
||||
print("\nResponse:")
|
||||
print(json.dumps(response, indent=2))
|
||||
|
||||
async def demo_chat(self, use_streaming: bool = True) -> None:
|
||||
print("=" * 60)
|
||||
print(f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}")
|
||||
print("=" * 60)
|
||||
|
||||
messages = [{"role": "user", "content": CHAT_PROMPT}]
|
||||
|
||||
if use_streaming:
|
||||
stream = await stream_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
endpoint_name=self.endpoint_name,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE
|
||||
)
|
||||
try:
|
||||
await self.handle_streaming_response(stream, show_reasoning=True)
|
||||
except Exception as e:
|
||||
log.error("\nError during streaming: %s", e, exc_info=True)
|
||||
else:
|
||||
response = await call_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
endpoint_name=self.endpoint_name,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE
|
||||
)
|
||||
choice = (response.get("choices") or [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
content = message.get("content", "")
|
||||
reasoning = message.get("reasoning_content", "") or message.get("reasoning", "")
|
||||
if reasoning:
|
||||
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
|
||||
print(f"\n💬 Assistant: {content}")
|
||||
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
|
||||
|
||||
async def test_tool_support(self) -> bool:
|
||||
"""Probe that tool schema is accepted (no actual call)"""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
minimal_tool = [
|
||||
{
|
||||
@@ -277,179 +293,158 @@ class APIDemo:
|
||||
"function": {"name": "test_function", "description": "Test function"},
|
||||
}
|
||||
]
|
||||
|
||||
config = ChatCompletionConfig(
|
||||
try:
|
||||
_ = await call_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
endpoint_name=self.endpoint_name,
|
||||
tools=minimal_tool,
|
||||
tool_choice="none", # Don't actually call the tool
|
||||
tool_choice="none",
|
||||
max_tokens=10
|
||||
)
|
||||
|
||||
try:
|
||||
response = self.client.call_chat_completions(config)
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"Error: Endpoint does not support tool calling: {e}")
|
||||
log.error("Endpoint does not support tool calling: %s", e)
|
||||
return False
|
||||
|
||||
def demo_completions(self) -> None:
|
||||
"""Demo: test basic completions endpoint"""
|
||||
print("=" * 60)
|
||||
print("COMPLETIONS DEMO")
|
||||
print("=" * 60)
|
||||
|
||||
config = CompletionConfig(
|
||||
model=self.model, prompt=COMPLETIONS_PROMPT, stream=False
|
||||
)
|
||||
|
||||
log.info(
|
||||
f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'"
|
||||
)
|
||||
response = self.client.call_completions(config)
|
||||
|
||||
if isinstance(response, dict):
|
||||
print("\nResponse:")
|
||||
print(json.dumps(response, indent=2))
|
||||
else:
|
||||
log.error("Unexpected response format")
|
||||
|
||||
def demo_chat(self, use_streaming: bool = True) -> None:
|
||||
"""
|
||||
Demo: test chat completions endpoint with optional streaming
|
||||
"""
|
||||
print("=" * 60)
|
||||
print(
|
||||
f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}"
|
||||
)
|
||||
print("=" * 60)
|
||||
|
||||
config = ChatCompletionConfig(
|
||||
model=self.model,
|
||||
messages=[{"role": "user", "content": CHAT_PROMPT}],
|
||||
stream=use_streaming,
|
||||
)
|
||||
|
||||
log.info(f"Testing chat completions with model '{self.model}'...")
|
||||
response = self.client.call_chat_completions(config)
|
||||
|
||||
if use_streaming:
|
||||
try:
|
||||
self.handle_streaming_response(response, show_reasoning=True)
|
||||
except Exception as e:
|
||||
log.error(f"\nError during streaming: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
else:
|
||||
if isinstance(response, dict):
|
||||
choice = response.get("choices", [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
content = message.get("content", "")
|
||||
reasoning = message.get("reasoning_content", "") or message.get(
|
||||
"reasoning", ""
|
||||
)
|
||||
|
||||
if reasoning:
|
||||
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
|
||||
|
||||
print(f"\n💬 Assistant: {content}")
|
||||
print(f"\nFull Response:")
|
||||
print(json.dumps(response, indent=2))
|
||||
else:
|
||||
log.error("Unexpected response format")
|
||||
|
||||
def demo_ls_tool(self) -> None:
|
||||
"""Demo: ask LLM to list files in the current directory and describe what it sees"""
|
||||
async def demo_ls_tool(self) -> None:
|
||||
"""Ask to list files using function calling, then provide final analysis"""
|
||||
print("=" * 60)
|
||||
print("TOOL USE DEMO: List Directory Contents")
|
||||
print("=" * 60)
|
||||
|
||||
# Test if tools are supported first
|
||||
if not self.test_tool_support():
|
||||
if not await self.test_tool_support():
|
||||
return
|
||||
|
||||
# Request with tool available
|
||||
messages = [{"role": "user", "content": TOOLS_PROMPT}]
|
||||
messages: List[Dict[str, Any]] = [{"role": "user", "content": TOOLS_PROMPT}]
|
||||
|
||||
config = ChatCompletionConfig(
|
||||
# First pass: let the model decide tools, stream tool_calls and partial content
|
||||
stream = await stream_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
endpoint_name=self.endpoint_name,
|
||||
tools=self.tool_manager.get_ls_tool_definition(),
|
||||
tool_choice="auto",
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
)
|
||||
|
||||
log.info(f"Making initial request with tool using model '{self.model}'...")
|
||||
response = self.client.call_chat_completions(config)
|
||||
assistant_content_buf: List[str] = []
|
||||
tool_calls_state: Dict[int, Dict[str, Any]] = {}
|
||||
printed_reasoning = False
|
||||
printed_answer = False
|
||||
|
||||
if not isinstance(response, dict):
|
||||
raise ValueError("Expected dict response for tool use")
|
||||
async for chunk in stream:
|
||||
choice = (chunk.get("choices") or [{}])[0]
|
||||
delta = choice.get("delta", {})
|
||||
|
||||
choice = response.get("choices", [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
rc = delta.get("reasoning_content")
|
||||
if rc:
|
||||
if not printed_reasoning:
|
||||
printed_reasoning = True
|
||||
print("🧠 Reasoning: ", end="", flush=True)
|
||||
print(rc, end="", flush=True)
|
||||
|
||||
print(f"Assistant response: {message.get('content', 'No content')}")
|
||||
content_part = delta.get("content")
|
||||
if content_part:
|
||||
assistant_content_buf.append(content_part)
|
||||
if not printed_answer:
|
||||
printed_answer = True
|
||||
print("\n💬 Response: ", end="", flush=True)
|
||||
print(content_part, end="", flush=True)
|
||||
|
||||
# Check for tool calls
|
||||
tool_calls = message.get("tool_calls")
|
||||
if not tool_calls:
|
||||
raise ValueError(
|
||||
"No tool calls made - model may not support function calling"
|
||||
)
|
||||
if "tool_calls" in delta and delta["tool_calls"]:
|
||||
for tc_delta in delta["tool_calls"]:
|
||||
_merge_tool_call_delta(tool_calls_state, tc_delta)
|
||||
|
||||
print(f"Tool calls detected: {len(tool_calls)}")
|
||||
# If no tool calls, we’re done.
|
||||
if not tool_calls_state:
|
||||
print("\n(No tool calls were made.)")
|
||||
return
|
||||
|
||||
# Execute the tool call
|
||||
for tool_call in tool_calls:
|
||||
function_name = tool_call["function"]["name"]
|
||||
print(f"Executing tool: {function_name}")
|
||||
|
||||
tool_result = self.tool_manager.execute_tool_call(tool_call)
|
||||
print(f"Tool result:\n{tool_result}")
|
||||
|
||||
# Add tool result and continue conversation
|
||||
messages.append(message) # Add assistant's message with tool call
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call["id"],
|
||||
"content": tool_result,
|
||||
# Build assistant message with tool_calls
|
||||
assistant_message = {
|
||||
"role": "assistant",
|
||||
"content": "".join(assistant_content_buf) if assistant_content_buf else None,
|
||||
"tool_calls": _tool_state_to_message_tool_calls(tool_calls_state),
|
||||
}
|
||||
)
|
||||
messages.append(assistant_message)
|
||||
|
||||
# Get final response
|
||||
final_config = ChatCompletionConfig(
|
||||
# Execute tools and feed results back
|
||||
for tc in assistant_message["tool_calls"]:
|
||||
tool_name = (tc.get("function") or {}).get("name")
|
||||
call_id = tc.get("id")
|
||||
raw_args = (tc.get("function") or {}).get("arguments") or "{}"
|
||||
|
||||
try:
|
||||
args = json.loads(raw_args) if raw_args.strip() else {}
|
||||
except Exception as e:
|
||||
tool_result = json.dumps({"error": f"Argument parse failed: {str(e)}", "raw_arguments": raw_args})
|
||||
messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result})
|
||||
continue
|
||||
|
||||
try:
|
||||
if tool_name == "list_files":
|
||||
tool_result = self.tool_manager.list_files()
|
||||
else:
|
||||
tool_result = json.dumps({"error": f"Unknown tool '{tool_name}'"})
|
||||
except Exception as e:
|
||||
tool_result = json.dumps({"error": f"Tool '{tool_name}' failed: {str(e)}"})
|
||||
|
||||
print("\n[Tool executed]", tool_name)
|
||||
print(tool_result[:500] + ("..." if len(tool_result) > 500 else ""))
|
||||
messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result})
|
||||
|
||||
# Second pass: get final streamed answer after tool results
|
||||
stream2 = await stream_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
tools=self.tool_manager.get_ls_tool_definition(),
|
||||
endpoint_name=self.endpoint_name,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
)
|
||||
|
||||
print("Getting final response...")
|
||||
final_response = self.client.call_chat_completions(final_config)
|
||||
final_buf = []
|
||||
printed_reasoning2 = False
|
||||
printed_answer2 = False
|
||||
|
||||
if isinstance(final_response, dict):
|
||||
final_choice = final_response.get("choices", [{}])[0]
|
||||
final_message = final_choice.get("message", {})
|
||||
final_content = final_message.get("content", "")
|
||||
async for chunk in stream2:
|
||||
choice = (chunk.get("choices") or [{}])[0]
|
||||
delta = choice.get("delta", {})
|
||||
|
||||
rc2 = delta.get("reasoning_content")
|
||||
if rc2:
|
||||
if not printed_reasoning2:
|
||||
printed_reasoning2 = True
|
||||
print("\n🧠 Reasoning (post-tools): ", end="", flush=True)
|
||||
print(rc2, end="", flush=True)
|
||||
|
||||
c2 = delta.get("content")
|
||||
if c2:
|
||||
final_buf.append(c2)
|
||||
if not printed_answer2:
|
||||
printed_answer2 = True
|
||||
print("\n💬 Response (final): ", end="", flush=True)
|
||||
print(c2, end="", flush=True)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("FINAL LLM ANALYSIS:")
|
||||
print("=" * 60)
|
||||
print(final_content)
|
||||
print("".join(final_buf))
|
||||
print("=" * 60)
|
||||
|
||||
def interactive_chat(self) -> None:
|
||||
async def interactive_chat(self) -> None:
|
||||
"""Interactive chat session with streaming"""
|
||||
print("=" * 60)
|
||||
print("INTERACTIVE STREAMING CHAT")
|
||||
print("=" * 60)
|
||||
print(f"Using model: {self.model}")
|
||||
print("Type 'quit' to exit, 'clear' to clear history")
|
||||
print()
|
||||
|
||||
messages = []
|
||||
messages: List[Dict[str, Any]] = []
|
||||
|
||||
while True:
|
||||
try:
|
||||
@@ -467,16 +462,16 @@ class APIDemo:
|
||||
|
||||
messages.append({"role": "user", "content": user_input})
|
||||
|
||||
config = ChatCompletionConfig(
|
||||
model=self.model, messages=messages, stream=True, temperature=0.7
|
||||
)
|
||||
|
||||
print("Assistant: ", end="", flush=True)
|
||||
|
||||
response = self.client.call_chat_completions(config)
|
||||
assistant_content = self.handle_streaming_response(
|
||||
response, show_reasoning=True
|
||||
stream = await stream_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
endpoint_name=self.endpoint_name,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=0.7
|
||||
)
|
||||
assistant_content = await self.handle_streaming_response(stream, show_reasoning=True)
|
||||
|
||||
# Add assistant response to conversation history
|
||||
messages.append({"role": "assistant", "content": assistant_content})
|
||||
@@ -485,115 +480,66 @@ class APIDemo:
|
||||
print("\n👋 Chat interrupted. Goodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
log.error(f"\nError: {e}")
|
||||
log.error("\nError: %s", e)
|
||||
continue
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function with CLI switches for different tests"""
|
||||
from lib.test_utils import test_args
|
||||
# ---------------------- CLI ----------------------
|
||||
def build_arg_parser() -> argparse.ArgumentParser:
|
||||
p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
|
||||
p.add_argument("--model", default=DEFAULT_MODEL, help=f"Model to use for requests (default: {DEFAULT_MODEL})")
|
||||
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
|
||||
|
||||
# Add mandatory model argument
|
||||
test_args.add_argument(
|
||||
"--model", required=True, help="Model to use for requests (required)"
|
||||
)
|
||||
modes = p.add_mutually_exclusive_group(required=False)
|
||||
modes.add_argument("--completion", action="store_true", help="Test completions endpoint")
|
||||
modes.add_argument("--chat", action="store_true", help="Test chat completions endpoint (non-streaming)")
|
||||
modes.add_argument("--chat-stream", action="store_true", help="Test chat completions endpoint with streaming")
|
||||
modes.add_argument("--tools", action="store_true", help="Test function calling with ls tool (non-streaming+streamed phases)")
|
||||
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming chat session")
|
||||
return p
|
||||
|
||||
# Add test mode arguments
|
||||
test_args.add_argument(
|
||||
"--completion", action="store_true", help="Test completions endpoint"
|
||||
)
|
||||
test_args.add_argument(
|
||||
"--chat",
|
||||
action="store_true",
|
||||
help="Test chat completions endpoint (non-streaming)",
|
||||
)
|
||||
test_args.add_argument(
|
||||
"--chat-stream",
|
||||
action="store_true",
|
||||
help="Test chat completions endpoint with streaming",
|
||||
)
|
||||
test_args.add_argument(
|
||||
"--tools",
|
||||
action="store_true",
|
||||
help="Test function calling with ls tool (non-streaming)",
|
||||
)
|
||||
test_args.add_argument(
|
||||
"--interactive",
|
||||
action="store_true",
|
||||
help="Start interactive streaming chat session",
|
||||
)
|
||||
|
||||
args = test_args.parse_args()
|
||||
async def main_async():
|
||||
args = build_arg_parser().parse_args()
|
||||
|
||||
# Check that only one test mode is selected
|
||||
test_modes = [
|
||||
args.completion,
|
||||
args.chat,
|
||||
args.chat_stream,
|
||||
args.tools,
|
||||
args.interactive,
|
||||
]
|
||||
selected_count = sum(test_modes)
|
||||
|
||||
if selected_count == 0:
|
||||
selected = sum([args.completion, args.chat, args.chat_stream, args.tools, args.interactive])
|
||||
if selected == 0:
|
||||
print("Please specify exactly one test mode:")
|
||||
print(" --completion : Test completions endpoint")
|
||||
print(" --chat : Test chat completions endpoint (non-streaming)")
|
||||
print(" --chat-stream : Test chat completions endpoint with streaming")
|
||||
print(" --tools : Test function calling with ls tool (non-streaming)")
|
||||
print(" --tools : Test function calling with ls tool")
|
||||
print(" --interactive : Start interactive streaming chat session")
|
||||
print(
|
||||
f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT"
|
||||
)
|
||||
print(f"\nExample: python {os.path.basename(sys.argv[0])} --model Qwen/Qwen3-8B --chat-stream --endpoint my-vllm-endpoint")
|
||||
sys.exit(1)
|
||||
elif selected_count > 1:
|
||||
elif selected > 1:
|
||||
print("Please specify exactly one test mode")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||
endpoint_name=args.endpoint_group_name,
|
||||
account_api_key=args.api_key,
|
||||
instance=args.instance,
|
||||
)
|
||||
|
||||
if not endpoint_api_key:
|
||||
log.error(
|
||||
f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Create the core API client
|
||||
client = APIClient(
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
api_key=args.api_key,
|
||||
server_url=Endpoint.get_autoscaler_server_url(args.instance),
|
||||
endpoint_api_key=endpoint_api_key,
|
||||
)
|
||||
|
||||
# Create tool manager and demo (passing the model parameter)
|
||||
tool_manager = ToolManager()
|
||||
demo = APIDemo(client, args.model, tool_manager)
|
||||
|
||||
print(f"Using model: {args.model}")
|
||||
print("=" * 60)
|
||||
print(f"Using model: {args.model}")
|
||||
print(f"Using endpoint: {args.endpoint}")
|
||||
|
||||
|
||||
try:
|
||||
async with Serverless() as client:
|
||||
demo = APIDemo(client, args.model, args.endpoint, ToolManager())
|
||||
|
||||
# Run the selected test
|
||||
if args.completion:
|
||||
demo.demo_completions()
|
||||
await demo.demo_completions()
|
||||
elif args.chat:
|
||||
demo.demo_chat(use_streaming=False)
|
||||
await demo.demo_chat(use_streaming=False)
|
||||
elif args.chat_stream:
|
||||
demo.demo_chat(use_streaming=True)
|
||||
await demo.demo_chat(use_streaming=True)
|
||||
elif args.tools:
|
||||
demo.demo_ls_tool()
|
||||
await demo.demo_ls_tool()
|
||||
elif args.interactive:
|
||||
demo.interactive_chat()
|
||||
await demo.interactive_chat()
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error during test: {e}", exc_info=True)
|
||||
log.error("Error during test: %s", e, exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
asyncio.run(main_async())
|
||||
|
||||
@@ -11,6 +11,7 @@ MODEL_SERVER_START_LOG_MSG = [
|
||||
"llama runner started", # Ollama
|
||||
'"message":"Connected","target":"text_generation_router"', # TGI
|
||||
'"message":"Connected","target":"text_generation_router::server"', # TGI
|
||||
"main: model loaded" # llama.cpp
|
||||
]
|
||||
|
||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||
@@ -34,6 +35,7 @@ backend = Backend(
|
||||
model_server_url=os.environ["MODEL_SERVER_URL"],
|
||||
model_log_file=os.environ["MODEL_LOG"],
|
||||
allow_parallel_requests=True,
|
||||
max_wait_time=600.0,
|
||||
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
|
||||
log_actions=[
|
||||
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
|
||||
|
||||
+93
-9
@@ -1,19 +1,103 @@
|
||||
This is the base PyWorker for TGI, designed to create PyWorkers that can utilize various LLMs. It offers two primary endpoints:
|
||||
# HuggingFace TGI PyWorker
|
||||
|
||||
1. `generate`: Generates the LLM's response to a given prompt in a single request.
|
||||
2. `generate_stream`: Streams the LLM's response token by token.
|
||||
This is the base PyWorker for HuggingFace Text Generation Inference (TGI) servers. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
|
||||
|
||||
Both endpoints use the following API payload format:
|
||||
## Instance Setup
|
||||
|
||||
1. Pick a template
|
||||
|
||||
This worker is compatible with any TGI backend. We have a template you can use or you can create your own.
|
||||
|
||||
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20(Serverless))
|
||||
|
||||
The template can be configured via the template interface. You may want to change the model or startup arguments.
|
||||
|
||||
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
||||
|
||||
## Client Setup (Demo)
|
||||
|
||||
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/vast-ai/pyworker
|
||||
cd pyworker
|
||||
pip install uv
|
||||
uv venv -p 3.12
|
||||
source .venv/bin/activate
|
||||
uv pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Using the Test Client
|
||||
|
||||
The test client demonstrates both streaming and non-streaming generation using TGI's native API.
|
||||
|
||||
First, set your API key as an environment variable:
|
||||
|
||||
```bash
|
||||
export VAST_API_KEY=<your_api_key>
|
||||
```
|
||||
|
||||
The `--endpoint` flag is optional. If not provided, it defaults to `my-tgi-endpoint`.
|
||||
|
||||
### Generate (Streaming)
|
||||
|
||||
Call to `/generate_stream` with streaming response:
|
||||
|
||||
```bash
|
||||
python -m workers.tgi.client --generate-stream --endpoint <ENDPOINT_NAME>
|
||||
```
|
||||
|
||||
### Generate (Non-Streaming)
|
||||
|
||||
Call to `/generate` with json response:
|
||||
|
||||
```bash
|
||||
python -m workers.tgi.client --generate --endpoint <ENDPOINT_NAME>
|
||||
```
|
||||
|
||||
### Interactive Session (Streaming)
|
||||
|
||||
Interactive session with streaming responses. Type `quit` to exit.
|
||||
|
||||
```bash
|
||||
python -m workers.tgi.client --interactive --endpoint <ENDPOINT_NAME>
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
TGI provides two primary endpoints:
|
||||
|
||||
### Generate (Non-Streaming)
|
||||
|
||||
`/generate` - Returns the complete response in a single request.
|
||||
|
||||
```json
|
||||
{
|
||||
"inputs": "PROMPT",
|
||||
"inputs": "Your prompt here",
|
||||
"parameters": {
|
||||
"max_new_tokens": 250
|
||||
"max_new_tokens": 1024,
|
||||
"temperature": 0.7,
|
||||
"return_full_text": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Note that the max_new_tokens parameter, rather than the prompt size, impacts performance. For example, if an
|
||||
instance is benchmarked to process 100 tokens per second, a request with max_new_tokens = 200 will take
|
||||
approximately 2 seconds to complete.
|
||||
### Generate Stream (Streaming)
|
||||
|
||||
`/generate_stream` - Streams the response token by token.
|
||||
|
||||
```json
|
||||
{
|
||||
"inputs": "Your prompt here",
|
||||
"parameters": {
|
||||
"max_new_tokens": 1024,
|
||||
"temperature": 0.7,
|
||||
"do_sample": true,
|
||||
"return_full_text": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Performance Notes
|
||||
|
||||
The `max_new_tokens` parameter (not the prompt size) primarily impacts performance. For example, if an instance is benchmarked to process 100 tokens per second, a request with `max_new_tokens = 200` will take approximately 2 seconds to complete.
|
||||
|
||||
+202
-105
@@ -1,11 +1,13 @@
|
||||
import logging
|
||||
import sys
|
||||
import json
|
||||
from urllib.parse import urljoin
|
||||
import requests
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
from vastai import Serverless
|
||||
import asyncio
|
||||
|
||||
# ---------------------- Logging ----------------------
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
@@ -13,113 +15,208 @@ logging.basicConfig(
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
# ---------------------- Defaults ----------------------
|
||||
DEFAULT_PROMPT = "Think step by step: Tell me about the Python programming language."
|
||||
|
||||
def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> None:
|
||||
WORKER_ENDPOINT = "/generate"
|
||||
COST = 100
|
||||
route_payload = {
|
||||
"endpoint": endpoint_group_name,
|
||||
"api_key": api_key,
|
||||
"cost": COST,
|
||||
ENDPOINT_NAME = "TGI-Prod2" # change this to your TGI endpoint name
|
||||
MAX_TOKENS = 1024
|
||||
DEFAULT_TEMPERATURE = 0.7
|
||||
|
||||
|
||||
# ---------------------- API Calls ----------------------
|
||||
async def call_generate(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs) -> dict:
|
||||
"""Non-streaming generation via /generate endpoint"""
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
|
||||
payload = {
|
||||
"inputs": prompt,
|
||||
"parameters": {
|
||||
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
"return_full_text": False,
|
||||
}
|
||||
response = requests.post(
|
||||
urljoin(server_url, "/route/"),
|
||||
json=route_payload,
|
||||
timeout=4,
|
||||
)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
message = response.json()
|
||||
url = message["url"]
|
||||
|
||||
auth_data = dict(
|
||||
signature=message["signature"],
|
||||
cost=message["cost"],
|
||||
endpoint=message["endpoint"],
|
||||
reqnum=message["reqnum"],
|
||||
url=url,
|
||||
)
|
||||
|
||||
payload = dict(inputs="tell me about cats", parameters=dict(max_new_tokens=500))
|
||||
req_data = dict(payload=payload, auth_data=auth_data)
|
||||
url = urljoin(url, WORKER_ENDPOINT)
|
||||
print(f"url: {url}")
|
||||
response = requests.post(
|
||||
url,
|
||||
json=req_data,
|
||||
verify=get_cert_file_path(),
|
||||
)
|
||||
response.raise_for_status()
|
||||
res = response.json()
|
||||
print(res)
|
||||
|
||||
|
||||
def call_generate_stream(
|
||||
endpoint_group_name: str, api_key: str, server_url: str
|
||||
) -> None:
|
||||
WORKER_ENDPOINT = "/generate_stream"
|
||||
COST = 100
|
||||
route_payload = {
|
||||
"endpoint": endpoint_group_name,
|
||||
"api_key": api_key,
|
||||
"cost": COST,
|
||||
}
|
||||
response = requests.post(
|
||||
urljoin(server_url, "/route/"),
|
||||
json=route_payload,
|
||||
timeout=4,
|
||||
log.debug("POST /generate %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request("/generate", payload, cost=payload["parameters"]["max_new_tokens"])
|
||||
return resp["response"]
|
||||
|
||||
|
||||
async def call_generate_stream(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs):
|
||||
"""Streaming generation via /generate_stream endpoint"""
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
|
||||
payload = {
|
||||
"inputs": prompt,
|
||||
"parameters": {
|
||||
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
"do_sample": True,
|
||||
"return_full_text": False,
|
||||
}
|
||||
}
|
||||
log.debug("STREAM /generate_stream %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request(
|
||||
"/generate_stream",
|
||||
payload,
|
||||
cost=payload["parameters"]["max_new_tokens"],
|
||||
stream=True,
|
||||
)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
message = response.json()
|
||||
url = message["url"]
|
||||
print(f"url: {url}")
|
||||
auth_data = dict(
|
||||
signature=message["signature"],
|
||||
cost=message["cost"],
|
||||
endpoint=message["endpoint"],
|
||||
reqnum=message["reqnum"],
|
||||
url=message["url"],
|
||||
return resp["response"] # async generator
|
||||
|
||||
|
||||
# ---------------------- Demo Runner ----------------------
|
||||
class APIDemo:
|
||||
"""Demo and testing functionality for the TGI API client"""
|
||||
|
||||
def __init__(self, client: Serverless, endpoint_name: str):
|
||||
self.client = client
|
||||
self.endpoint_name = endpoint_name
|
||||
|
||||
async def handle_streaming_response(self, stream) -> str:
|
||||
"""Process streaming response and print tokens"""
|
||||
full_response = ""
|
||||
printed_answer = False
|
||||
|
||||
async for event in stream:
|
||||
tok = (event.get("token") or {}).get("text")
|
||||
if tok:
|
||||
if not printed_answer:
|
||||
printed_answer = True
|
||||
print("\n💬 Response: ", end="", flush=True)
|
||||
print(tok, end="", flush=True)
|
||||
full_response += tok
|
||||
|
||||
print() # newline
|
||||
if printed_answer:
|
||||
print(f"\nStreaming completed. Response tokens: {len(full_response.split())}")
|
||||
|
||||
return full_response
|
||||
|
||||
async def demo_generate(self) -> None:
|
||||
"""Demo non-streaming generation"""
|
||||
print("=" * 60)
|
||||
print("GENERATE DEMO (NON-STREAMING)")
|
||||
print("=" * 60)
|
||||
|
||||
response = await call_generate(
|
||||
client=self.client,
|
||||
endpoint_name=self.endpoint_name,
|
||||
prompt=DEFAULT_PROMPT,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
)
|
||||
payload = dict(inputs="tell me about dogs", parameters=dict(max_new_tokens=500))
|
||||
req_data = dict(payload=payload, auth_data=auth_data)
|
||||
url = urljoin(url, WORKER_ENDPOINT)
|
||||
response = requests.post(url, json=req_data, stream=True)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
for line in response.iter_lines():
|
||||
payload = line.decode().lstrip("data:").rstrip()
|
||||
if payload:
|
||||
|
||||
print(f"\n💬 Response: {response.get('generated_text', '')}")
|
||||
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
|
||||
|
||||
async def demo_generate_stream(self) -> None:
|
||||
"""Demo streaming generation"""
|
||||
print("=" * 60)
|
||||
print("GENERATE DEMO (STREAMING)")
|
||||
print("=" * 60)
|
||||
|
||||
stream = await call_generate_stream(
|
||||
client=self.client,
|
||||
endpoint_name=self.endpoint_name,
|
||||
prompt=DEFAULT_PROMPT,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
)
|
||||
|
||||
try:
|
||||
data = json.loads(payload)
|
||||
print(data["token"]["text"], end="")
|
||||
sys.stdout.flush()
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
log.warning(f"Failed to parse streaming response: {e}")
|
||||
continue
|
||||
await self.handle_streaming_response(stream)
|
||||
except Exception as e:
|
||||
log.error("\nError during streaming: %s", e, exc_info=True)
|
||||
|
||||
async def interactive_chat(self) -> None:
|
||||
"""Interactive session with streaming generation"""
|
||||
print("=" * 60)
|
||||
print("INTERACTIVE STREAMING SESSION")
|
||||
print("=" * 60)
|
||||
print(f"Using endpoint: {self.endpoint_name}")
|
||||
print("Type 'quit' to exit")
|
||||
print()
|
||||
|
||||
while True:
|
||||
try:
|
||||
user_input = input("You: ").strip()
|
||||
|
||||
if user_input.lower() == "quit":
|
||||
print("👋 Goodbye!")
|
||||
break
|
||||
elif not user_input:
|
||||
continue
|
||||
|
||||
print("Assistant: ", end="", flush=True)
|
||||
stream = await call_generate_stream(
|
||||
client=self.client,
|
||||
endpoint_name=self.endpoint_name,
|
||||
prompt=user_input,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
)
|
||||
|
||||
full_response = ""
|
||||
async for event in stream:
|
||||
tok = (event.get("token") or {}).get("text")
|
||||
if tok:
|
||||
print(tok, end="", flush=True)
|
||||
full_response += tok
|
||||
print() # newline
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 Session interrupted. Goodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
log.error("\nError: %s", e)
|
||||
continue
|
||||
|
||||
|
||||
# ---------------------- CLI ----------------------
|
||||
def build_arg_parser() -> argparse.ArgumentParser:
|
||||
p = argparse.ArgumentParser(description="Vast TGI Demo (Serverless SDK)")
|
||||
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
|
||||
|
||||
modes = p.add_mutually_exclusive_group(required=False)
|
||||
modes.add_argument("--generate", action="store_true", help="Test generate endpoint (non-streaming)")
|
||||
modes.add_argument("--generate-stream", action="store_true", help="Test generate endpoint with streaming")
|
||||
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming session")
|
||||
return p
|
||||
|
||||
|
||||
async def main_async():
|
||||
args = build_arg_parser().parse_args()
|
||||
|
||||
selected = sum([args.generate, args.generate_stream, args.interactive])
|
||||
if selected == 0:
|
||||
print("Please specify exactly one test mode:")
|
||||
print(" --generate : Test generate endpoint (non-streaming)")
|
||||
print(" --generate-stream : Test generate endpoint with streaming")
|
||||
print(" --interactive : Start interactive streaming session")
|
||||
print(f"\nExample: python {os.path.basename(sys.argv[0])} --generate-stream --endpoint my-tgi-endpoint")
|
||||
sys.exit(1)
|
||||
elif selected > 1:
|
||||
print("Please specify exactly one test mode")
|
||||
sys.exit(1)
|
||||
|
||||
print("=" * 60)
|
||||
print(f"Using endpoint: {args.endpoint}")
|
||||
|
||||
try:
|
||||
async with Serverless() as client:
|
||||
demo = APIDemo(client, args.endpoint)
|
||||
|
||||
if args.generate:
|
||||
await demo.demo_generate()
|
||||
elif args.generate_stream:
|
||||
await demo.demo_generate_stream()
|
||||
elif args.interactive:
|
||||
await demo.interactive_chat()
|
||||
|
||||
except Exception as e:
|
||||
log.error("Error during test: %s", e, exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from lib.test_utils import test_args
|
||||
|
||||
args = test_args.parse_args()
|
||||
|
||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||
endpoint_name=args.endpoint_group_name,
|
||||
account_api_key=args.api_key,
|
||||
instance=args.instance,
|
||||
)
|
||||
if endpoint_api_key:
|
||||
try:
|
||||
call_generate(
|
||||
api_key=endpoint_api_key,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
server_url=args.server_url,
|
||||
)
|
||||
call_generate_stream(
|
||||
api_key=endpoint_api_key,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
server_url=args.server_url,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error during API call: {e}")
|
||||
else:
|
||||
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ")
|
||||
asyncio.run(main_async())
|
||||
|
||||
Reference in New Issue
Block a user