Compare commits

..

5 Commits

Author SHA1 Message Date
Colter Downing adedb8ba90 defaults to ENDPOINT_NAME and DEFAULT_MODEL but uses the flag first if present 2025-12-03 16:57:28 -08:00
LucasArmandVast 0339b471c5 Merge pull request #66 from vast-ai/synthesis
PyWorker Error Handling
2025-11-25 16:02:26 -08:00
Lucas Armand e143162438 bumpy pyworker version 2025-11-25 16:01:23 -08:00
LucasArmandVast 7a792fd176 Merge pull request #64 from vast-ai/add-llama-log
add llama log
2025-11-21 10:24:27 -08:00
Lucas Armand e0449cb3c7 add llama log 2025-11-21 10:22:16 -08:00
6 changed files with 72 additions and 128 deletions
+8 -34
View File
@@ -30,7 +30,7 @@ from lib.data_types import (
BenchmarkResult BenchmarkResult
) )
VERSION = "0.2.0" VERSION = "0.2.1"
MSG_HISTORY_LEN = 100 MSG_HISTORY_LEN = 100
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
@@ -235,14 +235,10 @@ class Backend:
log.debug("No healthcheck endpoint defined, skipping healthcheck") log.debug("No healthcheck endpoint defined, skipping healthcheck")
return return
first_healthcheck = True
while True: while True:
await sleep(10) await sleep(10)
if self.__start_healthcheck is False: if self.__start_healthcheck is False:
continue continue
if first_healthcheck:
log.info(f"[healthcheck] First healthcheck starting (model is now loaded)")
first_healthcheck = False
try: try:
log.debug(f"Performing healthcheck on {health_check_url}") log.debug(f"Performing healthcheck on {health_check_url}")
async with self.healthcheck_session.get(health_check_url) as response: async with self.healthcheck_session.get(health_check_url) as response:
@@ -260,22 +256,9 @@ class Backend:
self.backend_errored(str(e)) self.backend_errored(str(e))
async def _start_tracking(self) -> None: async def _start_tracking(self) -> None:
log.info("Starting tracking tasks (read_logs, send_metrics_loop, healthcheck, send_delete_requests_loop)") await gather(
task_names = ["read_logs", "send_metrics_loop", "healthcheck", "send_delete_requests_loop"] self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop()
results = await gather(
self.__read_logs(),
self.metrics._send_metrics_loop(),
self.__healthcheck(),
self.metrics._send_delete_requests_loop(),
return_exceptions=True
) )
# If we get here, one or more tasks exited (they should run forever)
log.error(f"CRITICAL: _start_tracking gather returned! This should never happen. Results: {results}")
for name, result in zip(task_names, results):
if isinstance(result, Exception):
log.error(f"Tracking task '{name}' crashed with exception: {result}", exc_info=result)
elif result is not None:
log.warning(f"Tracking task '{name}' exited unexpectedly with result: {result}")
def backend_errored(self, msg: str) -> None: def backend_errored(self, msg: str) -> None:
self.metrics._model_errored(msg) self.metrics._model_errored(msg)
@@ -416,20 +399,15 @@ class Backend:
# await sleep(5) # await sleep(5)
try: try:
max_throughput = await run_benchmark() max_throughput = await run_benchmark()
log.info(f"[benchmark] Benchmark complete, max_throughput={max_throughput}, setting healthcheck=True")
self.__start_healthcheck = True self.__start_healthcheck = True
self.metrics._model_loaded( self.metrics._model_loaded(
max_throughput=max_throughput, max_throughput=max_throughput,
) )
log.info(f"[benchmark] _model_loaded() called, returning from handle_log_line")
except ClientConnectorError as e: except ClientConnectorError as e:
log.debug( log.debug(
f"failed to connect to model api during benchmark" f"failed to connect to comfyui api during benchmark"
) )
self.backend_errored(str(e)) self.backend_errored(str(e))
except Exception as e:
log.error(f"Unexpected error during benchmark: {e}", exc_info=True)
self.backend_errored(f"Benchmark failed: {e}")
case LogAction.ModelError if msg in log_line: case LogAction.ModelError if msg in log_line:
log.debug(f"Got log line indicating error: {log_line}") log.debug(f"Got log line indicating error: {log_line}")
self.backend_errored(msg) self.backend_errored(msg)
@@ -441,14 +419,10 @@ class Backend:
log.debug(f"tailing file: {self.model_log_file}") log.debug(f"tailing file: {self.model_log_file}")
async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore') as f: async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore') as f:
while True: while True:
try: line = await f.readline()
line = await f.readline() if line:
if line: await handle_log_line(line.rstrip())
await handle_log_line(line.rstrip()) else:
else:
await asyncio.sleep(LOG_POLL_INTERVAL)
except Exception as e:
log.error(f"Error processing log line: {e}", exc_info=True)
await asyncio.sleep(LOG_POLL_INTERVAL) await asyncio.sleep(LOG_POLL_INTERVAL)
########### ###########
-35
View File
@@ -1,5 +1,4 @@
import os import os
import sys
import time import time
import logging import logging
import json import json
@@ -18,14 +17,6 @@ DELETE_REQUESTS_INTERVAL = 1
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
def _flush_logs():
"""Force flush all log handlers and stdout/stderr."""
for handler in logging.root.handlers:
handler.flush()
sys.stdout.flush()
sys.stderr.flush()
@cache @cache
def get_url() -> str: def get_url() -> str:
use_ssl = os.environ.get("USE_SSL", "false") == "true" use_ssl = os.environ.get("USE_SSL", "false") == "true"
@@ -128,41 +119,22 @@ class Metrics:
await self.__send_delete_requests_and_reset() await self.__send_delete_requests_and_reset()
async def _send_metrics_loop(self) -> Awaitable[NoReturn]: async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
loop_count = 0
first_loaded_send_done = False
while True: while True:
await sleep(METRICS_UPDATE_INTERVAL) await sleep(METRICS_UPDATE_INTERVAL)
loop_count += 1
elapsed = time.time() - self.last_metric_update elapsed = time.time() - self.last_metric_update
# Log heartbeat every 30 seconds to confirm loop is running
if loop_count % 30 == 0:
log.debug(f"[heartbeat] metrics loop alive, loop_count={loop_count}, model_loaded={self.system_metrics.model_is_loaded}")
_flush_logs()
# Extra logging for first few iterations after model loads
if self.system_metrics.model_is_loaded and not first_loaded_send_done:
log.info(f"[transition] First iteration with model_loaded=True, loop_count={loop_count}, elapsed={elapsed:.1f}")
_flush_logs()
if self.system_metrics.model_is_loaded is False and elapsed >= 10: if self.system_metrics.model_is_loaded is False and elapsed >= 10:
log.debug(f"sending loading model metrics after {int(elapsed)}s wait") log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
await self.__send_metrics_and_reset() await self.__send_metrics_and_reset()
elif self.update_pending or elapsed > 10: elif self.update_pending or elapsed > 10:
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait") log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
await self.__send_metrics_and_reset() await self.__send_metrics_and_reset()
if self.system_metrics.model_is_loaded and not first_loaded_send_done:
first_loaded_send_done = True
log.info(f"[transition] First loaded metrics send complete, continuing to next iteration...")
_flush_logs()
def _model_loaded(self, max_throughput: float) -> None: def _model_loaded(self, max_throughput: float) -> None:
log.info(f"MODEL LOADED: Setting model_is_loaded=True, max_throughput={max_throughput}")
_flush_logs()
self.system_metrics.model_loading_time = ( self.system_metrics.model_loading_time = (
time.time() - self.system_metrics.model_loading_start time.time() - self.system_metrics.model_loading_start
) )
self.system_metrics.model_is_loaded = True self.system_metrics.model_is_loaded = True
self.model_metrics.max_throughput = max_throughput self.model_metrics.max_throughput = max_throughput
log.info(f"MODEL LOADED: model_loading_time={self.system_metrics.model_loading_time}")
_flush_logs()
def _model_errored(self, error_msg: str) -> None: def _model_errored(self, error_msg: str) -> None:
self.model_metrics.set_errored(error_msg) self.model_metrics.set_errored(error_msg)
@@ -299,7 +271,6 @@ class Metrics:
########### ###########
self.system_metrics.update_disk_usage() self.system_metrics.update_disk_usage()
had_loadtime = loadtime_snapshot is not None and loadtime_snapshot > 0
sent = False sent = False
for report_addr in self.report_addr: for report_addr in self.report_addr:
@@ -308,14 +279,8 @@ class Metrics:
break break
if sent: if sent:
if had_loadtime:
log.info(f"FIRST LOADTIME METRICS SENT SUCCESSFULLY! loadtime={loadtime_snapshot}")
_flush_logs()
# clear the one-shot loadtime only if we actually sent *this* value # clear the one-shot loadtime only if we actually sent *this* value
self.system_metrics.reset(expected=loadtime_snapshot) self.system_metrics.reset(expected=loadtime_snapshot)
self.update_pending = False self.update_pending = False
self.model_metrics.reset() self.model_metrics.reset()
self.last_metric_update = time.time() self.last_metric_update = time.time()
if had_loadtime:
log.info(f"POST-SEND: reset complete, last_metric_update={self.last_metric_update}, continuing loop...")
_flush_logs()
-20
View File
@@ -1,7 +1,5 @@
import os import os
import logging import logging
import signal
import sys
from typing import List from typing import List
import ssl import ssl
from asyncio import run, gather from asyncio import run, gather
@@ -14,25 +12,7 @@ from aiohttp import web
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
def _setup_signal_handlers():
"""Setup signal handlers to log when process receives termination signals."""
def signal_handler(signum, frame):
sig_name = signal.Signals(signum).name
log.error(f"SIGNAL RECEIVED: {sig_name} ({signum}) - process is being terminated")
sys.stdout.flush()
sys.stderr.flush()
sys.exit(128 + signum)
# Handle common termination signals
for sig in [signal.SIGTERM, signal.SIGINT, signal.SIGHUP]:
try:
signal.signal(sig, signal_handler)
except (OSError, ValueError):
pass # Some signals may not be available
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs): def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
_setup_signal_handlers()
try: try:
log.debug("getting certificate...") log.debug("getting certificate...")
use_ssl = os.environ.get("USE_SSL", "false") == "true" use_ssl = os.environ.get("USE_SSL", "false") == "true"
+31 -23
View File
@@ -34,38 +34,20 @@ uv pip install -r requirements.txt
Several examples have been provided in the client to help you get started with your own implementation. Several examples have been provided in the client to help you get started with your own implementation.
### Completions First, set your API key as an environment variable:
Call to `/v1/completions` with json response
```bash ```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME> export VAST_API_KEY=<your_api_key>
``` ```
### Chat Completion (json) The `--model` and `--endpoint` flags are optional. If not provided, they default to `Qwen/Qwen3-8B` and `my-vllm-endpoint` respectively.
Call to `/v1/chat/completions` with json response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
```
### Chat Completion (streaming) ### Chat Completion (streaming)
Call to `/v1/chat/completions` with streaming response Call to `/v1/chat/completions` with streaming response
```bash ```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME> python -m workers.openai.client --chat-stream --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Tool Use (json)
Call to `/v1/chat/completions` with tool and json response.
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
``` ```
### Interactive Chat (streaming) ### Interactive Chat (streaming)
@@ -75,6 +57,32 @@ Interactive session with calls to `/v1/chat/completions`.
Type `clear` to clear the chat history or `quit` to exit. Type `clear` to clear the chat history or `quit` to exit.
```bash ```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME> python -m workers.openai.client --interactive --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Chat Completion (json)
Call to `/v1/chat/completions` with json response
```bash
python -m workers.openai.client --chat --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Tool Use (json)
Call to `/v1/chat/completions` with tool and json response.
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
```bash
python -m workers.openai.client --tools --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Completions
Call to `/v1/completions` with json response
```bash
python -m workers.openai.client --completion --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
``` ```
+32 -16
View File
@@ -18,7 +18,7 @@ logging.basicConfig(
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
# ---------------------- Prompts ---------------------- # ---------------------- Prompts ----------------------
COMPLETIONS_PROMPT = "the capital of USA is" COMPLETIONS_PROMPT = "Zebras are primarily grazers and can subsist on lower-quality vegetation. They are preyed on mainly by"
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language." CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
TOOLS_PROMPT = ( TOOLS_PROMPT = (
"Can you list the files in the current working directory and tell me what you see? " "Can you list the files in the current working directory and tell me what you see? "
@@ -97,9 +97,9 @@ def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[
# ---- OpenAI-compatible calls (non-streaming) ---- # ---- OpenAI-compatible calls (non-streaming) ----
async def call_completions(client: Serverless, *, model: str, prompt: str, **kwargs) -> Dict[str, Any]: async def call_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) endpoint = await client.get_endpoint(name=endpoint_name)
payload = { payload = {
"input": { "input": {
@@ -113,9 +113,9 @@ async def call_completions(client: Serverless, *, model: str, prompt: str, **kwa
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"]) resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
return resp["response"] return resp["response"]
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]: async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) endpoint = await client.get_endpoint(name=endpoint_name)
payload = { payload = {
"input": { "input": {
@@ -132,9 +132,9 @@ async def call_chat_completions(client: Serverless, *, model: str, messages: Lis
return resp["response"] return resp["response"]
# ---- Streaming variants ---- # ---- Streaming variants ----
async def stream_completions(client: Serverless, *, model: str, prompt: str, **kwargs): async def stream_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs):
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) endpoint = await client.get_endpoint(name=endpoint_name)
payload = { payload = {
"input": { "input": {
@@ -150,9 +150,9 @@ async def stream_completions(client: Serverless, *, model: str, prompt: str, **k
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True) resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
return resp["response"] # async generator return resp["response"] # async generator
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs): async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs):
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) endpoint = await client.get_endpoint(name=endpoint_name)
payload = { payload = {
"input": { "input": {
@@ -174,9 +174,10 @@ async def stream_chat_completions(client: Serverless, *, model: str, messages: L
class APIDemo: class APIDemo:
"""Demo and testing functionality for the API client""" """Demo and testing functionality for the API client"""
def __init__(self, client: Serverless, model: str, tool_manager: Optional[ToolManager] = None): def __init__(self, client: Serverless, model: str, endpoint_name: str, tool_manager: Optional[ToolManager] = None):
self.client = client self.client = client
self.model = model self.model = model
self.endpoint_name = endpoint_name
self.tool_manager = tool_manager or ToolManager() self.tool_manager = tool_manager or ToolManager()
# ----- Streaming handler ----- # ----- Streaming handler -----
@@ -185,10 +186,15 @@ class APIDemo:
reasoning_content = "" reasoning_content = ""
printed_reasoning = False printed_reasoning = False
printed_answer = False printed_answer = False
finish_reason = None
async for chunk in stream: async for chunk in stream:
choice = (chunk.get("choices") or [{}])[0] choice = (chunk.get("choices") or [{}])[0]
delta = choice.get("delta", {}) delta = choice.get("delta", {})
# Track finish reason
if choice.get("finish_reason"):
finish_reason = choice.get("finish_reason")
# reasoning tokens # reasoning tokens
rc = delta.get("reasoning_content") rc = delta.get("reasoning_content")
@@ -219,6 +225,8 @@ class APIDemo:
print(f"Reasoning tokens: {len(reasoning_content.split())}") print(f"Reasoning tokens: {len(reasoning_content.split())}")
if printed_answer: if printed_answer:
print(f"Response tokens: {len(full_response.split())}") print(f"Response tokens: {len(full_response.split())}")
if finish_reason:
print(f"Finish reason: {finish_reason}")
return full_response return full_response
@@ -231,6 +239,7 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
prompt=COMPLETIONS_PROMPT, prompt=COMPLETIONS_PROMPT,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE, temperature=DEFAULT_TEMPERATURE,
) )
@@ -249,6 +258,7 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE temperature=DEFAULT_TEMPERATURE
) )
@@ -261,6 +271,7 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE temperature=DEFAULT_TEMPERATURE
) )
@@ -287,6 +298,7 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
tools=minimal_tool, tools=minimal_tool,
tool_choice="none", tool_choice="none",
max_tokens=10 max_tokens=10
@@ -312,6 +324,7 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
tools=self.tool_manager.get_ls_tool_definition(), tools=self.tool_manager.get_ls_tool_definition(),
tool_choice="auto", tool_choice="auto",
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
@@ -389,6 +402,7 @@ class APIDemo:
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE, temperature=DEFAULT_TEMPERATURE,
) )
@@ -427,7 +441,6 @@ class APIDemo:
print("=" * 60) print("=" * 60)
print("INTERACTIVE STREAMING CHAT") print("INTERACTIVE STREAMING CHAT")
print("=" * 60) print("=" * 60)
print(f"Using model: {self.model}")
print("Type 'quit' to exit, 'clear' to clear history") print("Type 'quit' to exit, 'clear' to clear history")
print() print()
@@ -453,7 +466,8 @@ class APIDemo:
stream = await stream_chat_completions( stream = await stream_chat_completions(
client=self.client, client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
temperature=0.7 temperature=0.7
) )
@@ -473,8 +487,8 @@ class APIDemo:
# ---------------------- CLI ---------------------- # ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser: def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)") p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
p.add_argument("--model", required=True, help="Model to use for requests (required)") p.add_argument("--model", default=DEFAULT_MODEL, help=f"Model to use for requests (default: {DEFAULT_MODEL})")
p.add_argument("--endpoint", default="my-vllm-endpoint", help="Vast endpoint name (default: my-vllm-endpoint)") p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
modes = p.add_mutually_exclusive_group(required=False) modes = p.add_mutually_exclusive_group(required=False)
modes.add_argument("--completion", action="store_true", help="Test completions endpoint") modes.add_argument("--completion", action="store_true", help="Test completions endpoint")
@@ -502,12 +516,14 @@ async def main_async():
print("Please specify exactly one test mode") print("Please specify exactly one test mode")
sys.exit(1) sys.exit(1)
print(f"Using model: {args.model}")
print("=" * 60) print("=" * 60)
print(f"Using model: {args.model}")
print(f"Using endpoint: {args.endpoint}")
try: try:
async with Serverless() as client: async with Serverless() as client:
demo = APIDemo(client, args.model, ToolManager()) demo = APIDemo(client, args.model, args.endpoint, ToolManager())
if args.completion: if args.completion:
await demo.demo_completions() await demo.demo_completions()
+1
View File
@@ -11,6 +11,7 @@ MODEL_SERVER_START_LOG_MSG = [
"llama runner started", # Ollama "llama runner started", # Ollama
'"message":"Connected","target":"text_generation_router"', # TGI '"message":"Connected","target":"text_generation_router"', # TGI
'"message":"Connected","target":"text_generation_router::server"', # TGI '"message":"Connected","target":"text_generation_router::server"', # TGI
"main: model loaded" # llama.cpp
] ]
MODEL_SERVER_ERROR_LOG_MSGS = [ MODEL_SERVER_ERROR_LOG_MSGS = [