Compare commits

..

9 Commits

Author SHA1 Message Date
Colter Downing 6b5b1341a7 update tgi client 2025-12-03 18:38:42 -08:00
Colter-Downing 8be92c03de Merge pull request #69 from vast-ai/AUTO-874--fix-openai-worker-client
defaults to ENDPOINT_NAME and DEFAULT_MODEL but uses the flag first
2025-12-03 16:59:56 -08:00
Colter Downing adedb8ba90 defaults to ENDPOINT_NAME and DEFAULT_MODEL but uses the flag first if present 2025-12-03 16:57:28 -08:00
LucasArmandVast 2f543c01ad Merge pull request #68 from vast-ai/fix-vllm-concurrency
Increase model wait time for vLLM
2025-12-03 16:13:51 -05:00
Lucas Armand 0bcd2219ea Increase model wait time for vLLM 2025-12-03 12:38:52 -08:00
LucasArmandVast 0339b471c5 Merge pull request #66 from vast-ai/synthesis
PyWorker Error Handling
2025-11-25 16:02:26 -08:00
Lucas Armand e143162438 bumpy pyworker version 2025-11-25 16:01:23 -08:00
LucasArmandVast 7a792fd176 Merge pull request #64 from vast-ai/add-llama-log
add llama log
2025-11-21 10:24:27 -08:00
Lucas Armand e0449cb3c7 add llama log 2025-11-21 10:22:16 -08:00
8 changed files with 363 additions and 174 deletions
+8 -34
View File
@@ -30,7 +30,7 @@ from lib.data_types import (
BenchmarkResult
)
VERSION = "0.2.0"
VERSION = "0.2.1"
MSG_HISTORY_LEN = 100
log = logging.getLogger(__file__)
@@ -235,14 +235,10 @@ class Backend:
log.debug("No healthcheck endpoint defined, skipping healthcheck")
return
first_healthcheck = True
while True:
await sleep(10)
if self.__start_healthcheck is False:
continue
if first_healthcheck:
log.info(f"[healthcheck] First healthcheck starting (model is now loaded)")
first_healthcheck = False
try:
log.debug(f"Performing healthcheck on {health_check_url}")
async with self.healthcheck_session.get(health_check_url) as response:
@@ -260,22 +256,9 @@ class Backend:
self.backend_errored(str(e))
async def _start_tracking(self) -> None:
log.info("Starting tracking tasks (read_logs, send_metrics_loop, healthcheck, send_delete_requests_loop)")
task_names = ["read_logs", "send_metrics_loop", "healthcheck", "send_delete_requests_loop"]
results = await gather(
self.__read_logs(),
self.metrics._send_metrics_loop(),
self.__healthcheck(),
self.metrics._send_delete_requests_loop(),
return_exceptions=True
await gather(
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop()
)
# If we get here, one or more tasks exited (they should run forever)
log.error(f"CRITICAL: _start_tracking gather returned! This should never happen. Results: {results}")
for name, result in zip(task_names, results):
if isinstance(result, Exception):
log.error(f"Tracking task '{name}' crashed with exception: {result}", exc_info=result)
elif result is not None:
log.warning(f"Tracking task '{name}' exited unexpectedly with result: {result}")
def backend_errored(self, msg: str) -> None:
self.metrics._model_errored(msg)
@@ -416,20 +399,15 @@ class Backend:
# await sleep(5)
try:
max_throughput = await run_benchmark()
log.info(f"[benchmark] Benchmark complete, max_throughput={max_throughput}, setting healthcheck=True")
self.__start_healthcheck = True
self.metrics._model_loaded(
max_throughput=max_throughput,
)
log.info(f"[benchmark] _model_loaded() called, returning from handle_log_line")
except ClientConnectorError as e:
log.debug(
f"failed to connect to model api during benchmark"
f"failed to connect to comfyui api during benchmark"
)
self.backend_errored(str(e))
except Exception as e:
log.error(f"Unexpected error during benchmark: {e}", exc_info=True)
self.backend_errored(f"Benchmark failed: {e}")
case LogAction.ModelError if msg in log_line:
log.debug(f"Got log line indicating error: {log_line}")
self.backend_errored(msg)
@@ -441,14 +419,10 @@ class Backend:
log.debug(f"tailing file: {self.model_log_file}")
async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore') as f:
while True:
try:
line = await f.readline()
if line:
await handle_log_line(line.rstrip())
else:
await asyncio.sleep(LOG_POLL_INTERVAL)
except Exception as e:
log.error(f"Error processing log line: {e}", exc_info=True)
line = await f.readline()
if line:
await handle_log_line(line.rstrip())
else:
await asyncio.sleep(LOG_POLL_INTERVAL)
###########
-35
View File
@@ -1,5 +1,4 @@
import os
import sys
import time
import logging
import json
@@ -18,14 +17,6 @@ DELETE_REQUESTS_INTERVAL = 1
log = logging.getLogger(__file__)
def _flush_logs():
"""Force flush all log handlers and stdout/stderr."""
for handler in logging.root.handlers:
handler.flush()
sys.stdout.flush()
sys.stderr.flush()
@cache
def get_url() -> str:
use_ssl = os.environ.get("USE_SSL", "false") == "true"
@@ -128,41 +119,22 @@ class Metrics:
await self.__send_delete_requests_and_reset()
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
loop_count = 0
first_loaded_send_done = False
while True:
await sleep(METRICS_UPDATE_INTERVAL)
loop_count += 1
elapsed = time.time() - self.last_metric_update
# Log heartbeat every 30 seconds to confirm loop is running
if loop_count % 30 == 0:
log.debug(f"[heartbeat] metrics loop alive, loop_count={loop_count}, model_loaded={self.system_metrics.model_is_loaded}")
_flush_logs()
# Extra logging for first few iterations after model loads
if self.system_metrics.model_is_loaded and not first_loaded_send_done:
log.info(f"[transition] First iteration with model_loaded=True, loop_count={loop_count}, elapsed={elapsed:.1f}")
_flush_logs()
if self.system_metrics.model_is_loaded is False and elapsed >= 10:
log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
await self.__send_metrics_and_reset()
elif self.update_pending or elapsed > 10:
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
await self.__send_metrics_and_reset()
if self.system_metrics.model_is_loaded and not first_loaded_send_done:
first_loaded_send_done = True
log.info(f"[transition] First loaded metrics send complete, continuing to next iteration...")
_flush_logs()
def _model_loaded(self, max_throughput: float) -> None:
log.info(f"MODEL LOADED: Setting model_is_loaded=True, max_throughput={max_throughput}")
_flush_logs()
self.system_metrics.model_loading_time = (
time.time() - self.system_metrics.model_loading_start
)
self.system_metrics.model_is_loaded = True
self.model_metrics.max_throughput = max_throughput
log.info(f"MODEL LOADED: model_loading_time={self.system_metrics.model_loading_time}")
_flush_logs()
def _model_errored(self, error_msg: str) -> None:
self.model_metrics.set_errored(error_msg)
@@ -299,7 +271,6 @@ class Metrics:
###########
self.system_metrics.update_disk_usage()
had_loadtime = loadtime_snapshot is not None and loadtime_snapshot > 0
sent = False
for report_addr in self.report_addr:
@@ -308,14 +279,8 @@ class Metrics:
break
if sent:
if had_loadtime:
log.info(f"FIRST LOADTIME METRICS SENT SUCCESSFULLY! loadtime={loadtime_snapshot}")
_flush_logs()
# clear the one-shot loadtime only if we actually sent *this* value
self.system_metrics.reset(expected=loadtime_snapshot)
self.update_pending = False
self.model_metrics.reset()
self.last_metric_update = time.time()
if had_loadtime:
log.info(f"POST-SEND: reset complete, last_metric_update={self.last_metric_update}, continuing loop...")
_flush_logs()
-20
View File
@@ -1,7 +1,5 @@
import os
import logging
import signal
import sys
from typing import List
import ssl
from asyncio import run, gather
@@ -14,25 +12,7 @@ from aiohttp import web
log = logging.getLogger(__file__)
def _setup_signal_handlers():
"""Setup signal handlers to log when process receives termination signals."""
def signal_handler(signum, frame):
sig_name = signal.Signals(signum).name
log.error(f"SIGNAL RECEIVED: {sig_name} ({signum}) - process is being terminated")
sys.stdout.flush()
sys.stderr.flush()
sys.exit(128 + signum)
# Handle common termination signals
for sig in [signal.SIGTERM, signal.SIGINT, signal.SIGHUP]:
try:
signal.signal(sig, signal_handler)
except (OSError, ValueError):
pass # Some signals may not be available
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
_setup_signal_handlers()
try:
log.debug("getting certificate...")
use_ssl = os.environ.get("USE_SSL", "false") == "true"
+33 -26
View File
@@ -8,14 +8,13 @@ This is the base PyWorker for OpenAI compatible inference servers. See the [Ser
This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker.
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20%2B%20Qwen%2FQwen3-8B%20(Serverless)) (recommended)
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20(Serverless)) (recommended)
- [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless))
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20%2B%20Qwen3-8B%20(Serverless))
All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected.
2. Follow the [getting started guide](https://docs.vast.ai/serverless/getting-started) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
## Client Setup (Demo)
@@ -34,38 +33,20 @@ uv pip install -r requirements.txt
Several examples have been provided in the client to help you get started with your own implementation.
### Completions
Call to `/v1/completions` with json response
First, set your API key as an environment variable:
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
export VAST_API_KEY=<your_api_key>
```
### Chat Completion (json)
Call to `/v1/chat/completions` with json response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
```
The `--model` and `--endpoint` flags are optional. If not provided, they default to `Qwen/Qwen3-8B` and `my-vllm-endpoint` respectively.
### Chat Completion (streaming)
Call to `/v1/chat/completions` with streaming response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
```
### Tool Use (json)
Call to `/v1/chat/completions` with tool and json response.
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
python -m workers.openai.client --chat-stream --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Interactive Chat (streaming)
@@ -75,6 +56,32 @@ Interactive session with calls to `/v1/chat/completions`.
Type `clear` to clear the chat history or `quit` to exit.
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
python -m workers.openai.client --interactive --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Chat Completion (json)
Call to `/v1/chat/completions` with json response
```bash
python -m workers.openai.client --chat --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Tool Use (json)
Call to `/v1/chat/completions` with tool and json response.
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
```bash
python -m workers.openai.client --tools --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Completions
Call to `/v1/completions` with json response
```bash
python -m workers.openai.client --completion --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
+32 -16
View File
@@ -18,7 +18,7 @@ logging.basicConfig(
log = logging.getLogger(__file__)
# ---------------------- Prompts ----------------------
COMPLETIONS_PROMPT = "the capital of USA is"
COMPLETIONS_PROMPT = "Zebras are primarily grazers and can subsist on lower-quality vegetation. They are preyed on mainly by"
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
TOOLS_PROMPT = (
"Can you list the files in the current working directory and tell me what you see? "
@@ -97,9 +97,9 @@ def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[
# ---- OpenAI-compatible calls (non-streaming) ----
async def call_completions(client: Serverless, *, model: str, prompt: str, **kwargs) -> Dict[str, Any]:
async def call_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
@@ -113,9 +113,9 @@ async def call_completions(client: Serverless, *, model: str, prompt: str, **kwa
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
return resp["response"]
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
@@ -132,9 +132,9 @@ async def call_chat_completions(client: Serverless, *, model: str, messages: Lis
return resp["response"]
# ---- Streaming variants ----
async def stream_completions(client: Serverless, *, model: str, prompt: str, **kwargs):
async def stream_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs):
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
@@ -150,9 +150,9 @@ async def stream_completions(client: Serverless, *, model: str, prompt: str, **k
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
return resp["response"] # async generator
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs):
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs):
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
@@ -174,9 +174,10 @@ async def stream_chat_completions(client: Serverless, *, model: str, messages: L
class APIDemo:
"""Demo and testing functionality for the API client"""
def __init__(self, client: Serverless, model: str, tool_manager: Optional[ToolManager] = None):
def __init__(self, client: Serverless, model: str, endpoint_name: str, tool_manager: Optional[ToolManager] = None):
self.client = client
self.model = model
self.endpoint_name = endpoint_name
self.tool_manager = tool_manager or ToolManager()
# ----- Streaming handler -----
@@ -185,10 +186,15 @@ class APIDemo:
reasoning_content = ""
printed_reasoning = False
printed_answer = False
finish_reason = None
async for chunk in stream:
choice = (chunk.get("choices") or [{}])[0]
delta = choice.get("delta", {})
# Track finish reason
if choice.get("finish_reason"):
finish_reason = choice.get("finish_reason")
# reasoning tokens
rc = delta.get("reasoning_content")
@@ -219,6 +225,8 @@ class APIDemo:
print(f"Reasoning tokens: {len(reasoning_content.split())}")
if printed_answer:
print(f"Response tokens: {len(full_response.split())}")
if finish_reason:
print(f"Finish reason: {finish_reason}")
return full_response
@@ -231,6 +239,7 @@ class APIDemo:
client=self.client,
model=self.model,
prompt=COMPLETIONS_PROMPT,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
@@ -249,6 +258,7 @@ class APIDemo:
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE
)
@@ -261,6 +271,7 @@ class APIDemo:
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE
)
@@ -287,6 +298,7 @@ class APIDemo:
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
tools=minimal_tool,
tool_choice="none",
max_tokens=10
@@ -312,6 +324,7 @@ class APIDemo:
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
tools=self.tool_manager.get_ls_tool_definition(),
tool_choice="auto",
max_tokens=MAX_TOKENS,
@@ -389,6 +402,7 @@ class APIDemo:
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
@@ -427,7 +441,6 @@ class APIDemo:
print("=" * 60)
print("INTERACTIVE STREAMING CHAT")
print("=" * 60)
print(f"Using model: {self.model}")
print("Type 'quit' to exit, 'clear' to clear history")
print()
@@ -453,7 +466,8 @@ class APIDemo:
stream = await stream_chat_completions(
client=self.client,
model=self.model,
messages=messages,
messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=0.7
)
@@ -473,8 +487,8 @@ class APIDemo:
# ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
p.add_argument("--model", required=True, help="Model to use for requests (required)")
p.add_argument("--endpoint", default="my-vllm-endpoint", help="Vast endpoint name (default: my-vllm-endpoint)")
p.add_argument("--model", default=DEFAULT_MODEL, help=f"Model to use for requests (default: {DEFAULT_MODEL})")
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
modes = p.add_mutually_exclusive_group(required=False)
modes.add_argument("--completion", action="store_true", help="Test completions endpoint")
@@ -502,12 +516,14 @@ async def main_async():
print("Please specify exactly one test mode")
sys.exit(1)
print(f"Using model: {args.model}")
print("=" * 60)
print(f"Using model: {args.model}")
print(f"Using endpoint: {args.endpoint}")
try:
async with Serverless() as client:
demo = APIDemo(client, args.model, ToolManager())
demo = APIDemo(client, args.model, args.endpoint, ToolManager())
if args.completion:
await demo.demo_completions()
+2
View File
@@ -11,6 +11,7 @@ MODEL_SERVER_START_LOG_MSG = [
"llama runner started", # Ollama
'"message":"Connected","target":"text_generation_router"', # TGI
'"message":"Connected","target":"text_generation_router::server"', # TGI
"main: model loaded" # llama.cpp
]
MODEL_SERVER_ERROR_LOG_MSGS = [
@@ -34,6 +35,7 @@ backend = Backend(
model_server_url=os.environ["MODEL_SERVER_URL"],
model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=True,
max_wait_time=600.0,
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
log_actions=[
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
+93 -9
View File
@@ -1,19 +1,103 @@
This is the base PyWorker for TGI, designed to create PyWorkers that can utilize various LLMs. It offers two primary endpoints:
# HuggingFace TGI PyWorker
1. `generate`: Generates the LLM's response to a given prompt in a single request.
2. `generate_stream`: Streams the LLM's response token by token.
This is the base PyWorker for HuggingFace Text Generation Inference (TGI) servers. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
Both endpoints use the following API payload format:
## Instance Setup
1. Pick a template
This worker is compatible with any TGI backend. We have a template you can use or you can create your own.
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20(Serverless))
The template can be configured via the template interface. You may want to change the model or startup arguments.
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
## Client Setup (Demo)
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
```bash
git clone https://github.com/vast-ai/pyworker
cd pyworker
pip install uv
uv venv -p 3.12
source .venv/bin/activate
uv pip install -r requirements.txt
```
## Using the Test Client
The test client demonstrates both streaming and non-streaming generation using TGI's native API.
First, set your API key as an environment variable:
```bash
export VAST_API_KEY=<your_api_key>
```
The `--endpoint` flag is optional. If not provided, it defaults to `my-tgi-endpoint`.
### Generate (Streaming)
Call to `/generate_stream` with streaming response:
```bash
python -m workers.tgi.client --generate-stream --endpoint <ENDPOINT_NAME>
```
### Generate (Non-Streaming)
Call to `/generate` with json response:
```bash
python -m workers.tgi.client --generate --endpoint <ENDPOINT_NAME>
```
### Interactive Session (Streaming)
Interactive session with streaming responses. Type `quit` to exit.
```bash
python -m workers.tgi.client --interactive --endpoint <ENDPOINT_NAME>
```
## API Endpoints
TGI provides two primary endpoints:
### Generate (Non-Streaming)
`/generate` - Returns the complete response in a single request.
```json
{
"inputs": "PROMPT",
"inputs": "Your prompt here",
"parameters": {
"max_new_tokens": 250
"max_new_tokens": 1024,
"temperature": 0.7,
"return_full_text": false
}
}
```
Note that the max_new_tokens parameter, rather than the prompt size, impacts performance. For example, if an
instance is benchmarked to process 100 tokens per second, a request with max_new_tokens = 200 will take
approximately 2 seconds to complete.
### Generate Stream (Streaming)
`/generate_stream` - Streams the response token by token.
```json
{
"inputs": "Your prompt here",
"parameters": {
"max_new_tokens": 1024,
"temperature": 0.7,
"do_sample": true,
"return_full_text": false
}
}
```
## Performance Notes
The `max_new_tokens` parameter (not the prompt size) primarily impacts performance. For example, if an instance is benchmarked to process 100 tokens per second, a request with `max_new_tokens = 200` will take approximately 2 seconds to complete.
+195 -34
View File
@@ -1,61 +1,222 @@
import logging
import json
import os
import sys
import argparse
from vastai import Serverless
import asyncio
ENDPOINT_NAME = "my-tgi-endpoint" # Change this to match your endpoint name
MAX_TOKENS = 1024
PROMPT = "Think step by step: Tell me about the Python programming language."
# ---------------------- Logging ----------------------
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
async def call_generate(client: Serverless) -> None:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
# ---------------------- Defaults ----------------------
DEFAULT_PROMPT = "Think step by step: Tell me about the Python programming language."
ENDPOINT_NAME = "TGI-Prod2" # change this to your TGI endpoint name
MAX_TOKENS = 1024
DEFAULT_TEMPERATURE = 0.7
# ---------------------- API Calls ----------------------
async def call_generate(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs) -> dict:
"""Non-streaming generation via /generate endpoint"""
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"inputs": PROMPT,
"inputs": prompt,
"parameters": {
"max_new_tokens": MAX_TOKENS,
"temperature": 0.7,
"return_full_text": False
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"return_full_text": False,
}
}
resp = await endpoint.request("/generate", payload, cost=MAX_TOKENS)
print(resp["response"]["generated_text"])
log.debug("POST /generate %s", json.dumps(payload)[:500])
resp = await endpoint.request("/generate", payload, cost=payload["parameters"]["max_new_tokens"])
return resp["response"]
async def call_generate_stream(client: Serverless) -> None:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
async def call_generate_stream(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs):
"""Streaming generation via /generate_stream endpoint"""
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"inputs": PROMPT,
"inputs": prompt,
"parameters": {
"max_new_tokens": MAX_TOKENS,
"temperature": 0.7,
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"do_sample": True,
"return_full_text": False,
}
}
log.debug("STREAM /generate_stream %s", json.dumps(payload)[:500])
resp = await endpoint.request(
"/generate_stream",
payload,
cost=MAX_TOKENS,
cost=payload["parameters"]["max_new_tokens"],
stream=True,
)
stream = resp["response"]
return resp["response"] # async generator
printed_answer = False
async for event in stream:
tok = (event.get("token") or {}).get("text")
if tok:
if not printed_answer:
printed_answer = True
print("Answer:\n", end="", flush=True)
print(tok, end="", flush=True)
async def main():
async with Serverless() as client:
await call_generate(client)
await call_generate_stream(client)
# ---------------------- Demo Runner ----------------------
class APIDemo:
"""Demo and testing functionality for the TGI API client"""
def __init__(self, client: Serverless, endpoint_name: str):
self.client = client
self.endpoint_name = endpoint_name
async def handle_streaming_response(self, stream) -> str:
"""Process streaming response and print tokens"""
full_response = ""
printed_answer = False
async for event in stream:
tok = (event.get("token") or {}).get("text")
if tok:
if not printed_answer:
printed_answer = True
print("\n💬 Response: ", end="", flush=True)
print(tok, end="", flush=True)
full_response += tok
print() # newline
if printed_answer:
print(f"\nStreaming completed. Response tokens: {len(full_response.split())}")
return full_response
async def demo_generate(self) -> None:
"""Demo non-streaming generation"""
print("=" * 60)
print("GENERATE DEMO (NON-STREAMING)")
print("=" * 60)
response = await call_generate(
client=self.client,
endpoint_name=self.endpoint_name,
prompt=DEFAULT_PROMPT,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
print(f"\n💬 Response: {response.get('generated_text', '')}")
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
async def demo_generate_stream(self) -> None:
"""Demo streaming generation"""
print("=" * 60)
print("GENERATE DEMO (STREAMING)")
print("=" * 60)
stream = await call_generate_stream(
client=self.client,
endpoint_name=self.endpoint_name,
prompt=DEFAULT_PROMPT,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
try:
await self.handle_streaming_response(stream)
except Exception as e:
log.error("\nError during streaming: %s", e, exc_info=True)
async def interactive_chat(self) -> None:
"""Interactive session with streaming generation"""
print("=" * 60)
print("INTERACTIVE STREAMING SESSION")
print("=" * 60)
print(f"Using endpoint: {self.endpoint_name}")
print("Type 'quit' to exit")
print()
while True:
try:
user_input = input("You: ").strip()
if user_input.lower() == "quit":
print("👋 Goodbye!")
break
elif not user_input:
continue
print("Assistant: ", end="", flush=True)
stream = await call_generate_stream(
client=self.client,
endpoint_name=self.endpoint_name,
prompt=user_input,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
full_response = ""
async for event in stream:
tok = (event.get("token") or {}).get("text")
if tok:
print(tok, end="", flush=True)
full_response += tok
print() # newline
except KeyboardInterrupt:
print("\n👋 Session interrupted. Goodbye!")
break
except Exception as e:
log.error("\nError: %s", e)
continue
# ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast TGI Demo (Serverless SDK)")
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
modes = p.add_mutually_exclusive_group(required=False)
modes.add_argument("--generate", action="store_true", help="Test generate endpoint (non-streaming)")
modes.add_argument("--generate-stream", action="store_true", help="Test generate endpoint with streaming")
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming session")
return p
async def main_async():
args = build_arg_parser().parse_args()
selected = sum([args.generate, args.generate_stream, args.interactive])
if selected == 0:
print("Please specify exactly one test mode:")
print(" --generate : Test generate endpoint (non-streaming)")
print(" --generate-stream : Test generate endpoint with streaming")
print(" --interactive : Start interactive streaming session")
print(f"\nExample: python {os.path.basename(sys.argv[0])} --generate-stream --endpoint my-tgi-endpoint")
sys.exit(1)
elif selected > 1:
print("Please specify exactly one test mode")
sys.exit(1)
print("=" * 60)
print(f"Using endpoint: {args.endpoint}")
try:
async with Serverless() as client:
demo = APIDemo(client, args.endpoint)
if args.generate:
await demo.demo_generate()
elif args.generate_stream:
await demo.demo_generate_stream()
elif args.interactive:
await demo.interactive_chat()
except Exception as e:
log.error("Error during test: %s", e, exc_info=True)
sys.exit(1)
if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main_async())