Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 62fbfb061d | |||
| c772e1651b | |||
| ecc6a3ce0d |
+34
-8
@@ -30,7 +30,7 @@ from lib.data_types import (
|
|||||||
BenchmarkResult
|
BenchmarkResult
|
||||||
)
|
)
|
||||||
|
|
||||||
VERSION = "0.2.1"
|
VERSION = "0.2.0"
|
||||||
|
|
||||||
MSG_HISTORY_LEN = 100
|
MSG_HISTORY_LEN = 100
|
||||||
log = logging.getLogger(__file__)
|
log = logging.getLogger(__file__)
|
||||||
@@ -235,10 +235,14 @@ class Backend:
|
|||||||
log.debug("No healthcheck endpoint defined, skipping healthcheck")
|
log.debug("No healthcheck endpoint defined, skipping healthcheck")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
first_healthcheck = True
|
||||||
while True:
|
while True:
|
||||||
await sleep(10)
|
await sleep(10)
|
||||||
if self.__start_healthcheck is False:
|
if self.__start_healthcheck is False:
|
||||||
continue
|
continue
|
||||||
|
if first_healthcheck:
|
||||||
|
log.info(f"[healthcheck] First healthcheck starting (model is now loaded)")
|
||||||
|
first_healthcheck = False
|
||||||
try:
|
try:
|
||||||
log.debug(f"Performing healthcheck on {health_check_url}")
|
log.debug(f"Performing healthcheck on {health_check_url}")
|
||||||
async with self.healthcheck_session.get(health_check_url) as response:
|
async with self.healthcheck_session.get(health_check_url) as response:
|
||||||
@@ -256,9 +260,22 @@ class Backend:
|
|||||||
self.backend_errored(str(e))
|
self.backend_errored(str(e))
|
||||||
|
|
||||||
async def _start_tracking(self) -> None:
|
async def _start_tracking(self) -> None:
|
||||||
await gather(
|
log.info("Starting tracking tasks (read_logs, send_metrics_loop, healthcheck, send_delete_requests_loop)")
|
||||||
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop()
|
task_names = ["read_logs", "send_metrics_loop", "healthcheck", "send_delete_requests_loop"]
|
||||||
|
results = await gather(
|
||||||
|
self.__read_logs(),
|
||||||
|
self.metrics._send_metrics_loop(),
|
||||||
|
self.__healthcheck(),
|
||||||
|
self.metrics._send_delete_requests_loop(),
|
||||||
|
return_exceptions=True
|
||||||
)
|
)
|
||||||
|
# If we get here, one or more tasks exited (they should run forever)
|
||||||
|
log.error(f"CRITICAL: _start_tracking gather returned! This should never happen. Results: {results}")
|
||||||
|
for name, result in zip(task_names, results):
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
log.error(f"Tracking task '{name}' crashed with exception: {result}", exc_info=result)
|
||||||
|
elif result is not None:
|
||||||
|
log.warning(f"Tracking task '{name}' exited unexpectedly with result: {result}")
|
||||||
|
|
||||||
def backend_errored(self, msg: str) -> None:
|
def backend_errored(self, msg: str) -> None:
|
||||||
self.metrics._model_errored(msg)
|
self.metrics._model_errored(msg)
|
||||||
@@ -399,15 +416,20 @@ class Backend:
|
|||||||
# await sleep(5)
|
# await sleep(5)
|
||||||
try:
|
try:
|
||||||
max_throughput = await run_benchmark()
|
max_throughput = await run_benchmark()
|
||||||
|
log.info(f"[benchmark] Benchmark complete, max_throughput={max_throughput}, setting healthcheck=True")
|
||||||
self.__start_healthcheck = True
|
self.__start_healthcheck = True
|
||||||
self.metrics._model_loaded(
|
self.metrics._model_loaded(
|
||||||
max_throughput=max_throughput,
|
max_throughput=max_throughput,
|
||||||
)
|
)
|
||||||
|
log.info(f"[benchmark] _model_loaded() called, returning from handle_log_line")
|
||||||
except ClientConnectorError as e:
|
except ClientConnectorError as e:
|
||||||
log.debug(
|
log.debug(
|
||||||
f"failed to connect to comfyui api during benchmark"
|
f"failed to connect to model api during benchmark"
|
||||||
)
|
)
|
||||||
self.backend_errored(str(e))
|
self.backend_errored(str(e))
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Unexpected error during benchmark: {e}", exc_info=True)
|
||||||
|
self.backend_errored(f"Benchmark failed: {e}")
|
||||||
case LogAction.ModelError if msg in log_line:
|
case LogAction.ModelError if msg in log_line:
|
||||||
log.debug(f"Got log line indicating error: {log_line}")
|
log.debug(f"Got log line indicating error: {log_line}")
|
||||||
self.backend_errored(msg)
|
self.backend_errored(msg)
|
||||||
@@ -419,10 +441,14 @@ class Backend:
|
|||||||
log.debug(f"tailing file: {self.model_log_file}")
|
log.debug(f"tailing file: {self.model_log_file}")
|
||||||
async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore') as f:
|
async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore') as f:
|
||||||
while True:
|
while True:
|
||||||
line = await f.readline()
|
try:
|
||||||
if line:
|
line = await f.readline()
|
||||||
await handle_log_line(line.rstrip())
|
if line:
|
||||||
else:
|
await handle_log_line(line.rstrip())
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(LOG_POLL_INTERVAL)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error processing log line: {e}", exc_info=True)
|
||||||
await asyncio.sleep(LOG_POLL_INTERVAL)
|
await asyncio.sleep(LOG_POLL_INTERVAL)
|
||||||
|
|
||||||
###########
|
###########
|
||||||
|
|||||||
+1
-1
@@ -66,7 +66,7 @@ class AuthData:
|
|||||||
"""data used to authenticate requester"""
|
"""data used to authenticate requester"""
|
||||||
|
|
||||||
cost: str
|
cost: str
|
||||||
endpoint_id: int
|
endpoint: str
|
||||||
reqnum: int
|
reqnum: int
|
||||||
request_idx: int
|
request_idx: int
|
||||||
signature: str
|
signature: str
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
@@ -17,6 +18,14 @@ DELETE_REQUESTS_INTERVAL = 1
|
|||||||
log = logging.getLogger(__file__)
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
def _flush_logs():
|
||||||
|
"""Force flush all log handlers and stdout/stderr."""
|
||||||
|
for handler in logging.root.handlers:
|
||||||
|
handler.flush()
|
||||||
|
sys.stdout.flush()
|
||||||
|
sys.stderr.flush()
|
||||||
|
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def get_url() -> str:
|
def get_url() -> str:
|
||||||
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
||||||
@@ -119,22 +128,41 @@ class Metrics:
|
|||||||
await self.__send_delete_requests_and_reset()
|
await self.__send_delete_requests_and_reset()
|
||||||
|
|
||||||
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
|
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
|
||||||
|
loop_count = 0
|
||||||
|
first_loaded_send_done = False
|
||||||
while True:
|
while True:
|
||||||
await sleep(METRICS_UPDATE_INTERVAL)
|
await sleep(METRICS_UPDATE_INTERVAL)
|
||||||
|
loop_count += 1
|
||||||
elapsed = time.time() - self.last_metric_update
|
elapsed = time.time() - self.last_metric_update
|
||||||
|
# Log heartbeat every 30 seconds to confirm loop is running
|
||||||
|
if loop_count % 30 == 0:
|
||||||
|
log.debug(f"[heartbeat] metrics loop alive, loop_count={loop_count}, model_loaded={self.system_metrics.model_is_loaded}")
|
||||||
|
_flush_logs()
|
||||||
|
# Extra logging for first few iterations after model loads
|
||||||
|
if self.system_metrics.model_is_loaded and not first_loaded_send_done:
|
||||||
|
log.info(f"[transition] First iteration with model_loaded=True, loop_count={loop_count}, elapsed={elapsed:.1f}")
|
||||||
|
_flush_logs()
|
||||||
if self.system_metrics.model_is_loaded is False and elapsed >= 10:
|
if self.system_metrics.model_is_loaded is False and elapsed >= 10:
|
||||||
log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
|
log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
|
||||||
await self.__send_metrics_and_reset()
|
await self.__send_metrics_and_reset()
|
||||||
elif self.update_pending or elapsed > 10:
|
elif self.update_pending or elapsed > 10:
|
||||||
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
|
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
|
||||||
await self.__send_metrics_and_reset()
|
await self.__send_metrics_and_reset()
|
||||||
|
if self.system_metrics.model_is_loaded and not first_loaded_send_done:
|
||||||
|
first_loaded_send_done = True
|
||||||
|
log.info(f"[transition] First loaded metrics send complete, continuing to next iteration...")
|
||||||
|
_flush_logs()
|
||||||
|
|
||||||
def _model_loaded(self, max_throughput: float) -> None:
|
def _model_loaded(self, max_throughput: float) -> None:
|
||||||
|
log.info(f"MODEL LOADED: Setting model_is_loaded=True, max_throughput={max_throughput}")
|
||||||
|
_flush_logs()
|
||||||
self.system_metrics.model_loading_time = (
|
self.system_metrics.model_loading_time = (
|
||||||
time.time() - self.system_metrics.model_loading_start
|
time.time() - self.system_metrics.model_loading_start
|
||||||
)
|
)
|
||||||
self.system_metrics.model_is_loaded = True
|
self.system_metrics.model_is_loaded = True
|
||||||
self.model_metrics.max_throughput = max_throughput
|
self.model_metrics.max_throughput = max_throughput
|
||||||
|
log.info(f"MODEL LOADED: model_loading_time={self.system_metrics.model_loading_time}")
|
||||||
|
_flush_logs()
|
||||||
|
|
||||||
def _model_errored(self, error_msg: str) -> None:
|
def _model_errored(self, error_msg: str) -> None:
|
||||||
self.model_metrics.set_errored(error_msg)
|
self.model_metrics.set_errored(error_msg)
|
||||||
@@ -271,6 +299,7 @@ class Metrics:
|
|||||||
###########
|
###########
|
||||||
|
|
||||||
self.system_metrics.update_disk_usage()
|
self.system_metrics.update_disk_usage()
|
||||||
|
had_loadtime = loadtime_snapshot is not None and loadtime_snapshot > 0
|
||||||
|
|
||||||
sent = False
|
sent = False
|
||||||
for report_addr in self.report_addr:
|
for report_addr in self.report_addr:
|
||||||
@@ -279,8 +308,14 @@ class Metrics:
|
|||||||
break
|
break
|
||||||
|
|
||||||
if sent:
|
if sent:
|
||||||
|
if had_loadtime:
|
||||||
|
log.info(f"FIRST LOADTIME METRICS SENT SUCCESSFULLY! loadtime={loadtime_snapshot}")
|
||||||
|
_flush_logs()
|
||||||
# clear the one-shot loadtime only if we actually sent *this* value
|
# clear the one-shot loadtime only if we actually sent *this* value
|
||||||
self.system_metrics.reset(expected=loadtime_snapshot)
|
self.system_metrics.reset(expected=loadtime_snapshot)
|
||||||
self.update_pending = False
|
self.update_pending = False
|
||||||
self.model_metrics.reset()
|
self.model_metrics.reset()
|
||||||
self.last_metric_update = time.time()
|
self.last_metric_update = time.time()
|
||||||
|
if had_loadtime:
|
||||||
|
log.info(f"POST-SEND: reset complete, last_metric_update={self.last_metric_update}, continuing loop...")
|
||||||
|
_flush_logs()
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
from typing import List
|
from typing import List
|
||||||
import ssl
|
import ssl
|
||||||
from asyncio import run, gather
|
from asyncio import run, gather
|
||||||
@@ -12,7 +14,25 @@ from aiohttp import web
|
|||||||
log = logging.getLogger(__file__)
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_signal_handlers():
|
||||||
|
"""Setup signal handlers to log when process receives termination signals."""
|
||||||
|
def signal_handler(signum, frame):
|
||||||
|
sig_name = signal.Signals(signum).name
|
||||||
|
log.error(f"SIGNAL RECEIVED: {sig_name} ({signum}) - process is being terminated")
|
||||||
|
sys.stdout.flush()
|
||||||
|
sys.stderr.flush()
|
||||||
|
sys.exit(128 + signum)
|
||||||
|
|
||||||
|
# Handle common termination signals
|
||||||
|
for sig in [signal.SIGTERM, signal.SIGINT, signal.SIGHUP]:
|
||||||
|
try:
|
||||||
|
signal.signal(sig, signal_handler)
|
||||||
|
except (OSError, ValueError):
|
||||||
|
pass # Some signals may not be available
|
||||||
|
|
||||||
|
|
||||||
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
||||||
|
_setup_signal_handlers()
|
||||||
try:
|
try:
|
||||||
log.debug("getting certificate...")
|
log.debug("getting certificate...")
|
||||||
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
||||||
|
|||||||
+3
-7
@@ -75,7 +75,6 @@ def print_truncate_res(res: str):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ClientState:
|
class ClientState:
|
||||||
endpoint_group_name: str
|
endpoint_group_name: str
|
||||||
endpoint_id: int
|
|
||||||
api_key: str
|
api_key: str
|
||||||
server_url: str
|
server_url: str
|
||||||
worker_endpoint: str
|
worker_endpoint: str
|
||||||
@@ -96,7 +95,7 @@ class ClientState:
|
|||||||
self.status = ClientStatus.Error
|
self.status = ClientStatus.Error
|
||||||
return
|
return
|
||||||
route_payload = {
|
route_payload = {
|
||||||
"endpoint_id": self.endpoint_id,
|
"endpoint": self.endpoint_group_name,
|
||||||
"api_key": self.api_key,
|
"api_key": self.api_key,
|
||||||
"cost": self.payload.count_workload(),
|
"cost": self.payload.count_workload(),
|
||||||
}
|
}
|
||||||
@@ -245,19 +244,16 @@ def run_test(
|
|||||||
print_thread = threading.Thread(target=print_state, args=(clients, num_requests))
|
print_thread = threading.Thread(target=print_state, args=(clients, num_requests))
|
||||||
print_thread.daemon = True # makes threads get killed on program exit
|
print_thread.daemon = True # makes threads get killed on program exit
|
||||||
print_thread.start()
|
print_thread.start()
|
||||||
endpoint_info = Endpoint.get_endpoint_info(
|
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||||
endpoint_name=endpoint_group_name, account_api_key=api_key, instance=instance
|
endpoint_name=endpoint_group_name, account_api_key=api_key, instance=instance
|
||||||
)
|
)
|
||||||
if not endpoint_info:
|
if not endpoint_api_key:
|
||||||
log.debug(f"Endpoint {endpoint_group_name} not found for API key")
|
log.debug(f"Endpoint {endpoint_group_name} not found for API key")
|
||||||
return
|
return
|
||||||
endpoint_id = endpoint_info["id"]
|
|
||||||
endpoint_api_key = endpoint_info["api_key"]
|
|
||||||
try:
|
try:
|
||||||
for _ in range(num_requests):
|
for _ in range(num_requests):
|
||||||
client = ClientState(
|
client = ClientState(
|
||||||
endpoint_group_name=endpoint_group_name,
|
endpoint_group_name=endpoint_group_name,
|
||||||
endpoint_id=endpoint_id,
|
|
||||||
api_key=endpoint_api_key,
|
api_key=endpoint_api_key,
|
||||||
server_url=server_url,
|
server_url=server_url,
|
||||||
worker_endpoint=worker_endpoint,
|
worker_endpoint=worker_endpoint,
|
||||||
|
|||||||
@@ -1,16 +1,8 @@
|
|||||||
# ComfyUI PyWorker
|
# ComfyUI PyWorker
|
||||||
|
|
||||||
This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
|
This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture.
|
||||||
|
|
||||||
The cost for each request has a static value of `1`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node.
|
The cost for each request has a static value of `1`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node.
|
||||||
|
|
||||||
## Instance Setup
|
|
||||||
|
|
||||||
1. Pick a template
|
|
||||||
|
|
||||||
- [ComfyUI (Serverless)](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=ComfyUI%20(Serverless))
|
|
||||||
|
|
||||||
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
|
|
||||||
@@ -18,88 +10,6 @@ This worker requires both [ComfyUI](https://github.com/comfyanonymous/ComfyUI) a
|
|||||||
|
|
||||||
A docker image is provided but you may use any if the above requirements are met.
|
A docker image is provided but you may use any if the above requirements are met.
|
||||||
|
|
||||||
## Client
|
|
||||||
|
|
||||||
The client demonstrates how to use the Vast Serverless SDK to generate images, save them locally, and optionally upload to S3-compatible storage.
|
|
||||||
|
|
||||||
### Setup
|
|
||||||
|
|
||||||
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git clone https://github.com/vast-ai/pyworker
|
|
||||||
cd pyworker
|
|
||||||
pip install uv
|
|
||||||
uv venv -p 3.12
|
|
||||||
source .venv/bin/activate
|
|
||||||
uv pip install -r requirements.txt
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Set your API key:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
export VAST_API_KEY=<your_api_key>
|
|
||||||
```
|
|
||||||
|
|
||||||
### Usage
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Default prompt
|
|
||||||
python -m workers.comfyui-json.client
|
|
||||||
|
|
||||||
# Custom prompt
|
|
||||||
python -m workers.comfyui-json.client --prompt "a cat sitting on a rainbow"
|
|
||||||
|
|
||||||
# With options
|
|
||||||
python -m workers.comfyui-json.client --prompt "sunset" --width 1024 --height 1024 --steps 30
|
|
||||||
|
|
||||||
# Using a custom workflow file
|
|
||||||
python -m workers.comfyui-json.client --workflow my_workflow.json
|
|
||||||
|
|
||||||
# With S3 upload
|
|
||||||
python -m workers.comfyui-json.client --s3
|
|
||||||
```
|
|
||||||
|
|
||||||
### CLI Flags
|
|
||||||
|
|
||||||
| Flag | Default | Description |
|
|
||||||
|------|---------|-------------|
|
|
||||||
| `--endpoint` | `my-comfyui-endpoint` | Vast endpoint name |
|
|
||||||
| `--prompt` | (default) | Text prompt for image generation |
|
|
||||||
| `--workflow` | (none) | Path to custom workflow JSON file |
|
|
||||||
| `--width` | 512 | Image width in pixels |
|
|
||||||
| `--height` | 512 | Image height in pixels |
|
|
||||||
| `--steps` | 20 | Number of denoising steps |
|
|
||||||
| `--seed` | (random) | Random seed for reproducibility |
|
|
||||||
| `--s3` | (disabled) | Upload generated images to S3 |
|
|
||||||
|
|
||||||
### Output
|
|
||||||
|
|
||||||
Images are saved to `./generated_images/comfy_{seed}.png`.
|
|
||||||
|
|
||||||
### S3 Upload (Optional)
|
|
||||||
|
|
||||||
You can optionally upload generated images to an S3-compatible storage service (AWS S3, Cloudflare R2, Backblaze B2, etc.) by using the `--s3` flag.
|
|
||||||
|
|
||||||
**1. Set environment variables:**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
export S3_ENDPOINT_URL="https://your-account.r2.cloudflarestorage.com"
|
|
||||||
export S3_BUCKET_NAME="my-bucket"
|
|
||||||
export S3_ACCESS_KEY_ID="your-access-key-id"
|
|
||||||
export S3_SECRET_ACCESS_KEY="your-secret-access-key"
|
|
||||||
```
|
|
||||||
|
|
||||||
**2. Run with S3 upload enabled:**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m workers.comfyui-json.client --prompt "a beautiful landscape" --s3
|
|
||||||
```
|
|
||||||
|
|
||||||
Images will be saved locally AND uploaded to `s3://{bucket}/comfyui/{filename}`.
|
|
||||||
|
|
||||||
**Note:** Requires `boto3` (`pip install boto3`).
|
|
||||||
|
|
||||||
## Benchmarking
|
## Benchmarking
|
||||||
|
|
||||||
### Custom Benchmark Workflows
|
### Custom Benchmark Workflows
|
||||||
@@ -302,3 +212,11 @@ WEBHOOK_TIMEOUT=30 # Webhook timeout in seconds
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Client Libraries
|
||||||
|
|
||||||
|
See the test client examples for implementation details on how to integrate with the ComfyUI worker.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
See Vast's serverless documentation for more details on how to use ComfyUI with autoscaler.
|
||||||
+23
-300
@@ -1,312 +1,35 @@
|
|||||||
import os
|
from .data_types import count_workload
|
||||||
import sys
|
|
||||||
import json
|
|
||||||
import uuid
|
import uuid
|
||||||
import random
|
import random
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import random
|
||||||
import argparse
|
|
||||||
import aiohttp
|
|
||||||
|
|
||||||
from vastai import Serverless
|
from vastai import Serverless
|
||||||
|
|
||||||
# ---------------------- Config ----------------------
|
async def main():
|
||||||
DEFAULT_PROMPT = "a beautiful sunset over mountains, digital art, highly detailed"
|
async with Serverless() as client:
|
||||||
ENDPOINT_NAME = "my-comfyui-endpoint"
|
endpoint = await client.get_endpoint(name="my-comfy-endpoint") # Change this to your endpoint name
|
||||||
DEFAULT_WIDTH = 512
|
|
||||||
DEFAULT_HEIGHT = 512
|
|
||||||
DEFAULT_STEPS = 20
|
|
||||||
COST = 100 # Fixed cost for ComfyUI requests
|
|
||||||
|
|
||||||
# Optional S3 Configuration (from environment variables)
|
payload = {
|
||||||
S3_ENDPOINT_URL = os.getenv("S3_ENDPOINT_URL")
|
"input": {
|
||||||
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
|
"request_id": str(uuid.uuid4()),
|
||||||
S3_ACCESS_KEY_ID = os.getenv("S3_ACCESS_KEY_ID")
|
"modifier": "Text2Image",
|
||||||
S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY")
|
"modifications": {
|
||||||
|
"prompt": "a beautiful landscape with mountains and lakes",
|
||||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")
|
"width": 1024,
|
||||||
log = logging.getLogger(__name__)
|
"height": 1024,
|
||||||
|
"steps": 20,
|
||||||
|
"seed": random.randint(0, 2**32 - 1)
|
||||||
def get_s3_client():
|
},
|
||||||
"""Create and return an S3 client configured for the S3-compatible endpoint"""
|
"workflow_json": {} # Empty since using modifier approach
|
||||||
try:
|
}
|
||||||
import boto3
|
|
||||||
from botocore.config import Config
|
|
||||||
except ImportError:
|
|
||||||
log.error("boto3 is required for S3 uploads. Install with: pip install boto3")
|
|
||||||
return None
|
|
||||||
|
|
||||||
if not all([S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY]):
|
|
||||||
log.error("S3 environment variables not fully configured. Required:")
|
|
||||||
log.error(" S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return boto3.client(
|
|
||||||
"s3",
|
|
||||||
endpoint_url=S3_ENDPOINT_URL,
|
|
||||||
aws_access_key_id=S3_ACCESS_KEY_ID,
|
|
||||||
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
|
|
||||||
config=Config(signature_version="s3v4"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------- API Functions ----------------------
|
|
||||||
async def call_generate(
|
|
||||||
client: Serverless,
|
|
||||||
*,
|
|
||||||
endpoint_name: str,
|
|
||||||
prompt: str,
|
|
||||||
width: int,
|
|
||||||
height: int,
|
|
||||||
steps: int,
|
|
||||||
seed: int,
|
|
||||||
) -> dict:
|
|
||||||
"""Generate image using Text2Image modifier"""
|
|
||||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
|
||||||
payload = {
|
|
||||||
"input": {
|
|
||||||
"request_id": str(uuid.uuid4()),
|
|
||||||
"modifier": "Text2Image",
|
|
||||||
"modifications": {
|
|
||||||
"prompt": prompt,
|
|
||||||
"width": width,
|
|
||||||
"height": height,
|
|
||||||
"steps": steps,
|
|
||||||
"seed": seed,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return await endpoint.request("/generate/sync", payload, cost=COST)
|
|
||||||
|
|
||||||
|
|
||||||
async def call_generate_workflow(
|
|
||||||
client: Serverless,
|
|
||||||
*,
|
|
||||||
endpoint_name: str,
|
|
||||||
workflow_json: dict,
|
|
||||||
) -> dict:
|
|
||||||
"""Generate using custom workflow JSON"""
|
|
||||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
|
||||||
payload = {
|
|
||||||
"input": {
|
|
||||||
"request_id": str(uuid.uuid4()),
|
|
||||||
"workflow_json": workflow_json,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return await endpoint.request("/generate/sync", payload, cost=COST)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------- Demo Class ----------------------
|
|
||||||
class APIDemo:
|
|
||||||
def __init__(self, client: Serverless, endpoint_name: str, upload_s3: bool = False):
|
|
||||||
self.client = client
|
|
||||||
self.endpoint_name = endpoint_name
|
|
||||||
self.upload_s3 = upload_s3
|
|
||||||
self.s3_client = get_s3_client() if upload_s3 else None
|
|
||||||
|
|
||||||
if upload_s3 and not self.s3_client:
|
response = await endpoint.request("/generate/sync", payload, cost=count_workload())
|
||||||
log.warning("S3 upload requested but client creation failed. Images will only be saved locally.")
|
|
||||||
|
|
||||||
def extract_filename(self, response: dict) -> str | None:
|
|
||||||
"""Extract the generated image filename from ComfyUI response"""
|
|
||||||
if "comfyui_response" in response:
|
|
||||||
for data in response["comfyui_response"].values():
|
|
||||||
if isinstance(data, dict) and "outputs" in data:
|
|
||||||
for node_output in data["outputs"].values():
|
|
||||||
if "images" in node_output and node_output["images"]:
|
|
||||||
return node_output["images"][0].get("filename")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def save_image(self, worker_url: str, filename: str, local_name: str) -> str | None:
|
|
||||||
"""Fetch and save image locally from the worker, optionally upload to S3"""
|
|
||||||
os.makedirs("generated_images", exist_ok=True)
|
|
||||||
return await self._fetch_image(worker_url, filename, local_name)
|
|
||||||
|
|
||||||
def _upload_to_s3(self, local_path: str, s3_key: str) -> str | None:
|
|
||||||
"""Upload a local file to S3 and return the S3 URL"""
|
|
||||||
if not self.s3_client:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.s3_client.upload_file(
|
|
||||||
local_path,
|
|
||||||
S3_BUCKET_NAME,
|
|
||||||
s3_key,
|
|
||||||
ExtraArgs={"ContentType": "image/png"}
|
|
||||||
)
|
|
||||||
s3_url = f"{S3_ENDPOINT_URL}/{S3_BUCKET_NAME}/{s3_key}"
|
|
||||||
print(f" ☁️ Uploaded to S3: {s3_key}")
|
|
||||||
return s3_url
|
|
||||||
except Exception as e:
|
|
||||||
log.error(f"Failed to upload to S3: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _fetch_image(self, worker_url: str, filename: str, local_name: str) -> str | None:
|
|
||||||
"""Fetch image from worker's /view endpoint and save locally"""
|
|
||||||
if not worker_url:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
url = f"{worker_url}/view"
|
|
||||||
params = {"filename": filename, "type": "output"}
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get(url, params=params, ssl=False) as resp:
|
|
||||||
if resp.status == 200:
|
|
||||||
path = f"generated_images/{local_name}"
|
|
||||||
image_data = await resp.read()
|
|
||||||
with open(path, "wb") as f:
|
|
||||||
f.write(image_data)
|
|
||||||
print(f" 💾 Saved: {path}")
|
|
||||||
|
|
||||||
# Upload to S3 if enabled
|
|
||||||
if self.upload_s3 and self.s3_client:
|
|
||||||
s3_key = f"comfyui/{local_name}"
|
|
||||||
self._upload_to_s3(path, s3_key)
|
|
||||||
|
|
||||||
return path
|
|
||||||
return None
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def demo_prompt(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
width: int,
|
|
||||||
height: int,
|
|
||||||
steps: int,
|
|
||||||
seed: int | None,
|
|
||||||
):
|
|
||||||
"""Demo: Generate image from text prompt"""
|
|
||||||
print("=" * 60)
|
|
||||||
print("COMFYUI TEXT-TO-IMAGE DEMO")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
if seed is None:
|
|
||||||
seed = random.randint(0, 2**32 - 1)
|
|
||||||
|
|
||||||
print(f"Prompt: {prompt[:100]}..." if len(prompt) > 100 else f"Prompt: {prompt}")
|
|
||||||
print(f"Size: {width}x{height}, Steps: {steps}, Seed: {seed}")
|
|
||||||
print("\n🎨 Generating image...")
|
|
||||||
|
|
||||||
response = await call_generate(
|
|
||||||
self.client,
|
|
||||||
endpoint_name=self.endpoint_name,
|
|
||||||
prompt=prompt,
|
|
||||||
width=width,
|
|
||||||
height=height,
|
|
||||||
steps=steps,
|
|
||||||
seed=seed,
|
|
||||||
)
|
|
||||||
|
|
||||||
print("\n✅ Generation complete!")
|
|
||||||
|
|
||||||
# Get worker URL for fetching images
|
|
||||||
worker_url = response.get("url", "")
|
|
||||||
print(f"Worker URL: {worker_url}")
|
|
||||||
|
|
||||||
# Fetch and save image
|
|
||||||
if "response" in response:
|
|
||||||
filename = self.extract_filename(response["response"])
|
|
||||||
if filename:
|
|
||||||
path = await self.save_image(worker_url, filename, f"comfy_{seed}.png")
|
|
||||||
if not path:
|
|
||||||
print(f"❌ Failed to fetch image")
|
|
||||||
else:
|
|
||||||
print("❌ No image in response")
|
|
||||||
else:
|
|
||||||
print("❌ Unexpected response format")
|
|
||||||
|
|
||||||
async def demo_workflow(self, workflow_file: str):
|
|
||||||
"""Demo: Generate using custom workflow file"""
|
|
||||||
print("=" * 60)
|
|
||||||
print("COMFYUI CUSTOM WORKFLOW DEMO")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
if not os.path.exists(workflow_file):
|
|
||||||
log.error(f"Workflow file not found: {workflow_file}")
|
|
||||||
return
|
|
||||||
|
|
||||||
with open(workflow_file, "r") as f:
|
|
||||||
workflow_json = json.load(f)
|
|
||||||
|
|
||||||
print(f"Workflow: {workflow_file}")
|
|
||||||
print("\n🎨 Generating...")
|
|
||||||
|
|
||||||
response = await call_generate_workflow(
|
|
||||||
self.client,
|
|
||||||
endpoint_name=self.endpoint_name,
|
|
||||||
workflow_json=workflow_json,
|
|
||||||
)
|
|
||||||
|
|
||||||
print("\n✅ Generation complete!")
|
|
||||||
|
|
||||||
worker_url = response.get("url", "")
|
|
||||||
|
|
||||||
if "response" in response:
|
|
||||||
filename = self.extract_filename(response["response"])
|
|
||||||
if filename:
|
|
||||||
path = await self.save_image(worker_url, filename, "workflow.png")
|
|
||||||
if not path:
|
|
||||||
print(f"❌ Failed to fetch image")
|
|
||||||
else:
|
|
||||||
print("❌ No image in response")
|
|
||||||
else:
|
|
||||||
print("❌ Unexpected response format")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------- CLI ----------------------
|
|
||||||
def build_arg_parser() -> argparse.ArgumentParser:
|
|
||||||
p = argparse.ArgumentParser(description="Vast ComfyUI-JSON Demo (Serverless SDK)")
|
|
||||||
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
|
|
||||||
p.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, metavar="TEXT",
|
|
||||||
help=f"Prompt text (default: '{DEFAULT_PROMPT[:30]}...')")
|
|
||||||
p.add_argument("--workflow", type=str, metavar="FILE", help="Use custom workflow JSON file instead")
|
|
||||||
p.add_argument("--width", type=int, default=DEFAULT_WIDTH, help=f"Image width (default: {DEFAULT_WIDTH})")
|
|
||||||
p.add_argument("--height", type=int, default=DEFAULT_HEIGHT, help=f"Image height (default: {DEFAULT_HEIGHT})")
|
|
||||||
p.add_argument("--steps", type=int, default=DEFAULT_STEPS, help=f"Steps (default: {DEFAULT_STEPS})")
|
|
||||||
p.add_argument("--seed", type=int, default=None, help="Seed (default: random)")
|
|
||||||
p.add_argument("--s3", action="store_true",
|
|
||||||
help="Upload generated images to S3 (requires S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY env vars)")
|
|
||||||
return p
|
|
||||||
|
|
||||||
|
|
||||||
async def main_async():
|
|
||||||
args = build_arg_parser().parse_args()
|
|
||||||
|
|
||||||
print("=" * 60)
|
|
||||||
print(f"Using endpoint: {args.endpoint}")
|
|
||||||
if args.s3:
|
|
||||||
print(f"S3 upload: enabled (bucket: {S3_BUCKET_NAME})")
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with Serverless() as client:
|
|
||||||
demo = APIDemo(client, args.endpoint, upload_s3=args.s3)
|
|
||||||
|
|
||||||
if args.workflow:
|
|
||||||
await demo.demo_workflow(workflow_file=args.workflow)
|
|
||||||
else:
|
|
||||||
await demo.demo_prompt(
|
|
||||||
prompt=args.prompt,
|
|
||||||
width=args.width,
|
|
||||||
height=args.height,
|
|
||||||
steps=args.steps,
|
|
||||||
seed=args.seed,
|
|
||||||
)
|
|
||||||
|
|
||||||
except AttributeError as e:
|
|
||||||
if "API key" in str(e):
|
|
||||||
log.error("API key missing. Set VAST_API_KEY environment variable.")
|
|
||||||
else:
|
|
||||||
log.error(f"Error: {e}")
|
|
||||||
sys.exit(1)
|
|
||||||
except Exception as e:
|
|
||||||
log.error(f"Error: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
# Get the file from the path on the local machine using SCP or SFTP
|
||||||
|
# or configure S3 to upload to cloud storage.
|
||||||
|
print(response["response"]["output"][0]["local_path"])
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main_async())
|
asyncio.run(main())
|
||||||
@@ -4,7 +4,6 @@ import dataclasses
|
|||||||
import base64
|
import base64
|
||||||
from typing import Optional, Union, Type
|
from typing import Optional, Union, Type
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
from aiohttp import web, ClientResponse
|
from aiohttp import web, ClientResponse
|
||||||
|
|
||||||
from lib.backend import Backend, LogAction
|
from lib.backend import Backend, LogAction
|
||||||
@@ -14,7 +13,6 @@ from .data_types import ComfyWorkflowData
|
|||||||
|
|
||||||
|
|
||||||
MODEL_SERVER_URL = os.getenv("MODEL_SERVER_URL", "http://127.0.0.1:18288")
|
MODEL_SERVER_URL = os.getenv("MODEL_SERVER_URL", "http://127.0.0.1:18288")
|
||||||
COMFYUI_URL = os.getenv("COMFYUI_URL", "http://127.0.0.1:18188") # Raw ComfyUI server
|
|
||||||
|
|
||||||
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
||||||
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: "
|
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: "
|
||||||
@@ -110,39 +108,8 @@ async def handle_ping(_):
|
|||||||
return web.Response(body="pong")
|
return web.Response(body="pong")
|
||||||
|
|
||||||
|
|
||||||
async def handle_view(request: web.Request) -> web.Response:
|
|
||||||
"""Proxy /view requests to raw ComfyUI server to fetch generated images"""
|
|
||||||
# Forward query params to raw ComfyUI (not the API wrapper)
|
|
||||||
query_string = request.query_string
|
|
||||||
url = f"{COMFYUI_URL}/view?{query_string}"
|
|
||||||
|
|
||||||
log.debug(f"Proxying /view request to: {url}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get(url) as resp:
|
|
||||||
if resp.status == 200:
|
|
||||||
content = await resp.read()
|
|
||||||
return web.Response(
|
|
||||||
body=content,
|
|
||||||
status=200,
|
|
||||||
content_type=resp.content_type or "image/png"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
text = await resp.text()
|
|
||||||
return web.Response(
|
|
||||||
text=text,
|
|
||||||
status=resp.status,
|
|
||||||
content_type="text/plain"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
log.error(f"Error proxying /view: {e}")
|
|
||||||
return web.Response(text=str(e), status=500)
|
|
||||||
|
|
||||||
|
|
||||||
routes = [
|
routes = [
|
||||||
web.post("/generate/sync", backend.create_handler(ComfyWorkflowHandler())),
|
web.post("/generate/sync", backend.create_handler(ComfyWorkflowHandler())),
|
||||||
web.get("/view", handle_view),
|
|
||||||
web.get("/ping", handle_ping),
|
web.get("/ping", handle_ping),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
+11
-14
@@ -13,11 +13,11 @@ from vastai import Serverless
|
|||||||
ENDPOINT_NAME = "my-comfyui-endpoint"
|
ENDPOINT_NAME = "my-comfyui-endpoint"
|
||||||
COST = 100 # Use a constant cost for image generation
|
COST = 100 # Use a constant cost for image generation
|
||||||
|
|
||||||
def call_default_workflow(endpoint_id: int, api_key: str, server_url: str) -> None:
|
def call_default_workflow(client: Serverless) -> None:
|
||||||
WORKER_ENDPOINT = "/prompt"
|
WORKER_ENDPOINT = "/prompt"
|
||||||
COST = 100
|
COST = 100
|
||||||
route_payload = {
|
route_payload = {
|
||||||
"endpoint_id": endpoint_id,
|
"endpoint": endpoint_group_name,
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
"cost": COST,
|
"cost": COST,
|
||||||
}
|
}
|
||||||
@@ -32,7 +32,7 @@ def call_default_workflow(endpoint_id: int, api_key: str, server_url: str) -> No
|
|||||||
auth_data = dict(
|
auth_data = dict(
|
||||||
signature=message["signature"],
|
signature=message["signature"],
|
||||||
cost=message["cost"],
|
cost=message["cost"],
|
||||||
endpoint_id=message["endpoint_id"],
|
endpoint=message["endpoint"],
|
||||||
reqnum=message["reqnum"],
|
reqnum=message["reqnum"],
|
||||||
url=message["url"],
|
url=message["url"],
|
||||||
)
|
)
|
||||||
@@ -52,12 +52,12 @@ def call_default_workflow(endpoint_id: int, api_key: str, server_url: str) -> No
|
|||||||
|
|
||||||
|
|
||||||
def call_custom_workflow_for_sd3(
|
def call_custom_workflow_for_sd3(
|
||||||
endpoint_id: int, api_key: str, server_url: str
|
endpoint_group_name: str, api_key: str, server_url: str
|
||||||
) -> None:
|
) -> None:
|
||||||
WORKER_ENDPOINT = "/custom-workflow"
|
WORKER_ENDPOINT = "/custom-workflow"
|
||||||
COST = 100
|
COST = 100
|
||||||
route_payload = {
|
route_payload = {
|
||||||
"endpoint_id": endpoint_id,
|
"endpoint": endpoint_group_name,
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
"cost": COST,
|
"cost": COST,
|
||||||
}
|
}
|
||||||
@@ -72,7 +72,7 @@ def call_custom_workflow_for_sd3(
|
|||||||
auth_data = dict(
|
auth_data = dict(
|
||||||
signature=message["signature"],
|
signature=message["signature"],
|
||||||
cost=message["cost"],
|
cost=message["cost"],
|
||||||
endpoint_id=message["endpoint_id"],
|
endpoint=message["endpoint"],
|
||||||
reqnum=message["reqnum"],
|
reqnum=message["reqnum"],
|
||||||
url=message["url"],
|
url=message["url"],
|
||||||
request_idx=message["request_idx"],
|
request_idx=message["request_idx"],
|
||||||
@@ -146,28 +146,25 @@ def call_custom_workflow_for_sd3(
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from lib.test_utils import test_args
|
from lib.test_utils import test_args
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
args = test_args.parse_args()
|
args = test_args.parse_args()
|
||||||
endpoint_info = Endpoint.get_endpoint_info(
|
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||||
endpoint_name=args.endpoint_group_name,
|
endpoint_name=args.endpoint_group_name,
|
||||||
account_api_key=args.api_key,
|
account_api_key=args.api_key,
|
||||||
instance=args.instance,
|
instance=args.instance,
|
||||||
)
|
)
|
||||||
if endpoint_info:
|
if endpoint_api_key:
|
||||||
endpoint_id = endpoint_info["id"]
|
|
||||||
endpoint_api_key = endpoint_info["api_key"]
|
|
||||||
try:
|
try:
|
||||||
call_default_workflow(
|
call_default_workflow(
|
||||||
endpoint_id=endpoint_id,
|
|
||||||
api_key=endpoint_api_key,
|
api_key=endpoint_api_key,
|
||||||
|
endpoint_group_name=args.endpoint_group_name,
|
||||||
server_url=args.server_url,
|
server_url=args.server_url,
|
||||||
)
|
)
|
||||||
call_custom_workflow_for_sd3(
|
call_custom_workflow_for_sd3(
|
||||||
endpoint_id=endpoint_id,
|
|
||||||
api_key=endpoint_api_key,
|
api_key=endpoint_api_key,
|
||||||
|
endpoint_group_name=args.endpoint_group_name,
|
||||||
server_url=args.server_url,
|
server_url=args.server_url,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error during API call: {e}")
|
log.error(f"Error during API call: {e}")
|
||||||
else:
|
else:
|
||||||
log.error(f"Failed to get endpoint info for {args.endpoint_group_name}")
|
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ")
|
||||||
|
|||||||
+22
-29
@@ -8,13 +8,14 @@ This is the base PyWorker for OpenAI compatible inference servers. See the [Ser
|
|||||||
|
|
||||||
This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker.
|
This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker.
|
||||||
|
|
||||||
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20(Serverless)) (recommended)
|
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20%2B%20Qwen%2FQwen3-8B%20(Serverless)) (recommended)
|
||||||
- [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless))
|
- [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless))
|
||||||
|
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20%2B%20Qwen3-8B%20(Serverless))
|
||||||
|
|
||||||
|
|
||||||
All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected.
|
All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected.
|
||||||
|
|
||||||
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
2. Follow the [getting started guide](https://docs.vast.ai/serverless/getting-started) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
||||||
|
|
||||||
## Client Setup (Demo)
|
## Client Setup (Demo)
|
||||||
|
|
||||||
@@ -33,30 +34,12 @@ uv pip install -r requirements.txt
|
|||||||
|
|
||||||
Several examples have been provided in the client to help you get started with your own implementation.
|
Several examples have been provided in the client to help you get started with your own implementation.
|
||||||
|
|
||||||
First, set your API key as an environment variable:
|
### Completions
|
||||||
|
|
||||||
|
Call to `/v1/completions` with json response
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export VAST_API_KEY=<your_api_key>
|
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
|
||||||
```
|
|
||||||
|
|
||||||
The `--model` and `--endpoint` flags are optional. If not provided, they default to `Qwen/Qwen3-8B` and `my-vllm-endpoint` respectively.
|
|
||||||
|
|
||||||
### Chat Completion (streaming)
|
|
||||||
|
|
||||||
Call to `/v1/chat/completions` with streaming response
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m workers.openai.client --chat-stream --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
|
||||||
```
|
|
||||||
|
|
||||||
### Interactive Chat (streaming)
|
|
||||||
|
|
||||||
Interactive session with calls to `/v1/chat/completions`.
|
|
||||||
|
|
||||||
Type `clear` to clear the chat history or `quit` to exit.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m workers.openai.client --interactive --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Chat Completion (json)
|
### Chat Completion (json)
|
||||||
@@ -64,7 +47,15 @@ python -m workers.openai.client --interactive --endpoint <ENDPOINT_NAME> --model
|
|||||||
Call to `/v1/chat/completions` with json response
|
Call to `/v1/chat/completions` with json response
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m workers.openai.client --chat --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Chat Completion (streaming)
|
||||||
|
|
||||||
|
Call to `/v1/chat/completions` with streaming response
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
|
||||||
```
|
```
|
||||||
|
|
||||||
### Tool Use (json)
|
### Tool Use (json)
|
||||||
@@ -74,14 +65,16 @@ Call to `/v1/chat/completions` with tool and json response.
|
|||||||
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
|
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m workers.openai.client --tools --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
|
||||||
```
|
```
|
||||||
|
|
||||||
### Completions
|
### Interactive Chat (streaming)
|
||||||
|
|
||||||
Call to `/v1/completions` with json response
|
Interactive session with calls to `/v1/chat/completions`.
|
||||||
|
|
||||||
|
Type `clear` to clear the chat history or `quit` to exit.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m workers.openai.client --completion --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
+16
-32
@@ -18,7 +18,7 @@ logging.basicConfig(
|
|||||||
log = logging.getLogger(__file__)
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
# ---------------------- Prompts ----------------------
|
# ---------------------- Prompts ----------------------
|
||||||
COMPLETIONS_PROMPT = "Zebras are primarily grazers and can subsist on lower-quality vegetation. They are preyed on mainly by"
|
COMPLETIONS_PROMPT = "the capital of USA is"
|
||||||
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
|
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
|
||||||
TOOLS_PROMPT = (
|
TOOLS_PROMPT = (
|
||||||
"Can you list the files in the current working directory and tell me what you see? "
|
"Can you list the files in the current working directory and tell me what you see? "
|
||||||
@@ -97,9 +97,9 @@ def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[
|
|||||||
|
|
||||||
|
|
||||||
# ---- OpenAI-compatible calls (non-streaming) ----
|
# ---- OpenAI-compatible calls (non-streaming) ----
|
||||||
async def call_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs) -> Dict[str, Any]:
|
async def call_completions(client: Serverless, *, model: str, prompt: str, **kwargs) -> Dict[str, Any]:
|
||||||
|
|
||||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"input": {
|
"input": {
|
||||||
@@ -113,9 +113,9 @@ async def call_completions(client: Serverless, *, model: str, prompt: str, endpo
|
|||||||
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
|
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
|
||||||
return resp["response"]
|
return resp["response"]
|
||||||
|
|
||||||
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs) -> Dict[str, Any]:
|
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
|
||||||
|
|
||||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"input": {
|
"input": {
|
||||||
@@ -132,9 +132,9 @@ async def call_chat_completions(client: Serverless, *, model: str, messages: Lis
|
|||||||
return resp["response"]
|
return resp["response"]
|
||||||
|
|
||||||
# ---- Streaming variants ----
|
# ---- Streaming variants ----
|
||||||
async def stream_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs):
|
async def stream_completions(client: Serverless, *, model: str, prompt: str, **kwargs):
|
||||||
|
|
||||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"input": {
|
"input": {
|
||||||
@@ -150,9 +150,9 @@ async def stream_completions(client: Serverless, *, model: str, prompt: str, end
|
|||||||
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
|
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
|
||||||
return resp["response"] # async generator
|
return resp["response"] # async generator
|
||||||
|
|
||||||
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs):
|
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs):
|
||||||
|
|
||||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"input": {
|
"input": {
|
||||||
@@ -174,10 +174,9 @@ async def stream_chat_completions(client: Serverless, *, model: str, messages: L
|
|||||||
class APIDemo:
|
class APIDemo:
|
||||||
"""Demo and testing functionality for the API client"""
|
"""Demo and testing functionality for the API client"""
|
||||||
|
|
||||||
def __init__(self, client: Serverless, model: str, endpoint_name: str, tool_manager: Optional[ToolManager] = None):
|
def __init__(self, client: Serverless, model: str, tool_manager: Optional[ToolManager] = None):
|
||||||
self.client = client
|
self.client = client
|
||||||
self.model = model
|
self.model = model
|
||||||
self.endpoint_name = endpoint_name
|
|
||||||
self.tool_manager = tool_manager or ToolManager()
|
self.tool_manager = tool_manager or ToolManager()
|
||||||
|
|
||||||
# ----- Streaming handler -----
|
# ----- Streaming handler -----
|
||||||
@@ -186,15 +185,10 @@ class APIDemo:
|
|||||||
reasoning_content = ""
|
reasoning_content = ""
|
||||||
printed_reasoning = False
|
printed_reasoning = False
|
||||||
printed_answer = False
|
printed_answer = False
|
||||||
finish_reason = None
|
|
||||||
|
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
choice = (chunk.get("choices") or [{}])[0]
|
choice = (chunk.get("choices") or [{}])[0]
|
||||||
delta = choice.get("delta", {})
|
delta = choice.get("delta", {})
|
||||||
|
|
||||||
# Track finish reason
|
|
||||||
if choice.get("finish_reason"):
|
|
||||||
finish_reason = choice.get("finish_reason")
|
|
||||||
|
|
||||||
# reasoning tokens
|
# reasoning tokens
|
||||||
rc = delta.get("reasoning_content")
|
rc = delta.get("reasoning_content")
|
||||||
@@ -225,8 +219,6 @@ class APIDemo:
|
|||||||
print(f"Reasoning tokens: {len(reasoning_content.split())}")
|
print(f"Reasoning tokens: {len(reasoning_content.split())}")
|
||||||
if printed_answer:
|
if printed_answer:
|
||||||
print(f"Response tokens: {len(full_response.split())}")
|
print(f"Response tokens: {len(full_response.split())}")
|
||||||
if finish_reason:
|
|
||||||
print(f"Finish reason: {finish_reason}")
|
|
||||||
|
|
||||||
return full_response
|
return full_response
|
||||||
|
|
||||||
@@ -239,7 +231,6 @@ class APIDemo:
|
|||||||
client=self.client,
|
client=self.client,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
prompt=COMPLETIONS_PROMPT,
|
prompt=COMPLETIONS_PROMPT,
|
||||||
endpoint_name=self.endpoint_name,
|
|
||||||
max_tokens=MAX_TOKENS,
|
max_tokens=MAX_TOKENS,
|
||||||
temperature=DEFAULT_TEMPERATURE,
|
temperature=DEFAULT_TEMPERATURE,
|
||||||
)
|
)
|
||||||
@@ -258,7 +249,6 @@ class APIDemo:
|
|||||||
client=self.client,
|
client=self.client,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
endpoint_name=self.endpoint_name,
|
|
||||||
max_tokens=MAX_TOKENS,
|
max_tokens=MAX_TOKENS,
|
||||||
temperature=DEFAULT_TEMPERATURE
|
temperature=DEFAULT_TEMPERATURE
|
||||||
)
|
)
|
||||||
@@ -271,7 +261,6 @@ class APIDemo:
|
|||||||
client=self.client,
|
client=self.client,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
endpoint_name=self.endpoint_name,
|
|
||||||
max_tokens=MAX_TOKENS,
|
max_tokens=MAX_TOKENS,
|
||||||
temperature=DEFAULT_TEMPERATURE
|
temperature=DEFAULT_TEMPERATURE
|
||||||
)
|
)
|
||||||
@@ -298,7 +287,6 @@ class APIDemo:
|
|||||||
client=self.client,
|
client=self.client,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
endpoint_name=self.endpoint_name,
|
|
||||||
tools=minimal_tool,
|
tools=minimal_tool,
|
||||||
tool_choice="none",
|
tool_choice="none",
|
||||||
max_tokens=10
|
max_tokens=10
|
||||||
@@ -324,7 +312,6 @@ class APIDemo:
|
|||||||
client=self.client,
|
client=self.client,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
endpoint_name=self.endpoint_name,
|
|
||||||
tools=self.tool_manager.get_ls_tool_definition(),
|
tools=self.tool_manager.get_ls_tool_definition(),
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
max_tokens=MAX_TOKENS,
|
max_tokens=MAX_TOKENS,
|
||||||
@@ -402,7 +389,6 @@ class APIDemo:
|
|||||||
client=self.client,
|
client=self.client,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
endpoint_name=self.endpoint_name,
|
|
||||||
max_tokens=MAX_TOKENS,
|
max_tokens=MAX_TOKENS,
|
||||||
temperature=DEFAULT_TEMPERATURE,
|
temperature=DEFAULT_TEMPERATURE,
|
||||||
)
|
)
|
||||||
@@ -441,6 +427,7 @@ class APIDemo:
|
|||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("INTERACTIVE STREAMING CHAT")
|
print("INTERACTIVE STREAMING CHAT")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
print(f"Using model: {self.model}")
|
||||||
print("Type 'quit' to exit, 'clear' to clear history")
|
print("Type 'quit' to exit, 'clear' to clear history")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
@@ -466,8 +453,7 @@ class APIDemo:
|
|||||||
stream = await stream_chat_completions(
|
stream = await stream_chat_completions(
|
||||||
client=self.client,
|
client=self.client,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
endpoint_name=self.endpoint_name,
|
|
||||||
max_tokens=MAX_TOKENS,
|
max_tokens=MAX_TOKENS,
|
||||||
temperature=0.7
|
temperature=0.7
|
||||||
)
|
)
|
||||||
@@ -487,8 +473,8 @@ class APIDemo:
|
|||||||
# ---------------------- CLI ----------------------
|
# ---------------------- CLI ----------------------
|
||||||
def build_arg_parser() -> argparse.ArgumentParser:
|
def build_arg_parser() -> argparse.ArgumentParser:
|
||||||
p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
|
p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
|
||||||
p.add_argument("--model", default=DEFAULT_MODEL, help=f"Model to use for requests (default: {DEFAULT_MODEL})")
|
p.add_argument("--model", required=True, help="Model to use for requests (required)")
|
||||||
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
|
p.add_argument("--endpoint", default="my-vllm-endpoint", help="Vast endpoint name (default: my-vllm-endpoint)")
|
||||||
|
|
||||||
modes = p.add_mutually_exclusive_group(required=False)
|
modes = p.add_mutually_exclusive_group(required=False)
|
||||||
modes.add_argument("--completion", action="store_true", help="Test completions endpoint")
|
modes.add_argument("--completion", action="store_true", help="Test completions endpoint")
|
||||||
@@ -516,14 +502,12 @@ async def main_async():
|
|||||||
print("Please specify exactly one test mode")
|
print("Please specify exactly one test mode")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
print("=" * 60)
|
|
||||||
print(f"Using model: {args.model}")
|
print(f"Using model: {args.model}")
|
||||||
print(f"Using endpoint: {args.endpoint}")
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with Serverless() as client:
|
async with Serverless() as client:
|
||||||
demo = APIDemo(client, args.model, args.endpoint, ToolManager())
|
demo = APIDemo(client, args.model, ToolManager())
|
||||||
|
|
||||||
if args.completion:
|
if args.completion:
|
||||||
await demo.demo_completions()
|
await demo.demo_completions()
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ MODEL_SERVER_START_LOG_MSG = [
|
|||||||
"llama runner started", # Ollama
|
"llama runner started", # Ollama
|
||||||
'"message":"Connected","target":"text_generation_router"', # TGI
|
'"message":"Connected","target":"text_generation_router"', # TGI
|
||||||
'"message":"Connected","target":"text_generation_router::server"', # TGI
|
'"message":"Connected","target":"text_generation_router::server"', # TGI
|
||||||
"main: model loaded" # llama.cpp
|
|
||||||
]
|
]
|
||||||
|
|
||||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||||
@@ -35,7 +34,6 @@ backend = Backend(
|
|||||||
model_server_url=os.environ["MODEL_SERVER_URL"],
|
model_server_url=os.environ["MODEL_SERVER_URL"],
|
||||||
model_log_file=os.environ["MODEL_LOG"],
|
model_log_file=os.environ["MODEL_LOG"],
|
||||||
allow_parallel_requests=True,
|
allow_parallel_requests=True,
|
||||||
max_wait_time=600.0,
|
|
||||||
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
|
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
|
||||||
log_actions=[
|
log_actions=[
|
||||||
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
|
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ def do_one(endpoint_name: str,
|
|||||||
worker_session):
|
worker_session):
|
||||||
try:
|
try:
|
||||||
workload = payload.count_workload()
|
workload = payload.count_workload()
|
||||||
route_payload = {"endpoint_id": endpoint_id, "api_key": endpoint_api_key, "cost": workload}
|
route_payload = {"endpoint": endpoint_name, "api_key": endpoint_api_key, "cost": workload}
|
||||||
headers = {"Authorization": f"Bearer {endpoint_api_key}"}
|
headers = {"Authorization": f"Bearer {endpoint_api_key}"}
|
||||||
start = time.time()
|
start = time.time()
|
||||||
r0 = route_session.post(urljoin(server_url, "/route/"), json=route_payload, headers=headers, timeout=4)
|
r0 = route_session.post(urljoin(server_url, "/route/"), json=route_payload, headers=headers, timeout=4)
|
||||||
|
|||||||
+9
-93
@@ -1,103 +1,19 @@
|
|||||||
# HuggingFace TGI PyWorker
|
This is the base PyWorker for TGI, designed to create PyWorkers that can utilize various LLMs. It offers two primary endpoints:
|
||||||
|
|
||||||
This is the base PyWorker for HuggingFace Text Generation Inference (TGI) servers. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
|
1. `generate`: Generates the LLM's response to a given prompt in a single request.
|
||||||
|
2. `generate_stream`: Streams the LLM's response token by token.
|
||||||
|
|
||||||
## Instance Setup
|
Both endpoints use the following API payload format:
|
||||||
|
|
||||||
1. Pick a template
|
|
||||||
|
|
||||||
This worker is compatible with any TGI backend. We have a template you can use or you can create your own.
|
|
||||||
|
|
||||||
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20(Serverless))
|
|
||||||
|
|
||||||
The template can be configured via the template interface. You may want to change the model or startup arguments.
|
|
||||||
|
|
||||||
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
|
||||||
|
|
||||||
## Client Setup (Demo)
|
|
||||||
|
|
||||||
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git clone https://github.com/vast-ai/pyworker
|
|
||||||
cd pyworker
|
|
||||||
pip install uv
|
|
||||||
uv venv -p 3.12
|
|
||||||
source .venv/bin/activate
|
|
||||||
uv pip install -r requirements.txt
|
|
||||||
```
|
|
||||||
|
|
||||||
## Using the Test Client
|
|
||||||
|
|
||||||
The test client demonstrates both streaming and non-streaming generation using TGI's native API.
|
|
||||||
|
|
||||||
First, set your API key as an environment variable:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
export VAST_API_KEY=<your_api_key>
|
|
||||||
```
|
|
||||||
|
|
||||||
The `--endpoint` flag is optional. If not provided, it defaults to `my-tgi-endpoint`.
|
|
||||||
|
|
||||||
### Generate (Streaming)
|
|
||||||
|
|
||||||
Call to `/generate_stream` with streaming response:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m workers.tgi.client --generate-stream --endpoint <ENDPOINT_NAME>
|
|
||||||
```
|
|
||||||
|
|
||||||
### Generate (Non-Streaming)
|
|
||||||
|
|
||||||
Call to `/generate` with json response:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m workers.tgi.client --generate --endpoint <ENDPOINT_NAME>
|
|
||||||
```
|
|
||||||
|
|
||||||
### Interactive Session (Streaming)
|
|
||||||
|
|
||||||
Interactive session with streaming responses. Type `quit` to exit.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m workers.tgi.client --interactive --endpoint <ENDPOINT_NAME>
|
|
||||||
```
|
|
||||||
|
|
||||||
## API Endpoints
|
|
||||||
|
|
||||||
TGI provides two primary endpoints:
|
|
||||||
|
|
||||||
### Generate (Non-Streaming)
|
|
||||||
|
|
||||||
`/generate` - Returns the complete response in a single request.
|
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"inputs": "Your prompt here",
|
"inputs": "PROMPT",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"max_new_tokens": 1024,
|
"max_new_tokens": 250
|
||||||
"temperature": 0.7,
|
|
||||||
"return_full_text": false
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### Generate Stream (Streaming)
|
Note that the max_new_tokens parameter, rather than the prompt size, impacts performance. For example, if an
|
||||||
|
instance is benchmarked to process 100 tokens per second, a request with max_new_tokens = 200 will take
|
||||||
`/generate_stream` - Streams the response token by token.
|
approximately 2 seconds to complete.
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"inputs": "Your prompt here",
|
|
||||||
"parameters": {
|
|
||||||
"max_new_tokens": 1024,
|
|
||||||
"temperature": 0.7,
|
|
||||||
"do_sample": true,
|
|
||||||
"return_full_text": false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Performance Notes
|
|
||||||
|
|
||||||
The `max_new_tokens` parameter (not the prompt size) primarily impacts performance. For example, if an instance is benchmarked to process 100 tokens per second, a request with `max_new_tokens = 200` will take approximately 2 seconds to complete.
|
|
||||||
|
|||||||
+33
-194
@@ -1,222 +1,61 @@
|
|||||||
import logging
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
from vastai import Serverless
|
from vastai import Serverless
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
# ---------------------- Logging ----------------------
|
ENDPOINT_NAME = "my-tgi-endpoint" # Change this to match your endpoint name
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.DEBUG,
|
|
||||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
|
||||||
datefmt="%Y-%m-%d %H:%M:%S",
|
|
||||||
)
|
|
||||||
log = logging.getLogger(__file__)
|
|
||||||
|
|
||||||
# ---------------------- Defaults ----------------------
|
|
||||||
DEFAULT_PROMPT = "Think step by step: Tell me about the Python programming language."
|
|
||||||
|
|
||||||
ENDPOINT_NAME = "TGI-Prod2" # change this to your TGI endpoint name
|
|
||||||
MAX_TOKENS = 1024
|
MAX_TOKENS = 1024
|
||||||
DEFAULT_TEMPERATURE = 0.7
|
PROMPT = "Think step by step: Tell me about the Python programming language."
|
||||||
|
|
||||||
|
async def call_generate(client: Serverless) -> None:
|
||||||
# ---------------------- API Calls ----------------------
|
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||||
async def call_generate(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs) -> dict:
|
|
||||||
"""Non-streaming generation via /generate endpoint"""
|
|
||||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"inputs": prompt,
|
"inputs": PROMPT,
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
"max_new_tokens": MAX_TOKENS,
|
||||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
"temperature": 0.7,
|
||||||
"return_full_text": False,
|
"return_full_text": False
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.debug("POST /generate %s", json.dumps(payload)[:500])
|
|
||||||
resp = await endpoint.request("/generate", payload, cost=payload["parameters"]["max_new_tokens"])
|
resp = await endpoint.request("/generate", payload, cost=MAX_TOKENS)
|
||||||
return resp["response"]
|
|
||||||
|
print(resp["response"]["generated_text"])
|
||||||
|
|
||||||
|
|
||||||
async def call_generate_stream(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs):
|
async def call_generate_stream(client: Serverless) -> None:
|
||||||
"""Streaming generation via /generate_stream endpoint"""
|
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"inputs": prompt,
|
"inputs": PROMPT,
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
"max_new_tokens": MAX_TOKENS,
|
||||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
"temperature": 0.7,
|
||||||
"do_sample": True,
|
"do_sample": True,
|
||||||
"return_full_text": False,
|
"return_full_text": False,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.debug("STREAM /generate_stream %s", json.dumps(payload)[:500])
|
|
||||||
resp = await endpoint.request(
|
resp = await endpoint.request(
|
||||||
"/generate_stream",
|
"/generate_stream",
|
||||||
payload,
|
payload,
|
||||||
cost=payload["parameters"]["max_new_tokens"],
|
cost=MAX_TOKENS,
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
return resp["response"] # async generator
|
stream = resp["response"]
|
||||||
|
|
||||||
|
printed_answer = False
|
||||||
|
async for event in stream:
|
||||||
|
tok = (event.get("token") or {}).get("text")
|
||||||
|
if tok:
|
||||||
|
if not printed_answer:
|
||||||
|
printed_answer = True
|
||||||
|
print("Answer:\n", end="", flush=True)
|
||||||
|
print(tok, end="", flush=True)
|
||||||
|
|
||||||
# ---------------------- Demo Runner ----------------------
|
async def main():
|
||||||
class APIDemo:
|
async with Serverless() as client:
|
||||||
"""Demo and testing functionality for the TGI API client"""
|
await call_generate(client)
|
||||||
|
await call_generate_stream(client)
|
||||||
def __init__(self, client: Serverless, endpoint_name: str):
|
|
||||||
self.client = client
|
|
||||||
self.endpoint_name = endpoint_name
|
|
||||||
|
|
||||||
async def handle_streaming_response(self, stream) -> str:
|
|
||||||
"""Process streaming response and print tokens"""
|
|
||||||
full_response = ""
|
|
||||||
printed_answer = False
|
|
||||||
|
|
||||||
async for event in stream:
|
|
||||||
tok = (event.get("token") or {}).get("text")
|
|
||||||
if tok:
|
|
||||||
if not printed_answer:
|
|
||||||
printed_answer = True
|
|
||||||
print("\n💬 Response: ", end="", flush=True)
|
|
||||||
print(tok, end="", flush=True)
|
|
||||||
full_response += tok
|
|
||||||
|
|
||||||
print() # newline
|
|
||||||
if printed_answer:
|
|
||||||
print(f"\nStreaming completed. Response tokens: {len(full_response.split())}")
|
|
||||||
|
|
||||||
return full_response
|
|
||||||
|
|
||||||
async def demo_generate(self) -> None:
|
|
||||||
"""Demo non-streaming generation"""
|
|
||||||
print("=" * 60)
|
|
||||||
print("GENERATE DEMO (NON-STREAMING)")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
response = await call_generate(
|
|
||||||
client=self.client,
|
|
||||||
endpoint_name=self.endpoint_name,
|
|
||||||
prompt=DEFAULT_PROMPT,
|
|
||||||
max_tokens=MAX_TOKENS,
|
|
||||||
temperature=DEFAULT_TEMPERATURE,
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"\n💬 Response: {response.get('generated_text', '')}")
|
|
||||||
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
|
|
||||||
|
|
||||||
async def demo_generate_stream(self) -> None:
|
|
||||||
"""Demo streaming generation"""
|
|
||||||
print("=" * 60)
|
|
||||||
print("GENERATE DEMO (STREAMING)")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
stream = await call_generate_stream(
|
|
||||||
client=self.client,
|
|
||||||
endpoint_name=self.endpoint_name,
|
|
||||||
prompt=DEFAULT_PROMPT,
|
|
||||||
max_tokens=MAX_TOKENS,
|
|
||||||
temperature=DEFAULT_TEMPERATURE,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
await self.handle_streaming_response(stream)
|
|
||||||
except Exception as e:
|
|
||||||
log.error("\nError during streaming: %s", e, exc_info=True)
|
|
||||||
|
|
||||||
async def interactive_chat(self) -> None:
|
|
||||||
"""Interactive session with streaming generation"""
|
|
||||||
print("=" * 60)
|
|
||||||
print("INTERACTIVE STREAMING SESSION")
|
|
||||||
print("=" * 60)
|
|
||||||
print(f"Using endpoint: {self.endpoint_name}")
|
|
||||||
print("Type 'quit' to exit")
|
|
||||||
print()
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
user_input = input("You: ").strip()
|
|
||||||
|
|
||||||
if user_input.lower() == "quit":
|
|
||||||
print("👋 Goodbye!")
|
|
||||||
break
|
|
||||||
elif not user_input:
|
|
||||||
continue
|
|
||||||
|
|
||||||
print("Assistant: ", end="", flush=True)
|
|
||||||
stream = await call_generate_stream(
|
|
||||||
client=self.client,
|
|
||||||
endpoint_name=self.endpoint_name,
|
|
||||||
prompt=user_input,
|
|
||||||
max_tokens=MAX_TOKENS,
|
|
||||||
temperature=DEFAULT_TEMPERATURE,
|
|
||||||
)
|
|
||||||
|
|
||||||
full_response = ""
|
|
||||||
async for event in stream:
|
|
||||||
tok = (event.get("token") or {}).get("text")
|
|
||||||
if tok:
|
|
||||||
print(tok, end="", flush=True)
|
|
||||||
full_response += tok
|
|
||||||
print() # newline
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("\n👋 Session interrupted. Goodbye!")
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
log.error("\nError: %s", e)
|
|
||||||
continue
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------- CLI ----------------------
|
|
||||||
def build_arg_parser() -> argparse.ArgumentParser:
|
|
||||||
p = argparse.ArgumentParser(description="Vast TGI Demo (Serverless SDK)")
|
|
||||||
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
|
|
||||||
|
|
||||||
modes = p.add_mutually_exclusive_group(required=False)
|
|
||||||
modes.add_argument("--generate", action="store_true", help="Test generate endpoint (non-streaming)")
|
|
||||||
modes.add_argument("--generate-stream", action="store_true", help="Test generate endpoint with streaming")
|
|
||||||
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming session")
|
|
||||||
return p
|
|
||||||
|
|
||||||
|
|
||||||
async def main_async():
|
|
||||||
args = build_arg_parser().parse_args()
|
|
||||||
|
|
||||||
selected = sum([args.generate, args.generate_stream, args.interactive])
|
|
||||||
if selected == 0:
|
|
||||||
print("Please specify exactly one test mode:")
|
|
||||||
print(" --generate : Test generate endpoint (non-streaming)")
|
|
||||||
print(" --generate-stream : Test generate endpoint with streaming")
|
|
||||||
print(" --interactive : Start interactive streaming session")
|
|
||||||
print(f"\nExample: python {os.path.basename(sys.argv[0])} --generate-stream --endpoint my-tgi-endpoint")
|
|
||||||
sys.exit(1)
|
|
||||||
elif selected > 1:
|
|
||||||
print("Please specify exactly one test mode")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
print("=" * 60)
|
|
||||||
print(f"Using endpoint: {args.endpoint}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with Serverless() as client:
|
|
||||||
demo = APIDemo(client, args.endpoint)
|
|
||||||
|
|
||||||
if args.generate:
|
|
||||||
await demo.demo_generate()
|
|
||||||
elif args.generate_stream:
|
|
||||||
await demo.demo_generate_stream()
|
|
||||||
elif args.interactive:
|
|
||||||
await demo.interactive_chat()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
log.error("Error during test: %s", e, exc_info=True)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main_async())
|
asyncio.run(main())
|
||||||
|
|||||||
Reference in New Issue
Block a user