From b55bfa961124a2a124b78a57867f3732130e80f4 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Tue, 11 Nov 2025 17:09:28 -0800 Subject: [PATCH 1/3] Updated clients, include vastai-sdk, handle non-UTF-8 --- lib/backend.py | 2 +- requirements.txt | 1 + workers/comfyui-json/client.py | 175 ++------ workers/comfyui/client.py | 17 +- workers/openai/client.py | 785 +++++++++++++++------------------ workers/tgi/client.py | 162 ++----- 6 files changed, 441 insertions(+), 701 deletions(-) diff --git a/lib/backend.py b/lib/backend.py index 5cbb7ff..bf1d746 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -417,7 +417,7 @@ class Backend: async def tail_log(): log.debug(f"tailing file: {self.model_log_file}") - async with await open_file(self.model_log_file) as f: + async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore'): while True: line = await f.readline() if line: diff --git a/requirements.txt b/requirements.txt index 1d99304..13b194e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ Requests~=2.32 transformers~=4.52 utils==1.0.* hf_transfer>=0.1.9 +vastai-sdk>=0.2.0g \ No newline at end of file diff --git a/workers/comfyui-json/client.py b/workers/comfyui-json/client.py index e4ac92c..c877df2 100644 --- a/workers/comfyui-json/client.py +++ b/workers/comfyui-json/client.py @@ -1,156 +1,35 @@ -import logging -import uuid -import random -from urllib.parse import urljoin -import json - -import requests - -from lib.test_utils import print_truncate_res -from utils.endpoint_util import Endpoint -from utils.ssl import get_cert_file_path +from vastai import Serverless from .data_types import count_workload -logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s[%(levelname)-5s] %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", -) -log = logging.getLogger(__file__) +import uuid +import random +import asyncio +import random +async def main(): + async with Serverless() as client: + endpoint = await client.get_endpoint(name="my-comfy-endpoint") # Change this to your endpoint name -def call_text2image_workflow( - endpoint_group_name: str, api_key: str, server_url: str -) -> None: - """Simple Text2Image using the new modifier-based approach""" - - def make_request(url: str, payload: dict, timeout: int = None, verify=True, context: str = "request"): - """Helper function for making requests with consistent error handling""" - try: - response = requests.post( - url, - json=payload, - timeout=timeout, - verify=verify - ) - - response.raise_for_status() - return response.json() - - except requests.exceptions.HTTPError as http_err: - log.error(f"HTTP error occurred during {context}: {http_err}") - log.error(f"Status Code: {response.status_code}") - log.error("Response content:", response.text) - return None - except requests.exceptions.Timeout: - log.error(f"Timeout occurred during {context}: {url}") - return None - except requests.exceptions.ConnectionError: - log.error(f"Connection error occurred during {context}: {url}") - return None - except json.JSONDecodeError as json_err: - log.error(f"Failed to decode JSON response during {context}: {json_err}") - if 'response' in locals(): - print("Response content:", response.text) - return None - except Exception as err: - log.error(f"An unexpected error occurred during {context}: {err}") - if 'response' in locals(): - log.error("Response content (if available):", response.text) - return None - - WORKER_ENDPOINT = "/generate/sync" - - # This worker has concurrency = 1. All workloads have cost value 1.0 - COST = count_workload() - - # Route to get worker URL - route_payload = { - "endpoint": endpoint_group_name, - "api_key": api_key, - "cost": COST, - } - - # First request - get routing information - route_response = make_request( - url=urljoin(server_url, "/route/"), - payload=route_payload, - timeout=4, - context="route request" - ) - - if route_response is None: - return None - - if "url" not in route_response or not route_response["url"]: - log.error("Error: No worker in 'Ready' state. Please wait while the serverless engine removes errored workers or finishes loading new workers.") - return None - - if "status" in route_response: - print(f"Autoscaler status: {route_response['status']}") - return None - - # Extract data from route response - url = route_response["url"] - auth_data = dict( - signature=route_response["signature"], - cost=route_response["cost"], - endpoint=route_response["endpoint"], - reqnum=route_response["reqnum"], - url=route_response["url"], - request_idx=route_response["request_idx"], - ) - - # Build the payload for the worker request - worker_payload = { - "input": { - "request_id": str(uuid.uuid4()), - "modifier": "Text2Image", - "modifications": { - "prompt": "a beautiful landscape with mountains and lakes", - "width": 1024, - "height": 1024, - "steps": 20, - "seed": random.randint(0, 2**32 - 1) - }, - "workflow_json": {} # Empty since using modifier approach + payload = { + "input": { + "request_id": str(uuid.uuid4()), + "modifier": "Text2Image", + "modifications": { + "prompt": "a beautiful landscape with mountains and lakes", + "width": 1024, + "height": 1024, + "steps": 20, + "seed": random.randint(0, 2**32 - 1) + }, + "workflow_json": {} # Empty since using modifier approach + } } - } - - req_data = dict(payload=worker_payload, auth_data=auth_data) - worker_url = urljoin(url, WORKER_ENDPOINT) - print(f"url: {worker_url}") - - # Second request - call the worker endpoint - worker_response = make_request( - url=worker_url, - payload=req_data, - verify=get_cert_file_path(), - context="worker request" - ) - - return worker_response + + response = await endpoint.request("/generate/sync", payload, cost=count_workload()) + # Get the file from the path on the local machine using SCP or SFTP + # or configure S3 to upload to cloud storage. + print(response["response"]["output"][0]["local_path"]) if __name__ == "__main__": - from lib.test_utils import test_args - - args = test_args.parse_args() - endpoint_api_key = Endpoint.get_endpoint_api_key( - endpoint_name=args.endpoint_group_name, - account_api_key=args.api_key, - instance=args.instance, - ) - - if endpoint_api_key: - result = call_text2image_workflow( - api_key=endpoint_api_key, - endpoint_group_name=args.endpoint_group_name, - server_url=args.server_url, - ) - if result is None: - log.error("Text2Image workflow failed") - else: - print(result) - else: - log.error(f"Failed to get API key for endpoint {args.endpoint_group_name}") + asyncio.run(main()) \ No newline at end of file diff --git a/workers/comfyui/client.py b/workers/comfyui/client.py index 986ff22..7d1935e 100644 --- a/workers/comfyui/client.py +++ b/workers/comfyui/client.py @@ -7,20 +7,13 @@ from lib.test_utils import print_truncate_res from utils.endpoint_util import Endpoint from utils.ssl import get_cert_file_path -""" -NOTE: this client example uses a custom comfy workflow compatible with SD3 only -""" -logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s[%(levelname)-5s] %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", -) -log = logging.getLogger(__file__) +from vastai import Serverless -def call_default_workflow( - endpoint_group_name: str, api_key: str, server_url: str -) -> None: +ENDPOINT_NAME = "my-comfyui-endpoint" +COST = 100 # Use a constant cost for image generation + +def call_default_workflow(client: Serverless) -> None: WORKER_ENDPOINT = "/prompt" COST = 100 route_payload = { diff --git a/workers/openai/client.py b/workers/openai/client.py index e34cc90..1dadc68 100644 --- a/workers/openai/client.py +++ b/workers/openai/client.py @@ -1,14 +1,16 @@ +#!/usr/bin/env python3 import logging -import sys import json +import os +import sys import subprocess -from urllib.parse import urljoin -from typing import Dict, Any, Optional, Iterator, Union, List -import requests -from utils.endpoint_util import Endpoint -from utils.ssl import get_cert_file_path -from .data_types.client import CompletionConfig, ChatCompletionConfig +import argparse +from typing import Any, Dict, List, Optional +from vastai import Serverless +import asyncio + +# ---------------------- Logging ---------------------- logging.basicConfig( level=logging.DEBUG, format="%(asctime)s[%(levelname)-5s] %(message)s", @@ -16,135 +18,20 @@ logging.basicConfig( ) log = logging.getLogger(__file__) +# ---------------------- Prompts ---------------------- COMPLETIONS_PROMPT = "the capital of USA is" CHAT_PROMPT = "Think step by step: Tell me about the Python programming language." -TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?" - - -class APIClient: - """Lightweight client focused solely on API communication""" - - # Remove the generic WORKER_ENDPOINT since we're now going direct - DEFAULT_COST = 100 - DEFAULT_TIMEOUT = 4 - - def __init__( - self, - endpoint_group_name: str, - api_key: str, - server_url: str, - endpoint_api_key: str, - ): - self.endpoint_group_name = endpoint_group_name - self.api_key = api_key - self.server_url = server_url - self.endpoint_api_key = endpoint_api_key - - def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]: - """Get worker URL and auth data from routing service""" - if not self.endpoint_api_key: - raise ValueError("No valid endpoint API key available") - - route_payload = { - "endpoint": self.endpoint_group_name, - "api_key": self.endpoint_api_key, - "cost": cost, - } - - response = requests.post( - urljoin(self.server_url, "/route/"), - json=route_payload, - timeout=self.DEFAULT_TIMEOUT, - ) - response.raise_for_status() - return response.json() - - def _create_auth_data(self, message: Dict[str, Any]) -> Dict[str, Any]: - """Create auth data from routing response""" - return { - "signature": message["signature"], - "cost": message["cost"], - "endpoint": message["endpoint"], - "reqnum": message["reqnum"], - "url": message["url"], - } - - def _make_request( - self, - payload: Dict[str, Any], - endpoint: str, - method: str = "POST", - stream: bool = False, - ) -> Union[Dict[str, Any], Iterator[str]]: - """Make request directly to the specific worker endpoint""" - # Get worker URL and auth data - cost = payload.get("max_tokens", self.DEFAULT_COST) - message = self._get_worker_url(cost=cost) - worker_url = message["url"] - auth_data = self._create_auth_data(message) - - req_data = {"payload": {"input": payload}, "auth_data": auth_data} - - url = urljoin(worker_url, endpoint) - log.debug(f"Making direct request to: {url}") - log.debug(f"Payload: {req_data}") - - # Make the request using the specified method - if method.upper() == "POST": - response = requests.post( - url, json=req_data, stream=stream, verify=get_cert_file_path() - ) - elif method.upper() == "GET": - response = requests.get( - url, params=req_data, stream=stream, verify=get_cert_file_path() - ) - else: - raise ValueError(f"Unsupported HTTP method: {method}") - - response.raise_for_status() - - if stream: - return self._handle_streaming_response(response) - else: - return response.json() - - def _handle_streaming_response(self, response: requests.Response) -> Iterator[str]: - """Handle streaming response and yield tokens""" - try: - for line in response.iter_lines(decode_unicode=True): - if line: - if line.startswith("data: "): - data_str = line[6:] - if data_str.strip() == "[DONE]": - break - try: - data = json.loads(data_str) - yield data # Yield the full chunk - except json.JSONDecodeError: - continue - except Exception as e: - log.error(f"Error handling streaming response: {e}") - raise - - def call_completions( - self, config: CompletionConfig - ) -> Union[Dict[str, Any], Iterator[str]]: - payload = config.to_dict() - - return self._make_request( - payload=payload, endpoint="/v1/completions", stream=config.stream - ) - - def call_chat_completions( - self, config: ChatCompletionConfig - ) -> Union[Dict[str, Any], Iterator[str]]: - payload = config.to_dict() - - return self._make_request( - payload=payload, endpoint="/v1/chat/completions", stream=config.stream - ) +TOOLS_PROMPT = ( + "Can you list the files in the current working directory and tell me what you see? " + "What do you think this directory might be for?" +) +ENDPOINT_NAME = "my-vllm-endpoint" # change this to your vLLM endpoint name +DEFAULT_MODEL = "Qwen/Qwen3-8B" # must support tool calling +MAX_TOKENS = 1024 +DEFAULT_TEMPERATURE = 0.7 +# ---------------------- Tooling ---------------------- class ToolManager: """Handles tool definitions and execution""" @@ -164,7 +51,7 @@ class ToolManager: @staticmethod def get_ls_tool_definition() -> List[Dict[str, Any]]: - """Get the ls tool definition""" + """OpenAI-compatible tool schema""" return [ { "type": "function", @@ -178,98 +65,217 @@ class ToolManager: def execute_tool_call(self, tool_call: Dict[str, Any]) -> str: """Execute a tool call and return the result""" - function_name = tool_call["function"]["name"] - + function_name = (tool_call.get("function") or {}).get("name") if function_name == "list_files": return self.list_files() - else: - raise ValueError(f"Unknown tool function: {function_name}") + raise ValueError(f"Unknown tool function: {function_name}") +# ----- Helpers to handle streamed tool_calls assembly ----- +def _merge_tool_call_delta(state: Dict[int, Dict[str, Any]], tc_delta: Dict[str, Any]) -> None: + """ + OpenAI-style streaming sends partial tool_calls with an index and partial fields. + We merge into a per-index state dict until the assistant message finishes. + """ + idx = tc_delta.get("index") + if idx is None: + return + + entry = state.setdefault(idx, {"id": None, "function": {"name": None, "arguments": ""}, "type": "function"}) + + if tc_delta.get("id"): + entry["id"] = tc_delta["id"] + + fn_delta = tc_delta.get("function") or {} + if "name" in fn_delta and fn_delta["name"]: + entry["function"]["name"] = fn_delta["name"] + if "arguments" in fn_delta and fn_delta["arguments"]: + entry["function"]["arguments"] += fn_delta["arguments"] + + +def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[Dict[str, Any]]: + return [state[i] for i in sorted(state.keys())] + + +# ---- OpenAI-compatible calls (non-streaming) ---- +async def call_completions(client: Serverless, *, model: str, prompt: str, **kwargs) -> Dict[str, Any]: + + endpoint = await client.get_endpoint(name=ENDPOINT_NAME) + + payload = { + "input": { + "model": model, + "prompt": prompt, + "max_tokens": kwargs.get("max_tokens", MAX_TOKENS), + "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), + } + } + log.debug("POST /v1/completions %s", json.dumps(payload)[:500]) + resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"]) + return resp["response"] + +async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]: + + endpoint = await client.get_endpoint(name=ENDPOINT_NAME) + + payload = { + "input": { + "model": model, + "messages": messages, + "max_tokens": kwargs.get("max_tokens", MAX_TOKENS), + "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), + **({"tools": kwargs["tools"]} if "tools" in kwargs else {}), + **({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}), + } + } + log.debug("POST /v1/chat/completions %s", json.dumps(payload)[:500]) + resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"]) + return resp["response"] + +# ---- Streaming variants ---- +async def stream_completions(client: Serverless, *, model: str, prompt: str, **kwargs): + + endpoint = await client.get_endpoint(name=ENDPOINT_NAME) + + payload = { + "input": { + "model": model, + "prompt": prompt, + "max_tokens": kwargs.get("max_tokens", MAX_TOKENS), + "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), + "stream": True, + **({"stop": kwargs["stop"]} if "stop" in kwargs else {}), + } + } + log.debug("STREAM /v1/completions %s", json.dumps(payload)[:500]) + resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True) + return resp["response"] # async generator + +async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs): + + endpoint = await client.get_endpoint(name=ENDPOINT_NAME) + + payload = { + "input": { + "model": model, + "messages": messages, + "max_tokens": kwargs.get("max_tokens", MAX_TOKENS), + "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), + "stream": True, + **({"tools": kwargs["tools"]} if "tools" in kwargs else {}), + **({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}), + } + } + log.debug("STREAM /v1/chat/completions %s", json.dumps(payload)[:500]) + resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"], stream=True) + return resp["response"] # async generator + + +# ---------------------- Demo Runner ---------------------- class APIDemo: """Demo and testing functionality for the API client""" - def __init__( - self, client: APIClient, model: str, tool_manager: Optional[ToolManager] = None - ): + def __init__(self, client: Serverless, model: str, tool_manager: Optional[ToolManager] = None): self.client = client self.model = model self.tool_manager = tool_manager or ToolManager() - def handle_streaming_response( - self, response_stream, show_reasoning: bool = True - ) -> str: - """ - Handle streaming chat response and display all output. - """ - + # ----- Streaming handler ----- + async def handle_streaming_response(self, stream, show_reasoning: bool = True) -> str: full_response = "" reasoning_content = "" - reasoning_started = False - content_started = False + printed_reasoning = False + printed_answer = False - for chunk in response_stream: - # Normalize the chunk - if isinstance(chunk, str): - chunk = chunk.strip() - if chunk.startswith("data: "): - chunk = chunk[6:].strip() - if chunk in ["[DONE]", ""]: - continue - try: - parsed_chunk = json.loads(chunk) - except json.JSONDecodeError: - continue - elif isinstance(chunk, dict): - parsed_chunk = chunk - else: - continue + async for chunk in stream: + choice = (chunk.get("choices") or [{}])[0] + delta = choice.get("delta", {}) - # Parse delta from the chunk - choices = parsed_chunk.get("choices", []) - if not choices: - continue - - delta = choices[0].get("delta", {}) - reasoning_token = delta.get("reasoning_content", "") - content_token = delta.get("content", "") - - # Print reasoning token if applicable - if show_reasoning and reasoning_token: - if not reasoning_started: + # reasoning tokens + rc = delta.get("reasoning_content") + if rc and show_reasoning: + if not printed_reasoning: print("\n🧠 Reasoning: ", end="", flush=True) - reasoning_started = True - print(f"\033[90m{reasoning_token}\033[0m", end="", flush=True) - reasoning_content += reasoning_token + printed_reasoning = True + print(rc, end="", flush=True) + reasoning_content += rc - # Print content token - if content_token: - if not content_started: - if show_reasoning and reasoning_started: - print(f"\n💬 Response: ", end="", flush=True) + # content tokens + content_part = delta.get("content") + if content_part: + if not printed_answer: + if show_reasoning and printed_reasoning: + print("\n💬 Response: ", end="", flush=True) else: print("Assistant: ", end="", flush=True) - content_started = True - print(content_token, end="", flush=True) - full_response += content_token - - print() # Ensure newline after response + printed_answer = True + print(content_part, end="", flush=True) + full_response += content_part + print() # newline if show_reasoning: - if reasoning_started or content_started: + if printed_reasoning or printed_answer: print("\nStreaming completed.") - if reasoning_started: + if printed_reasoning: print(f"Reasoning tokens: {len(reasoning_content.split())}") - if content_started: + if printed_answer: print(f"Response tokens: {len(full_response.split())}") return full_response + + async def demo_completions(self) -> None: + print("=" * 60) + print("COMPLETIONS DEMO") + print("=" * 60) - def test_tool_support(self) -> bool: - """Test if the endpoint supports function calling""" - log.debug("Testing endpoint tool calling support...") + response = await call_completions( + client=self.client, + model=self.model, + prompt=COMPLETIONS_PROMPT, + max_tokens=MAX_TOKENS, + temperature=DEFAULT_TEMPERATURE, + ) + print("\nResponse:") + print(json.dumps(response, indent=2)) - # Try a simple request with minimal tools to test support + async def demo_chat(self, use_streaming: bool = True) -> None: + print("=" * 60) + print(f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}") + print("=" * 60) + + messages = [{"role": "user", "content": CHAT_PROMPT}] + + if use_streaming: + stream = await stream_chat_completions( + client=self.client, + model=self.model, + messages=messages, + max_tokens=MAX_TOKENS, + temperature=DEFAULT_TEMPERATURE + ) + try: + await self.handle_streaming_response(stream, show_reasoning=True) + except Exception as e: + log.error("\nError during streaming: %s", e, exc_info=True) + else: + response = await call_chat_completions( + client=self.client, + model=self.model, + messages=messages, + max_tokens=MAX_TOKENS, + temperature=DEFAULT_TEMPERATURE + ) + choice = (response.get("choices") or [{}])[0] + message = choice.get("message", {}) + content = message.get("content", "") + reasoning = message.get("reasoning_content", "") or message.get("reasoning", "") + if reasoning: + print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m") + print(f"\n💬 Assistant: {content}") + print(f"\nFull Response:\n{json.dumps(response, indent=2)}") + + async def test_tool_support(self) -> bool: + """Probe that tool schema is accepted (no actual call)""" messages = [{"role": "user", "content": "Hello"}] minimal_tool = [ { @@ -277,170 +283,147 @@ class APIDemo: "function": {"name": "test_function", "description": "Test function"}, } ] - - config = ChatCompletionConfig( - model=self.model, - messages=messages, - max_tokens=10, - tools=minimal_tool, - tool_choice="none", # Don't actually call the tool - ) - try: - response = self.client.call_chat_completions(config) + _ = await call_chat_completions( + client=self.client, + model=self.model, + messages=messages, + tools=minimal_tool, + tool_choice="none", + max_tokens=10 + ) return True except Exception as e: - log.error(f"Error: Endpoint does not support tool calling: {e}") + log.error("Endpoint does not support tool calling: %s", e) return False - def demo_completions(self) -> None: - """Demo: test basic completions endpoint""" - print("=" * 60) - print("COMPLETIONS DEMO") - print("=" * 60) - - config = CompletionConfig( - model=self.model, prompt=COMPLETIONS_PROMPT, stream=False - ) - - log.info( - f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'" - ) - response = self.client.call_completions(config) - - if isinstance(response, dict): - print("\nResponse:") - print(json.dumps(response, indent=2)) - else: - log.error("Unexpected response format") - - def demo_chat(self, use_streaming: bool = True) -> None: - """ - Demo: test chat completions endpoint with optional streaming - """ - print("=" * 60) - print( - f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}" - ) - print("=" * 60) - - config = ChatCompletionConfig( - model=self.model, - messages=[{"role": "user", "content": CHAT_PROMPT}], - stream=use_streaming, - ) - - log.info(f"Testing chat completions with model '{self.model}'...") - response = self.client.call_chat_completions(config) - - if use_streaming: - try: - self.handle_streaming_response(response, show_reasoning=True) - except Exception as e: - log.error(f"\nError during streaming: {e}") - import traceback - - traceback.print_exc() - return - - else: - if isinstance(response, dict): - choice = response.get("choices", [{}])[0] - message = choice.get("message", {}) - content = message.get("content", "") - reasoning = message.get("reasoning_content", "") or message.get( - "reasoning", "" - ) - - if reasoning: - print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m") - - print(f"\n💬 Assistant: {content}") - print(f"\nFull Response:") - print(json.dumps(response, indent=2)) - else: - log.error("Unexpected response format") - - def demo_ls_tool(self) -> None: - """Demo: ask LLM to list files in the current directory and describe what it sees""" + async def demo_ls_tool(self) -> None: + """Ask to list files using function calling, then provide final analysis""" print("=" * 60) print("TOOL USE DEMO: List Directory Contents") print("=" * 60) - # Test if tools are supported first - if not self.test_tool_support(): + if not await self.test_tool_support(): return - # Request with tool available - messages = [{"role": "user", "content": TOOLS_PROMPT}] + messages: List[Dict[str, Any]] = [{"role": "user", "content": TOOLS_PROMPT}] - config = ChatCompletionConfig( + # First pass: let the model decide tools, stream tool_calls and partial content + stream = await stream_chat_completions( + client=self.client, model=self.model, messages=messages, tools=self.tool_manager.get_ls_tool_definition(), tool_choice="auto", + max_tokens=MAX_TOKENS, + temperature=DEFAULT_TEMPERATURE, ) - log.info(f"Making initial request with tool using model '{self.model}'...") - response = self.client.call_chat_completions(config) + assistant_content_buf: List[str] = [] + tool_calls_state: Dict[int, Dict[str, Any]] = {} + printed_reasoning = False + printed_answer = False - if not isinstance(response, dict): - raise ValueError("Expected dict response for tool use") + async for chunk in stream: + choice = (chunk.get("choices") or [{}])[0] + delta = choice.get("delta", {}) - choice = response.get("choices", [{}])[0] - message = choice.get("message", {}) + rc = delta.get("reasoning_content") + if rc: + if not printed_reasoning: + printed_reasoning = True + print("🧠 Reasoning: ", end="", flush=True) + print(rc, end="", flush=True) - print(f"Assistant response: {message.get('content', 'No content')}") + content_part = delta.get("content") + if content_part: + assistant_content_buf.append(content_part) + if not printed_answer: + printed_answer = True + print("\n💬 Response: ", end="", flush=True) + print(content_part, end="", flush=True) - # Check for tool calls - tool_calls = message.get("tool_calls") - if not tool_calls: - raise ValueError( - "No tool calls made - model may not support function calling" - ) + if "tool_calls" in delta and delta["tool_calls"]: + for tc_delta in delta["tool_calls"]: + _merge_tool_call_delta(tool_calls_state, tc_delta) - print(f"Tool calls detected: {len(tool_calls)}") + # If no tool calls, we’re done. + if not tool_calls_state: + print("\n(No tool calls were made.)") + return - # Execute the tool call - for tool_call in tool_calls: - function_name = tool_call["function"]["name"] - print(f"Executing tool: {function_name}") + # Build assistant message with tool_calls + assistant_message = { + "role": "assistant", + "content": "".join(assistant_content_buf) if assistant_content_buf else None, + "tool_calls": _tool_state_to_message_tool_calls(tool_calls_state), + } + messages.append(assistant_message) - tool_result = self.tool_manager.execute_tool_call(tool_call) - print(f"Tool result:\n{tool_result}") + # Execute tools and feed results back + for tc in assistant_message["tool_calls"]: + tool_name = (tc.get("function") or {}).get("name") + call_id = tc.get("id") + raw_args = (tc.get("function") or {}).get("arguments") or "{}" - # Add tool result and continue conversation - messages.append(message) # Add assistant's message with tool call - messages.append( - { - "role": "tool", - "tool_call_id": tool_call["id"], - "content": tool_result, - } - ) + try: + args = json.loads(raw_args) if raw_args.strip() else {} + except Exception as e: + tool_result = json.dumps({"error": f"Argument parse failed: {str(e)}", "raw_arguments": raw_args}) + messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result}) + continue - # Get final response - final_config = ChatCompletionConfig( - model=self.model, - messages=messages, - tools=self.tool_manager.get_ls_tool_definition(), - ) + try: + if tool_name == "list_files": + tool_result = self.tool_manager.list_files() + else: + tool_result = json.dumps({"error": f"Unknown tool '{tool_name}'"}) + except Exception as e: + tool_result = json.dumps({"error": f"Tool '{tool_name}' failed: {str(e)}"}) - print("Getting final response...") - final_response = self.client.call_chat_completions(final_config) + print("\n[Tool executed]", tool_name) + print(tool_result[:500] + ("..." if len(tool_result) > 500 else "")) + messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result}) - if isinstance(final_response, dict): - final_choice = final_response.get("choices", [{}])[0] - final_message = final_choice.get("message", {}) - final_content = final_message.get("content", "") + # Second pass: get final streamed answer after tool results + stream2 = await stream_chat_completions( + client=self.client, + model=self.model, + messages=messages, + max_tokens=MAX_TOKENS, + temperature=DEFAULT_TEMPERATURE, + ) - print("\n" + "=" * 60) - print("FINAL LLM ANALYSIS:") - print("=" * 60) - print(final_content) - print("=" * 60) + final_buf = [] + printed_reasoning2 = False + printed_answer2 = False - def interactive_chat(self) -> None: + async for chunk in stream2: + choice = (chunk.get("choices") or [{}])[0] + delta = choice.get("delta", {}) + + rc2 = delta.get("reasoning_content") + if rc2: + if not printed_reasoning2: + printed_reasoning2 = True + print("\n🧠 Reasoning (post-tools): ", end="", flush=True) + print(rc2, end="", flush=True) + + c2 = delta.get("content") + if c2: + final_buf.append(c2) + if not printed_answer2: + printed_answer2 = True + print("\n💬 Response (final): ", end="", flush=True) + print(c2, end="", flush=True) + + print("\n" + "=" * 60) + print("FINAL LLM ANALYSIS:") + print("=" * 60) + print("".join(final_buf)) + print("=" * 60) + + async def interactive_chat(self) -> None: """Interactive chat session with streaming""" print("=" * 60) print("INTERACTIVE STREAMING CHAT") @@ -449,7 +432,7 @@ class APIDemo: print("Type 'quit' to exit, 'clear' to clear history") print() - messages = [] + messages: List[Dict[str, Any]] = [] while True: try: @@ -467,16 +450,15 @@ class APIDemo: messages.append({"role": "user", "content": user_input}) - config = ChatCompletionConfig( - model=self.model, messages=messages, stream=True, temperature=0.7 - ) - print("Assistant: ", end="", flush=True) - - response = self.client.call_chat_completions(config) - assistant_content = self.handle_streaming_response( - response, show_reasoning=True + stream = await stream_chat_completions( + client=self.client, + model=self.model, + messages=messages, + max_tokens=MAX_TOKENS, + temperature=0.7 ) + assistant_content = await self.handle_streaming_response(stream, show_reasoning=True) # Add assistant response to conversation history messages.append({"role": "assistant", "content": assistant_content}) @@ -485,115 +467,64 @@ class APIDemo: print("\n👋 Chat interrupted. Goodbye!") break except Exception as e: - log.error(f"\nError: {e}") + log.error("\nError: %s", e) continue -def main(): - """Main function with CLI switches for different tests""" - from lib.test_utils import test_args +# ---------------------- CLI ---------------------- +def build_arg_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)") + p.add_argument("--model", required=True, help="Model to use for requests (required)") + p.add_argument("--endpoint", default="my-vllm-endpoint", help="Vast endpoint name (default: my-vllm-endpoint)") - # Add mandatory model argument - test_args.add_argument( - "--model", required=True, help="Model to use for requests (required)" - ) + modes = p.add_mutually_exclusive_group(required=False) + modes.add_argument("--completion", action="store_true", help="Test completions endpoint") + modes.add_argument("--chat", action="store_true", help="Test chat completions endpoint (non-streaming)") + modes.add_argument("--chat-stream", action="store_true", help="Test chat completions endpoint with streaming") + modes.add_argument("--tools", action="store_true", help="Test function calling with ls tool (non-streaming+streamed phases)") + modes.add_argument("--interactive", action="store_true", help="Start interactive streaming chat session") + return p - # Add test mode arguments - test_args.add_argument( - "--completion", action="store_true", help="Test completions endpoint" - ) - test_args.add_argument( - "--chat", - action="store_true", - help="Test chat completions endpoint (non-streaming)", - ) - test_args.add_argument( - "--chat-stream", - action="store_true", - help="Test chat completions endpoint with streaming", - ) - test_args.add_argument( - "--tools", - action="store_true", - help="Test function calling with ls tool (non-streaming)", - ) - test_args.add_argument( - "--interactive", - action="store_true", - help="Start interactive streaming chat session", - ) - args = test_args.parse_args() +async def main_async(): + args = build_arg_parser().parse_args() - # Check that only one test mode is selected - test_modes = [ - args.completion, - args.chat, - args.chat_stream, - args.tools, - args.interactive, - ] - selected_count = sum(test_modes) - - if selected_count == 0: + selected = sum([args.completion, args.chat, args.chat_stream, args.tools, args.interactive]) + if selected == 0: print("Please specify exactly one test mode:") print(" --completion : Test completions endpoint") print(" --chat : Test chat completions endpoint (non-streaming)") print(" --chat-stream : Test chat completions endpoint with streaming") - print(" --tools : Test function calling with ls tool (non-streaming)") + print(" --tools : Test function calling with ls tool") print(" --interactive : Start interactive streaming chat session") - print( - f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT" - ) + print(f"\nExample: python {os.path.basename(sys.argv[0])} --model Qwen/Qwen3-8B --chat-stream --endpoint my-vllm-endpoint") sys.exit(1) - elif selected_count > 1: + elif selected > 1: print("Please specify exactly one test mode") sys.exit(1) + print(f"Using model: {args.model}") + print("=" * 60) + try: - endpoint_api_key = Endpoint.get_endpoint_api_key( - endpoint_name=args.endpoint_group_name, - account_api_key=args.api_key, - instance=args.instance, - ) + async with Serverless() as client: + demo = APIDemo(client, args.model, ToolManager()) - if not endpoint_api_key: - log.error( - f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting." - ) - sys.exit(1) - - # Create the core API client - client = APIClient( - endpoint_group_name=args.endpoint_group_name, - api_key=args.api_key, - server_url=Endpoint.get_autoscaler_server_url(args.instance), - endpoint_api_key=endpoint_api_key, - ) - - # Create tool manager and demo (passing the model parameter) - tool_manager = ToolManager() - demo = APIDemo(client, args.model, tool_manager) - - print(f"Using model: {args.model}") - print("=" * 60) - - # Run the selected test - if args.completion: - demo.demo_completions() - elif args.chat: - demo.demo_chat(use_streaming=False) - elif args.chat_stream: - demo.demo_chat(use_streaming=True) - elif args.tools: - demo.demo_ls_tool() - elif args.interactive: - demo.interactive_chat() + if args.completion: + await demo.demo_completions() + elif args.chat: + await demo.demo_chat(use_streaming=False) + elif args.chat_stream: + await demo.demo_chat(use_streaming=True) + elif args.tools: + await demo.demo_ls_tool() + elif args.interactive: + await demo.interactive_chat() except Exception as e: - log.error(f"Error during test: {e}", exc_info=True) + log.error("Error during test: %s", e, exc_info=True) sys.exit(1) if __name__ == "__main__": - main() + asyncio.run(main_async()) diff --git a/workers/tgi/client.py b/workers/tgi/client.py index 66dacb9..f307602 100644 --- a/workers/tgi/client.py +++ b/workers/tgi/client.py @@ -1,125 +1,61 @@ -import logging -import sys -import json -from urllib.parse import urljoin -import requests -from utils.endpoint_util import Endpoint -from utils.ssl import get_cert_file_path +from vastai import Serverless +import asyncio -logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s[%(levelname)-5s] %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", -) -log = logging.getLogger(__file__) +ENDPOINT_NAME = "my-tgi-endpoint" # Change this to match your endpoint name +MAX_TOKENS = 1024 +PROMPT = "Think step by step: Tell me about the Python programming language." +async def call_generate(client: Serverless) -> None: + endpoint = await client.get_endpoint(name=ENDPOINT_NAME) -def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> None: - WORKER_ENDPOINT = "/generate" - COST = 100 - route_payload = { - "endpoint": endpoint_group_name, - "api_key": api_key, - "cost": COST, + payload = { + "inputs": PROMPT, + "parameters": { + "max_new_tokens": MAX_TOKENS, + "temperature": 0.7, + "return_full_text": False + } } - response = requests.post( - urljoin(server_url, "/route/"), - json=route_payload, - timeout=4, - ) - response.raise_for_status() # Raise an exception for bad status codes - message = response.json() - url = message["url"] - auth_data = dict( - signature=message["signature"], - cost=message["cost"], - endpoint=message["endpoint"], - reqnum=message["reqnum"], - url=url, - ) + resp = await endpoint.request("/generate", payload, cost=MAX_TOKENS) - payload = dict(inputs="tell me about cats", parameters=dict(max_new_tokens=500)) - req_data = dict(payload=payload, auth_data=auth_data) - url = urljoin(url, WORKER_ENDPOINT) - print(f"url: {url}") - response = requests.post( - url, - json=req_data, - verify=get_cert_file_path(), - ) - response.raise_for_status() - res = response.json() - print(res) + print(resp["response"]["generated_text"]) -def call_generate_stream( - endpoint_group_name: str, api_key: str, server_url: str -) -> None: - WORKER_ENDPOINT = "/generate_stream" - COST = 100 - route_payload = { - "endpoint": endpoint_group_name, - "api_key": api_key, - "cost": COST, +async def call_generate_stream(client: Serverless) -> None: + endpoint = await client.get_endpoint(name=ENDPOINT_NAME) + + payload = { + "inputs": PROMPT, + "parameters": { + "max_new_tokens": MAX_TOKENS, + "temperature": 0.7, + "do_sample": True, + "return_full_text": False, + } } - response = requests.post( - urljoin(server_url, "/route/"), - json=route_payload, - timeout=4, - ) - response.raise_for_status() # Raise an exception for bad status codes - message = response.json() - url = message["url"] - print(f"url: {url}") - auth_data = dict( - signature=message["signature"], - cost=message["cost"], - endpoint=message["endpoint"], - reqnum=message["reqnum"], - url=message["url"], - ) - payload = dict(inputs="tell me about dogs", parameters=dict(max_new_tokens=500)) - req_data = dict(payload=payload, auth_data=auth_data) - url = urljoin(url, WORKER_ENDPOINT) - response = requests.post(url, json=req_data, stream=True) - response.raise_for_status() # Raise an exception for bad status codes - for line in response.iter_lines(): - payload = line.decode().lstrip("data:").rstrip() - if payload: - try: - data = json.loads(payload) - print(data["token"]["text"], end="") - sys.stdout.flush() - except (json.JSONDecodeError, KeyError) as e: - log.warning(f"Failed to parse streaming response: {e}") - continue - print() + resp = await endpoint.request( + "/generate_stream", + payload, + cost=MAX_TOKENS, + stream=True, + ) + stream = resp["response"] + + printed_answer = False + async for event in stream: + tok = (event.get("token") or {}).get("text") + if tok: + if not printed_answer: + printed_answer = True + print("Answer:\n", end="", flush=True) + print(tok, end="", flush=True) + +async def main(): + async with Serverless() as client: + await call_generate(client) + await call_generate_stream(client) if __name__ == "__main__": - from lib.test_utils import test_args - - args = test_args.parse_args() - - endpoint_api_key = Endpoint.get_endpoint_api_key( - endpoint_name=args.endpoint_group_name, - account_api_key=args.api_key, - instance=args.instance, - ) - if endpoint_api_key: - try: - call_generate( - api_key=endpoint_api_key, - endpoint_group_name=args.endpoint_group_name, - server_url=args.server_url, - ) - call_generate_stream( - api_key=endpoint_api_key, - endpoint_group_name=args.endpoint_group_name, - server_url=args.server_url, - ) - except Exception as e: - log.error(f"Error during API call: {e}") - else: - log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ") + asyncio.run(main()) From 3adec1826d982231805ea8a91745fadb37283d28 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Tue, 11 Nov 2025 17:11:38 -0800 Subject: [PATCH 2/3] minor changes --- workers/comfyui-json/client.py | 4 ++-- workers/openai/client.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/workers/comfyui-json/client.py b/workers/comfyui-json/client.py index c877df2..93e184c 100644 --- a/workers/comfyui-json/client.py +++ b/workers/comfyui-json/client.py @@ -1,11 +1,11 @@ -from vastai import Serverless from .data_types import count_workload - import uuid import random import asyncio import random +from vastai import Serverless + async def main(): async with Serverless() as client: endpoint = await client.get_endpoint(name="my-comfy-endpoint") # Change this to your endpoint name diff --git a/workers/openai/client.py b/workers/openai/client.py index 1dadc68..8c88444 100644 --- a/workers/openai/client.py +++ b/workers/openai/client.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 import logging import json import os From eedf81c0a314c759977109a4f91ca3059402bd1f Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Tue, 11 Nov 2025 17:18:40 -0800 Subject: [PATCH 3/3] Updated readme and .gitignore --- .gitignore | 3 ++- README.md | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 226869e..dc47eed 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ .envrc __pycache__ bin/ -lib64 \ No newline at end of file +lib64 +.venv \ No newline at end of file diff --git a/README.md b/README.md index 117600d..dda0ea2 100644 --- a/README.md +++ b/README.md @@ -39,11 +39,12 @@ reporting these metrics to the autoscaler. If you are using a Vast.ai template that includes PyWorker integration (marked as autoscaler compatible), it should work out of the box. The template will typically start the appropriate PyWorker server automatically. Here's a few: -* **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=72d8dcb41ea3a58e06c741e2c725bc00) -* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=ad72c8bf7cf695c3c9ddf0eaf6da0447) +* **vLLM:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=63ae93902bf3978bea033782592b784d) +* **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=6fa6bd5bdf5f0df63db80e40b086037d) +* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=e6748878ba688e765e3e9fca29541938) Currently available workers: -* `hello_world`: A simple example worker for a basic LLM server. +* `openai`: A simple example worker for a basic vLLM server. * `comfyui`: A worker for the ComfyUI image generation backend. * `tgi`: A worker for the Text Generation Inference backend.