Compare commits

...

10 Commits

Author SHA1 Message Date
Lucas Armand 63550d5af3 Actual fix -- MOVED TO TEMPLATE 2025-11-17 10:57:55 -08:00
Lucas Armand 7ec0e11938 add await 2025-11-17 10:46:09 -08:00
Lucas Armand 191fbbfe18 try seek 2025-11-17 10:44:13 -08:00
Lucas Armand 9a4a39c71b Read model log 2025-11-17 10:32:38 -08:00
Lucas Armand a4339bd3f1 hotfix: add f 2025-11-12 16:10:55 -08:00
Lucas Armand 2b26e5e20c hotfix: remove g 2025-11-12 16:01:57 -08:00
LucasArmandVast d3727d4fd7 Merge pull request #58 from vast-ai/update-client-scripts
Update client scripts
2025-11-12 10:22:42 -08:00
Lucas Armand eedf81c0a3 Updated readme and .gitignore 2025-11-11 17:18:40 -08:00
Lucas Armand 3adec1826d minor changes 2025-11-11 17:11:38 -08:00
Lucas Armand b55bfa9611 Updated clients, include vastai-sdk, handle non-UTF-8 2025-11-11 17:09:28 -08:00
9 changed files with 446 additions and 709 deletions
+1
View File
@@ -3,3 +3,4 @@
__pycache__ __pycache__
bin/ bin/
lib64 lib64
.venv
+4 -3
View File
@@ -39,11 +39,12 @@ reporting these metrics to the autoscaler.
If you are using a Vast.ai template that includes PyWorker integration (marked as autoscaler compatible), it should work out of the box. The template will typically start the appropriate PyWorker server automatically. Here's a few: If you are using a Vast.ai template that includes PyWorker integration (marked as autoscaler compatible), it should work out of the box. The template will typically start the appropriate PyWorker server automatically. Here's a few:
* **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=72d8dcb41ea3a58e06c741e2c725bc00) * **vLLM:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=63ae93902bf3978bea033782592b784d)
* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=ad72c8bf7cf695c3c9ddf0eaf6da0447) * **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=6fa6bd5bdf5f0df63db80e40b086037d)
* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=e6748878ba688e765e3e9fca29541938)
Currently available workers: Currently available workers:
* `hello_world`: A simple example worker for a basic LLM server. * `openai`: A simple example worker for a basic vLLM server.
* `comfyui`: A worker for the ComfyUI image generation backend. * `comfyui`: A worker for the ComfyUI image generation backend.
* `tgi`: A worker for the Text Generation Inference backend. * `tgi`: A worker for the Text Generation Inference backend.
+2 -2
View File
@@ -417,7 +417,7 @@ class Backend:
async def tail_log(): async def tail_log():
log.debug(f"tailing file: {self.model_log_file}") log.debug(f"tailing file: {self.model_log_file}")
async with await open_file(self.model_log_file) as f: async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore') as f:
while True: while True:
line = await f.readline() line = await f.readline()
if line: if line:
@@ -431,4 +431,4 @@ class Backend:
if os.path.isfile(self.model_log_file) is True: if os.path.isfile(self.model_log_file) is True:
return await tail_log() return await tail_log()
else: else:
await sleep(1) await asyncio.sleep(1)
+1
View File
@@ -8,3 +8,4 @@ Requests~=2.32
transformers~=4.52 transformers~=4.52
utils==1.0.* utils==1.0.*
hf_transfer>=0.1.9 hf_transfer>=0.1.9
vastai-sdk>=0.2.0
-4
View File
@@ -124,9 +124,5 @@ cd "$SERVER_DIR"
echo "launching PyWorker server" echo "launching PyWorker server"
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") & (python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
echo "launching PyWorker server done" echo "launching PyWorker server done"
+13 -134
View File
@@ -1,108 +1,16 @@
import logging from .data_types import count_workload
import uuid import uuid
import random import random
from urllib.parse import urljoin import asyncio
import json import random
import requests from vastai import Serverless
from lib.test_utils import print_truncate_res async def main():
from utils.endpoint_util import Endpoint async with Serverless() as client:
from utils.ssl import get_cert_file_path endpoint = await client.get_endpoint(name="my-comfy-endpoint") # Change this to your endpoint name
from .data_types import count_workload
logging.basicConfig( payload = {
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
def call_text2image_workflow(
endpoint_group_name: str, api_key: str, server_url: str
) -> None:
"""Simple Text2Image using the new modifier-based approach"""
def make_request(url: str, payload: dict, timeout: int = None, verify=True, context: str = "request"):
"""Helper function for making requests with consistent error handling"""
try:
response = requests.post(
url,
json=payload,
timeout=timeout,
verify=verify
)
response.raise_for_status()
return response.json()
except requests.exceptions.HTTPError as http_err:
log.error(f"HTTP error occurred during {context}: {http_err}")
log.error(f"Status Code: {response.status_code}")
log.error("Response content:", response.text)
return None
except requests.exceptions.Timeout:
log.error(f"Timeout occurred during {context}: {url}")
return None
except requests.exceptions.ConnectionError:
log.error(f"Connection error occurred during {context}: {url}")
return None
except json.JSONDecodeError as json_err:
log.error(f"Failed to decode JSON response during {context}: {json_err}")
if 'response' in locals():
print("Response content:", response.text)
return None
except Exception as err:
log.error(f"An unexpected error occurred during {context}: {err}")
if 'response' in locals():
log.error("Response content (if available):", response.text)
return None
WORKER_ENDPOINT = "/generate/sync"
# This worker has concurrency = 1. All workloads have cost value 1.0
COST = count_workload()
# Route to get worker URL
route_payload = {
"endpoint": endpoint_group_name,
"api_key": api_key,
"cost": COST,
}
# First request - get routing information
route_response = make_request(
url=urljoin(server_url, "/route/"),
payload=route_payload,
timeout=4,
context="route request"
)
if route_response is None:
return None
if "url" not in route_response or not route_response["url"]:
log.error("Error: No worker in 'Ready' state. Please wait while the serverless engine removes errored workers or finishes loading new workers.")
return None
if "status" in route_response:
print(f"Autoscaler status: {route_response['status']}")
return None
# Extract data from route response
url = route_response["url"]
auth_data = dict(
signature=route_response["signature"],
cost=route_response["cost"],
endpoint=route_response["endpoint"],
reqnum=route_response["reqnum"],
url=route_response["url"],
request_idx=route_response["request_idx"],
)
# Build the payload for the worker request
worker_payload = {
"input": { "input": {
"request_id": str(uuid.uuid4()), "request_id": str(uuid.uuid4()),
"modifier": "Text2Image", "modifier": "Text2Image",
@@ -117,40 +25,11 @@ def call_text2image_workflow(
} }
} }
req_data = dict(payload=worker_payload, auth_data=auth_data) response = await endpoint.request("/generate/sync", payload, cost=count_workload())
worker_url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {worker_url}")
# Second request - call the worker endpoint
worker_response = make_request(
url=worker_url,
payload=req_data,
verify=get_cert_file_path(),
context="worker request"
)
return worker_response
# Get the file from the path on the local machine using SCP or SFTP
# or configure S3 to upload to cloud storage.
print(response["response"]["output"][0]["local_path"])
if __name__ == "__main__": if __name__ == "__main__":
from lib.test_utils import test_args asyncio.run(main())
args = test_args.parse_args()
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if endpoint_api_key:
result = call_text2image_workflow(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
if result is None:
log.error("Text2Image workflow failed")
else:
print(result)
else:
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name}")
+5 -12
View File
@@ -7,20 +7,13 @@ from lib.test_utils import print_truncate_res
from utils.endpoint_util import Endpoint from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path from utils.ssl import get_cert_file_path
""" from vastai import Serverless
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
"""
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
def call_default_workflow( ENDPOINT_NAME = "my-comfyui-endpoint"
endpoint_group_name: str, api_key: str, server_url: str COST = 100 # Use a constant cost for image generation
) -> None:
def call_default_workflow(client: Serverless) -> None:
WORKER_ENDPOINT = "/prompt" WORKER_ENDPOINT = "/prompt"
COST = 100 COST = 100
route_payload = { route_payload = {
+341 -411
View File
@@ -1,14 +1,15 @@
import logging import logging
import sys
import json import json
import os
import sys
import subprocess import subprocess
from urllib.parse import urljoin import argparse
from typing import Dict, Any, Optional, Iterator, Union, List from typing import Any, Dict, List, Optional
import requests
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from .data_types.client import CompletionConfig, ChatCompletionConfig
from vastai import Serverless
import asyncio
# ---------------------- Logging ----------------------
logging.basicConfig( logging.basicConfig(
level=logging.DEBUG, level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s", format="%(asctime)s[%(levelname)-5s] %(message)s",
@@ -16,135 +17,20 @@ logging.basicConfig(
) )
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
# ---------------------- Prompts ----------------------
COMPLETIONS_PROMPT = "the capital of USA is" COMPLETIONS_PROMPT = "the capital of USA is"
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language." CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?" TOOLS_PROMPT = (
"Can you list the files in the current working directory and tell me what you see? "
"What do you think this directory might be for?"
class APIClient: )
"""Lightweight client focused solely on API communication"""
# Remove the generic WORKER_ENDPOINT since we're now going direct
DEFAULT_COST = 100
DEFAULT_TIMEOUT = 4
def __init__(
self,
endpoint_group_name: str,
api_key: str,
server_url: str,
endpoint_api_key: str,
):
self.endpoint_group_name = endpoint_group_name
self.api_key = api_key
self.server_url = server_url
self.endpoint_api_key = endpoint_api_key
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
"""Get worker URL and auth data from routing service"""
if not self.endpoint_api_key:
raise ValueError("No valid endpoint API key available")
route_payload = {
"endpoint": self.endpoint_group_name,
"api_key": self.endpoint_api_key,
"cost": cost,
}
response = requests.post(
urljoin(self.server_url, "/route/"),
json=route_payload,
timeout=self.DEFAULT_TIMEOUT,
)
response.raise_for_status()
return response.json()
def _create_auth_data(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""Create auth data from routing response"""
return {
"signature": message["signature"],
"cost": message["cost"],
"endpoint": message["endpoint"],
"reqnum": message["reqnum"],
"url": message["url"],
}
def _make_request(
self,
payload: Dict[str, Any],
endpoint: str,
method: str = "POST",
stream: bool = False,
) -> Union[Dict[str, Any], Iterator[str]]:
"""Make request directly to the specific worker endpoint"""
# Get worker URL and auth data
cost = payload.get("max_tokens", self.DEFAULT_COST)
message = self._get_worker_url(cost=cost)
worker_url = message["url"]
auth_data = self._create_auth_data(message)
req_data = {"payload": {"input": payload}, "auth_data": auth_data}
url = urljoin(worker_url, endpoint)
log.debug(f"Making direct request to: {url}")
log.debug(f"Payload: {req_data}")
# Make the request using the specified method
if method.upper() == "POST":
response = requests.post(
url, json=req_data, stream=stream, verify=get_cert_file_path()
)
elif method.upper() == "GET":
response = requests.get(
url, params=req_data, stream=stream, verify=get_cert_file_path()
)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status()
if stream:
return self._handle_streaming_response(response)
else:
return response.json()
def _handle_streaming_response(self, response: requests.Response) -> Iterator[str]:
"""Handle streaming response and yield tokens"""
try:
for line in response.iter_lines(decode_unicode=True):
if line:
if line.startswith("data: "):
data_str = line[6:]
if data_str.strip() == "[DONE]":
break
try:
data = json.loads(data_str)
yield data # Yield the full chunk
except json.JSONDecodeError:
continue
except Exception as e:
log.error(f"Error handling streaming response: {e}")
raise
def call_completions(
self, config: CompletionConfig
) -> Union[Dict[str, Any], Iterator[str]]:
payload = config.to_dict()
return self._make_request(
payload=payload, endpoint="/v1/completions", stream=config.stream
)
def call_chat_completions(
self, config: ChatCompletionConfig
) -> Union[Dict[str, Any], Iterator[str]]:
payload = config.to_dict()
return self._make_request(
payload=payload, endpoint="/v1/chat/completions", stream=config.stream
)
ENDPOINT_NAME = "my-vllm-endpoint" # change this to your vLLM endpoint name
DEFAULT_MODEL = "Qwen/Qwen3-8B" # must support tool calling
MAX_TOKENS = 1024
DEFAULT_TEMPERATURE = 0.7
# ---------------------- Tooling ----------------------
class ToolManager: class ToolManager:
"""Handles tool definitions and execution""" """Handles tool definitions and execution"""
@@ -164,7 +50,7 @@ class ToolManager:
@staticmethod @staticmethod
def get_ls_tool_definition() -> List[Dict[str, Any]]: def get_ls_tool_definition() -> List[Dict[str, Any]]:
"""Get the ls tool definition""" """OpenAI-compatible tool schema"""
return [ return [
{ {
"type": "function", "type": "function",
@@ -178,98 +64,217 @@ class ToolManager:
def execute_tool_call(self, tool_call: Dict[str, Any]) -> str: def execute_tool_call(self, tool_call: Dict[str, Any]) -> str:
"""Execute a tool call and return the result""" """Execute a tool call and return the result"""
function_name = tool_call["function"]["name"] function_name = (tool_call.get("function") or {}).get("name")
if function_name == "list_files": if function_name == "list_files":
return self.list_files() return self.list_files()
else:
raise ValueError(f"Unknown tool function: {function_name}") raise ValueError(f"Unknown tool function: {function_name}")
# ----- Helpers to handle streamed tool_calls assembly -----
def _merge_tool_call_delta(state: Dict[int, Dict[str, Any]], tc_delta: Dict[str, Any]) -> None:
"""
OpenAI-style streaming sends partial tool_calls with an index and partial fields.
We merge into a per-index state dict until the assistant message finishes.
"""
idx = tc_delta.get("index")
if idx is None:
return
entry = state.setdefault(idx, {"id": None, "function": {"name": None, "arguments": ""}, "type": "function"})
if tc_delta.get("id"):
entry["id"] = tc_delta["id"]
fn_delta = tc_delta.get("function") or {}
if "name" in fn_delta and fn_delta["name"]:
entry["function"]["name"] = fn_delta["name"]
if "arguments" in fn_delta and fn_delta["arguments"]:
entry["function"]["arguments"] += fn_delta["arguments"]
def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[Dict[str, Any]]:
return [state[i] for i in sorted(state.keys())]
# ---- OpenAI-compatible calls (non-streaming) ----
async def call_completions(client: Serverless, *, model: str, prompt: str, **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = {
"input": {
"model": model,
"prompt": prompt,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
}
}
log.debug("POST /v1/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
return resp["response"]
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = {
"input": {
"model": model,
"messages": messages,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
}
}
log.debug("POST /v1/chat/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"])
return resp["response"]
# ---- Streaming variants ----
async def stream_completions(client: Serverless, *, model: str, prompt: str, **kwargs):
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = {
"input": {
"model": model,
"prompt": prompt,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"stream": True,
**({"stop": kwargs["stop"]} if "stop" in kwargs else {}),
}
}
log.debug("STREAM /v1/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
return resp["response"] # async generator
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs):
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = {
"input": {
"model": model,
"messages": messages,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"stream": True,
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
}
}
log.debug("STREAM /v1/chat/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
return resp["response"] # async generator
# ---------------------- Demo Runner ----------------------
class APIDemo: class APIDemo:
"""Demo and testing functionality for the API client""" """Demo and testing functionality for the API client"""
def __init__( def __init__(self, client: Serverless, model: str, tool_manager: Optional[ToolManager] = None):
self, client: APIClient, model: str, tool_manager: Optional[ToolManager] = None
):
self.client = client self.client = client
self.model = model self.model = model
self.tool_manager = tool_manager or ToolManager() self.tool_manager = tool_manager or ToolManager()
def handle_streaming_response( # ----- Streaming handler -----
self, response_stream, show_reasoning: bool = True async def handle_streaming_response(self, stream, show_reasoning: bool = True) -> str:
) -> str:
"""
Handle streaming chat response and display all output.
"""
full_response = "" full_response = ""
reasoning_content = "" reasoning_content = ""
reasoning_started = False printed_reasoning = False
content_started = False printed_answer = False
for chunk in response_stream: async for chunk in stream:
# Normalize the chunk choice = (chunk.get("choices") or [{}])[0]
if isinstance(chunk, str): delta = choice.get("delta", {})
chunk = chunk.strip()
if chunk.startswith("data: "):
chunk = chunk[6:].strip()
if chunk in ["[DONE]", ""]:
continue
try:
parsed_chunk = json.loads(chunk)
except json.JSONDecodeError:
continue
elif isinstance(chunk, dict):
parsed_chunk = chunk
else:
continue
# Parse delta from the chunk # reasoning tokens
choices = parsed_chunk.get("choices", []) rc = delta.get("reasoning_content")
if not choices: if rc and show_reasoning:
continue if not printed_reasoning:
delta = choices[0].get("delta", {})
reasoning_token = delta.get("reasoning_content", "")
content_token = delta.get("content", "")
# Print reasoning token if applicable
if show_reasoning and reasoning_token:
if not reasoning_started:
print("\n🧠 Reasoning: ", end="", flush=True) print("\n🧠 Reasoning: ", end="", flush=True)
reasoning_started = True printed_reasoning = True
print(f"\033[90m{reasoning_token}\033[0m", end="", flush=True) print(rc, end="", flush=True)
reasoning_content += reasoning_token reasoning_content += rc
# Print content token # content tokens
if content_token: content_part = delta.get("content")
if not content_started: if content_part:
if show_reasoning and reasoning_started: if not printed_answer:
print(f"\n💬 Response: ", end="", flush=True) if show_reasoning and printed_reasoning:
print("\n💬 Response: ", end="", flush=True)
else: else:
print("Assistant: ", end="", flush=True) print("Assistant: ", end="", flush=True)
content_started = True printed_answer = True
print(content_token, end="", flush=True) print(content_part, end="", flush=True)
full_response += content_token full_response += content_part
print() # Ensure newline after response
print() # newline
if show_reasoning: if show_reasoning:
if reasoning_started or content_started: if printed_reasoning or printed_answer:
print("\nStreaming completed.") print("\nStreaming completed.")
if reasoning_started: if printed_reasoning:
print(f"Reasoning tokens: {len(reasoning_content.split())}") print(f"Reasoning tokens: {len(reasoning_content.split())}")
if content_started: if printed_answer:
print(f"Response tokens: {len(full_response.split())}") print(f"Response tokens: {len(full_response.split())}")
return full_response return full_response
def test_tool_support(self) -> bool: async def demo_completions(self) -> None:
"""Test if the endpoint supports function calling""" print("=" * 60)
log.debug("Testing endpoint tool calling support...") print("COMPLETIONS DEMO")
print("=" * 60)
# Try a simple request with minimal tools to test support response = await call_completions(
client=self.client,
model=self.model,
prompt=COMPLETIONS_PROMPT,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
print("\nResponse:")
print(json.dumps(response, indent=2))
async def demo_chat(self, use_streaming: bool = True) -> None:
print("=" * 60)
print(f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}")
print("=" * 60)
messages = [{"role": "user", "content": CHAT_PROMPT}]
if use_streaming:
stream = await stream_chat_completions(
client=self.client,
model=self.model,
messages=messages,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE
)
try:
await self.handle_streaming_response(stream, show_reasoning=True)
except Exception as e:
log.error("\nError during streaming: %s", e, exc_info=True)
else:
response = await call_chat_completions(
client=self.client,
model=self.model,
messages=messages,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE
)
choice = (response.get("choices") or [{}])[0]
message = choice.get("message", {})
content = message.get("content", "")
reasoning = message.get("reasoning_content", "") or message.get("reasoning", "")
if reasoning:
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
print(f"\n💬 Assistant: {content}")
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
async def test_tool_support(self) -> bool:
"""Probe that tool schema is accepted (no actual call)"""
messages = [{"role": "user", "content": "Hello"}] messages = [{"role": "user", "content": "Hello"}]
minimal_tool = [ minimal_tool = [
{ {
@@ -277,170 +282,147 @@ class APIDemo:
"function": {"name": "test_function", "description": "Test function"}, "function": {"name": "test_function", "description": "Test function"},
} }
] ]
try:
config = ChatCompletionConfig( _ = await call_chat_completions(
client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
max_tokens=10,
tools=minimal_tool, tools=minimal_tool,
tool_choice="none", # Don't actually call the tool tool_choice="none",
max_tokens=10
) )
try:
response = self.client.call_chat_completions(config)
return True return True
except Exception as e: except Exception as e:
log.error(f"Error: Endpoint does not support tool calling: {e}") log.error("Endpoint does not support tool calling: %s", e)
return False return False
def demo_completions(self) -> None: async def demo_ls_tool(self) -> None:
"""Demo: test basic completions endpoint""" """Ask to list files using function calling, then provide final analysis"""
print("=" * 60)
print("COMPLETIONS DEMO")
print("=" * 60)
config = CompletionConfig(
model=self.model, prompt=COMPLETIONS_PROMPT, stream=False
)
log.info(
f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'"
)
response = self.client.call_completions(config)
if isinstance(response, dict):
print("\nResponse:")
print(json.dumps(response, indent=2))
else:
log.error("Unexpected response format")
def demo_chat(self, use_streaming: bool = True) -> None:
"""
Demo: test chat completions endpoint with optional streaming
"""
print("=" * 60)
print(
f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}"
)
print("=" * 60)
config = ChatCompletionConfig(
model=self.model,
messages=[{"role": "user", "content": CHAT_PROMPT}],
stream=use_streaming,
)
log.info(f"Testing chat completions with model '{self.model}'...")
response = self.client.call_chat_completions(config)
if use_streaming:
try:
self.handle_streaming_response(response, show_reasoning=True)
except Exception as e:
log.error(f"\nError during streaming: {e}")
import traceback
traceback.print_exc()
return
else:
if isinstance(response, dict):
choice = response.get("choices", [{}])[0]
message = choice.get("message", {})
content = message.get("content", "")
reasoning = message.get("reasoning_content", "") or message.get(
"reasoning", ""
)
if reasoning:
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
print(f"\n💬 Assistant: {content}")
print(f"\nFull Response:")
print(json.dumps(response, indent=2))
else:
log.error("Unexpected response format")
def demo_ls_tool(self) -> None:
"""Demo: ask LLM to list files in the current directory and describe what it sees"""
print("=" * 60) print("=" * 60)
print("TOOL USE DEMO: List Directory Contents") print("TOOL USE DEMO: List Directory Contents")
print("=" * 60) print("=" * 60)
# Test if tools are supported first if not await self.test_tool_support():
if not self.test_tool_support():
return return
# Request with tool available messages: List[Dict[str, Any]] = [{"role": "user", "content": TOOLS_PROMPT}]
messages = [{"role": "user", "content": TOOLS_PROMPT}]
config = ChatCompletionConfig( # First pass: let the model decide tools, stream tool_calls and partial content
stream = await stream_chat_completions(
client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
tools=self.tool_manager.get_ls_tool_definition(), tools=self.tool_manager.get_ls_tool_definition(),
tool_choice="auto", tool_choice="auto",
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
) )
log.info(f"Making initial request with tool using model '{self.model}'...") assistant_content_buf: List[str] = []
response = self.client.call_chat_completions(config) tool_calls_state: Dict[int, Dict[str, Any]] = {}
printed_reasoning = False
printed_answer = False
if not isinstance(response, dict): async for chunk in stream:
raise ValueError("Expected dict response for tool use") choice = (chunk.get("choices") or [{}])[0]
delta = choice.get("delta", {})
choice = response.get("choices", [{}])[0] rc = delta.get("reasoning_content")
message = choice.get("message", {}) if rc:
if not printed_reasoning:
printed_reasoning = True
print("🧠 Reasoning: ", end="", flush=True)
print(rc, end="", flush=True)
print(f"Assistant response: {message.get('content', 'No content')}") content_part = delta.get("content")
if content_part:
assistant_content_buf.append(content_part)
if not printed_answer:
printed_answer = True
print("\n💬 Response: ", end="", flush=True)
print(content_part, end="", flush=True)
# Check for tool calls if "tool_calls" in delta and delta["tool_calls"]:
tool_calls = message.get("tool_calls") for tc_delta in delta["tool_calls"]:
if not tool_calls: _merge_tool_call_delta(tool_calls_state, tc_delta)
raise ValueError(
"No tool calls made - model may not support function calling"
)
print(f"Tool calls detected: {len(tool_calls)}") # If no tool calls, were done.
if not tool_calls_state:
print("\n(No tool calls were made.)")
return
# Execute the tool call # Build assistant message with tool_calls
for tool_call in tool_calls: assistant_message = {
function_name = tool_call["function"]["name"] "role": "assistant",
print(f"Executing tool: {function_name}") "content": "".join(assistant_content_buf) if assistant_content_buf else None,
"tool_calls": _tool_state_to_message_tool_calls(tool_calls_state),
tool_result = self.tool_manager.execute_tool_call(tool_call)
print(f"Tool result:\n{tool_result}")
# Add tool result and continue conversation
messages.append(message) # Add assistant's message with tool call
messages.append(
{
"role": "tool",
"tool_call_id": tool_call["id"],
"content": tool_result,
} }
) messages.append(assistant_message)
# Get final response # Execute tools and feed results back
final_config = ChatCompletionConfig( for tc in assistant_message["tool_calls"]:
tool_name = (tc.get("function") or {}).get("name")
call_id = tc.get("id")
raw_args = (tc.get("function") or {}).get("arguments") or "{}"
try:
args = json.loads(raw_args) if raw_args.strip() else {}
except Exception as e:
tool_result = json.dumps({"error": f"Argument parse failed: {str(e)}", "raw_arguments": raw_args})
messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result})
continue
try:
if tool_name == "list_files":
tool_result = self.tool_manager.list_files()
else:
tool_result = json.dumps({"error": f"Unknown tool '{tool_name}'"})
except Exception as e:
tool_result = json.dumps({"error": f"Tool '{tool_name}' failed: {str(e)}"})
print("\n[Tool executed]", tool_name)
print(tool_result[:500] + ("..." if len(tool_result) > 500 else ""))
messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result})
# Second pass: get final streamed answer after tool results
stream2 = await stream_chat_completions(
client=self.client,
model=self.model, model=self.model,
messages=messages, messages=messages,
tools=self.tool_manager.get_ls_tool_definition(), max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
) )
print("Getting final response...") final_buf = []
final_response = self.client.call_chat_completions(final_config) printed_reasoning2 = False
printed_answer2 = False
if isinstance(final_response, dict): async for chunk in stream2:
final_choice = final_response.get("choices", [{}])[0] choice = (chunk.get("choices") or [{}])[0]
final_message = final_choice.get("message", {}) delta = choice.get("delta", {})
final_content = final_message.get("content", "")
rc2 = delta.get("reasoning_content")
if rc2:
if not printed_reasoning2:
printed_reasoning2 = True
print("\n🧠 Reasoning (post-tools): ", end="", flush=True)
print(rc2, end="", flush=True)
c2 = delta.get("content")
if c2:
final_buf.append(c2)
if not printed_answer2:
printed_answer2 = True
print("\n💬 Response (final): ", end="", flush=True)
print(c2, end="", flush=True)
print("\n" + "=" * 60) print("\n" + "=" * 60)
print("FINAL LLM ANALYSIS:") print("FINAL LLM ANALYSIS:")
print("=" * 60) print("=" * 60)
print(final_content) print("".join(final_buf))
print("=" * 60) print("=" * 60)
def interactive_chat(self) -> None: async def interactive_chat(self) -> None:
"""Interactive chat session with streaming""" """Interactive chat session with streaming"""
print("=" * 60) print("=" * 60)
print("INTERACTIVE STREAMING CHAT") print("INTERACTIVE STREAMING CHAT")
@@ -449,7 +431,7 @@ class APIDemo:
print("Type 'quit' to exit, 'clear' to clear history") print("Type 'quit' to exit, 'clear' to clear history")
print() print()
messages = [] messages: List[Dict[str, Any]] = []
while True: while True:
try: try:
@@ -467,16 +449,15 @@ class APIDemo:
messages.append({"role": "user", "content": user_input}) messages.append({"role": "user", "content": user_input})
config = ChatCompletionConfig(
model=self.model, messages=messages, stream=True, temperature=0.7
)
print("Assistant: ", end="", flush=True) print("Assistant: ", end="", flush=True)
stream = await stream_chat_completions(
response = self.client.call_chat_completions(config) client=self.client,
assistant_content = self.handle_streaming_response( model=self.model,
response, show_reasoning=True messages=messages,
max_tokens=MAX_TOKENS,
temperature=0.7
) )
assistant_content = await self.handle_streaming_response(stream, show_reasoning=True)
# Add assistant response to conversation history # Add assistant response to conversation history
messages.append({"role": "assistant", "content": assistant_content}) messages.append({"role": "assistant", "content": assistant_content})
@@ -485,115 +466,64 @@ class APIDemo:
print("\n👋 Chat interrupted. Goodbye!") print("\n👋 Chat interrupted. Goodbye!")
break break
except Exception as e: except Exception as e:
log.error(f"\nError: {e}") log.error("\nError: %s", e)
continue continue
def main(): # ---------------------- CLI ----------------------
"""Main function with CLI switches for different tests""" def build_arg_parser() -> argparse.ArgumentParser:
from lib.test_utils import test_args p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
p.add_argument("--model", required=True, help="Model to use for requests (required)")
p.add_argument("--endpoint", default="my-vllm-endpoint", help="Vast endpoint name (default: my-vllm-endpoint)")
# Add mandatory model argument modes = p.add_mutually_exclusive_group(required=False)
test_args.add_argument( modes.add_argument("--completion", action="store_true", help="Test completions endpoint")
"--model", required=True, help="Model to use for requests (required)" modes.add_argument("--chat", action="store_true", help="Test chat completions endpoint (non-streaming)")
) modes.add_argument("--chat-stream", action="store_true", help="Test chat completions endpoint with streaming")
modes.add_argument("--tools", action="store_true", help="Test function calling with ls tool (non-streaming+streamed phases)")
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming chat session")
return p
# Add test mode arguments
test_args.add_argument(
"--completion", action="store_true", help="Test completions endpoint"
)
test_args.add_argument(
"--chat",
action="store_true",
help="Test chat completions endpoint (non-streaming)",
)
test_args.add_argument(
"--chat-stream",
action="store_true",
help="Test chat completions endpoint with streaming",
)
test_args.add_argument(
"--tools",
action="store_true",
help="Test function calling with ls tool (non-streaming)",
)
test_args.add_argument(
"--interactive",
action="store_true",
help="Start interactive streaming chat session",
)
args = test_args.parse_args() async def main_async():
args = build_arg_parser().parse_args()
# Check that only one test mode is selected selected = sum([args.completion, args.chat, args.chat_stream, args.tools, args.interactive])
test_modes = [ if selected == 0:
args.completion,
args.chat,
args.chat_stream,
args.tools,
args.interactive,
]
selected_count = sum(test_modes)
if selected_count == 0:
print("Please specify exactly one test mode:") print("Please specify exactly one test mode:")
print(" --completion : Test completions endpoint") print(" --completion : Test completions endpoint")
print(" --chat : Test chat completions endpoint (non-streaming)") print(" --chat : Test chat completions endpoint (non-streaming)")
print(" --chat-stream : Test chat completions endpoint with streaming") print(" --chat-stream : Test chat completions endpoint with streaming")
print(" --tools : Test function calling with ls tool (non-streaming)") print(" --tools : Test function calling with ls tool")
print(" --interactive : Start interactive streaming chat session") print(" --interactive : Start interactive streaming chat session")
print( print(f"\nExample: python {os.path.basename(sys.argv[0])} --model Qwen/Qwen3-8B --chat-stream --endpoint my-vllm-endpoint")
f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT"
)
sys.exit(1) sys.exit(1)
elif selected_count > 1: elif selected > 1:
print("Please specify exactly one test mode") print("Please specify exactly one test mode")
sys.exit(1) sys.exit(1)
try:
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if not endpoint_api_key:
log.error(
f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting."
)
sys.exit(1)
# Create the core API client
client = APIClient(
endpoint_group_name=args.endpoint_group_name,
api_key=args.api_key,
server_url=Endpoint.get_autoscaler_server_url(args.instance),
endpoint_api_key=endpoint_api_key,
)
# Create tool manager and demo (passing the model parameter)
tool_manager = ToolManager()
demo = APIDemo(client, args.model, tool_manager)
print(f"Using model: {args.model}") print(f"Using model: {args.model}")
print("=" * 60) print("=" * 60)
# Run the selected test try:
async with Serverless() as client:
demo = APIDemo(client, args.model, ToolManager())
if args.completion: if args.completion:
demo.demo_completions() await demo.demo_completions()
elif args.chat: elif args.chat:
demo.demo_chat(use_streaming=False) await demo.demo_chat(use_streaming=False)
elif args.chat_stream: elif args.chat_stream:
demo.demo_chat(use_streaming=True) await demo.demo_chat(use_streaming=True)
elif args.tools: elif args.tools:
demo.demo_ls_tool() await demo.demo_ls_tool()
elif args.interactive: elif args.interactive:
demo.interactive_chat() await demo.interactive_chat()
except Exception as e: except Exception as e:
log.error(f"Error during test: {e}", exc_info=True) log.error("Error during test: %s", e, exc_info=True)
sys.exit(1) sys.exit(1)
if __name__ == "__main__": if __name__ == "__main__":
main() asyncio.run(main_async())
+53 -117
View File
@@ -1,125 +1,61 @@
import logging from vastai import Serverless
import sys import asyncio
import json
from urllib.parse import urljoin
import requests
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
logging.basicConfig( ENDPOINT_NAME = "my-tgi-endpoint" # Change this to match your endpoint name
level=logging.DEBUG, MAX_TOKENS = 1024
format="%(asctime)s[%(levelname)-5s] %(message)s", PROMPT = "Think step by step: Tell me about the Python programming language."
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
async def call_generate(client: Serverless) -> None:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> None: payload = {
WORKER_ENDPOINT = "/generate" "inputs": PROMPT,
COST = 100 "parameters": {
route_payload = { "max_new_tokens": MAX_TOKENS,
"endpoint": endpoint_group_name, "temperature": 0.7,
"api_key": api_key, "return_full_text": False
"cost": COST,
} }
response = requests.post(
urljoin(server_url, "/route/"),
json=route_payload,
timeout=4,
)
response.raise_for_status() # Raise an exception for bad status codes
message = response.json()
url = message["url"]
auth_data = dict(
signature=message["signature"],
cost=message["cost"],
endpoint=message["endpoint"],
reqnum=message["reqnum"],
url=url,
)
payload = dict(inputs="tell me about cats", parameters=dict(max_new_tokens=500))
req_data = dict(payload=payload, auth_data=auth_data)
url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {url}")
response = requests.post(
url,
json=req_data,
verify=get_cert_file_path(),
)
response.raise_for_status()
res = response.json()
print(res)
def call_generate_stream(
endpoint_group_name: str, api_key: str, server_url: str
) -> None:
WORKER_ENDPOINT = "/generate_stream"
COST = 100
route_payload = {
"endpoint": endpoint_group_name,
"api_key": api_key,
"cost": COST,
} }
response = requests.post(
urljoin(server_url, "/route/"),
json=route_payload,
timeout=4,
)
response.raise_for_status() # Raise an exception for bad status codes
message = response.json()
url = message["url"]
print(f"url: {url}")
auth_data = dict(
signature=message["signature"],
cost=message["cost"],
endpoint=message["endpoint"],
reqnum=message["reqnum"],
url=message["url"],
)
payload = dict(inputs="tell me about dogs", parameters=dict(max_new_tokens=500))
req_data = dict(payload=payload, auth_data=auth_data)
url = urljoin(url, WORKER_ENDPOINT)
response = requests.post(url, json=req_data, stream=True)
response.raise_for_status() # Raise an exception for bad status codes
for line in response.iter_lines():
payload = line.decode().lstrip("data:").rstrip()
if payload:
try:
data = json.loads(payload)
print(data["token"]["text"], end="")
sys.stdout.flush()
except (json.JSONDecodeError, KeyError) as e:
log.warning(f"Failed to parse streaming response: {e}")
continue
print()
resp = await endpoint.request("/generate", payload, cost=MAX_TOKENS)
print(resp["response"]["generated_text"])
async def call_generate_stream(client: Serverless) -> None:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = {
"inputs": PROMPT,
"parameters": {
"max_new_tokens": MAX_TOKENS,
"temperature": 0.7,
"do_sample": True,
"return_full_text": False,
}
}
resp = await endpoint.request(
"/generate_stream",
payload,
cost=MAX_TOKENS,
stream=True,
)
stream = resp["response"]
printed_answer = False
async for event in stream:
tok = (event.get("token") or {}).get("text")
if tok:
if not printed_answer:
printed_answer = True
print("Answer:\n", end="", flush=True)
print(tok, end="", flush=True)
async def main():
async with Serverless() as client:
await call_generate(client)
await call_generate_stream(client)
if __name__ == "__main__": if __name__ == "__main__":
from lib.test_utils import test_args asyncio.run(main())
args = test_args.parse_args()
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if endpoint_api_key:
try:
call_generate(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
call_generate_stream(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
except Exception as e:
log.error(f"Error during API call: {e}")
else:
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ")