Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6b5b1341a7 | |||
| 8be92c03de | |||
| adedb8ba90 | |||
| 2f543c01ad | |||
| 0bcd2219ea | |||
| 0339b471c5 | |||
| e143162438 | |||
| 7a792fd176 | |||
| e0449cb3c7 |
+8
-34
@@ -30,7 +30,7 @@ from lib.data_types import (
|
||||
BenchmarkResult
|
||||
)
|
||||
|
||||
VERSION = "0.2.0"
|
||||
VERSION = "0.2.1"
|
||||
|
||||
MSG_HISTORY_LEN = 100
|
||||
log = logging.getLogger(__file__)
|
||||
@@ -235,14 +235,10 @@ class Backend:
|
||||
log.debug("No healthcheck endpoint defined, skipping healthcheck")
|
||||
return
|
||||
|
||||
first_healthcheck = True
|
||||
while True:
|
||||
await sleep(10)
|
||||
if self.__start_healthcheck is False:
|
||||
continue
|
||||
if first_healthcheck:
|
||||
log.info(f"[healthcheck] First healthcheck starting (model is now loaded)")
|
||||
first_healthcheck = False
|
||||
try:
|
||||
log.debug(f"Performing healthcheck on {health_check_url}")
|
||||
async with self.healthcheck_session.get(health_check_url) as response:
|
||||
@@ -260,22 +256,9 @@ class Backend:
|
||||
self.backend_errored(str(e))
|
||||
|
||||
async def _start_tracking(self) -> None:
|
||||
log.info("Starting tracking tasks (read_logs, send_metrics_loop, healthcheck, send_delete_requests_loop)")
|
||||
task_names = ["read_logs", "send_metrics_loop", "healthcheck", "send_delete_requests_loop"]
|
||||
results = await gather(
|
||||
self.__read_logs(),
|
||||
self.metrics._send_metrics_loop(),
|
||||
self.__healthcheck(),
|
||||
self.metrics._send_delete_requests_loop(),
|
||||
return_exceptions=True
|
||||
await gather(
|
||||
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop()
|
||||
)
|
||||
# If we get here, one or more tasks exited (they should run forever)
|
||||
log.error(f"CRITICAL: _start_tracking gather returned! This should never happen. Results: {results}")
|
||||
for name, result in zip(task_names, results):
|
||||
if isinstance(result, Exception):
|
||||
log.error(f"Tracking task '{name}' crashed with exception: {result}", exc_info=result)
|
||||
elif result is not None:
|
||||
log.warning(f"Tracking task '{name}' exited unexpectedly with result: {result}")
|
||||
|
||||
def backend_errored(self, msg: str) -> None:
|
||||
self.metrics._model_errored(msg)
|
||||
@@ -416,20 +399,15 @@ class Backend:
|
||||
# await sleep(5)
|
||||
try:
|
||||
max_throughput = await run_benchmark()
|
||||
log.info(f"[benchmark] Benchmark complete, max_throughput={max_throughput}, setting healthcheck=True")
|
||||
self.__start_healthcheck = True
|
||||
self.metrics._model_loaded(
|
||||
max_throughput=max_throughput,
|
||||
)
|
||||
log.info(f"[benchmark] _model_loaded() called, returning from handle_log_line")
|
||||
except ClientConnectorError as e:
|
||||
log.debug(
|
||||
f"failed to connect to model api during benchmark"
|
||||
f"failed to connect to comfyui api during benchmark"
|
||||
)
|
||||
self.backend_errored(str(e))
|
||||
except Exception as e:
|
||||
log.error(f"Unexpected error during benchmark: {e}", exc_info=True)
|
||||
self.backend_errored(f"Benchmark failed: {e}")
|
||||
case LogAction.ModelError if msg in log_line:
|
||||
log.debug(f"Got log line indicating error: {log_line}")
|
||||
self.backend_errored(msg)
|
||||
@@ -441,14 +419,10 @@ class Backend:
|
||||
log.debug(f"tailing file: {self.model_log_file}")
|
||||
async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore') as f:
|
||||
while True:
|
||||
try:
|
||||
line = await f.readline()
|
||||
if line:
|
||||
await handle_log_line(line.rstrip())
|
||||
else:
|
||||
await asyncio.sleep(LOG_POLL_INTERVAL)
|
||||
except Exception as e:
|
||||
log.error(f"Error processing log line: {e}", exc_info=True)
|
||||
line = await f.readline()
|
||||
if line:
|
||||
await handle_log_line(line.rstrip())
|
||||
else:
|
||||
await asyncio.sleep(LOG_POLL_INTERVAL)
|
||||
|
||||
###########
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
import json
|
||||
@@ -18,14 +17,6 @@ DELETE_REQUESTS_INTERVAL = 1
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
def _flush_logs():
|
||||
"""Force flush all log handlers and stdout/stderr."""
|
||||
for handler in logging.root.handlers:
|
||||
handler.flush()
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
|
||||
|
||||
@cache
|
||||
def get_url() -> str:
|
||||
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
||||
@@ -128,41 +119,22 @@ class Metrics:
|
||||
await self.__send_delete_requests_and_reset()
|
||||
|
||||
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
|
||||
loop_count = 0
|
||||
first_loaded_send_done = False
|
||||
while True:
|
||||
await sleep(METRICS_UPDATE_INTERVAL)
|
||||
loop_count += 1
|
||||
elapsed = time.time() - self.last_metric_update
|
||||
# Log heartbeat every 30 seconds to confirm loop is running
|
||||
if loop_count % 30 == 0:
|
||||
log.debug(f"[heartbeat] metrics loop alive, loop_count={loop_count}, model_loaded={self.system_metrics.model_is_loaded}")
|
||||
_flush_logs()
|
||||
# Extra logging for first few iterations after model loads
|
||||
if self.system_metrics.model_is_loaded and not first_loaded_send_done:
|
||||
log.info(f"[transition] First iteration with model_loaded=True, loop_count={loop_count}, elapsed={elapsed:.1f}")
|
||||
_flush_logs()
|
||||
if self.system_metrics.model_is_loaded is False and elapsed >= 10:
|
||||
log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
|
||||
await self.__send_metrics_and_reset()
|
||||
elif self.update_pending or elapsed > 10:
|
||||
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
|
||||
await self.__send_metrics_and_reset()
|
||||
if self.system_metrics.model_is_loaded and not first_loaded_send_done:
|
||||
first_loaded_send_done = True
|
||||
log.info(f"[transition] First loaded metrics send complete, continuing to next iteration...")
|
||||
_flush_logs()
|
||||
|
||||
def _model_loaded(self, max_throughput: float) -> None:
|
||||
log.info(f"MODEL LOADED: Setting model_is_loaded=True, max_throughput={max_throughput}")
|
||||
_flush_logs()
|
||||
self.system_metrics.model_loading_time = (
|
||||
time.time() - self.system_metrics.model_loading_start
|
||||
)
|
||||
self.system_metrics.model_is_loaded = True
|
||||
self.model_metrics.max_throughput = max_throughput
|
||||
log.info(f"MODEL LOADED: model_loading_time={self.system_metrics.model_loading_time}")
|
||||
_flush_logs()
|
||||
|
||||
def _model_errored(self, error_msg: str) -> None:
|
||||
self.model_metrics.set_errored(error_msg)
|
||||
@@ -299,7 +271,6 @@ class Metrics:
|
||||
###########
|
||||
|
||||
self.system_metrics.update_disk_usage()
|
||||
had_loadtime = loadtime_snapshot is not None and loadtime_snapshot > 0
|
||||
|
||||
sent = False
|
||||
for report_addr in self.report_addr:
|
||||
@@ -308,14 +279,8 @@ class Metrics:
|
||||
break
|
||||
|
||||
if sent:
|
||||
if had_loadtime:
|
||||
log.info(f"FIRST LOADTIME METRICS SENT SUCCESSFULLY! loadtime={loadtime_snapshot}")
|
||||
_flush_logs()
|
||||
# clear the one-shot loadtime only if we actually sent *this* value
|
||||
self.system_metrics.reset(expected=loadtime_snapshot)
|
||||
self.update_pending = False
|
||||
self.model_metrics.reset()
|
||||
self.last_metric_update = time.time()
|
||||
if had_loadtime:
|
||||
log.info(f"POST-SEND: reset complete, last_metric_update={self.last_metric_update}, continuing loop...")
|
||||
_flush_logs()
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import os
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
from typing import List
|
||||
import ssl
|
||||
from asyncio import run, gather
|
||||
@@ -14,25 +12,7 @@ from aiohttp import web
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
def _setup_signal_handlers():
|
||||
"""Setup signal handlers to log when process receives termination signals."""
|
||||
def signal_handler(signum, frame):
|
||||
sig_name = signal.Signals(signum).name
|
||||
log.error(f"SIGNAL RECEIVED: {sig_name} ({signum}) - process is being terminated")
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
sys.exit(128 + signum)
|
||||
|
||||
# Handle common termination signals
|
||||
for sig in [signal.SIGTERM, signal.SIGINT, signal.SIGHUP]:
|
||||
try:
|
||||
signal.signal(sig, signal_handler)
|
||||
except (OSError, ValueError):
|
||||
pass # Some signals may not be available
|
||||
|
||||
|
||||
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
||||
_setup_signal_handlers()
|
||||
try:
|
||||
log.debug("getting certificate...")
|
||||
use_ssl = os.environ.get("USE_SSL", "false") == "true"
|
||||
|
||||
+33
-26
@@ -8,14 +8,13 @@ This is the base PyWorker for OpenAI compatible inference servers. See the [Ser
|
||||
|
||||
This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker.
|
||||
|
||||
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20%2B%20Qwen%2FQwen3-8B%20(Serverless)) (recommended)
|
||||
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20(Serverless)) (recommended)
|
||||
- [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless))
|
||||
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20%2B%20Qwen3-8B%20(Serverless))
|
||||
|
||||
|
||||
All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected.
|
||||
|
||||
2. Follow the [getting started guide](https://docs.vast.ai/serverless/getting-started) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
||||
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
||||
|
||||
## Client Setup (Demo)
|
||||
|
||||
@@ -34,38 +33,20 @@ uv pip install -r requirements.txt
|
||||
|
||||
Several examples have been provided in the client to help you get started with your own implementation.
|
||||
|
||||
### Completions
|
||||
|
||||
Call to `/v1/completions` with json response
|
||||
First, set your API key as an environment variable:
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
|
||||
export VAST_API_KEY=<your_api_key>
|
||||
```
|
||||
|
||||
### Chat Completion (json)
|
||||
|
||||
Call to `/v1/chat/completions` with json response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
|
||||
```
|
||||
The `--model` and `--endpoint` flags are optional. If not provided, they default to `Qwen/Qwen3-8B` and `my-vllm-endpoint` respectively.
|
||||
|
||||
### Chat Completion (streaming)
|
||||
|
||||
Call to `/v1/chat/completions` with streaming response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Tool Use (json)
|
||||
|
||||
Call to `/v1/chat/completions` with tool and json response.
|
||||
|
||||
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
|
||||
python -m workers.openai.client --chat-stream --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Interactive Chat (streaming)
|
||||
@@ -75,6 +56,32 @@ Interactive session with calls to `/v1/chat/completions`.
|
||||
Type `clear` to clear the chat history or `quit` to exit.
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
|
||||
python -m workers.openai.client --interactive --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Chat Completion (json)
|
||||
|
||||
Call to `/v1/chat/completions` with json response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client --chat --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Tool Use (json)
|
||||
|
||||
Call to `/v1/chat/completions` with tool and json response.
|
||||
|
||||
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client --tools --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Completions
|
||||
|
||||
Call to `/v1/completions` with json response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client --completion --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
|
||||
+32
-16
@@ -18,7 +18,7 @@ logging.basicConfig(
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
# ---------------------- Prompts ----------------------
|
||||
COMPLETIONS_PROMPT = "the capital of USA is"
|
||||
COMPLETIONS_PROMPT = "Zebras are primarily grazers and can subsist on lower-quality vegetation. They are preyed on mainly by"
|
||||
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
|
||||
TOOLS_PROMPT = (
|
||||
"Can you list the files in the current working directory and tell me what you see? "
|
||||
@@ -97,9 +97,9 @@ def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[
|
||||
|
||||
|
||||
# ---- OpenAI-compatible calls (non-streaming) ----
|
||||
async def call_completions(client: Serverless, *, model: str, prompt: str, **kwargs) -> Dict[str, Any]:
|
||||
async def call_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs) -> Dict[str, Any]:
|
||||
|
||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
@@ -113,9 +113,9 @@ async def call_completions(client: Serverless, *, model: str, prompt: str, **kwa
|
||||
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
|
||||
return resp["response"]
|
||||
|
||||
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
|
||||
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs) -> Dict[str, Any]:
|
||||
|
||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
@@ -132,9 +132,9 @@ async def call_chat_completions(client: Serverless, *, model: str, messages: Lis
|
||||
return resp["response"]
|
||||
|
||||
# ---- Streaming variants ----
|
||||
async def stream_completions(client: Serverless, *, model: str, prompt: str, **kwargs):
|
||||
async def stream_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs):
|
||||
|
||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
@@ -150,9 +150,9 @@ async def stream_completions(client: Serverless, *, model: str, prompt: str, **k
|
||||
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
|
||||
return resp["response"] # async generator
|
||||
|
||||
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs):
|
||||
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs):
|
||||
|
||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
@@ -174,9 +174,10 @@ async def stream_chat_completions(client: Serverless, *, model: str, messages: L
|
||||
class APIDemo:
|
||||
"""Demo and testing functionality for the API client"""
|
||||
|
||||
def __init__(self, client: Serverless, model: str, tool_manager: Optional[ToolManager] = None):
|
||||
def __init__(self, client: Serverless, model: str, endpoint_name: str, tool_manager: Optional[ToolManager] = None):
|
||||
self.client = client
|
||||
self.model = model
|
||||
self.endpoint_name = endpoint_name
|
||||
self.tool_manager = tool_manager or ToolManager()
|
||||
|
||||
# ----- Streaming handler -----
|
||||
@@ -185,10 +186,15 @@ class APIDemo:
|
||||
reasoning_content = ""
|
||||
printed_reasoning = False
|
||||
printed_answer = False
|
||||
finish_reason = None
|
||||
|
||||
async for chunk in stream:
|
||||
choice = (chunk.get("choices") or [{}])[0]
|
||||
delta = choice.get("delta", {})
|
||||
|
||||
# Track finish reason
|
||||
if choice.get("finish_reason"):
|
||||
finish_reason = choice.get("finish_reason")
|
||||
|
||||
# reasoning tokens
|
||||
rc = delta.get("reasoning_content")
|
||||
@@ -219,6 +225,8 @@ class APIDemo:
|
||||
print(f"Reasoning tokens: {len(reasoning_content.split())}")
|
||||
if printed_answer:
|
||||
print(f"Response tokens: {len(full_response.split())}")
|
||||
if finish_reason:
|
||||
print(f"Finish reason: {finish_reason}")
|
||||
|
||||
return full_response
|
||||
|
||||
@@ -231,6 +239,7 @@ class APIDemo:
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
prompt=COMPLETIONS_PROMPT,
|
||||
endpoint_name=self.endpoint_name,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
)
|
||||
@@ -249,6 +258,7 @@ class APIDemo:
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
endpoint_name=self.endpoint_name,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE
|
||||
)
|
||||
@@ -261,6 +271,7 @@ class APIDemo:
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
endpoint_name=self.endpoint_name,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE
|
||||
)
|
||||
@@ -287,6 +298,7 @@ class APIDemo:
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
endpoint_name=self.endpoint_name,
|
||||
tools=minimal_tool,
|
||||
tool_choice="none",
|
||||
max_tokens=10
|
||||
@@ -312,6 +324,7 @@ class APIDemo:
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
endpoint_name=self.endpoint_name,
|
||||
tools=self.tool_manager.get_ls_tool_definition(),
|
||||
tool_choice="auto",
|
||||
max_tokens=MAX_TOKENS,
|
||||
@@ -389,6 +402,7 @@ class APIDemo:
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
endpoint_name=self.endpoint_name,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
)
|
||||
@@ -427,7 +441,6 @@ class APIDemo:
|
||||
print("=" * 60)
|
||||
print("INTERACTIVE STREAMING CHAT")
|
||||
print("=" * 60)
|
||||
print(f"Using model: {self.model}")
|
||||
print("Type 'quit' to exit, 'clear' to clear history")
|
||||
print()
|
||||
|
||||
@@ -453,7 +466,8 @@ class APIDemo:
|
||||
stream = await stream_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
messages=messages,
|
||||
endpoint_name=self.endpoint_name,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=0.7
|
||||
)
|
||||
@@ -473,8 +487,8 @@ class APIDemo:
|
||||
# ---------------------- CLI ----------------------
|
||||
def build_arg_parser() -> argparse.ArgumentParser:
|
||||
p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
|
||||
p.add_argument("--model", required=True, help="Model to use for requests (required)")
|
||||
p.add_argument("--endpoint", default="my-vllm-endpoint", help="Vast endpoint name (default: my-vllm-endpoint)")
|
||||
p.add_argument("--model", default=DEFAULT_MODEL, help=f"Model to use for requests (default: {DEFAULT_MODEL})")
|
||||
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
|
||||
|
||||
modes = p.add_mutually_exclusive_group(required=False)
|
||||
modes.add_argument("--completion", action="store_true", help="Test completions endpoint")
|
||||
@@ -502,12 +516,14 @@ async def main_async():
|
||||
print("Please specify exactly one test mode")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Using model: {args.model}")
|
||||
print("=" * 60)
|
||||
print(f"Using model: {args.model}")
|
||||
print(f"Using endpoint: {args.endpoint}")
|
||||
|
||||
|
||||
try:
|
||||
async with Serverless() as client:
|
||||
demo = APIDemo(client, args.model, ToolManager())
|
||||
demo = APIDemo(client, args.model, args.endpoint, ToolManager())
|
||||
|
||||
if args.completion:
|
||||
await demo.demo_completions()
|
||||
|
||||
@@ -11,6 +11,7 @@ MODEL_SERVER_START_LOG_MSG = [
|
||||
"llama runner started", # Ollama
|
||||
'"message":"Connected","target":"text_generation_router"', # TGI
|
||||
'"message":"Connected","target":"text_generation_router::server"', # TGI
|
||||
"main: model loaded" # llama.cpp
|
||||
]
|
||||
|
||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||
@@ -34,6 +35,7 @@ backend = Backend(
|
||||
model_server_url=os.environ["MODEL_SERVER_URL"],
|
||||
model_log_file=os.environ["MODEL_LOG"],
|
||||
allow_parallel_requests=True,
|
||||
max_wait_time=600.0,
|
||||
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
|
||||
log_actions=[
|
||||
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
|
||||
|
||||
+93
-9
@@ -1,19 +1,103 @@
|
||||
This is the base PyWorker for TGI, designed to create PyWorkers that can utilize various LLMs. It offers two primary endpoints:
|
||||
# HuggingFace TGI PyWorker
|
||||
|
||||
1. `generate`: Generates the LLM's response to a given prompt in a single request.
|
||||
2. `generate_stream`: Streams the LLM's response token by token.
|
||||
This is the base PyWorker for HuggingFace Text Generation Inference (TGI) servers. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
|
||||
|
||||
Both endpoints use the following API payload format:
|
||||
## Instance Setup
|
||||
|
||||
1. Pick a template
|
||||
|
||||
This worker is compatible with any TGI backend. We have a template you can use or you can create your own.
|
||||
|
||||
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20(Serverless))
|
||||
|
||||
The template can be configured via the template interface. You may want to change the model or startup arguments.
|
||||
|
||||
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
||||
|
||||
## Client Setup (Demo)
|
||||
|
||||
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/vast-ai/pyworker
|
||||
cd pyworker
|
||||
pip install uv
|
||||
uv venv -p 3.12
|
||||
source .venv/bin/activate
|
||||
uv pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Using the Test Client
|
||||
|
||||
The test client demonstrates both streaming and non-streaming generation using TGI's native API.
|
||||
|
||||
First, set your API key as an environment variable:
|
||||
|
||||
```bash
|
||||
export VAST_API_KEY=<your_api_key>
|
||||
```
|
||||
|
||||
The `--endpoint` flag is optional. If not provided, it defaults to `my-tgi-endpoint`.
|
||||
|
||||
### Generate (Streaming)
|
||||
|
||||
Call to `/generate_stream` with streaming response:
|
||||
|
||||
```bash
|
||||
python -m workers.tgi.client --generate-stream --endpoint <ENDPOINT_NAME>
|
||||
```
|
||||
|
||||
### Generate (Non-Streaming)
|
||||
|
||||
Call to `/generate` with json response:
|
||||
|
||||
```bash
|
||||
python -m workers.tgi.client --generate --endpoint <ENDPOINT_NAME>
|
||||
```
|
||||
|
||||
### Interactive Session (Streaming)
|
||||
|
||||
Interactive session with streaming responses. Type `quit` to exit.
|
||||
|
||||
```bash
|
||||
python -m workers.tgi.client --interactive --endpoint <ENDPOINT_NAME>
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
TGI provides two primary endpoints:
|
||||
|
||||
### Generate (Non-Streaming)
|
||||
|
||||
`/generate` - Returns the complete response in a single request.
|
||||
|
||||
```json
|
||||
{
|
||||
"inputs": "PROMPT",
|
||||
"inputs": "Your prompt here",
|
||||
"parameters": {
|
||||
"max_new_tokens": 250
|
||||
"max_new_tokens": 1024,
|
||||
"temperature": 0.7,
|
||||
"return_full_text": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Note that the max_new_tokens parameter, rather than the prompt size, impacts performance. For example, if an
|
||||
instance is benchmarked to process 100 tokens per second, a request with max_new_tokens = 200 will take
|
||||
approximately 2 seconds to complete.
|
||||
### Generate Stream (Streaming)
|
||||
|
||||
`/generate_stream` - Streams the response token by token.
|
||||
|
||||
```json
|
||||
{
|
||||
"inputs": "Your prompt here",
|
||||
"parameters": {
|
||||
"max_new_tokens": 1024,
|
||||
"temperature": 0.7,
|
||||
"do_sample": true,
|
||||
"return_full_text": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Performance Notes
|
||||
|
||||
The `max_new_tokens` parameter (not the prompt size) primarily impacts performance. For example, if an instance is benchmarked to process 100 tokens per second, a request with `max_new_tokens = 200` will take approximately 2 seconds to complete.
|
||||
|
||||
+195
-34
@@ -1,61 +1,222 @@
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
from vastai import Serverless
|
||||
import asyncio
|
||||
|
||||
ENDPOINT_NAME = "my-tgi-endpoint" # Change this to match your endpoint name
|
||||
MAX_TOKENS = 1024
|
||||
PROMPT = "Think step by step: Tell me about the Python programming language."
|
||||
# ---------------------- Logging ----------------------
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
async def call_generate(client: Serverless) -> None:
|
||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||
# ---------------------- Defaults ----------------------
|
||||
DEFAULT_PROMPT = "Think step by step: Tell me about the Python programming language."
|
||||
|
||||
ENDPOINT_NAME = "TGI-Prod2" # change this to your TGI endpoint name
|
||||
MAX_TOKENS = 1024
|
||||
DEFAULT_TEMPERATURE = 0.7
|
||||
|
||||
|
||||
# ---------------------- API Calls ----------------------
|
||||
async def call_generate(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs) -> dict:
|
||||
"""Non-streaming generation via /generate endpoint"""
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
|
||||
payload = {
|
||||
"inputs": PROMPT,
|
||||
"inputs": prompt,
|
||||
"parameters": {
|
||||
"max_new_tokens": MAX_TOKENS,
|
||||
"temperature": 0.7,
|
||||
"return_full_text": False
|
||||
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
"return_full_text": False,
|
||||
}
|
||||
}
|
||||
|
||||
resp = await endpoint.request("/generate", payload, cost=MAX_TOKENS)
|
||||
|
||||
print(resp["response"]["generated_text"])
|
||||
log.debug("POST /generate %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request("/generate", payload, cost=payload["parameters"]["max_new_tokens"])
|
||||
return resp["response"]
|
||||
|
||||
|
||||
async def call_generate_stream(client: Serverless) -> None:
|
||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||
async def call_generate_stream(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs):
|
||||
"""Streaming generation via /generate_stream endpoint"""
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
|
||||
payload = {
|
||||
"inputs": PROMPT,
|
||||
"inputs": prompt,
|
||||
"parameters": {
|
||||
"max_new_tokens": MAX_TOKENS,
|
||||
"temperature": 0.7,
|
||||
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
"do_sample": True,
|
||||
"return_full_text": False,
|
||||
}
|
||||
}
|
||||
|
||||
log.debug("STREAM /generate_stream %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request(
|
||||
"/generate_stream",
|
||||
payload,
|
||||
cost=MAX_TOKENS,
|
||||
cost=payload["parameters"]["max_new_tokens"],
|
||||
stream=True,
|
||||
)
|
||||
stream = resp["response"]
|
||||
return resp["response"] # async generator
|
||||
|
||||
printed_answer = False
|
||||
async for event in stream:
|
||||
tok = (event.get("token") or {}).get("text")
|
||||
if tok:
|
||||
if not printed_answer:
|
||||
printed_answer = True
|
||||
print("Answer:\n", end="", flush=True)
|
||||
print(tok, end="", flush=True)
|
||||
|
||||
async def main():
|
||||
async with Serverless() as client:
|
||||
await call_generate(client)
|
||||
await call_generate_stream(client)
|
||||
# ---------------------- Demo Runner ----------------------
|
||||
class APIDemo:
|
||||
"""Demo and testing functionality for the TGI API client"""
|
||||
|
||||
def __init__(self, client: Serverless, endpoint_name: str):
|
||||
self.client = client
|
||||
self.endpoint_name = endpoint_name
|
||||
|
||||
async def handle_streaming_response(self, stream) -> str:
|
||||
"""Process streaming response and print tokens"""
|
||||
full_response = ""
|
||||
printed_answer = False
|
||||
|
||||
async for event in stream:
|
||||
tok = (event.get("token") or {}).get("text")
|
||||
if tok:
|
||||
if not printed_answer:
|
||||
printed_answer = True
|
||||
print("\n💬 Response: ", end="", flush=True)
|
||||
print(tok, end="", flush=True)
|
||||
full_response += tok
|
||||
|
||||
print() # newline
|
||||
if printed_answer:
|
||||
print(f"\nStreaming completed. Response tokens: {len(full_response.split())}")
|
||||
|
||||
return full_response
|
||||
|
||||
async def demo_generate(self) -> None:
|
||||
"""Demo non-streaming generation"""
|
||||
print("=" * 60)
|
||||
print("GENERATE DEMO (NON-STREAMING)")
|
||||
print("=" * 60)
|
||||
|
||||
response = await call_generate(
|
||||
client=self.client,
|
||||
endpoint_name=self.endpoint_name,
|
||||
prompt=DEFAULT_PROMPT,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
)
|
||||
|
||||
print(f"\n💬 Response: {response.get('generated_text', '')}")
|
||||
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
|
||||
|
||||
async def demo_generate_stream(self) -> None:
|
||||
"""Demo streaming generation"""
|
||||
print("=" * 60)
|
||||
print("GENERATE DEMO (STREAMING)")
|
||||
print("=" * 60)
|
||||
|
||||
stream = await call_generate_stream(
|
||||
client=self.client,
|
||||
endpoint_name=self.endpoint_name,
|
||||
prompt=DEFAULT_PROMPT,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
)
|
||||
|
||||
try:
|
||||
await self.handle_streaming_response(stream)
|
||||
except Exception as e:
|
||||
log.error("\nError during streaming: %s", e, exc_info=True)
|
||||
|
||||
async def interactive_chat(self) -> None:
|
||||
"""Interactive session with streaming generation"""
|
||||
print("=" * 60)
|
||||
print("INTERACTIVE STREAMING SESSION")
|
||||
print("=" * 60)
|
||||
print(f"Using endpoint: {self.endpoint_name}")
|
||||
print("Type 'quit' to exit")
|
||||
print()
|
||||
|
||||
while True:
|
||||
try:
|
||||
user_input = input("You: ").strip()
|
||||
|
||||
if user_input.lower() == "quit":
|
||||
print("👋 Goodbye!")
|
||||
break
|
||||
elif not user_input:
|
||||
continue
|
||||
|
||||
print("Assistant: ", end="", flush=True)
|
||||
stream = await call_generate_stream(
|
||||
client=self.client,
|
||||
endpoint_name=self.endpoint_name,
|
||||
prompt=user_input,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
)
|
||||
|
||||
full_response = ""
|
||||
async for event in stream:
|
||||
tok = (event.get("token") or {}).get("text")
|
||||
if tok:
|
||||
print(tok, end="", flush=True)
|
||||
full_response += tok
|
||||
print() # newline
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 Session interrupted. Goodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
log.error("\nError: %s", e)
|
||||
continue
|
||||
|
||||
|
||||
# ---------------------- CLI ----------------------
|
||||
def build_arg_parser() -> argparse.ArgumentParser:
|
||||
p = argparse.ArgumentParser(description="Vast TGI Demo (Serverless SDK)")
|
||||
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
|
||||
|
||||
modes = p.add_mutually_exclusive_group(required=False)
|
||||
modes.add_argument("--generate", action="store_true", help="Test generate endpoint (non-streaming)")
|
||||
modes.add_argument("--generate-stream", action="store_true", help="Test generate endpoint with streaming")
|
||||
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming session")
|
||||
return p
|
||||
|
||||
|
||||
async def main_async():
|
||||
args = build_arg_parser().parse_args()
|
||||
|
||||
selected = sum([args.generate, args.generate_stream, args.interactive])
|
||||
if selected == 0:
|
||||
print("Please specify exactly one test mode:")
|
||||
print(" --generate : Test generate endpoint (non-streaming)")
|
||||
print(" --generate-stream : Test generate endpoint with streaming")
|
||||
print(" --interactive : Start interactive streaming session")
|
||||
print(f"\nExample: python {os.path.basename(sys.argv[0])} --generate-stream --endpoint my-tgi-endpoint")
|
||||
sys.exit(1)
|
||||
elif selected > 1:
|
||||
print("Please specify exactly one test mode")
|
||||
sys.exit(1)
|
||||
|
||||
print("=" * 60)
|
||||
print(f"Using endpoint: {args.endpoint}")
|
||||
|
||||
try:
|
||||
async with Serverless() as client:
|
||||
demo = APIDemo(client, args.endpoint)
|
||||
|
||||
if args.generate:
|
||||
await demo.demo_generate()
|
||||
elif args.generate_stream:
|
||||
await demo.demo_generate_stream()
|
||||
elif args.interactive:
|
||||
await demo.interactive_chat()
|
||||
|
||||
except Exception as e:
|
||||
log.error("Error during test: %s", e, exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
asyncio.run(main_async())
|
||||
|
||||
Reference in New Issue
Block a user