Compare commits

...

12 Commits

Author SHA1 Message Date
Colter Downing adedb8ba90 defaults to ENDPOINT_NAME and DEFAULT_MODEL but uses the flag first if present 2025-12-03 16:57:28 -08:00
LucasArmandVast 0339b471c5 Merge pull request #66 from vast-ai/synthesis
PyWorker Error Handling
2025-11-25 16:02:26 -08:00
Lucas Armand e143162438 bumpy pyworker version 2025-11-25 16:01:23 -08:00
Lucas Armand 7986e51e9e early errors 2025-11-24 15:24:06 -08:00
Lucas Armand 9c6ab78503 Move model log line 2025-11-24 15:22:23 -08:00
Lucas Armand 45e0c7d9ca Move model log rotate to top 2025-11-24 15:02:33 -08:00
LucasArmandVast 7a792fd176 Merge pull request #64 from vast-ai/add-llama-log
add llama log
2025-11-21 10:24:27 -08:00
Lucas Armand a47c9d1ed0 remove test bugs 2025-11-11 18:13:46 -08:00
Lucas Armand 0b14562a63 dont exit on pyworker fail 2025-11-11 17:57:08 -08:00
Lucas Armand de9b50abb9 use set +e 2025-11-11 17:53:36 -08:00
Lucas Armand c510801723 fix 2025-11-11 17:49:34 -08:00
Lucas Armand a12523b1d2 Added bad code to tgi server to test 2025-11-11 17:41:12 -08:00
5 changed files with 156 additions and 70 deletions
+1 -1
View File
@@ -30,7 +30,7 @@ from lib.data_types import (
BenchmarkResult
)
VERSION = "0.2.0"
VERSION = "0.2.1"
MSG_HISTORY_LEN = 100
log = logging.getLogger(__file__)
+45 -25
View File
@@ -3,38 +3,58 @@ import logging
from typing import List
import ssl
from asyncio import run, gather
import asyncio
from lib.backend import Backend
from lib.metrics import Metrics
from aiohttp import web
log = logging.getLogger(__file__)
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
log.debug("getting certificate...")
use_ssl = os.environ.get("USE_SSL", "false") == "true"
if use_ssl is True:
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context.load_cert_chain(
certfile="/etc/instance.crt",
keyfile="/etc/instance.key",
)
else:
ssl_context = None
try:
log.debug("getting certificate...")
use_ssl = os.environ.get("USE_SSL", "false") == "true"
if use_ssl is True:
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context.load_cert_chain(
certfile="/etc/instance.crt",
keyfile="/etc/instance.key",
)
else:
ssl_context = None
async def main():
log.debug("starting server...")
app = web.Application()
app.add_routes(routes)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(
runner,
ssl_context=ssl_context,
port=int(os.environ["WORKER_PORT"]),
**kwargs
)
await gather(site.start(), backend._start_tracking())
async def main():
log.debug("starting server...")
app = web.Application()
app.add_routes(routes)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(
runner,
ssl_context=ssl_context,
port=int(os.environ["WORKER_PORT"]),
**kwargs
)
await gather(site.start(), backend._start_tracking())
run(main())
run(main())
except Exception as e:
err_msg = f"PyWorker failed to launch: {e}"
log.error(err_msg)
async def beacon():
metrics = Metrics()
metrics._set_version(getattr(backend, "version", "0"))
metrics._set_mtoken(getattr(backend, "mtoken", ""))
try:
while True:
metrics._model_errored(err_msg)
await metrics._Metrics__send_metrics_and_reset()
await asyncio.sleep(10)
finally:
await metrics.aclose()
run(beacon())
+46 -4
View File
@@ -41,6 +41,14 @@ echo_var DEBUG_LOG
echo_var PYWORKER_LOG
echo_var MODEL_LOG
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
if [ -e "$MODEL_LOG" ]; then
echo "Rotating model log at $MODEL_LOG to $MODEL_LOG.old"
cat "$MODEL_LOG" >> "$MODEL_LOG.old"
: > "$MODEL_LOG"
fi
# Populate /etc/environment with quoted values
if ! grep -q "VAST" /etc/environment; then
env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do
@@ -124,9 +132,43 @@ cd "$SERVER_DIR"
echo "launching PyWorker server"
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
set +e
python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG"
PY_STATUS=${PIPESTATUS[0]}
set -e
if [ "${PY_STATUS}" -ne 0 ]; then
echo "PyWorker exited with status ${PY_STATUS}; notifying autoscaler..."
ERROR_MSG="PyWorker exited: code ${PY_STATUS}"
MTOKEN="${MASTER_TOKEN:-}"
VERSION="${PYWORKER_VERSION:-0}"
IFS=',' read -r -a REPORT_ADDRS <<< "${REPORT_ADDR}"
for addr in "${REPORT_ADDRS[@]}"; do
curl -sS -X POST -H 'Content-Type: application/json' \
-d "$(cat <<JSON
{
"id": ${CONTAINER_ID:-0},
"mtoken": "${MTOKEN}",
"version": "${VERSION}",
"loadtime": 0,
"new_load": 0,
"cur_load": 0,
"rej_load": 0,
"max_perf": 0,
"cur_perf": 0,
"error_msg": "${ERROR_MSG}",
"num_requests_working": 0,
"num_requests_recieved": 0,
"additional_disk_usage": 0,
"working_request_idxs": [],
"cur_capacity": 0,
"max_capacity": 0,
"url": "${URL}"
}
JSON
)" "${addr%/}/worker_status/" || true
done
fi
(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
echo "launching PyWorker server done"
+31 -23
View File
@@ -34,38 +34,20 @@ uv pip install -r requirements.txt
Several examples have been provided in the client to help you get started with your own implementation.
### Completions
Call to `/v1/completions` with json response
First, set your API key as an environment variable:
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
export VAST_API_KEY=<your_api_key>
```
### Chat Completion (json)
Call to `/v1/chat/completions` with json response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
```
The `--model` and `--endpoint` flags are optional. If not provided, they default to `Qwen/Qwen3-8B` and `my-vllm-endpoint` respectively.
### Chat Completion (streaming)
Call to `/v1/chat/completions` with streaming response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
```
### Tool Use (json)
Call to `/v1/chat/completions` with tool and json response.
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
python -m workers.openai.client --chat-stream --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Interactive Chat (streaming)
@@ -75,6 +57,32 @@ Interactive session with calls to `/v1/chat/completions`.
Type `clear` to clear the chat history or `quit` to exit.
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
python -m workers.openai.client --interactive --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Chat Completion (json)
Call to `/v1/chat/completions` with json response
```bash
python -m workers.openai.client --chat --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Tool Use (json)
Call to `/v1/chat/completions` with tool and json response.
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
```bash
python -m workers.openai.client --tools --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Completions
Call to `/v1/completions` with json response
```bash
python -m workers.openai.client --completion --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
+31 -15
View File
@@ -18,7 +18,7 @@ logging.basicConfig(
log = logging.getLogger(__file__)
# ---------------------- Prompts ----------------------
COMPLETIONS_PROMPT = "the capital of USA is"
COMPLETIONS_PROMPT = "Zebras are primarily grazers and can subsist on lower-quality vegetation. They are preyed on mainly by"
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
TOOLS_PROMPT = (
"Can you list the files in the current working directory and tell me what you see? "
@@ -97,9 +97,9 @@ def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[
# ---- OpenAI-compatible calls (non-streaming) ----
async def call_completions(client: Serverless, *, model: str, prompt: str, **kwargs) -> Dict[str, Any]:
async def call_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
@@ -113,9 +113,9 @@ async def call_completions(client: Serverless, *, model: str, prompt: str, **kwa
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
return resp["response"]
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
@@ -132,9 +132,9 @@ async def call_chat_completions(client: Serverless, *, model: str, messages: Lis
return resp["response"]
# ---- Streaming variants ----
async def stream_completions(client: Serverless, *, model: str, prompt: str, **kwargs):
async def stream_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs):
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
@@ -150,9 +150,9 @@ async def stream_completions(client: Serverless, *, model: str, prompt: str, **k
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
return resp["response"] # async generator
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs):
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs):
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
@@ -174,9 +174,10 @@ async def stream_chat_completions(client: Serverless, *, model: str, messages: L
class APIDemo:
"""Demo and testing functionality for the API client"""
def __init__(self, client: Serverless, model: str, tool_manager: Optional[ToolManager] = None):
def __init__(self, client: Serverless, model: str, endpoint_name: str, tool_manager: Optional[ToolManager] = None):
self.client = client
self.model = model
self.endpoint_name = endpoint_name
self.tool_manager = tool_manager or ToolManager()
# ----- Streaming handler -----
@@ -185,11 +186,16 @@ class APIDemo:
reasoning_content = ""
printed_reasoning = False
printed_answer = False
finish_reason = None
async for chunk in stream:
choice = (chunk.get("choices") or [{}])[0]
delta = choice.get("delta", {})
# Track finish reason
if choice.get("finish_reason"):
finish_reason = choice.get("finish_reason")
# reasoning tokens
rc = delta.get("reasoning_content")
if rc and show_reasoning:
@@ -219,6 +225,8 @@ class APIDemo:
print(f"Reasoning tokens: {len(reasoning_content.split())}")
if printed_answer:
print(f"Response tokens: {len(full_response.split())}")
if finish_reason:
print(f"Finish reason: {finish_reason}")
return full_response
@@ -231,6 +239,7 @@ class APIDemo:
client=self.client,
model=self.model,
prompt=COMPLETIONS_PROMPT,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
@@ -249,6 +258,7 @@ class APIDemo:
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE
)
@@ -261,6 +271,7 @@ class APIDemo:
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE
)
@@ -287,6 +298,7 @@ class APIDemo:
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
tools=minimal_tool,
tool_choice="none",
max_tokens=10
@@ -312,6 +324,7 @@ class APIDemo:
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
tools=self.tool_manager.get_ls_tool_definition(),
tool_choice="auto",
max_tokens=MAX_TOKENS,
@@ -389,6 +402,7 @@ class APIDemo:
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
@@ -427,7 +441,6 @@ class APIDemo:
print("=" * 60)
print("INTERACTIVE STREAMING CHAT")
print("=" * 60)
print(f"Using model: {self.model}")
print("Type 'quit' to exit, 'clear' to clear history")
print()
@@ -454,6 +467,7 @@ class APIDemo:
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=0.7
)
@@ -473,8 +487,8 @@ class APIDemo:
# ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
p.add_argument("--model", required=True, help="Model to use for requests (required)")
p.add_argument("--endpoint", default="my-vllm-endpoint", help="Vast endpoint name (default: my-vllm-endpoint)")
p.add_argument("--model", default=DEFAULT_MODEL, help=f"Model to use for requests (default: {DEFAULT_MODEL})")
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
modes = p.add_mutually_exclusive_group(required=False)
modes.add_argument("--completion", action="store_true", help="Test completions endpoint")
@@ -502,12 +516,14 @@ async def main_async():
print("Please specify exactly one test mode")
sys.exit(1)
print(f"Using model: {args.model}")
print("=" * 60)
print(f"Using model: {args.model}")
print(f"Using endpoint: {args.endpoint}")
try:
async with Serverless() as client:
demo = APIDemo(client, args.model, ToolManager())
demo = APIDemo(client, args.model, args.endpoint, ToolManager())
if args.completion:
await demo.demo_completions()