Compare commits
46 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 222ac2a0dd | |||
| 40aed9b5f8 | |||
| d4d36bf86e | |||
| e839cfc6e8 | |||
| f04138e13b | |||
| de3aa87c8f | |||
| 6b5b1341a7 | |||
| 8be92c03de | |||
| adedb8ba90 | |||
| 2f543c01ad | |||
| 0bcd2219ea | |||
| 0339b471c5 | |||
| e143162438 | |||
| 7986e51e9e | |||
| 9c6ab78503 | |||
| 45e0c7d9ca | |||
| 7a792fd176 | |||
| e0449cb3c7 | |||
| a4339bd3f1 | |||
| 2b26e5e20c | |||
| d3727d4fd7 | |||
| a47c9d1ed0 | |||
| 0b14562a63 | |||
| de9b50abb9 | |||
| c510801723 | |||
| a12523b1d2 | |||
| eedf81c0a3 | |||
| 3adec1826d | |||
| b55bfa9611 | |||
| 7db54f3bd7 | |||
| d63a060202 | |||
| c6521cb6d4 | |||
| b7fe4ebb91 | |||
| 8ae7b74605 | |||
| 106067d716 | |||
| f5134d4bf5 | |||
| 47e5460532 | |||
| ec2ac0a21a | |||
| 2cde573c56 | |||
| b2e4a5db0c | |||
| 7437028cb2 | |||
| 7c0f316eeb | |||
| b4025a744f | |||
| d190308329 | |||
| 944f83fc03 | |||
| f56bbc0ebe |
+2
-1
@@ -2,4 +2,5 @@
|
||||
.envrc
|
||||
__pycache__
|
||||
bin/
|
||||
lib64
|
||||
lib64
|
||||
.venv
|
||||
@@ -39,11 +39,12 @@ reporting these metrics to the autoscaler.
|
||||
|
||||
If you are using a Vast.ai template that includes PyWorker integration (marked as autoscaler compatible), it should work out of the box. The template will typically start the appropriate PyWorker server automatically. Here's a few:
|
||||
|
||||
* **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=72d8dcb41ea3a58e06c741e2c725bc00)
|
||||
* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=ad72c8bf7cf695c3c9ddf0eaf6da0447)
|
||||
* **vLLM:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=63ae93902bf3978bea033782592b784d)
|
||||
* **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=6fa6bd5bdf5f0df63db80e40b086037d)
|
||||
* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=e6748878ba688e765e3e9fca29541938)
|
||||
|
||||
Currently available workers:
|
||||
* `hello_world`: A simple example worker for a basic LLM server.
|
||||
* `openai`: A simple example worker for a basic vLLM server.
|
||||
* `comfyui`: A worker for the ComfyUI image generation backend.
|
||||
* `tgi`: A worker for the Text Generation Inference backend.
|
||||
|
||||
|
||||
+27
-24
@@ -30,7 +30,7 @@ from lib.data_types import (
|
||||
BenchmarkResult
|
||||
)
|
||||
|
||||
VERSION = "0.1.0"
|
||||
VERSION = "0.2.1"
|
||||
|
||||
MSG_HISTORY_LEN = 100
|
||||
log = logging.getLogger(__file__)
|
||||
@@ -66,10 +66,17 @@ class Backend:
|
||||
unsecured: bool = dataclasses.field(
|
||||
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
||||
)
|
||||
report_addr: str = dataclasses.field(
|
||||
default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai")
|
||||
)
|
||||
mtoken: str = dataclasses.field(
|
||||
default_factory=lambda: os.environ.get("MASTER_TOKEN", "")
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
self.metrics = Metrics()
|
||||
self.metrics._set_version(self.version)
|
||||
self.metrics._set_mtoken(self.mtoken)
|
||||
self._total_pubkey_fetch_errors = 0
|
||||
self._pubkey = self._fetch_pubkey()
|
||||
self.__start_healthcheck: bool = False
|
||||
@@ -104,23 +111,19 @@ class Backend:
|
||||
|
||||
#######################################Private#######################################
|
||||
def _fetch_pubkey(self):
|
||||
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
|
||||
result = subprocess.check_output(command, universal_newlines=True)
|
||||
log.debug("public key:")
|
||||
log.debug(result)
|
||||
key = None
|
||||
for _ in range(5):
|
||||
try:
|
||||
key = RSA.import_key(result)
|
||||
break
|
||||
except ValueError as e:
|
||||
log.debug(f"Error downloading key: {e}")
|
||||
time.sleep(15)
|
||||
if key is None:
|
||||
self._total_pubkey_fetch_errors += 1
|
||||
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
|
||||
self.backend_errored("Failed to get autoscaler pubkey")
|
||||
return key
|
||||
report_addr = self.report_addr.rstrip("/")
|
||||
command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"]
|
||||
try:
|
||||
result = subprocess.check_output(command, universal_newlines=True)
|
||||
log.debug("public key:")
|
||||
log.debug(result)
|
||||
key = RSA.import_key(result)
|
||||
if key is not None:
|
||||
return key
|
||||
except (ValueError , subprocess.CalledProcessError) as e:
|
||||
log.debug(f"Error downloading key: {e}")
|
||||
self.backend_errored("Failed to get autoscaler pubkey")
|
||||
|
||||
|
||||
async def __handle_request(
|
||||
self,
|
||||
@@ -315,10 +318,10 @@ class Backend:
|
||||
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
||||
log.debug("already ran benchmark")
|
||||
# trigger model load
|
||||
payload = self.benchmark_handler.make_benchmark_payload()
|
||||
_ = await self.__call_api(
|
||||
handler=self.benchmark_handler, payload=payload
|
||||
)
|
||||
# payload = self.benchmark_handler.make_benchmark_payload()
|
||||
# _ = await self.__call_api(
|
||||
# handler=self.benchmark_handler, payload=payload
|
||||
# )
|
||||
return float(f.readline())
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
@@ -393,7 +396,7 @@ class Backend:
|
||||
)
|
||||
# some backends need a few seconds after logging successful startup before
|
||||
# they can begin accepting requests
|
||||
await sleep(5)
|
||||
# await sleep(5)
|
||||
try:
|
||||
max_throughput = await run_benchmark()
|
||||
self.__start_healthcheck = True
|
||||
@@ -414,7 +417,7 @@ class Backend:
|
||||
|
||||
async def tail_log():
|
||||
log.debug(f"tailing file: {self.model_log_file}")
|
||||
async with await open_file(self.model_log_file) as f:
|
||||
async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore') as f:
|
||||
while True:
|
||||
line = await f.readline()
|
||||
if line:
|
||||
|
||||
@@ -286,6 +286,7 @@ class AutoScalerData:
|
||||
"""Data that is reported to autoscaler"""
|
||||
|
||||
id: int
|
||||
mtoken: str
|
||||
version: str
|
||||
loadtime: float
|
||||
cur_load: float
|
||||
|
||||
+16
-2
@@ -28,6 +28,7 @@ def get_url() -> str:
|
||||
@dataclass
|
||||
class Metrics:
|
||||
version: str = "0"
|
||||
mtoken: str = ""
|
||||
last_metric_update: float = 0.0
|
||||
last_request_served: float = 0.0
|
||||
update_pending: bool = False
|
||||
@@ -142,12 +143,16 @@ class Metrics:
|
||||
def _set_version(self, version: str) -> None:
|
||||
self.version = version
|
||||
|
||||
def _set_mtoken(self, mtoken: str) -> None:
|
||||
self.mtoken = mtoken
|
||||
|
||||
#######################################Private#######################################
|
||||
|
||||
async def __send_delete_requests_and_reset(self):
|
||||
async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool:
|
||||
data = {
|
||||
"worker_id": self.id,
|
||||
"mtoken": self.mtoken,
|
||||
"request_idxs": idxs,
|
||||
"success": success_flag,
|
||||
}
|
||||
@@ -209,6 +214,7 @@ class Metrics:
|
||||
def compute_autoscaler_data() -> AutoScalerData:
|
||||
return AutoScalerData(
|
||||
id=self.id,
|
||||
mtoken=self.mtoken,
|
||||
version=self.version,
|
||||
loadtime=(loadtime_snapshot or 0.0),
|
||||
new_load=self.model_metrics.workload_processing,
|
||||
@@ -228,17 +234,25 @@ class Metrics:
|
||||
|
||||
async def send_data(report_addr: str) -> bool:
|
||||
data = compute_autoscaler_data()
|
||||
full_path = report_addr.rstrip("/") + "/worker_status/"
|
||||
log_data = asdict(data)
|
||||
def obfuscate(secret: str) -> str:
|
||||
if secret is None:
|
||||
return ""
|
||||
return secret[:7] + "..." if len(secret) > 7 else ("*" * len(secret))
|
||||
|
||||
log_data["mtoken"] = obfuscate(log_data.get("mtoken"))
|
||||
log.debug(
|
||||
"\n".join(
|
||||
[
|
||||
"#" * 60,
|
||||
f"sending data to autoscaler",
|
||||
f"{json.dumps((asdict(data)), indent=2)}",
|
||||
f"{json.dumps(log_data, indent=2)}",
|
||||
"#" * 60,
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
full_path = report_addr.rstrip("/") + "/worker_status/"
|
||||
for attempt in range(1, 4):
|
||||
try:
|
||||
session = await self.http()
|
||||
|
||||
+45
-25
@@ -3,38 +3,58 @@ import logging
|
||||
from typing import List
|
||||
import ssl
|
||||
from asyncio import run, gather
|
||||
|
||||
import asyncio
|
||||
|
||||
from lib.backend import Backend
|
||||
from lib.metrics import Metrics
|
||||
from aiohttp import web
|
||||
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
||||
log.debug("getting certificate...")
|
||||
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
||||
if use_ssl is True:
|
||||
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
ssl_context.load_cert_chain(
|
||||
certfile="/etc/instance.crt",
|
||||
keyfile="/etc/instance.key",
|
||||
)
|
||||
else:
|
||||
ssl_context = None
|
||||
try:
|
||||
log.debug("getting certificate...")
|
||||
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
||||
if use_ssl is True:
|
||||
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
ssl_context.load_cert_chain(
|
||||
certfile="/etc/instance.crt",
|
||||
keyfile="/etc/instance.key",
|
||||
)
|
||||
else:
|
||||
ssl_context = None
|
||||
|
||||
async def main():
|
||||
log.debug("starting server...")
|
||||
app = web.Application()
|
||||
app.add_routes(routes)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(
|
||||
runner,
|
||||
ssl_context=ssl_context,
|
||||
port=int(os.environ["WORKER_PORT"]),
|
||||
**kwargs
|
||||
)
|
||||
await gather(site.start(), backend._start_tracking())
|
||||
async def main():
|
||||
log.debug("starting server...")
|
||||
app = web.Application()
|
||||
app.add_routes(routes)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(
|
||||
runner,
|
||||
ssl_context=ssl_context,
|
||||
port=int(os.environ["WORKER_PORT"]),
|
||||
**kwargs
|
||||
)
|
||||
await gather(site.start(), backend._start_tracking())
|
||||
|
||||
run(main())
|
||||
run(main())
|
||||
|
||||
except Exception as e:
|
||||
err_msg = f"PyWorker failed to launch: {e}"
|
||||
log.error(err_msg)
|
||||
|
||||
async def beacon():
|
||||
metrics = Metrics()
|
||||
metrics._set_version(getattr(backend, "version", "0"))
|
||||
metrics._set_mtoken(getattr(backend, "mtoken", ""))
|
||||
try:
|
||||
while True:
|
||||
metrics._model_errored(err_msg)
|
||||
await metrics._Metrics__send_metrics_and_reset()
|
||||
await asyncio.sleep(10)
|
||||
finally:
|
||||
await metrics.aclose()
|
||||
|
||||
run(beacon())
|
||||
|
||||
@@ -8,3 +8,4 @@ Requests~=2.32
|
||||
transformers~=4.52
|
||||
utils==1.0.*
|
||||
hf_transfer>=0.1.9
|
||||
vastai-sdk>=0.2.0
|
||||
+47
-6
@@ -12,7 +12,6 @@ PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
|
||||
REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}"
|
||||
USE_SSL="${USE_SSL:-true}"
|
||||
WORKER_PORT="${WORKER_PORT:-3000}"
|
||||
MODEL_TYPE="${MODEL_TYPE:-image}"
|
||||
mkdir -p "$WORKSPACE_DIR"
|
||||
cd "$WORKSPACE_DIR"
|
||||
|
||||
@@ -42,6 +41,14 @@ echo_var DEBUG_LOG
|
||||
echo_var PYWORKER_LOG
|
||||
echo_var MODEL_LOG
|
||||
|
||||
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
|
||||
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
|
||||
if [ -e "$MODEL_LOG" ]; then
|
||||
echo "Rotating model log at $MODEL_LOG to $MODEL_LOG.old"
|
||||
cat "$MODEL_LOG" >> "$MODEL_LOG.old"
|
||||
: > "$MODEL_LOG"
|
||||
fi
|
||||
|
||||
# Populate /etc/environment with quoted values
|
||||
if ! grep -q "VAST" /etc/environment; then
|
||||
env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do
|
||||
@@ -125,9 +132,43 @@ cd "$SERVER_DIR"
|
||||
|
||||
echo "launching PyWorker server"
|
||||
|
||||
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
|
||||
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
|
||||
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
|
||||
set +e
|
||||
python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG"
|
||||
PY_STATUS=${PIPESTATUS[0]}
|
||||
set -e
|
||||
|
||||
(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
|
||||
echo "launching PyWorker server done"
|
||||
if [ "${PY_STATUS}" -ne 0 ]; then
|
||||
echo "PyWorker exited with status ${PY_STATUS}; notifying autoscaler..."
|
||||
ERROR_MSG="PyWorker exited: code ${PY_STATUS}"
|
||||
MTOKEN="${MASTER_TOKEN:-}"
|
||||
VERSION="${PYWORKER_VERSION:-0}"
|
||||
|
||||
IFS=',' read -r -a REPORT_ADDRS <<< "${REPORT_ADDR}"
|
||||
for addr in "${REPORT_ADDRS[@]}"; do
|
||||
curl -sS -X POST -H 'Content-Type: application/json' \
|
||||
-d "$(cat <<JSON
|
||||
{
|
||||
"id": ${CONTAINER_ID:-0},
|
||||
"mtoken": "${MTOKEN}",
|
||||
"version": "${VERSION}",
|
||||
"loadtime": 0,
|
||||
"new_load": 0,
|
||||
"cur_load": 0,
|
||||
"rej_load": 0,
|
||||
"max_perf": 0,
|
||||
"cur_perf": 0,
|
||||
"error_msg": "${ERROR_MSG}",
|
||||
"num_requests_working": 0,
|
||||
"num_requests_recieved": 0,
|
||||
"additional_disk_usage": 0,
|
||||
"working_request_idxs": [],
|
||||
"cur_capacity": 0,
|
||||
"max_capacity": 0,
|
||||
"url": "${URL}"
|
||||
}
|
||||
JSON
|
||||
)" "${addr%/}/worker_status/" || true
|
||||
done
|
||||
fi
|
||||
|
||||
echo "launching PyWorker server done"
|
||||
@@ -1,8 +1,16 @@
|
||||
# ComfyUI PyWorker
|
||||
|
||||
This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture.
|
||||
This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
|
||||
|
||||
The cost for each request has a static value of `1`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node.
|
||||
The cost for each request has a static value of `1`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node.
|
||||
|
||||
## Instance Setup
|
||||
|
||||
1. Pick a template
|
||||
|
||||
- [ComfyUI (Serverless)](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=ComfyUI%20(Serverless))
|
||||
|
||||
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
||||
|
||||
## Requirements
|
||||
|
||||
@@ -10,6 +18,88 @@ This worker requires both [ComfyUI](https://github.com/comfyanonymous/ComfyUI) a
|
||||
|
||||
A docker image is provided but you may use any if the above requirements are met.
|
||||
|
||||
## Client
|
||||
|
||||
The client demonstrates how to use the Vast Serverless SDK to generate images, save them locally, and optionally upload to S3-compatible storage.
|
||||
|
||||
### Setup
|
||||
|
||||
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/vast-ai/pyworker
|
||||
cd pyworker
|
||||
pip install uv
|
||||
uv venv -p 3.12
|
||||
source .venv/bin/activate
|
||||
uv pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. Set your API key:
|
||||
|
||||
```bash
|
||||
export VAST_API_KEY=<your_api_key>
|
||||
```
|
||||
|
||||
### Usage
|
||||
|
||||
```bash
|
||||
# Default prompt
|
||||
python -m workers.comfyui-json.client
|
||||
|
||||
# Custom prompt
|
||||
python -m workers.comfyui-json.client --prompt "a cat sitting on a rainbow"
|
||||
|
||||
# With options
|
||||
python -m workers.comfyui-json.client --prompt "sunset" --width 1024 --height 1024 --steps 30
|
||||
|
||||
# Using a custom workflow file
|
||||
python -m workers.comfyui-json.client --workflow my_workflow.json
|
||||
|
||||
# With S3 upload
|
||||
python -m workers.comfyui-json.client --s3
|
||||
```
|
||||
|
||||
### CLI Flags
|
||||
|
||||
| Flag | Default | Description |
|
||||
|------|---------|-------------|
|
||||
| `--endpoint` | `my-comfyui-endpoint` | Vast endpoint name |
|
||||
| `--prompt` | (default) | Text prompt for image generation |
|
||||
| `--workflow` | (none) | Path to custom workflow JSON file |
|
||||
| `--width` | 512 | Image width in pixels |
|
||||
| `--height` | 512 | Image height in pixels |
|
||||
| `--steps` | 20 | Number of denoising steps |
|
||||
| `--seed` | (random) | Random seed for reproducibility |
|
||||
| `--s3` | (disabled) | Upload generated images to S3 |
|
||||
|
||||
### Output
|
||||
|
||||
Images are saved to `./generated_images/comfy_{seed}.png`.
|
||||
|
||||
### S3 Upload (Optional)
|
||||
|
||||
You can optionally upload generated images to an S3-compatible storage service (AWS S3, Cloudflare R2, Backblaze B2, etc.) by using the `--s3` flag.
|
||||
|
||||
**1. Set environment variables:**
|
||||
|
||||
```bash
|
||||
export S3_ENDPOINT_URL="https://your-account.r2.cloudflarestorage.com"
|
||||
export S3_BUCKET_NAME="my-bucket"
|
||||
export S3_ACCESS_KEY_ID="your-access-key-id"
|
||||
export S3_SECRET_ACCESS_KEY="your-secret-access-key"
|
||||
```
|
||||
|
||||
**2. Run with S3 upload enabled:**
|
||||
|
||||
```bash
|
||||
python -m workers.comfyui-json.client --prompt "a beautiful landscape" --s3
|
||||
```
|
||||
|
||||
Images will be saved locally AND uploaded to `s3://{bucket}/comfyui/{filename}`.
|
||||
|
||||
**Note:** Requires `boto3` (`pip install boto3`).
|
||||
|
||||
## Benchmarking
|
||||
|
||||
### Custom Benchmark Workflows
|
||||
@@ -212,11 +302,3 @@ WEBHOOK_TIMEOUT=30 # Webhook timeout in seconds
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Client Libraries
|
||||
|
||||
See the test client examples for implementation details on how to integrate with the ComfyUI worker.
|
||||
|
||||
---
|
||||
|
||||
See Vast's serverless documentation for more details on how to use ComfyUI with autoscaler.
|
||||
+293
-136
@@ -1,155 +1,312 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import uuid
|
||||
import random
|
||||
from urllib.parse import urljoin
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
import argparse
|
||||
import aiohttp
|
||||
|
||||
import requests
|
||||
from vastai import Serverless
|
||||
|
||||
from lib.test_utils import print_truncate_res
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
from .data_types import count_workload
|
||||
# ---------------------- Config ----------------------
|
||||
DEFAULT_PROMPT = "a beautiful sunset over mountains, digital art, highly detailed"
|
||||
ENDPOINT_NAME = "my-comfyui-endpoint"
|
||||
DEFAULT_WIDTH = 512
|
||||
DEFAULT_HEIGHT = 512
|
||||
DEFAULT_STEPS = 20
|
||||
COST = 100 # Fixed cost for ComfyUI requests
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
# Optional S3 Configuration (from environment variables)
|
||||
S3_ENDPOINT_URL = os.getenv("S3_ENDPOINT_URL")
|
||||
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
|
||||
S3_ACCESS_KEY_ID = os.getenv("S3_ACCESS_KEY_ID")
|
||||
S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY")
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def call_text2image_workflow(
|
||||
endpoint_group_name: str, api_key: str, server_url: str
|
||||
) -> None:
|
||||
"""Simple Text2Image using the new modifier-based approach"""
|
||||
|
||||
def make_request(url: str, payload: dict, timeout: int = None, verify=True, context: str = "request"):
|
||||
"""Helper function for making requests with consistent error handling"""
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
json=payload,
|
||||
timeout=timeout,
|
||||
verify=verify
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
log.error(f"HTTP error occurred during {context}: {http_err}")
|
||||
log.error(f"Status Code: {response.status_code}")
|
||||
log.error("Response content:", response.text)
|
||||
return None
|
||||
except requests.exceptions.Timeout:
|
||||
log.error(f"Timeout occurred during {context}: {url}")
|
||||
return None
|
||||
except requests.exceptions.ConnectionError:
|
||||
log.error(f"Connection error occurred during {context}: {url}")
|
||||
return None
|
||||
except json.JSONDecodeError as json_err:
|
||||
log.error(f"Failed to decode JSON response during {context}: {json_err}")
|
||||
if 'response' in locals():
|
||||
print("Response content:", response.text)
|
||||
return None
|
||||
except Exception as err:
|
||||
log.error(f"An unexpected error occurred during {context}: {err}")
|
||||
if 'response' in locals():
|
||||
log.error("Response content (if available):", response.text)
|
||||
return None
|
||||
|
||||
WORKER_ENDPOINT = "/generate/sync"
|
||||
def get_s3_client():
|
||||
"""Create and return an S3 client configured for the S3-compatible endpoint"""
|
||||
try:
|
||||
import boto3
|
||||
from botocore.config import Config
|
||||
except ImportError:
|
||||
log.error("boto3 is required for S3 uploads. Install with: pip install boto3")
|
||||
return None
|
||||
|
||||
# This worker has concurrency = 1. All workloads have cost value 1.0
|
||||
COST = count_workload()
|
||||
|
||||
# Route to get worker URL
|
||||
route_payload = {
|
||||
"endpoint": endpoint_group_name,
|
||||
"api_key": api_key,
|
||||
"cost": COST,
|
||||
}
|
||||
|
||||
# First request - get routing information
|
||||
route_response = make_request(
|
||||
url=urljoin(server_url, "/route/"),
|
||||
payload=route_payload,
|
||||
timeout=4,
|
||||
context="route request"
|
||||
if not all([S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY]):
|
||||
log.error("S3 environment variables not fully configured. Required:")
|
||||
log.error(" S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY")
|
||||
return None
|
||||
|
||||
return boto3.client(
|
||||
"s3",
|
||||
endpoint_url=S3_ENDPOINT_URL,
|
||||
aws_access_key_id=S3_ACCESS_KEY_ID,
|
||||
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
|
||||
config=Config(signature_version="s3v4"),
|
||||
)
|
||||
|
||||
if route_response is None:
|
||||
return None
|
||||
|
||||
if "url" not in route_response or not route_response["url"]:
|
||||
log.error("Error: No worker in 'Ready' state. Please wait while the serverless engine removes errored workers or finishes loading new workers.")
|
||||
return None
|
||||
|
||||
if "status" in route_response:
|
||||
print(f"Autoscaler status: {route_response['status']}")
|
||||
return None
|
||||
|
||||
# Extract data from route response
|
||||
url = route_response["url"]
|
||||
auth_data = dict(
|
||||
signature=route_response["signature"],
|
||||
cost=route_response["cost"],
|
||||
endpoint=route_response["endpoint"],
|
||||
reqnum=route_response["reqnum"],
|
||||
url=route_response["url"],
|
||||
)
|
||||
|
||||
# Build the payload for the worker request
|
||||
worker_payload = {
|
||||
|
||||
|
||||
# ---------------------- API Functions ----------------------
|
||||
async def call_generate(
|
||||
client: Serverless,
|
||||
*,
|
||||
endpoint_name: str,
|
||||
prompt: str,
|
||||
width: int,
|
||||
height: int,
|
||||
steps: int,
|
||||
seed: int,
|
||||
) -> dict:
|
||||
"""Generate image using Text2Image modifier"""
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
payload = {
|
||||
"input": {
|
||||
"request_id": str(uuid.uuid4()),
|
||||
"modifier": "Text2Image",
|
||||
"modifications": {
|
||||
"prompt": "a beautiful landscape with mountains and lakes",
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"steps": 20,
|
||||
"seed": random.randint(0, 2**32 - 1)
|
||||
"prompt": prompt,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"steps": steps,
|
||||
"seed": seed,
|
||||
},
|
||||
"workflow_json": {} # Empty since using modifier approach
|
||||
}
|
||||
}
|
||||
|
||||
req_data = dict(payload=worker_payload, auth_data=auth_data)
|
||||
worker_url = urljoin(url, WORKER_ENDPOINT)
|
||||
print(f"url: {worker_url}")
|
||||
|
||||
# Second request - call the worker endpoint
|
||||
worker_response = make_request(
|
||||
url=worker_url,
|
||||
payload=req_data,
|
||||
verify=get_cert_file_path(),
|
||||
context="worker request"
|
||||
)
|
||||
|
||||
return worker_response
|
||||
return await endpoint.request("/generate/sync", payload, cost=COST)
|
||||
|
||||
|
||||
async def call_generate_workflow(
|
||||
client: Serverless,
|
||||
*,
|
||||
endpoint_name: str,
|
||||
workflow_json: dict,
|
||||
) -> dict:
|
||||
"""Generate using custom workflow JSON"""
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
payload = {
|
||||
"input": {
|
||||
"request_id": str(uuid.uuid4()),
|
||||
"workflow_json": workflow_json,
|
||||
}
|
||||
}
|
||||
return await endpoint.request("/generate/sync", payload, cost=COST)
|
||||
|
||||
|
||||
# ---------------------- Demo Class ----------------------
|
||||
class APIDemo:
|
||||
def __init__(self, client: Serverless, endpoint_name: str, upload_s3: bool = False):
|
||||
self.client = client
|
||||
self.endpoint_name = endpoint_name
|
||||
self.upload_s3 = upload_s3
|
||||
self.s3_client = get_s3_client() if upload_s3 else None
|
||||
|
||||
if upload_s3 and not self.s3_client:
|
||||
log.warning("S3 upload requested but client creation failed. Images will only be saved locally.")
|
||||
|
||||
def extract_filename(self, response: dict) -> str | None:
|
||||
"""Extract the generated image filename from ComfyUI response"""
|
||||
if "comfyui_response" in response:
|
||||
for data in response["comfyui_response"].values():
|
||||
if isinstance(data, dict) and "outputs" in data:
|
||||
for node_output in data["outputs"].values():
|
||||
if "images" in node_output and node_output["images"]:
|
||||
return node_output["images"][0].get("filename")
|
||||
return None
|
||||
|
||||
async def save_image(self, worker_url: str, filename: str, local_name: str) -> str | None:
|
||||
"""Fetch and save image locally from the worker, optionally upload to S3"""
|
||||
os.makedirs("generated_images", exist_ok=True)
|
||||
return await self._fetch_image(worker_url, filename, local_name)
|
||||
|
||||
def _upload_to_s3(self, local_path: str, s3_key: str) -> str | None:
|
||||
"""Upload a local file to S3 and return the S3 URL"""
|
||||
if not self.s3_client:
|
||||
return None
|
||||
|
||||
try:
|
||||
self.s3_client.upload_file(
|
||||
local_path,
|
||||
S3_BUCKET_NAME,
|
||||
s3_key,
|
||||
ExtraArgs={"ContentType": "image/png"}
|
||||
)
|
||||
s3_url = f"{S3_ENDPOINT_URL}/{S3_BUCKET_NAME}/{s3_key}"
|
||||
print(f" ☁️ Uploaded to S3: {s3_key}")
|
||||
return s3_url
|
||||
except Exception as e:
|
||||
log.error(f"Failed to upload to S3: {e}")
|
||||
return None
|
||||
|
||||
async def _fetch_image(self, worker_url: str, filename: str, local_name: str) -> str | None:
|
||||
"""Fetch image from worker's /view endpoint and save locally"""
|
||||
if not worker_url:
|
||||
return None
|
||||
|
||||
try:
|
||||
url = f"{worker_url}/view"
|
||||
params = {"filename": filename, "type": "output"}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, params=params, ssl=False) as resp:
|
||||
if resp.status == 200:
|
||||
path = f"generated_images/{local_name}"
|
||||
image_data = await resp.read()
|
||||
with open(path, "wb") as f:
|
||||
f.write(image_data)
|
||||
print(f" 💾 Saved: {path}")
|
||||
|
||||
# Upload to S3 if enabled
|
||||
if self.upload_s3 and self.s3_client:
|
||||
s3_key = f"comfyui/{local_name}"
|
||||
self._upload_to_s3(path, s3_key)
|
||||
|
||||
return path
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def demo_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
width: int,
|
||||
height: int,
|
||||
steps: int,
|
||||
seed: int | None,
|
||||
):
|
||||
"""Demo: Generate image from text prompt"""
|
||||
print("=" * 60)
|
||||
print("COMFYUI TEXT-TO-IMAGE DEMO")
|
||||
print("=" * 60)
|
||||
|
||||
if seed is None:
|
||||
seed = random.randint(0, 2**32 - 1)
|
||||
|
||||
print(f"Prompt: {prompt[:100]}..." if len(prompt) > 100 else f"Prompt: {prompt}")
|
||||
print(f"Size: {width}x{height}, Steps: {steps}, Seed: {seed}")
|
||||
print("\n🎨 Generating image...")
|
||||
|
||||
response = await call_generate(
|
||||
self.client,
|
||||
endpoint_name=self.endpoint_name,
|
||||
prompt=prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
steps=steps,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
print("\n✅ Generation complete!")
|
||||
|
||||
# Get worker URL for fetching images
|
||||
worker_url = response.get("url", "")
|
||||
print(f"Worker URL: {worker_url}")
|
||||
|
||||
# Fetch and save image
|
||||
if "response" in response:
|
||||
filename = self.extract_filename(response["response"])
|
||||
if filename:
|
||||
path = await self.save_image(worker_url, filename, f"comfy_{seed}.png")
|
||||
if not path:
|
||||
print(f"❌ Failed to fetch image")
|
||||
else:
|
||||
print("❌ No image in response")
|
||||
else:
|
||||
print("❌ Unexpected response format")
|
||||
|
||||
async def demo_workflow(self, workflow_file: str):
|
||||
"""Demo: Generate using custom workflow file"""
|
||||
print("=" * 60)
|
||||
print("COMFYUI CUSTOM WORKFLOW DEMO")
|
||||
print("=" * 60)
|
||||
|
||||
if not os.path.exists(workflow_file):
|
||||
log.error(f"Workflow file not found: {workflow_file}")
|
||||
return
|
||||
|
||||
with open(workflow_file, "r") as f:
|
||||
workflow_json = json.load(f)
|
||||
|
||||
print(f"Workflow: {workflow_file}")
|
||||
print("\n🎨 Generating...")
|
||||
|
||||
response = await call_generate_workflow(
|
||||
self.client,
|
||||
endpoint_name=self.endpoint_name,
|
||||
workflow_json=workflow_json,
|
||||
)
|
||||
|
||||
print("\n✅ Generation complete!")
|
||||
|
||||
worker_url = response.get("url", "")
|
||||
|
||||
if "response" in response:
|
||||
filename = self.extract_filename(response["response"])
|
||||
if filename:
|
||||
path = await self.save_image(worker_url, filename, "workflow.png")
|
||||
if not path:
|
||||
print(f"❌ Failed to fetch image")
|
||||
else:
|
||||
print("❌ No image in response")
|
||||
else:
|
||||
print("❌ Unexpected response format")
|
||||
|
||||
|
||||
# ---------------------- CLI ----------------------
|
||||
def build_arg_parser() -> argparse.ArgumentParser:
|
||||
p = argparse.ArgumentParser(description="Vast ComfyUI-JSON Demo (Serverless SDK)")
|
||||
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
|
||||
p.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, metavar="TEXT",
|
||||
help=f"Prompt text (default: '{DEFAULT_PROMPT[:30]}...')")
|
||||
p.add_argument("--workflow", type=str, metavar="FILE", help="Use custom workflow JSON file instead")
|
||||
p.add_argument("--width", type=int, default=DEFAULT_WIDTH, help=f"Image width (default: {DEFAULT_WIDTH})")
|
||||
p.add_argument("--height", type=int, default=DEFAULT_HEIGHT, help=f"Image height (default: {DEFAULT_HEIGHT})")
|
||||
p.add_argument("--steps", type=int, default=DEFAULT_STEPS, help=f"Steps (default: {DEFAULT_STEPS})")
|
||||
p.add_argument("--seed", type=int, default=None, help="Seed (default: random)")
|
||||
p.add_argument("--s3", action="store_true",
|
||||
help="Upload generated images to S3 (requires S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY env vars)")
|
||||
return p
|
||||
|
||||
|
||||
async def main_async():
|
||||
args = build_arg_parser().parse_args()
|
||||
|
||||
print("=" * 60)
|
||||
print(f"Using endpoint: {args.endpoint}")
|
||||
if args.s3:
|
||||
print(f"S3 upload: enabled (bucket: {S3_BUCKET_NAME})")
|
||||
|
||||
try:
|
||||
async with Serverless() as client:
|
||||
demo = APIDemo(client, args.endpoint, upload_s3=args.s3)
|
||||
|
||||
if args.workflow:
|
||||
await demo.demo_workflow(workflow_file=args.workflow)
|
||||
else:
|
||||
await demo.demo_prompt(
|
||||
prompt=args.prompt,
|
||||
width=args.width,
|
||||
height=args.height,
|
||||
steps=args.steps,
|
||||
seed=args.seed,
|
||||
)
|
||||
|
||||
except AttributeError as e:
|
||||
if "API key" in str(e):
|
||||
log.error("API key missing. Set VAST_API_KEY environment variable.")
|
||||
else:
|
||||
log.error(f"Error: {e}")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
log.error(f"Error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from lib.test_utils import test_args
|
||||
|
||||
args = test_args.parse_args()
|
||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||
endpoint_name=args.endpoint_group_name,
|
||||
account_api_key=args.api_key,
|
||||
instance=args.instance,
|
||||
)
|
||||
|
||||
if endpoint_api_key:
|
||||
result = call_text2image_workflow(
|
||||
api_key=endpoint_api_key,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
server_url=args.server_url,
|
||||
)
|
||||
if result is None:
|
||||
log.error("Text2Image workflow failed")
|
||||
else:
|
||||
print(result)
|
||||
else:
|
||||
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name}")
|
||||
asyncio.run(main_async())
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import dataclasses
|
||||
from enum import Enum
|
||||
import os
|
||||
import sys
|
||||
import random
|
||||
@@ -15,12 +13,6 @@ from lib.data_types import ApiPayload, JsonDataException
|
||||
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
class ModelType(Enum):
|
||||
image = "image"
|
||||
audio = "audio"
|
||||
video = "video"
|
||||
|
||||
|
||||
def count_workload() -> float:
|
||||
# Always 100.0 where there is a single instance of ComfyUI handling requests
|
||||
# Results will indicate % or a job completed per second. Avoids sub 0.1 sec performance indication
|
||||
@@ -29,11 +21,6 @@ def count_workload() -> float:
|
||||
@dataclasses.dataclass
|
||||
class ComfyWorkflowData(ApiPayload):
|
||||
input: dict
|
||||
model_type: ModelType = dataclasses.field(
|
||||
default_factory=lambda: ModelType(
|
||||
os.environ.get("MODEL_TYPE", "image").lower()
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def for_test(cls):
|
||||
@@ -43,17 +30,15 @@ class ComfyWorkflowData(ApiPayload):
|
||||
Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090)
|
||||
"""
|
||||
# Try to load benchmark.json
|
||||
#Note: We should cross check with Rob if the audio sample benchmark file is correct
|
||||
model_type = ModelType(os.environ.get("MODEL_TYPE", "image").lower())
|
||||
benchmark_file = Path(f"workers/comfyui-json/misc/benchmark_{model_type.value}.json")
|
||||
benchmark_file = Path("workers/comfyui-json/misc/benchmark.json")
|
||||
|
||||
if benchmark_file.exists():
|
||||
try:
|
||||
with open(benchmark_file, "r") as f:
|
||||
benchmark_workflow = json.load(f)
|
||||
log.info(f"using benchmark json file for {model_type.value}")
|
||||
return cls(
|
||||
input={
|
||||
"request_id": f"{model_type.value}-{random.randint(1000, 99999)}",
|
||||
"request_id": f"test-{random.randint(1000, 99999)}",
|
||||
"workflow_json": benchmark_workflow
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,118 +0,0 @@
|
||||
{
|
||||
"3": {
|
||||
"inputs": {
|
||||
"seed": 98942092715729,
|
||||
"steps": 50,
|
||||
"cfg": 4.98,
|
||||
"sampler_name": "dpmpp_3m_sde_gpu",
|
||||
"scheduler": "exponential",
|
||||
"denoise": 1,
|
||||
"model": [
|
||||
"4",
|
||||
0
|
||||
],
|
||||
"positive": [
|
||||
"6",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"7",
|
||||
0
|
||||
],
|
||||
"latent_image": [
|
||||
"11",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "KSampler",
|
||||
"_meta": {
|
||||
"title": "KSampler"
|
||||
}
|
||||
},
|
||||
"4": {
|
||||
"inputs": {
|
||||
"ckpt_name": "stable-audio-open-1.0.safetensors"
|
||||
},
|
||||
"class_type": "CheckpointLoaderSimple",
|
||||
"_meta": {
|
||||
"title": "Load Checkpoint"
|
||||
}
|
||||
},
|
||||
"6": {
|
||||
"inputs": {
|
||||
"text": "heaven church electronic dance music",
|
||||
"clip": [
|
||||
"10",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"7": {
|
||||
"inputs": {
|
||||
"text": "",
|
||||
"clip": [
|
||||
"10",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"10": {
|
||||
"inputs": {
|
||||
"clip_name": "t5-base.safetensors",
|
||||
"type": "stable_audio",
|
||||
"device": "default"
|
||||
},
|
||||
"class_type": "CLIPLoader",
|
||||
"_meta": {
|
||||
"title": "Load CLIP"
|
||||
}
|
||||
},
|
||||
"11": {
|
||||
"inputs": {
|
||||
"seconds": 47.6,
|
||||
"batch_size": 1
|
||||
},
|
||||
"class_type": "EmptyLatentAudio",
|
||||
"_meta": {
|
||||
"title": "EmptyLatentAudio"
|
||||
}
|
||||
},
|
||||
"12": {
|
||||
"inputs": {
|
||||
"samples": [
|
||||
"3",
|
||||
0
|
||||
],
|
||||
"vae": [
|
||||
"4",
|
||||
2
|
||||
]
|
||||
},
|
||||
"class_type": "VAEDecodeAudio",
|
||||
"_meta": {
|
||||
"title": "VAEDecodeAudio"
|
||||
}
|
||||
},
|
||||
"13": {
|
||||
"inputs": {
|
||||
"filename_prefix": "audio/ComfyUI",
|
||||
"audioUI": "",
|
||||
"audio": [
|
||||
"12",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "SaveAudio",
|
||||
"_meta": {
|
||||
"title": "SaveAudio"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,107 +0,0 @@
|
||||
{
|
||||
"3": {
|
||||
"inputs": {
|
||||
"seed": 588445435278533,
|
||||
"steps": 20,
|
||||
"cfg": 8,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "normal",
|
||||
"denoise": 1,
|
||||
"model": [
|
||||
"4",
|
||||
0
|
||||
],
|
||||
"positive": [
|
||||
"6",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"7",
|
||||
0
|
||||
],
|
||||
"latent_image": [
|
||||
"5",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "KSampler",
|
||||
"_meta": {
|
||||
"title": "KSampler"
|
||||
}
|
||||
},
|
||||
"4": {
|
||||
"inputs": {
|
||||
"ckpt_name": "v1-5-pruned-emaonly-fp16.safetensors"
|
||||
},
|
||||
"class_type": "CheckpointLoaderSimple",
|
||||
"_meta": {
|
||||
"title": "Load Checkpoint"
|
||||
}
|
||||
},
|
||||
"5": {
|
||||
"inputs": {
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
"batch_size": 1
|
||||
},
|
||||
"class_type": "EmptyLatentImage",
|
||||
"_meta": {
|
||||
"title": "Empty Latent Image"
|
||||
}
|
||||
},
|
||||
"6": {
|
||||
"inputs": {
|
||||
"text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,",
|
||||
"clip": [
|
||||
"4",
|
||||
1
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"7": {
|
||||
"inputs": {
|
||||
"text": "text, watermark",
|
||||
"clip": [
|
||||
"4",
|
||||
1
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"8": {
|
||||
"inputs": {
|
||||
"samples": [
|
||||
"3",
|
||||
0
|
||||
],
|
||||
"vae": [
|
||||
"4",
|
||||
2
|
||||
]
|
||||
},
|
||||
"class_type": "VAEDecode",
|
||||
"_meta": {
|
||||
"title": "VAE Decode"
|
||||
}
|
||||
},
|
||||
"9": {
|
||||
"inputs": {
|
||||
"filename_prefix": "ComfyUI",
|
||||
"images": [
|
||||
"8",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "SaveImage",
|
||||
"_meta": {
|
||||
"title": "Save Image"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,216 +0,0 @@
|
||||
{
|
||||
"90": {
|
||||
"inputs": {
|
||||
"clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors",
|
||||
"type": "wan",
|
||||
"device": "default"
|
||||
},
|
||||
"class_type": "CLIPLoader",
|
||||
"_meta": {
|
||||
"title": "Load CLIP"
|
||||
}
|
||||
},
|
||||
"91": {
|
||||
"inputs": {
|
||||
"text": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,裸露,NSFW",
|
||||
"clip": [
|
||||
"90",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Negative Prompt)"
|
||||
}
|
||||
},
|
||||
"92": {
|
||||
"inputs": {
|
||||
"vae_name": "wan_2.1_vae.safetensors"
|
||||
},
|
||||
"class_type": "VAELoader",
|
||||
"_meta": {
|
||||
"title": "Load VAE"
|
||||
}
|
||||
},
|
||||
"93": {
|
||||
"inputs": {
|
||||
"shift": 8.000000000000002,
|
||||
"model": [
|
||||
"101",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "ModelSamplingSD3",
|
||||
"_meta": {
|
||||
"title": "ModelSamplingSD3"
|
||||
}
|
||||
},
|
||||
"94": {
|
||||
"inputs": {
|
||||
"shift": 8,
|
||||
"model": [
|
||||
"102",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "ModelSamplingSD3",
|
||||
"_meta": {
|
||||
"title": "ModelSamplingSD3"
|
||||
}
|
||||
},
|
||||
"95": {
|
||||
"inputs": {
|
||||
"add_noise": "disable",
|
||||
"noise_seed": 0,
|
||||
"steps": 20,
|
||||
"cfg": 3.5,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "simple",
|
||||
"start_at_step": 10,
|
||||
"end_at_step": 10000,
|
||||
"return_with_leftover_noise": "disable",
|
||||
"model": [
|
||||
"94",
|
||||
0
|
||||
],
|
||||
"positive": [
|
||||
"99",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"91",
|
||||
0
|
||||
],
|
||||
"latent_image": [
|
||||
"96",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "KSamplerAdvanced",
|
||||
"_meta": {
|
||||
"title": "KSampler (Advanced)"
|
||||
}
|
||||
},
|
||||
"96": {
|
||||
"inputs": {
|
||||
"add_noise": "enable",
|
||||
"noise_seed": "__RANDOM_INT__",
|
||||
"steps": 20,
|
||||
"cfg": 3.5,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "simple",
|
||||
"start_at_step": 0,
|
||||
"end_at_step": 10,
|
||||
"return_with_leftover_noise": "enable",
|
||||
"model": [
|
||||
"93",
|
||||
0
|
||||
],
|
||||
"positive": [
|
||||
"99",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"91",
|
||||
0
|
||||
],
|
||||
"latent_image": [
|
||||
"104",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "KSamplerAdvanced",
|
||||
"_meta": {
|
||||
"title": "KSampler (Advanced)"
|
||||
}
|
||||
},
|
||||
"97": {
|
||||
"inputs": {
|
||||
"samples": [
|
||||
"95",
|
||||
0
|
||||
],
|
||||
"vae": [
|
||||
"92",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "VAEDecode",
|
||||
"_meta": {
|
||||
"title": "VAE Decode"
|
||||
}
|
||||
},
|
||||
"98": {
|
||||
"inputs": {
|
||||
"filename_prefix": "video/ComfyUI",
|
||||
"format": "auto",
|
||||
"codec": "auto",
|
||||
"video": [
|
||||
"100",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "SaveVideo",
|
||||
"_meta": {
|
||||
"title": "Save Video"
|
||||
}
|
||||
},
|
||||
"99": {
|
||||
"inputs": {
|
||||
"text": "Beautiful young European woman with honey blonde hair gracefully turning her head back over shoulder, gentle smile, bright eyes looking at camera. Hair flowing in slow motion as she turns. Soft natural lighting, clean background, cinematic portrait.",
|
||||
"clip": [
|
||||
"90",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Positive Prompt)"
|
||||
}
|
||||
},
|
||||
"100": {
|
||||
"inputs": {
|
||||
"fps": 16,
|
||||
"images": [
|
||||
"97",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "CreateVideo",
|
||||
"_meta": {
|
||||
"title": "Create Video"
|
||||
}
|
||||
},
|
||||
"101": {
|
||||
"inputs": {
|
||||
"unet_name": "wan2.2_t2v_high_noise_14B_fp8_scaled.safetensors",
|
||||
"weight_dtype": "default"
|
||||
},
|
||||
"class_type": "UNETLoader",
|
||||
"_meta": {
|
||||
"title": "Load Diffusion Model"
|
||||
}
|
||||
},
|
||||
"102": {
|
||||
"inputs": {
|
||||
"unet_name": "wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors",
|
||||
"weight_dtype": "default"
|
||||
},
|
||||
"class_type": "UNETLoader",
|
||||
"_meta": {
|
||||
"title": "Load Diffusion Model"
|
||||
}
|
||||
},
|
||||
"104": {
|
||||
"inputs": {
|
||||
"width": 640,
|
||||
"height": 640,
|
||||
"length": 81,
|
||||
"batch_size": 1
|
||||
},
|
||||
"class_type": "EmptyHunyuanLatentVideo",
|
||||
"_meta": {
|
||||
"title": "EmptyHunyuanLatentVideo"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import dataclasses
|
||||
import base64
|
||||
from typing import Optional, Union, Type
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import web, ClientResponse
|
||||
|
||||
from lib.backend import Backend, LogAction
|
||||
@@ -13,6 +14,7 @@ from .data_types import ComfyWorkflowData
|
||||
|
||||
|
||||
MODEL_SERVER_URL = os.getenv("MODEL_SERVER_URL", "http://127.0.0.1:18288")
|
||||
COMFYUI_URL = os.getenv("COMFYUI_URL", "http://127.0.0.1:18188") # Raw ComfyUI server
|
||||
|
||||
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
||||
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: "
|
||||
@@ -108,8 +110,39 @@ async def handle_ping(_):
|
||||
return web.Response(body="pong")
|
||||
|
||||
|
||||
async def handle_view(request: web.Request) -> web.Response:
|
||||
"""Proxy /view requests to raw ComfyUI server to fetch generated images"""
|
||||
# Forward query params to raw ComfyUI (not the API wrapper)
|
||||
query_string = request.query_string
|
||||
url = f"{COMFYUI_URL}/view?{query_string}"
|
||||
|
||||
log.debug(f"Proxying /view request to: {url}")
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as resp:
|
||||
if resp.status == 200:
|
||||
content = await resp.read()
|
||||
return web.Response(
|
||||
body=content,
|
||||
status=200,
|
||||
content_type=resp.content_type or "image/png"
|
||||
)
|
||||
else:
|
||||
text = await resp.text()
|
||||
return web.Response(
|
||||
text=text,
|
||||
status=resp.status,
|
||||
content_type="text/plain"
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error proxying /view: {e}")
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
|
||||
routes = [
|
||||
web.post("/generate/sync", backend.create_handler(ComfyWorkflowHandler())),
|
||||
web.get("/view", handle_view),
|
||||
web.get("/ping", handle_ping),
|
||||
]
|
||||
|
||||
|
||||
@@ -7,20 +7,13 @@ from lib.test_utils import print_truncate_res
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
|
||||
"""
|
||||
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
|
||||
"""
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
from vastai import Serverless
|
||||
|
||||
|
||||
def call_default_workflow(
|
||||
endpoint_group_name: str, api_key: str, server_url: str
|
||||
) -> None:
|
||||
ENDPOINT_NAME = "my-comfyui-endpoint"
|
||||
COST = 100 # Use a constant cost for image generation
|
||||
|
||||
def call_default_workflow(client: Serverless) -> None:
|
||||
WORKER_ENDPOINT = "/prompt"
|
||||
COST = 100
|
||||
route_payload = {
|
||||
@@ -82,6 +75,7 @@ def call_custom_workflow_for_sd3(
|
||||
endpoint=message["endpoint"],
|
||||
reqnum=message["reqnum"],
|
||||
url=message["url"],
|
||||
request_idx=message["request_idx"],
|
||||
)
|
||||
workflow = {
|
||||
"3": {
|
||||
|
||||
+33
-26
@@ -8,14 +8,13 @@ This is the base PyWorker for OpenAI compatible inference servers. See the [Ser
|
||||
|
||||
This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker.
|
||||
|
||||
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20%2B%20Qwen%2FQwen3-8B%20(Serverless)) (recommended)
|
||||
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20(Serverless)) (recommended)
|
||||
- [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless))
|
||||
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20%2B%20Qwen3-8B%20(Serverless))
|
||||
|
||||
|
||||
All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected.
|
||||
|
||||
2. Follow the [getting started guide](https://docs.vast.ai/serverless/getting-started) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
||||
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
||||
|
||||
## Client Setup (Demo)
|
||||
|
||||
@@ -34,38 +33,20 @@ uv pip install -r requirements.txt
|
||||
|
||||
Several examples have been provided in the client to help you get started with your own implementation.
|
||||
|
||||
### Completions
|
||||
|
||||
Call to `/v1/completions` with json response
|
||||
First, set your API key as an environment variable:
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
|
||||
export VAST_API_KEY=<your_api_key>
|
||||
```
|
||||
|
||||
### Chat Completion (json)
|
||||
|
||||
Call to `/v1/chat/completions` with json response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
|
||||
```
|
||||
The `--model` and `--endpoint` flags are optional. If not provided, they default to `Qwen/Qwen3-8B` and `my-vllm-endpoint` respectively.
|
||||
|
||||
### Chat Completion (streaming)
|
||||
|
||||
Call to `/v1/chat/completions` with streaming response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Tool Use (json)
|
||||
|
||||
Call to `/v1/chat/completions` with tool and json response.
|
||||
|
||||
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
|
||||
python -m workers.openai.client --chat-stream --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Interactive Chat (streaming)
|
||||
@@ -75,6 +56,32 @@ Interactive session with calls to `/v1/chat/completions`.
|
||||
Type `clear` to clear the chat history or `quit` to exit.
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
|
||||
python -m workers.openai.client --interactive --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Chat Completion (json)
|
||||
|
||||
Call to `/v1/chat/completions` with json response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client --chat --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Tool Use (json)
|
||||
|
||||
Call to `/v1/chat/completions` with tool and json response.
|
||||
|
||||
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client --tools --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Completions
|
||||
|
||||
Call to `/v1/completions` with json response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client --completion --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
|
||||
+375
-429
@@ -1,14 +1,15 @@
|
||||
import logging
|
||||
import sys
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
from urllib.parse import urljoin
|
||||
from typing import Dict, Any, Optional, Iterator, Union, List
|
||||
import requests
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
from .data_types.client import CompletionConfig, ChatCompletionConfig
|
||||
import argparse
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from vastai import Serverless
|
||||
import asyncio
|
||||
|
||||
# ---------------------- Logging ----------------------
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
@@ -16,135 +17,20 @@ logging.basicConfig(
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
COMPLETIONS_PROMPT = "the capital of USA is"
|
||||
# ---------------------- Prompts ----------------------
|
||||
COMPLETIONS_PROMPT = "Zebras are primarily grazers and can subsist on lower-quality vegetation. They are preyed on mainly by"
|
||||
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
|
||||
TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?"
|
||||
|
||||
|
||||
class APIClient:
|
||||
"""Lightweight client focused solely on API communication"""
|
||||
|
||||
# Remove the generic WORKER_ENDPOINT since we're now going direct
|
||||
DEFAULT_COST = 100
|
||||
DEFAULT_TIMEOUT = 4
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint_group_name: str,
|
||||
api_key: str,
|
||||
server_url: str,
|
||||
endpoint_api_key: str,
|
||||
):
|
||||
self.endpoint_group_name = endpoint_group_name
|
||||
self.api_key = api_key
|
||||
self.server_url = server_url
|
||||
self.endpoint_api_key = endpoint_api_key
|
||||
|
||||
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
|
||||
"""Get worker URL and auth data from routing service"""
|
||||
if not self.endpoint_api_key:
|
||||
raise ValueError("No valid endpoint API key available")
|
||||
|
||||
route_payload = {
|
||||
"endpoint": self.endpoint_group_name,
|
||||
"api_key": self.endpoint_api_key,
|
||||
"cost": cost,
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
urljoin(self.server_url, "/route/"),
|
||||
json=route_payload,
|
||||
timeout=self.DEFAULT_TIMEOUT,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def _create_auth_data(self, message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Create auth data from routing response"""
|
||||
return {
|
||||
"signature": message["signature"],
|
||||
"cost": message["cost"],
|
||||
"endpoint": message["endpoint"],
|
||||
"reqnum": message["reqnum"],
|
||||
"url": message["url"],
|
||||
}
|
||||
|
||||
def _make_request(
|
||||
self,
|
||||
payload: Dict[str, Any],
|
||||
endpoint: str,
|
||||
method: str = "POST",
|
||||
stream: bool = False,
|
||||
) -> Union[Dict[str, Any], Iterator[str]]:
|
||||
"""Make request directly to the specific worker endpoint"""
|
||||
# Get worker URL and auth data
|
||||
cost = payload.get("max_tokens", self.DEFAULT_COST)
|
||||
message = self._get_worker_url(cost=cost)
|
||||
worker_url = message["url"]
|
||||
auth_data = self._create_auth_data(message)
|
||||
|
||||
req_data = {"payload": {"input": payload}, "auth_data": auth_data}
|
||||
|
||||
url = urljoin(worker_url, endpoint)
|
||||
log.debug(f"Making direct request to: {url}")
|
||||
log.debug(f"Payload: {req_data}")
|
||||
|
||||
# Make the request using the specified method
|
||||
if method.upper() == "POST":
|
||||
response = requests.post(
|
||||
url, json=req_data, stream=stream, verify=get_cert_file_path()
|
||||
)
|
||||
elif method.upper() == "GET":
|
||||
response = requests.get(
|
||||
url, params=req_data, stream=stream, verify=get_cert_file_path()
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
if stream:
|
||||
return self._handle_streaming_response(response)
|
||||
else:
|
||||
return response.json()
|
||||
|
||||
def _handle_streaming_response(self, response: requests.Response) -> Iterator[str]:
|
||||
"""Handle streaming response and yield tokens"""
|
||||
try:
|
||||
for line in response.iter_lines(decode_unicode=True):
|
||||
if line:
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
if data_str.strip() == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
yield data # Yield the full chunk
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
except Exception as e:
|
||||
log.error(f"Error handling streaming response: {e}")
|
||||
raise
|
||||
|
||||
def call_completions(
|
||||
self, config: CompletionConfig
|
||||
) -> Union[Dict[str, Any], Iterator[str]]:
|
||||
payload = config.to_dict()
|
||||
|
||||
return self._make_request(
|
||||
payload=payload, endpoint="/v1/completions", stream=config.stream
|
||||
)
|
||||
|
||||
def call_chat_completions(
|
||||
self, config: ChatCompletionConfig
|
||||
) -> Union[Dict[str, Any], Iterator[str]]:
|
||||
payload = config.to_dict()
|
||||
|
||||
return self._make_request(
|
||||
payload=payload, endpoint="/v1/chat/completions", stream=config.stream
|
||||
)
|
||||
TOOLS_PROMPT = (
|
||||
"Can you list the files in the current working directory and tell me what you see? "
|
||||
"What do you think this directory might be for?"
|
||||
)
|
||||
|
||||
ENDPOINT_NAME = "my-vllm-endpoint" # change this to your vLLM endpoint name
|
||||
DEFAULT_MODEL = "Qwen/Qwen3-8B" # must support tool calling
|
||||
MAX_TOKENS = 1024
|
||||
DEFAULT_TEMPERATURE = 0.7
|
||||
|
||||
# ---------------------- Tooling ----------------------
|
||||
class ToolManager:
|
||||
"""Handles tool definitions and execution"""
|
||||
|
||||
@@ -164,7 +50,7 @@ class ToolManager:
|
||||
|
||||
@staticmethod
|
||||
def get_ls_tool_definition() -> List[Dict[str, Any]]:
|
||||
"""Get the ls tool definition"""
|
||||
"""OpenAI-compatible tool schema"""
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
@@ -178,98 +64,228 @@ class ToolManager:
|
||||
|
||||
def execute_tool_call(self, tool_call: Dict[str, Any]) -> str:
|
||||
"""Execute a tool call and return the result"""
|
||||
function_name = tool_call["function"]["name"]
|
||||
|
||||
function_name = (tool_call.get("function") or {}).get("name")
|
||||
if function_name == "list_files":
|
||||
return self.list_files()
|
||||
else:
|
||||
raise ValueError(f"Unknown tool function: {function_name}")
|
||||
raise ValueError(f"Unknown tool function: {function_name}")
|
||||
|
||||
|
||||
# ----- Helpers to handle streamed tool_calls assembly -----
|
||||
def _merge_tool_call_delta(state: Dict[int, Dict[str, Any]], tc_delta: Dict[str, Any]) -> None:
|
||||
"""
|
||||
OpenAI-style streaming sends partial tool_calls with an index and partial fields.
|
||||
We merge into a per-index state dict until the assistant message finishes.
|
||||
"""
|
||||
idx = tc_delta.get("index")
|
||||
if idx is None:
|
||||
return
|
||||
|
||||
entry = state.setdefault(idx, {"id": None, "function": {"name": None, "arguments": ""}, "type": "function"})
|
||||
|
||||
if tc_delta.get("id"):
|
||||
entry["id"] = tc_delta["id"]
|
||||
|
||||
fn_delta = tc_delta.get("function") or {}
|
||||
if "name" in fn_delta and fn_delta["name"]:
|
||||
entry["function"]["name"] = fn_delta["name"]
|
||||
if "arguments" in fn_delta and fn_delta["arguments"]:
|
||||
entry["function"]["arguments"] += fn_delta["arguments"]
|
||||
|
||||
|
||||
def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
return [state[i] for i in sorted(state.keys())]
|
||||
|
||||
|
||||
# ---- OpenAI-compatible calls (non-streaming) ----
|
||||
async def call_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs) -> Dict[str, Any]:
|
||||
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
}
|
||||
}
|
||||
log.debug("POST /v1/completions %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
|
||||
return resp["response"]
|
||||
|
||||
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs) -> Dict[str, Any]:
|
||||
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
|
||||
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
|
||||
}
|
||||
}
|
||||
log.debug("POST /v1/chat/completions %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"])
|
||||
return resp["response"]
|
||||
|
||||
# ---- Streaming variants ----
|
||||
async def stream_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs):
|
||||
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
"stream": True,
|
||||
**({"stop": kwargs["stop"]} if "stop" in kwargs else {}),
|
||||
}
|
||||
}
|
||||
log.debug("STREAM /v1/completions %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
|
||||
return resp["response"] # async generator
|
||||
|
||||
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs):
|
||||
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
"stream": True,
|
||||
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
|
||||
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
|
||||
}
|
||||
}
|
||||
log.debug("STREAM /v1/chat/completions %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
|
||||
return resp["response"] # async generator
|
||||
|
||||
|
||||
# ---------------------- Demo Runner ----------------------
|
||||
class APIDemo:
|
||||
"""Demo and testing functionality for the API client"""
|
||||
|
||||
def __init__(
|
||||
self, client: APIClient, model: str, tool_manager: Optional[ToolManager] = None
|
||||
):
|
||||
def __init__(self, client: Serverless, model: str, endpoint_name: str, tool_manager: Optional[ToolManager] = None):
|
||||
self.client = client
|
||||
self.model = model
|
||||
self.endpoint_name = endpoint_name
|
||||
self.tool_manager = tool_manager or ToolManager()
|
||||
|
||||
def handle_streaming_response(
|
||||
self, response_stream, show_reasoning: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Handle streaming chat response and display all output.
|
||||
"""
|
||||
|
||||
# ----- Streaming handler -----
|
||||
async def handle_streaming_response(self, stream, show_reasoning: bool = True) -> str:
|
||||
full_response = ""
|
||||
reasoning_content = ""
|
||||
reasoning_started = False
|
||||
content_started = False
|
||||
printed_reasoning = False
|
||||
printed_answer = False
|
||||
finish_reason = None
|
||||
|
||||
for chunk in response_stream:
|
||||
# Normalize the chunk
|
||||
if isinstance(chunk, str):
|
||||
chunk = chunk.strip()
|
||||
if chunk.startswith("data: "):
|
||||
chunk = chunk[6:].strip()
|
||||
if chunk in ["[DONE]", ""]:
|
||||
continue
|
||||
try:
|
||||
parsed_chunk = json.loads(chunk)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
elif isinstance(chunk, dict):
|
||||
parsed_chunk = chunk
|
||||
else:
|
||||
continue
|
||||
async for chunk in stream:
|
||||
choice = (chunk.get("choices") or [{}])[0]
|
||||
delta = choice.get("delta", {})
|
||||
|
||||
# Track finish reason
|
||||
if choice.get("finish_reason"):
|
||||
finish_reason = choice.get("finish_reason")
|
||||
|
||||
# Parse delta from the chunk
|
||||
choices = parsed_chunk.get("choices", [])
|
||||
if not choices:
|
||||
continue
|
||||
|
||||
delta = choices[0].get("delta", {})
|
||||
reasoning_token = delta.get("reasoning_content", "")
|
||||
content_token = delta.get("content", "")
|
||||
|
||||
# Print reasoning token if applicable
|
||||
if show_reasoning and reasoning_token:
|
||||
if not reasoning_started:
|
||||
# reasoning tokens
|
||||
rc = delta.get("reasoning_content")
|
||||
if rc and show_reasoning:
|
||||
if not printed_reasoning:
|
||||
print("\n🧠 Reasoning: ", end="", flush=True)
|
||||
reasoning_started = True
|
||||
print(f"\033[90m{reasoning_token}\033[0m", end="", flush=True)
|
||||
reasoning_content += reasoning_token
|
||||
printed_reasoning = True
|
||||
print(rc, end="", flush=True)
|
||||
reasoning_content += rc
|
||||
|
||||
# Print content token
|
||||
if content_token:
|
||||
if not content_started:
|
||||
if show_reasoning and reasoning_started:
|
||||
print(f"\n💬 Response: ", end="", flush=True)
|
||||
# content tokens
|
||||
content_part = delta.get("content")
|
||||
if content_part:
|
||||
if not printed_answer:
|
||||
if show_reasoning and printed_reasoning:
|
||||
print("\n💬 Response: ", end="", flush=True)
|
||||
else:
|
||||
print("Assistant: ", end="", flush=True)
|
||||
content_started = True
|
||||
print(content_token, end="", flush=True)
|
||||
full_response += content_token
|
||||
|
||||
print() # Ensure newline after response
|
||||
printed_answer = True
|
||||
print(content_part, end="", flush=True)
|
||||
full_response += content_part
|
||||
|
||||
print() # newline
|
||||
if show_reasoning:
|
||||
if reasoning_started or content_started:
|
||||
if printed_reasoning or printed_answer:
|
||||
print("\nStreaming completed.")
|
||||
if reasoning_started:
|
||||
if printed_reasoning:
|
||||
print(f"Reasoning tokens: {len(reasoning_content.split())}")
|
||||
if content_started:
|
||||
if printed_answer:
|
||||
print(f"Response tokens: {len(full_response.split())}")
|
||||
if finish_reason:
|
||||
print(f"Finish reason: {finish_reason}")
|
||||
|
||||
return full_response
|
||||
|
||||
async def demo_completions(self) -> None:
|
||||
print("=" * 60)
|
||||
print("COMPLETIONS DEMO")
|
||||
print("=" * 60)
|
||||
|
||||
def test_tool_support(self) -> bool:
|
||||
"""Test if the endpoint supports function calling"""
|
||||
log.debug("Testing endpoint tool calling support...")
|
||||
response = await call_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
prompt=COMPLETIONS_PROMPT,
|
||||
endpoint_name=self.endpoint_name,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
)
|
||||
print("\nResponse:")
|
||||
print(json.dumps(response, indent=2))
|
||||
|
||||
# Try a simple request with minimal tools to test support
|
||||
async def demo_chat(self, use_streaming: bool = True) -> None:
|
||||
print("=" * 60)
|
||||
print(f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}")
|
||||
print("=" * 60)
|
||||
|
||||
messages = [{"role": "user", "content": CHAT_PROMPT}]
|
||||
|
||||
if use_streaming:
|
||||
stream = await stream_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
endpoint_name=self.endpoint_name,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE
|
||||
)
|
||||
try:
|
||||
await self.handle_streaming_response(stream, show_reasoning=True)
|
||||
except Exception as e:
|
||||
log.error("\nError during streaming: %s", e, exc_info=True)
|
||||
else:
|
||||
response = await call_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
endpoint_name=self.endpoint_name,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE
|
||||
)
|
||||
choice = (response.get("choices") or [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
content = message.get("content", "")
|
||||
reasoning = message.get("reasoning_content", "") or message.get("reasoning", "")
|
||||
if reasoning:
|
||||
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
|
||||
print(f"\n💬 Assistant: {content}")
|
||||
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
|
||||
|
||||
async def test_tool_support(self) -> bool:
|
||||
"""Probe that tool schema is accepted (no actual call)"""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
minimal_tool = [
|
||||
{
|
||||
@@ -277,179 +293,158 @@ class APIDemo:
|
||||
"function": {"name": "test_function", "description": "Test function"},
|
||||
}
|
||||
]
|
||||
|
||||
config = ChatCompletionConfig(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
tools=minimal_tool,
|
||||
tool_choice="none", # Don't actually call the tool
|
||||
)
|
||||
|
||||
try:
|
||||
response = self.client.call_chat_completions(config)
|
||||
_ = await call_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
endpoint_name=self.endpoint_name,
|
||||
tools=minimal_tool,
|
||||
tool_choice="none",
|
||||
max_tokens=10
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"Error: Endpoint does not support tool calling: {e}")
|
||||
log.error("Endpoint does not support tool calling: %s", e)
|
||||
return False
|
||||
|
||||
def demo_completions(self) -> None:
|
||||
"""Demo: test basic completions endpoint"""
|
||||
print("=" * 60)
|
||||
print("COMPLETIONS DEMO")
|
||||
print("=" * 60)
|
||||
|
||||
config = CompletionConfig(
|
||||
model=self.model, prompt=COMPLETIONS_PROMPT, stream=False
|
||||
)
|
||||
|
||||
log.info(
|
||||
f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'"
|
||||
)
|
||||
response = self.client.call_completions(config)
|
||||
|
||||
if isinstance(response, dict):
|
||||
print("\nResponse:")
|
||||
print(json.dumps(response, indent=2))
|
||||
else:
|
||||
log.error("Unexpected response format")
|
||||
|
||||
def demo_chat(self, use_streaming: bool = True) -> None:
|
||||
"""
|
||||
Demo: test chat completions endpoint with optional streaming
|
||||
"""
|
||||
print("=" * 60)
|
||||
print(
|
||||
f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}"
|
||||
)
|
||||
print("=" * 60)
|
||||
|
||||
config = ChatCompletionConfig(
|
||||
model=self.model,
|
||||
messages=[{"role": "user", "content": CHAT_PROMPT}],
|
||||
stream=use_streaming,
|
||||
)
|
||||
|
||||
log.info(f"Testing chat completions with model '{self.model}'...")
|
||||
response = self.client.call_chat_completions(config)
|
||||
|
||||
if use_streaming:
|
||||
try:
|
||||
self.handle_streaming_response(response, show_reasoning=True)
|
||||
except Exception as e:
|
||||
log.error(f"\nError during streaming: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
else:
|
||||
if isinstance(response, dict):
|
||||
choice = response.get("choices", [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
content = message.get("content", "")
|
||||
reasoning = message.get("reasoning_content", "") or message.get(
|
||||
"reasoning", ""
|
||||
)
|
||||
|
||||
if reasoning:
|
||||
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
|
||||
|
||||
print(f"\n💬 Assistant: {content}")
|
||||
print(f"\nFull Response:")
|
||||
print(json.dumps(response, indent=2))
|
||||
else:
|
||||
log.error("Unexpected response format")
|
||||
|
||||
def demo_ls_tool(self) -> None:
|
||||
"""Demo: ask LLM to list files in the current directory and describe what it sees"""
|
||||
async def demo_ls_tool(self) -> None:
|
||||
"""Ask to list files using function calling, then provide final analysis"""
|
||||
print("=" * 60)
|
||||
print("TOOL USE DEMO: List Directory Contents")
|
||||
print("=" * 60)
|
||||
|
||||
# Test if tools are supported first
|
||||
if not self.test_tool_support():
|
||||
if not await self.test_tool_support():
|
||||
return
|
||||
|
||||
# Request with tool available
|
||||
messages = [{"role": "user", "content": TOOLS_PROMPT}]
|
||||
messages: List[Dict[str, Any]] = [{"role": "user", "content": TOOLS_PROMPT}]
|
||||
|
||||
config = ChatCompletionConfig(
|
||||
# First pass: let the model decide tools, stream tool_calls and partial content
|
||||
stream = await stream_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
endpoint_name=self.endpoint_name,
|
||||
tools=self.tool_manager.get_ls_tool_definition(),
|
||||
tool_choice="auto",
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
)
|
||||
|
||||
log.info(f"Making initial request with tool using model '{self.model}'...")
|
||||
response = self.client.call_chat_completions(config)
|
||||
assistant_content_buf: List[str] = []
|
||||
tool_calls_state: Dict[int, Dict[str, Any]] = {}
|
||||
printed_reasoning = False
|
||||
printed_answer = False
|
||||
|
||||
if not isinstance(response, dict):
|
||||
raise ValueError("Expected dict response for tool use")
|
||||
async for chunk in stream:
|
||||
choice = (chunk.get("choices") or [{}])[0]
|
||||
delta = choice.get("delta", {})
|
||||
|
||||
choice = response.get("choices", [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
rc = delta.get("reasoning_content")
|
||||
if rc:
|
||||
if not printed_reasoning:
|
||||
printed_reasoning = True
|
||||
print("🧠 Reasoning: ", end="", flush=True)
|
||||
print(rc, end="", flush=True)
|
||||
|
||||
print(f"Assistant response: {message.get('content', 'No content')}")
|
||||
content_part = delta.get("content")
|
||||
if content_part:
|
||||
assistant_content_buf.append(content_part)
|
||||
if not printed_answer:
|
||||
printed_answer = True
|
||||
print("\n💬 Response: ", end="", flush=True)
|
||||
print(content_part, end="", flush=True)
|
||||
|
||||
# Check for tool calls
|
||||
tool_calls = message.get("tool_calls")
|
||||
if not tool_calls:
|
||||
raise ValueError(
|
||||
"No tool calls made - model may not support function calling"
|
||||
)
|
||||
if "tool_calls" in delta and delta["tool_calls"]:
|
||||
for tc_delta in delta["tool_calls"]:
|
||||
_merge_tool_call_delta(tool_calls_state, tc_delta)
|
||||
|
||||
print(f"Tool calls detected: {len(tool_calls)}")
|
||||
# If no tool calls, we’re done.
|
||||
if not tool_calls_state:
|
||||
print("\n(No tool calls were made.)")
|
||||
return
|
||||
|
||||
# Execute the tool call
|
||||
for tool_call in tool_calls:
|
||||
function_name = tool_call["function"]["name"]
|
||||
print(f"Executing tool: {function_name}")
|
||||
# Build assistant message with tool_calls
|
||||
assistant_message = {
|
||||
"role": "assistant",
|
||||
"content": "".join(assistant_content_buf) if assistant_content_buf else None,
|
||||
"tool_calls": _tool_state_to_message_tool_calls(tool_calls_state),
|
||||
}
|
||||
messages.append(assistant_message)
|
||||
|
||||
tool_result = self.tool_manager.execute_tool_call(tool_call)
|
||||
print(f"Tool result:\n{tool_result}")
|
||||
# Execute tools and feed results back
|
||||
for tc in assistant_message["tool_calls"]:
|
||||
tool_name = (tc.get("function") or {}).get("name")
|
||||
call_id = tc.get("id")
|
||||
raw_args = (tc.get("function") or {}).get("arguments") or "{}"
|
||||
|
||||
# Add tool result and continue conversation
|
||||
messages.append(message) # Add assistant's message with tool call
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call["id"],
|
||||
"content": tool_result,
|
||||
}
|
||||
)
|
||||
try:
|
||||
args = json.loads(raw_args) if raw_args.strip() else {}
|
||||
except Exception as e:
|
||||
tool_result = json.dumps({"error": f"Argument parse failed: {str(e)}", "raw_arguments": raw_args})
|
||||
messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result})
|
||||
continue
|
||||
|
||||
# Get final response
|
||||
final_config = ChatCompletionConfig(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
tools=self.tool_manager.get_ls_tool_definition(),
|
||||
)
|
||||
try:
|
||||
if tool_name == "list_files":
|
||||
tool_result = self.tool_manager.list_files()
|
||||
else:
|
||||
tool_result = json.dumps({"error": f"Unknown tool '{tool_name}'"})
|
||||
except Exception as e:
|
||||
tool_result = json.dumps({"error": f"Tool '{tool_name}' failed: {str(e)}"})
|
||||
|
||||
print("Getting final response...")
|
||||
final_response = self.client.call_chat_completions(final_config)
|
||||
print("\n[Tool executed]", tool_name)
|
||||
print(tool_result[:500] + ("..." if len(tool_result) > 500 else ""))
|
||||
messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result})
|
||||
|
||||
if isinstance(final_response, dict):
|
||||
final_choice = final_response.get("choices", [{}])[0]
|
||||
final_message = final_choice.get("message", {})
|
||||
final_content = final_message.get("content", "")
|
||||
# Second pass: get final streamed answer after tool results
|
||||
stream2 = await stream_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
endpoint_name=self.endpoint_name,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("FINAL LLM ANALYSIS:")
|
||||
print("=" * 60)
|
||||
print(final_content)
|
||||
print("=" * 60)
|
||||
final_buf = []
|
||||
printed_reasoning2 = False
|
||||
printed_answer2 = False
|
||||
|
||||
def interactive_chat(self) -> None:
|
||||
async for chunk in stream2:
|
||||
choice = (chunk.get("choices") or [{}])[0]
|
||||
delta = choice.get("delta", {})
|
||||
|
||||
rc2 = delta.get("reasoning_content")
|
||||
if rc2:
|
||||
if not printed_reasoning2:
|
||||
printed_reasoning2 = True
|
||||
print("\n🧠 Reasoning (post-tools): ", end="", flush=True)
|
||||
print(rc2, end="", flush=True)
|
||||
|
||||
c2 = delta.get("content")
|
||||
if c2:
|
||||
final_buf.append(c2)
|
||||
if not printed_answer2:
|
||||
printed_answer2 = True
|
||||
print("\n💬 Response (final): ", end="", flush=True)
|
||||
print(c2, end="", flush=True)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("FINAL LLM ANALYSIS:")
|
||||
print("=" * 60)
|
||||
print("".join(final_buf))
|
||||
print("=" * 60)
|
||||
|
||||
async def interactive_chat(self) -> None:
|
||||
"""Interactive chat session with streaming"""
|
||||
print("=" * 60)
|
||||
print("INTERACTIVE STREAMING CHAT")
|
||||
print("=" * 60)
|
||||
print(f"Using model: {self.model}")
|
||||
print("Type 'quit' to exit, 'clear' to clear history")
|
||||
print()
|
||||
|
||||
messages = []
|
||||
messages: List[Dict[str, Any]] = []
|
||||
|
||||
while True:
|
||||
try:
|
||||
@@ -467,16 +462,16 @@ class APIDemo:
|
||||
|
||||
messages.append({"role": "user", "content": user_input})
|
||||
|
||||
config = ChatCompletionConfig(
|
||||
model=self.model, messages=messages, stream=True, temperature=0.7
|
||||
)
|
||||
|
||||
print("Assistant: ", end="", flush=True)
|
||||
|
||||
response = self.client.call_chat_completions(config)
|
||||
assistant_content = self.handle_streaming_response(
|
||||
response, show_reasoning=True
|
||||
stream = await stream_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
endpoint_name=self.endpoint_name,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=0.7
|
||||
)
|
||||
assistant_content = await self.handle_streaming_response(stream, show_reasoning=True)
|
||||
|
||||
# Add assistant response to conversation history
|
||||
messages.append({"role": "assistant", "content": assistant_content})
|
||||
@@ -485,115 +480,66 @@ class APIDemo:
|
||||
print("\n👋 Chat interrupted. Goodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
log.error(f"\nError: {e}")
|
||||
log.error("\nError: %s", e)
|
||||
continue
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function with CLI switches for different tests"""
|
||||
from lib.test_utils import test_args
|
||||
# ---------------------- CLI ----------------------
|
||||
def build_arg_parser() -> argparse.ArgumentParser:
|
||||
p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
|
||||
p.add_argument("--model", default=DEFAULT_MODEL, help=f"Model to use for requests (default: {DEFAULT_MODEL})")
|
||||
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
|
||||
|
||||
# Add mandatory model argument
|
||||
test_args.add_argument(
|
||||
"--model", required=True, help="Model to use for requests (required)"
|
||||
)
|
||||
modes = p.add_mutually_exclusive_group(required=False)
|
||||
modes.add_argument("--completion", action="store_true", help="Test completions endpoint")
|
||||
modes.add_argument("--chat", action="store_true", help="Test chat completions endpoint (non-streaming)")
|
||||
modes.add_argument("--chat-stream", action="store_true", help="Test chat completions endpoint with streaming")
|
||||
modes.add_argument("--tools", action="store_true", help="Test function calling with ls tool (non-streaming+streamed phases)")
|
||||
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming chat session")
|
||||
return p
|
||||
|
||||
# Add test mode arguments
|
||||
test_args.add_argument(
|
||||
"--completion", action="store_true", help="Test completions endpoint"
|
||||
)
|
||||
test_args.add_argument(
|
||||
"--chat",
|
||||
action="store_true",
|
||||
help="Test chat completions endpoint (non-streaming)",
|
||||
)
|
||||
test_args.add_argument(
|
||||
"--chat-stream",
|
||||
action="store_true",
|
||||
help="Test chat completions endpoint with streaming",
|
||||
)
|
||||
test_args.add_argument(
|
||||
"--tools",
|
||||
action="store_true",
|
||||
help="Test function calling with ls tool (non-streaming)",
|
||||
)
|
||||
test_args.add_argument(
|
||||
"--interactive",
|
||||
action="store_true",
|
||||
help="Start interactive streaming chat session",
|
||||
)
|
||||
|
||||
args = test_args.parse_args()
|
||||
async def main_async():
|
||||
args = build_arg_parser().parse_args()
|
||||
|
||||
# Check that only one test mode is selected
|
||||
test_modes = [
|
||||
args.completion,
|
||||
args.chat,
|
||||
args.chat_stream,
|
||||
args.tools,
|
||||
args.interactive,
|
||||
]
|
||||
selected_count = sum(test_modes)
|
||||
|
||||
if selected_count == 0:
|
||||
selected = sum([args.completion, args.chat, args.chat_stream, args.tools, args.interactive])
|
||||
if selected == 0:
|
||||
print("Please specify exactly one test mode:")
|
||||
print(" --completion : Test completions endpoint")
|
||||
print(" --chat : Test chat completions endpoint (non-streaming)")
|
||||
print(" --chat-stream : Test chat completions endpoint with streaming")
|
||||
print(" --tools : Test function calling with ls tool (non-streaming)")
|
||||
print(" --tools : Test function calling with ls tool")
|
||||
print(" --interactive : Start interactive streaming chat session")
|
||||
print(
|
||||
f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT"
|
||||
)
|
||||
print(f"\nExample: python {os.path.basename(sys.argv[0])} --model Qwen/Qwen3-8B --chat-stream --endpoint my-vllm-endpoint")
|
||||
sys.exit(1)
|
||||
elif selected_count > 1:
|
||||
elif selected > 1:
|
||||
print("Please specify exactly one test mode")
|
||||
sys.exit(1)
|
||||
|
||||
print("=" * 60)
|
||||
print(f"Using model: {args.model}")
|
||||
print(f"Using endpoint: {args.endpoint}")
|
||||
|
||||
|
||||
try:
|
||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||
endpoint_name=args.endpoint_group_name,
|
||||
account_api_key=args.api_key,
|
||||
instance=args.instance,
|
||||
)
|
||||
async with Serverless() as client:
|
||||
demo = APIDemo(client, args.model, args.endpoint, ToolManager())
|
||||
|
||||
if not endpoint_api_key:
|
||||
log.error(
|
||||
f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Create the core API client
|
||||
client = APIClient(
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
api_key=args.api_key,
|
||||
server_url=Endpoint.get_autoscaler_server_url(args.instance),
|
||||
endpoint_api_key=endpoint_api_key,
|
||||
)
|
||||
|
||||
# Create tool manager and demo (passing the model parameter)
|
||||
tool_manager = ToolManager()
|
||||
demo = APIDemo(client, args.model, tool_manager)
|
||||
|
||||
print(f"Using model: {args.model}")
|
||||
print("=" * 60)
|
||||
|
||||
# Run the selected test
|
||||
if args.completion:
|
||||
demo.demo_completions()
|
||||
elif args.chat:
|
||||
demo.demo_chat(use_streaming=False)
|
||||
elif args.chat_stream:
|
||||
demo.demo_chat(use_streaming=True)
|
||||
elif args.tools:
|
||||
demo.demo_ls_tool()
|
||||
elif args.interactive:
|
||||
demo.interactive_chat()
|
||||
if args.completion:
|
||||
await demo.demo_completions()
|
||||
elif args.chat:
|
||||
await demo.demo_chat(use_streaming=False)
|
||||
elif args.chat_stream:
|
||||
await demo.demo_chat(use_streaming=True)
|
||||
elif args.tools:
|
||||
await demo.demo_ls_tool()
|
||||
elif args.interactive:
|
||||
await demo.interactive_chat()
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error during test: {e}", exc_info=True)
|
||||
log.error("Error during test: %s", e, exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
asyncio.run(main_async())
|
||||
|
||||
@@ -11,6 +11,7 @@ MODEL_SERVER_START_LOG_MSG = [
|
||||
"llama runner started", # Ollama
|
||||
'"message":"Connected","target":"text_generation_router"', # TGI
|
||||
'"message":"Connected","target":"text_generation_router::server"', # TGI
|
||||
"main: model loaded" # llama.cpp
|
||||
]
|
||||
|
||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||
@@ -34,6 +35,7 @@ backend = Backend(
|
||||
model_server_url=os.environ["MODEL_SERVER_URL"],
|
||||
model_log_file=os.environ["MODEL_LOG"],
|
||||
allow_parallel_requests=True,
|
||||
max_wait_time=600.0,
|
||||
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
|
||||
log_actions=[
|
||||
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
|
||||
|
||||
+93
-9
@@ -1,19 +1,103 @@
|
||||
This is the base PyWorker for TGI, designed to create PyWorkers that can utilize various LLMs. It offers two primary endpoints:
|
||||
# HuggingFace TGI PyWorker
|
||||
|
||||
1. `generate`: Generates the LLM's response to a given prompt in a single request.
|
||||
2. `generate_stream`: Streams the LLM's response token by token.
|
||||
This is the base PyWorker for HuggingFace Text Generation Inference (TGI) servers. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
|
||||
|
||||
Both endpoints use the following API payload format:
|
||||
## Instance Setup
|
||||
|
||||
1. Pick a template
|
||||
|
||||
This worker is compatible with any TGI backend. We have a template you can use or you can create your own.
|
||||
|
||||
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20(Serverless))
|
||||
|
||||
The template can be configured via the template interface. You may want to change the model or startup arguments.
|
||||
|
||||
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
||||
|
||||
## Client Setup (Demo)
|
||||
|
||||
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/vast-ai/pyworker
|
||||
cd pyworker
|
||||
pip install uv
|
||||
uv venv -p 3.12
|
||||
source .venv/bin/activate
|
||||
uv pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Using the Test Client
|
||||
|
||||
The test client demonstrates both streaming and non-streaming generation using TGI's native API.
|
||||
|
||||
First, set your API key as an environment variable:
|
||||
|
||||
```bash
|
||||
export VAST_API_KEY=<your_api_key>
|
||||
```
|
||||
|
||||
The `--endpoint` flag is optional. If not provided, it defaults to `my-tgi-endpoint`.
|
||||
|
||||
### Generate (Streaming)
|
||||
|
||||
Call to `/generate_stream` with streaming response:
|
||||
|
||||
```bash
|
||||
python -m workers.tgi.client --generate-stream --endpoint <ENDPOINT_NAME>
|
||||
```
|
||||
|
||||
### Generate (Non-Streaming)
|
||||
|
||||
Call to `/generate` with json response:
|
||||
|
||||
```bash
|
||||
python -m workers.tgi.client --generate --endpoint <ENDPOINT_NAME>
|
||||
```
|
||||
|
||||
### Interactive Session (Streaming)
|
||||
|
||||
Interactive session with streaming responses. Type `quit` to exit.
|
||||
|
||||
```bash
|
||||
python -m workers.tgi.client --interactive --endpoint <ENDPOINT_NAME>
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
TGI provides two primary endpoints:
|
||||
|
||||
### Generate (Non-Streaming)
|
||||
|
||||
`/generate` - Returns the complete response in a single request.
|
||||
|
||||
```json
|
||||
{
|
||||
"inputs": "PROMPT",
|
||||
"inputs": "Your prompt here",
|
||||
"parameters": {
|
||||
"max_new_tokens": 250
|
||||
"max_new_tokens": 1024,
|
||||
"temperature": 0.7,
|
||||
"return_full_text": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Note that the max_new_tokens parameter, rather than the prompt size, impacts performance. For example, if an
|
||||
instance is benchmarked to process 100 tokens per second, a request with max_new_tokens = 200 will take
|
||||
approximately 2 seconds to complete.
|
||||
### Generate Stream (Streaming)
|
||||
|
||||
`/generate_stream` - Streams the response token by token.
|
||||
|
||||
```json
|
||||
{
|
||||
"inputs": "Your prompt here",
|
||||
"parameters": {
|
||||
"max_new_tokens": 1024,
|
||||
"temperature": 0.7,
|
||||
"do_sample": true,
|
||||
"return_full_text": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Performance Notes
|
||||
|
||||
The `max_new_tokens` parameter (not the prompt size) primarily impacts performance. For example, if an instance is benchmarked to process 100 tokens per second, a request with `max_new_tokens = 200` will take approximately 2 seconds to complete.
|
||||
|
||||
+201
-104
@@ -1,11 +1,13 @@
|
||||
import logging
|
||||
import sys
|
||||
import json
|
||||
from urllib.parse import urljoin
|
||||
import requests
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
from vastai import Serverless
|
||||
import asyncio
|
||||
|
||||
# ---------------------- Logging ----------------------
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
@@ -13,113 +15,208 @@ logging.basicConfig(
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
# ---------------------- Defaults ----------------------
|
||||
DEFAULT_PROMPT = "Think step by step: Tell me about the Python programming language."
|
||||
|
||||
def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> None:
|
||||
WORKER_ENDPOINT = "/generate"
|
||||
COST = 100
|
||||
route_payload = {
|
||||
"endpoint": endpoint_group_name,
|
||||
"api_key": api_key,
|
||||
"cost": COST,
|
||||
ENDPOINT_NAME = "TGI-Prod2" # change this to your TGI endpoint name
|
||||
MAX_TOKENS = 1024
|
||||
DEFAULT_TEMPERATURE = 0.7
|
||||
|
||||
|
||||
# ---------------------- API Calls ----------------------
|
||||
async def call_generate(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs) -> dict:
|
||||
"""Non-streaming generation via /generate endpoint"""
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
|
||||
payload = {
|
||||
"inputs": prompt,
|
||||
"parameters": {
|
||||
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
"return_full_text": False,
|
||||
}
|
||||
}
|
||||
response = requests.post(
|
||||
urljoin(server_url, "/route/"),
|
||||
json=route_payload,
|
||||
timeout=4,
|
||||
)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
message = response.json()
|
||||
url = message["url"]
|
||||
|
||||
auth_data = dict(
|
||||
signature=message["signature"],
|
||||
cost=message["cost"],
|
||||
endpoint=message["endpoint"],
|
||||
reqnum=message["reqnum"],
|
||||
url=url,
|
||||
)
|
||||
|
||||
payload = dict(inputs="tell me about cats", parameters=dict(max_new_tokens=500))
|
||||
req_data = dict(payload=payload, auth_data=auth_data)
|
||||
url = urljoin(url, WORKER_ENDPOINT)
|
||||
print(f"url: {url}")
|
||||
response = requests.post(
|
||||
url,
|
||||
json=req_data,
|
||||
verify=get_cert_file_path(),
|
||||
)
|
||||
response.raise_for_status()
|
||||
res = response.json()
|
||||
print(res)
|
||||
log.debug("POST /generate %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request("/generate", payload, cost=payload["parameters"]["max_new_tokens"])
|
||||
return resp["response"]
|
||||
|
||||
|
||||
def call_generate_stream(
|
||||
endpoint_group_name: str, api_key: str, server_url: str
|
||||
) -> None:
|
||||
WORKER_ENDPOINT = "/generate_stream"
|
||||
COST = 100
|
||||
route_payload = {
|
||||
"endpoint": endpoint_group_name,
|
||||
"api_key": api_key,
|
||||
"cost": COST,
|
||||
async def call_generate_stream(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs):
|
||||
"""Streaming generation via /generate_stream endpoint"""
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
|
||||
payload = {
|
||||
"inputs": prompt,
|
||||
"parameters": {
|
||||
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
"do_sample": True,
|
||||
"return_full_text": False,
|
||||
}
|
||||
}
|
||||
response = requests.post(
|
||||
urljoin(server_url, "/route/"),
|
||||
json=route_payload,
|
||||
timeout=4,
|
||||
log.debug("STREAM /generate_stream %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request(
|
||||
"/generate_stream",
|
||||
payload,
|
||||
cost=payload["parameters"]["max_new_tokens"],
|
||||
stream=True,
|
||||
)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
message = response.json()
|
||||
url = message["url"]
|
||||
print(f"url: {url}")
|
||||
auth_data = dict(
|
||||
signature=message["signature"],
|
||||
cost=message["cost"],
|
||||
endpoint=message["endpoint"],
|
||||
reqnum=message["reqnum"],
|
||||
url=message["url"],
|
||||
)
|
||||
payload = dict(inputs="tell me about dogs", parameters=dict(max_new_tokens=500))
|
||||
req_data = dict(payload=payload, auth_data=auth_data)
|
||||
url = urljoin(url, WORKER_ENDPOINT)
|
||||
response = requests.post(url, json=req_data, stream=True)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
for line in response.iter_lines():
|
||||
payload = line.decode().lstrip("data:").rstrip()
|
||||
if payload:
|
||||
return resp["response"] # async generator
|
||||
|
||||
|
||||
# ---------------------- Demo Runner ----------------------
|
||||
class APIDemo:
|
||||
"""Demo and testing functionality for the TGI API client"""
|
||||
|
||||
def __init__(self, client: Serverless, endpoint_name: str):
|
||||
self.client = client
|
||||
self.endpoint_name = endpoint_name
|
||||
|
||||
async def handle_streaming_response(self, stream) -> str:
|
||||
"""Process streaming response and print tokens"""
|
||||
full_response = ""
|
||||
printed_answer = False
|
||||
|
||||
async for event in stream:
|
||||
tok = (event.get("token") or {}).get("text")
|
||||
if tok:
|
||||
if not printed_answer:
|
||||
printed_answer = True
|
||||
print("\n💬 Response: ", end="", flush=True)
|
||||
print(tok, end="", flush=True)
|
||||
full_response += tok
|
||||
|
||||
print() # newline
|
||||
if printed_answer:
|
||||
print(f"\nStreaming completed. Response tokens: {len(full_response.split())}")
|
||||
|
||||
return full_response
|
||||
|
||||
async def demo_generate(self) -> None:
|
||||
"""Demo non-streaming generation"""
|
||||
print("=" * 60)
|
||||
print("GENERATE DEMO (NON-STREAMING)")
|
||||
print("=" * 60)
|
||||
|
||||
response = await call_generate(
|
||||
client=self.client,
|
||||
endpoint_name=self.endpoint_name,
|
||||
prompt=DEFAULT_PROMPT,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
)
|
||||
|
||||
print(f"\n💬 Response: {response.get('generated_text', '')}")
|
||||
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
|
||||
|
||||
async def demo_generate_stream(self) -> None:
|
||||
"""Demo streaming generation"""
|
||||
print("=" * 60)
|
||||
print("GENERATE DEMO (STREAMING)")
|
||||
print("=" * 60)
|
||||
|
||||
stream = await call_generate_stream(
|
||||
client=self.client,
|
||||
endpoint_name=self.endpoint_name,
|
||||
prompt=DEFAULT_PROMPT,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
)
|
||||
|
||||
try:
|
||||
await self.handle_streaming_response(stream)
|
||||
except Exception as e:
|
||||
log.error("\nError during streaming: %s", e, exc_info=True)
|
||||
|
||||
async def interactive_chat(self) -> None:
|
||||
"""Interactive session with streaming generation"""
|
||||
print("=" * 60)
|
||||
print("INTERACTIVE STREAMING SESSION")
|
||||
print("=" * 60)
|
||||
print(f"Using endpoint: {self.endpoint_name}")
|
||||
print("Type 'quit' to exit")
|
||||
print()
|
||||
|
||||
while True:
|
||||
try:
|
||||
data = json.loads(payload)
|
||||
print(data["token"]["text"], end="")
|
||||
sys.stdout.flush()
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
log.warning(f"Failed to parse streaming response: {e}")
|
||||
user_input = input("You: ").strip()
|
||||
|
||||
if user_input.lower() == "quit":
|
||||
print("👋 Goodbye!")
|
||||
break
|
||||
elif not user_input:
|
||||
continue
|
||||
|
||||
print("Assistant: ", end="", flush=True)
|
||||
stream = await call_generate_stream(
|
||||
client=self.client,
|
||||
endpoint_name=self.endpoint_name,
|
||||
prompt=user_input,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
)
|
||||
|
||||
full_response = ""
|
||||
async for event in stream:
|
||||
tok = (event.get("token") or {}).get("text")
|
||||
if tok:
|
||||
print(tok, end="", flush=True)
|
||||
full_response += tok
|
||||
print() # newline
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 Session interrupted. Goodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
log.error("\nError: %s", e)
|
||||
continue
|
||||
print()
|
||||
|
||||
|
||||
# ---------------------- CLI ----------------------
|
||||
def build_arg_parser() -> argparse.ArgumentParser:
|
||||
p = argparse.ArgumentParser(description="Vast TGI Demo (Serverless SDK)")
|
||||
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
|
||||
|
||||
modes = p.add_mutually_exclusive_group(required=False)
|
||||
modes.add_argument("--generate", action="store_true", help="Test generate endpoint (non-streaming)")
|
||||
modes.add_argument("--generate-stream", action="store_true", help="Test generate endpoint with streaming")
|
||||
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming session")
|
||||
return p
|
||||
|
||||
|
||||
async def main_async():
|
||||
args = build_arg_parser().parse_args()
|
||||
|
||||
selected = sum([args.generate, args.generate_stream, args.interactive])
|
||||
if selected == 0:
|
||||
print("Please specify exactly one test mode:")
|
||||
print(" --generate : Test generate endpoint (non-streaming)")
|
||||
print(" --generate-stream : Test generate endpoint with streaming")
|
||||
print(" --interactive : Start interactive streaming session")
|
||||
print(f"\nExample: python {os.path.basename(sys.argv[0])} --generate-stream --endpoint my-tgi-endpoint")
|
||||
sys.exit(1)
|
||||
elif selected > 1:
|
||||
print("Please specify exactly one test mode")
|
||||
sys.exit(1)
|
||||
|
||||
print("=" * 60)
|
||||
print(f"Using endpoint: {args.endpoint}")
|
||||
|
||||
try:
|
||||
async with Serverless() as client:
|
||||
demo = APIDemo(client, args.endpoint)
|
||||
|
||||
if args.generate:
|
||||
await demo.demo_generate()
|
||||
elif args.generate_stream:
|
||||
await demo.demo_generate_stream()
|
||||
elif args.interactive:
|
||||
await demo.interactive_chat()
|
||||
|
||||
except Exception as e:
|
||||
log.error("Error during test: %s", e, exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from lib.test_utils import test_args
|
||||
|
||||
args = test_args.parse_args()
|
||||
|
||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||
endpoint_name=args.endpoint_group_name,
|
||||
account_api_key=args.api_key,
|
||||
instance=args.instance,
|
||||
)
|
||||
if endpoint_api_key:
|
||||
try:
|
||||
call_generate(
|
||||
api_key=endpoint_api_key,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
server_url=args.server_url,
|
||||
)
|
||||
call_generate_stream(
|
||||
api_key=endpoint_api_key,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
server_url=args.server_url,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error during API call: {e}")
|
||||
else:
|
||||
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ")
|
||||
asyncio.run(main_async())
|
||||
|
||||
Reference in New Issue
Block a user