OpenAI compatible worker (#19)
Adds initial support for OpenAI compatible inference servers Available endpoints: - `/v1/completions` - `/v1/chat/completions`
This commit is contained in:
@@ -0,0 +1,80 @@
|
||||
# OpenAI Compatible PyWorker
|
||||
|
||||
This is the base PyWorker for OpenAI compatible inference servers. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
|
||||
|
||||
## Instance Setup
|
||||
|
||||
1. Pick a template
|
||||
|
||||
This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker.
|
||||
|
||||
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20%2B%20Qwen%2FQwen3-8B%20(Serverless)) (recommended)
|
||||
- [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless))
|
||||
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20%2B%20Qwen3-8B%20(Serverless))
|
||||
|
||||
|
||||
All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected.
|
||||
|
||||
2. Follow the [getting started guide](https://docs.vast.ai/serverless/getting-started) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
||||
|
||||
## Client Setup (Demo)
|
||||
|
||||
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/vast-ai/pyworker
|
||||
cd pyworker
|
||||
pip install uv
|
||||
uv venv -p 3.12
|
||||
source .venv/bin/activate
|
||||
uv pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Using the Test Client
|
||||
|
||||
Several examples have been provided in the client to help you get started with your own implementation.
|
||||
|
||||
### Completions
|
||||
|
||||
Call to `/v1/completions` with json response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Chat Completion (json)
|
||||
|
||||
Call to `/v1/chat/completions` with json response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Chat Completion (streaming)
|
||||
|
||||
Call to `/v1/chat/completions` with streaming response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Tool Use (json)
|
||||
|
||||
Call to `/v1/chat/completions` with tool and json response.
|
||||
|
||||
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Interactive Chat (streaming)
|
||||
|
||||
Interactive session with calls to `/v1/chat/completions`.
|
||||
|
||||
Type `clear` to clear the chat history or `quit` to exit.
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
# <INFERENCE_SERVER> + <MODEL_NAME> (serverless)
|
||||
|
||||
Run <INFERENCE_SERVER> with our serverless autoscaling infrastructure.
|
||||
|
||||
See the [serverless documentation](https://docs.vast.ai/serverless) and the [Getting Started](https://docs.vast.ai/serverless/getting-started) guide for in-depth details about how to use these templates.
|
||||
|
||||
## Configuration
|
||||
|
||||
Two environment variables are provided to help you configure the <INFERENCE_SERVER> server:
|
||||
|
||||
| Variable | Default Value | Used For |
|
||||
| --- | --- | --- |
|
||||
| `MODEL_NAME` | `<MODEL_NAME>` | The model to load. Also accepts [hf.co/repo/model](#) links |
|
||||
| `<ARGS_VAR>` | `<ARGS_VAL>` | Arguments to pass to the `<ARGS_RECEIVER>` command |
|
||||
|
||||
This template has been configured to work with <MIN_VRAM> VRAM. Setting alternative models and server arguments will change the VRAM requirements. Check model cards and <INFERENCE_SERVER_DOCS> for guidance.
|
||||
|
||||
## Usage
|
||||
|
||||
We have provided a demonstration client to help you implement this template into your own infrastructure
|
||||
|
||||
### Client Setup
|
||||
|
||||
Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/vast-ai/pyworker
|
||||
cd pyworker
|
||||
pip install uv
|
||||
uv venv -p 3.12
|
||||
source .venv/bin/activate
|
||||
uv pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### Completions
|
||||
|
||||
Call to `/v1/completions` with json response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Chat Completion (json)
|
||||
|
||||
Call to `/v1/chat/completions` with json response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Chat Completion (streaming)
|
||||
|
||||
Call to `/v1/chat/completions` with streaming response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Tool Use (json)
|
||||
|
||||
Call to `/v1/chat/completions` with tool and json response.
|
||||
|
||||
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Interactive Chat (streaming)
|
||||
|
||||
Interactive session with calls to `/v1/chat/completions`.
|
||||
|
||||
Type `clear` to clear the chat history or `quit` to exit.
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
|
||||
```
|
||||
@@ -0,0 +1,578 @@
|
||||
import logging
|
||||
import sys
|
||||
import json
|
||||
import subprocess
|
||||
from urllib.parse import urljoin
|
||||
from typing import Dict, Any, Optional, Iterator, Union, List
|
||||
import requests
|
||||
from utils.endpoint_util import Endpoint
|
||||
from .data_types.client import CompletionConfig, ChatCompletionConfig
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
COMPLETIONS_PROMPT = "the capital of USA is"
|
||||
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
|
||||
TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?"
|
||||
|
||||
class APIClient:
|
||||
"""Lightweight client focused solely on API communication"""
|
||||
|
||||
# Remove the generic WORKER_ENDPOINT since we're now going direct
|
||||
DEFAULT_COST = 100
|
||||
DEFAULT_TIMEOUT = 4
|
||||
|
||||
def __init__(self, endpoint_group_name: str, api_key: str, server_url: str):
|
||||
self.endpoint_group_name = endpoint_group_name
|
||||
self.api_key = api_key
|
||||
self.server_url = server_url
|
||||
self.endpoint_api_key = self._get_endpoint_api_key()
|
||||
|
||||
def _get_endpoint_api_key(self) -> Optional[str]:
|
||||
"""Get the endpoint API key"""
|
||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||
endpoint_name=self.endpoint_group_name,
|
||||
account_api_key=self.api_key,
|
||||
)
|
||||
if not endpoint_api_key:
|
||||
log.error(f"Failed to get API key for endpoint {self.endpoint_group_name}")
|
||||
return endpoint_api_key
|
||||
|
||||
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
|
||||
"""Get worker URL and auth data from routing service"""
|
||||
if not self.endpoint_api_key:
|
||||
raise ValueError("No valid endpoint API key available")
|
||||
|
||||
route_payload = {
|
||||
"endpoint": self.endpoint_group_name,
|
||||
"api_key": self.endpoint_api_key,
|
||||
"cost": cost,
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
urljoin(self.server_url, "/route/"),
|
||||
json=route_payload,
|
||||
timeout=self.DEFAULT_TIMEOUT,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def _create_auth_data(self, message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Create auth data from routing response"""
|
||||
return {
|
||||
"signature": message["signature"],
|
||||
"cost": message["cost"],
|
||||
"endpoint": message["endpoint"],
|
||||
"reqnum": message["reqnum"],
|
||||
"url": message["url"],
|
||||
}
|
||||
|
||||
def _make_request(self, payload: Dict[str, Any], endpoint: str, method: str = "POST",
|
||||
stream: bool = False) -> Union[Dict[str, Any], Iterator[str]]:
|
||||
"""Make request directly to the specific worker endpoint"""
|
||||
# Get worker URL and auth data
|
||||
cost = payload.get('max_tokens')
|
||||
message = self._get_worker_url(cost=cost)
|
||||
worker_url = message["url"]
|
||||
auth_data = self._create_auth_data(message)
|
||||
|
||||
req_data = {
|
||||
"payload": {
|
||||
"input": payload
|
||||
},
|
||||
"auth_data": auth_data
|
||||
}
|
||||
|
||||
url = urljoin(worker_url, endpoint)
|
||||
log.debug(f"Making direct request to: {url}")
|
||||
log.debug(f"Payload: {req_data}")
|
||||
|
||||
# Make the request using the specified method
|
||||
if method.upper() == "POST":
|
||||
response = requests.post(url, json=req_data, stream=stream)
|
||||
elif method.upper() == "GET":
|
||||
response = requests.get(url, params=req_data, stream=stream)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
if stream:
|
||||
return self._handle_streaming_response(response)
|
||||
else:
|
||||
return response.json()
|
||||
|
||||
def _handle_streaming_response(self, response: requests.Response) -> Iterator[str]:
|
||||
"""Handle streaming response and yield tokens"""
|
||||
try:
|
||||
for line in response.iter_lines(decode_unicode=True):
|
||||
if line:
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
if data_str.strip() == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
yield data # Yield the full chunk
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
except Exception as e:
|
||||
log.error(f"Error handling streaming response: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def call_completions(self, config: CompletionConfig) -> Union[Dict[str, Any], Iterator[str]]:
|
||||
payload = config.to_dict()
|
||||
|
||||
return self._make_request(
|
||||
payload=payload,
|
||||
endpoint="/v1/completions",
|
||||
stream=config.stream
|
||||
)
|
||||
|
||||
def call_chat_completions(self, config: ChatCompletionConfig) -> Union[Dict[str, Any], Iterator[str]]:
|
||||
payload = config.to_dict()
|
||||
|
||||
return self._make_request(
|
||||
payload=payload,
|
||||
endpoint="/v1/chat/completions",
|
||||
stream=config.stream
|
||||
)
|
||||
|
||||
|
||||
class ToolManager:
|
||||
"""Handles tool definitions and execution"""
|
||||
|
||||
@staticmethod
|
||||
def list_files() -> str:
|
||||
"""Execute ls on current directory"""
|
||||
try:
|
||||
result = subprocess.run(['ls', '-la', '.'], capture_output=True, text=True, timeout=10)
|
||||
if result.returncode == 0:
|
||||
return result.stdout
|
||||
else:
|
||||
return f"Error: {result.stderr}"
|
||||
except Exception as e:
|
||||
return f"Error running ls: {e}"
|
||||
|
||||
@staticmethod
|
||||
def get_ls_tool_definition() -> List[Dict[str, Any]]:
|
||||
"""Get the ls tool definition"""
|
||||
return [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "list_files",
|
||||
"description": "List files and directories in the cwd",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
}]
|
||||
|
||||
def execute_tool_call(self, tool_call: Dict[str, Any]) -> str:
|
||||
"""Execute a tool call and return the result"""
|
||||
function_name = tool_call["function"]["name"]
|
||||
|
||||
if function_name == "list_files":
|
||||
return self.list_files()
|
||||
else:
|
||||
raise ValueError(f"Unknown tool function: {function_name}")
|
||||
|
||||
|
||||
class APIDemo:
|
||||
"""Demo and testing functionality for the API client"""
|
||||
|
||||
def __init__(self, client: APIClient, model: str, tool_manager: ToolManager = None):
|
||||
self.client = client
|
||||
self.model = model
|
||||
self.tool_manager = tool_manager or ToolManager()
|
||||
|
||||
def handle_streaming_response(self, response_stream, show_reasoning: bool = True) -> str:
|
||||
"""
|
||||
Handle streaming chat response and display all output.
|
||||
"""
|
||||
|
||||
full_response = ""
|
||||
reasoning_content = ""
|
||||
reasoning_started = False
|
||||
content_started = False
|
||||
|
||||
for chunk in response_stream:
|
||||
# Normalize the chunk
|
||||
if isinstance(chunk, str):
|
||||
chunk = chunk.strip()
|
||||
if chunk.startswith("data: "):
|
||||
chunk = chunk[6:].strip()
|
||||
if chunk in ["[DONE]", ""]:
|
||||
continue
|
||||
try:
|
||||
parsed_chunk = json.loads(chunk)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
elif isinstance(chunk, dict):
|
||||
parsed_chunk = chunk
|
||||
else:
|
||||
continue
|
||||
|
||||
# Parse delta from the chunk
|
||||
choices = parsed_chunk.get("choices", [])
|
||||
if not choices:
|
||||
continue
|
||||
|
||||
delta = choices[0].get("delta", {})
|
||||
reasoning_token = delta.get("reasoning_content", "")
|
||||
content_token = delta.get("content", "")
|
||||
|
||||
# Print reasoning token if applicable
|
||||
if show_reasoning and reasoning_token:
|
||||
if not reasoning_started:
|
||||
print("\n🧠 Reasoning: ", end="", flush=True)
|
||||
reasoning_started = True
|
||||
print(f"\033[90m{reasoning_token}\033[0m", end="", flush=True)
|
||||
reasoning_content += reasoning_token
|
||||
|
||||
# Print content token
|
||||
if content_token:
|
||||
if not content_started:
|
||||
if show_reasoning and reasoning_started:
|
||||
print(f"\n💬 Response: ", end="", flush=True)
|
||||
else:
|
||||
print("Assistant: ", end="", flush=True)
|
||||
content_started = True
|
||||
print(content_token, end="", flush=True)
|
||||
full_response += content_token
|
||||
|
||||
print() # Ensure newline after response
|
||||
|
||||
if show_reasoning:
|
||||
if reasoning_started or content_started:
|
||||
print("\nStreaming completed.")
|
||||
if reasoning_started:
|
||||
print(f"Reasoning tokens: {len(reasoning_content.split())}")
|
||||
if content_started:
|
||||
print(f"Response tokens: {len(full_response.split())}")
|
||||
|
||||
return full_response
|
||||
|
||||
|
||||
def test_tool_support(self) -> bool:
|
||||
"""Test if the endpoint supports function calling"""
|
||||
log.debug("Testing endpoint tool calling support...")
|
||||
|
||||
# Try a simple request with minimal tools to test support
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
minimal_tool = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_function",
|
||||
"description": "Test function"
|
||||
}
|
||||
}]
|
||||
|
||||
config = ChatCompletionConfig(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
tools=minimal_tool,
|
||||
tool_choice="none" # Don't actually call the tool
|
||||
)
|
||||
|
||||
try:
|
||||
response = self.client.call_chat_completions(config)
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"Error: Endpoint does not support tool calling: {e}")
|
||||
return False
|
||||
|
||||
def demo_completions(self) -> None:
|
||||
"""Demo: test basic completions endpoint"""
|
||||
print("=" * 60)
|
||||
print("COMPLETIONS DEMO")
|
||||
print("=" * 60)
|
||||
|
||||
config = CompletionConfig(
|
||||
model=self.model,
|
||||
prompt=COMPLETIONS_PROMPT,
|
||||
stream=False
|
||||
)
|
||||
|
||||
log.info(f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'")
|
||||
response = self.client.call_completions(config)
|
||||
|
||||
if isinstance(response, dict):
|
||||
print("\nResponse:")
|
||||
print(json.dumps(response, indent=2))
|
||||
else:
|
||||
log.error("Unexpected response format")
|
||||
|
||||
def demo_chat(self, use_streaming: bool = True) -> None:
|
||||
"""
|
||||
Demo: test chat completions endpoint with optional streaming
|
||||
"""
|
||||
print("=" * 60)
|
||||
print(f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}")
|
||||
print("=" * 60)
|
||||
|
||||
config = ChatCompletionConfig(
|
||||
model=self.model,
|
||||
messages=[{"role": "user", "content": CHAT_PROMPT}],
|
||||
stream=use_streaming,
|
||||
)
|
||||
|
||||
log.info(f"Testing chat completions with model '{self.model}'...")
|
||||
response = self.client.call_chat_completions(config)
|
||||
|
||||
if use_streaming:
|
||||
try:
|
||||
self.handle_streaming_response(response, show_reasoning=True)
|
||||
except Exception as e:
|
||||
log.error(f"\nError during streaming: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
else:
|
||||
if isinstance(response, dict):
|
||||
choice = response.get("choices", [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
content = message.get("content", "")
|
||||
reasoning = message.get("reasoning_content", "") or message.get("reasoning", "")
|
||||
|
||||
if reasoning:
|
||||
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
|
||||
|
||||
print(f"\n💬 Assistant: {content}")
|
||||
print(f"\nFull Response:")
|
||||
print(json.dumps(response, indent=2))
|
||||
else:
|
||||
log.error("Unexpected response format")
|
||||
|
||||
|
||||
|
||||
def demo_ls_tool(self) -> None:
|
||||
"""Demo: ask LLM to list files in the current directory and describe what it sees"""
|
||||
print("=" * 60)
|
||||
print("TOOL USE DEMO: List Directory Contents")
|
||||
print("=" * 60)
|
||||
|
||||
# Test if tools are supported first
|
||||
if not self.test_tool_support():
|
||||
return
|
||||
|
||||
# Request with tool available
|
||||
messages = [
|
||||
{"role": "user", "content": TOOLS_PROMPT}
|
||||
]
|
||||
|
||||
config = ChatCompletionConfig(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
tools=self.tool_manager.get_ls_tool_definition(),
|
||||
tool_choice="auto"
|
||||
)
|
||||
|
||||
log.info(f"Making initial request with tool using model '{self.model}'...")
|
||||
response = self.client.call_chat_completions(config)
|
||||
|
||||
if not isinstance(response, dict):
|
||||
raise ValueError("Expected dict response for tool use")
|
||||
|
||||
choice = response.get("choices", [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
|
||||
print(f"Assistant response: {message.get('content', 'No content')}")
|
||||
|
||||
# Check for tool calls
|
||||
tool_calls = message.get("tool_calls")
|
||||
if not tool_calls:
|
||||
raise ValueError("No tool calls made - model may not support function calling")
|
||||
|
||||
print(f"Tool calls detected: {len(tool_calls)}")
|
||||
|
||||
# Execute the tool call
|
||||
for tool_call in tool_calls:
|
||||
function_name = tool_call["function"]["name"]
|
||||
print(f"Executing tool: {function_name}")
|
||||
|
||||
tool_result = self.tool_manager.execute_tool_call(tool_call)
|
||||
print(f"Tool result:\n{tool_result}")
|
||||
|
||||
# Add tool result and continue conversation
|
||||
messages.append(message) # Add assistant's message with tool call
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call["id"],
|
||||
"content": tool_result
|
||||
})
|
||||
|
||||
# Get final response
|
||||
final_config = ChatCompletionConfig(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
tools=self.tool_manager.get_ls_tool_definition()
|
||||
)
|
||||
|
||||
print("Getting final response...")
|
||||
final_response = self.client.call_chat_completions(final_config)
|
||||
|
||||
if isinstance(final_response, dict):
|
||||
final_choice = final_response.get("choices", [{}])[0]
|
||||
final_message = final_choice.get("message", {})
|
||||
final_content = final_message.get("content", "")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("FINAL LLM ANALYSIS:")
|
||||
print("=" * 60)
|
||||
print(final_content)
|
||||
print("=" * 60)
|
||||
|
||||
def interactive_chat(self) -> None:
|
||||
"""Interactive chat session with streaming"""
|
||||
print("=" * 60)
|
||||
print("INTERACTIVE STREAMING CHAT")
|
||||
print("=" * 60)
|
||||
print(f"Using model: {self.model}")
|
||||
print("Type 'quit' to exit, 'clear' to clear history")
|
||||
print()
|
||||
|
||||
messages = []
|
||||
|
||||
while True:
|
||||
try:
|
||||
user_input = input("You: ").strip()
|
||||
|
||||
if user_input.lower() == 'quit':
|
||||
print("👋 Goodbye!")
|
||||
break
|
||||
elif user_input.lower() == 'clear':
|
||||
messages = []
|
||||
print("Chat history cleared")
|
||||
continue
|
||||
elif not user_input:
|
||||
continue
|
||||
|
||||
messages.append({"role": "user", "content": user_input})
|
||||
|
||||
config = ChatCompletionConfig(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
stream=True,
|
||||
temperature=0.7
|
||||
)
|
||||
|
||||
print("Assistant: ", end="", flush=True)
|
||||
|
||||
response = self.client.call_chat_completions(config)
|
||||
assistant_content = self.handle_streaming_response(response, show_reasoning=True)
|
||||
|
||||
# Add assistant response to conversation history
|
||||
messages.append({"role": "assistant", "content": assistant_content})
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 Chat interrupted. Goodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
log.error(f"\nError: {e}")
|
||||
continue
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function with CLI switches for different tests"""
|
||||
from lib.test_utils import test_args
|
||||
|
||||
# Add mandatory model argument
|
||||
test_args.add_argument(
|
||||
"--model",
|
||||
required=True,
|
||||
help="Model to use for requests (required)"
|
||||
)
|
||||
|
||||
# Add test mode arguments
|
||||
test_args.add_argument(
|
||||
"--completion",
|
||||
action="store_true",
|
||||
help="Test completions endpoint"
|
||||
)
|
||||
test_args.add_argument(
|
||||
"--chat",
|
||||
action="store_true",
|
||||
help="Test chat completions endpoint (non-streaming)"
|
||||
)
|
||||
test_args.add_argument(
|
||||
"--chat-stream",
|
||||
action="store_true",
|
||||
help="Test chat completions endpoint with streaming"
|
||||
)
|
||||
test_args.add_argument(
|
||||
"--tools",
|
||||
action="store_true",
|
||||
help="Test function calling with ls tool (non-streaming)"
|
||||
)
|
||||
test_args.add_argument(
|
||||
"--interactive",
|
||||
action="store_true",
|
||||
help="Start interactive streaming chat session"
|
||||
)
|
||||
|
||||
args = test_args.parse_args()
|
||||
|
||||
# Check that only one test mode is selected
|
||||
test_modes = [
|
||||
args.completion, args.chat, args.chat_stream,
|
||||
args.tools, args.interactive
|
||||
]
|
||||
selected_count = sum(test_modes)
|
||||
|
||||
if selected_count == 0:
|
||||
print("Please specify exactly one test mode:")
|
||||
print(" --completion : Test completions endpoint")
|
||||
print(" --chat : Test chat completions endpoint (non-streaming)")
|
||||
print(" --chat-stream : Test chat completions endpoint with streaming")
|
||||
print(" --tools : Test function calling with ls tool (non-streaming)")
|
||||
print(" --interactive : Start interactive streaming chat session")
|
||||
print(f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT")
|
||||
sys.exit(1)
|
||||
elif selected_count > 1:
|
||||
print("Please specify exactly one test mode")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
# Create the core API client
|
||||
client = APIClient(
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
api_key=args.api_key,
|
||||
server_url=args.server_url
|
||||
)
|
||||
|
||||
# Create tool manager and demo (passing the model parameter)
|
||||
tool_manager = ToolManager()
|
||||
demo = APIDemo(client, args.model, tool_manager)
|
||||
|
||||
print(f"Using model: {args.model}")
|
||||
print("=" * 60)
|
||||
|
||||
# Run the selected test
|
||||
if args.completion:
|
||||
demo.demo_completions()
|
||||
elif args.chat:
|
||||
demo.demo_chat(use_streaming=False)
|
||||
elif args.chat_stream:
|
||||
demo.demo_chat(use_streaming=True)
|
||||
elif args.tools:
|
||||
demo.demo_ls_tool()
|
||||
elif args.interactive:
|
||||
demo.interactive_chat()
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error during test: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,54 @@
|
||||
import json
|
||||
from dataclasses import dataclass, field, fields, is_dataclass
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
|
||||
class SerializableDataclass:
|
||||
def _serialize_recursive(self, obj: Any) -> Any:
|
||||
if is_dataclass(obj):
|
||||
return {field.name: self._serialize_recursive(getattr(obj, field.name))
|
||||
for field in fields(obj)}
|
||||
elif isinstance(obj, dict):
|
||||
return {key: self._serialize_recursive(value) for key, value in obj.items()}
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
return [self._serialize_recursive(item) for item in obj]
|
||||
elif isinstance(obj, set):
|
||||
return [self._serialize_recursive(item) for item in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return self._serialize_recursive(self)
|
||||
|
||||
def to_json(self, indent: int = 2) -> str:
|
||||
return json.dumps(self.to_dict(), indent=indent)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompletionConfig(SerializableDataclass):
|
||||
"""Configuration for completion requests"""
|
||||
model: str
|
||||
prompt: str = "Hello"
|
||||
max_tokens: int = 256
|
||||
temperature: float = 0.7
|
||||
top_k: int = 20
|
||||
top_p: float = 0.4
|
||||
stream: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatCompletionConfig(SerializableDataclass):
|
||||
"""Configuration for chat completion requests"""
|
||||
model: str
|
||||
messages: list = None
|
||||
max_tokens: int = 2096
|
||||
temperature: float = 0.7
|
||||
top_k: int = 20
|
||||
top_p: float = 0.4
|
||||
stream: bool = False
|
||||
tools: Optional[List[Dict[str, Any]]] = field(default_factory=list)
|
||||
tool_choice: str = "auto"
|
||||
|
||||
def __post_init__(self):
|
||||
if self.messages is None:
|
||||
self.messages = [{"role": "user", "content": "Hello"}]
|
||||
@@ -0,0 +1,177 @@
|
||||
import os, json, random
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from lib.data_types import EndpointHandler, ApiPayload, JsonDataException
|
||||
from typing import Union, Type, Dict, Any
|
||||
from aiohttp import web, ClientResponse
|
||||
import nltk
|
||||
import logging
|
||||
|
||||
nltk.download("words")
|
||||
WORD_LIST = nltk.corpus.words.words()
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
"""
|
||||
Generic dataclass accepts any dictionary in input.
|
||||
"""
|
||||
@dataclass
|
||||
class GenericData(ApiPayload, ABC):
|
||||
input: Dict[str, Any]
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "GenericData":
|
||||
return cls(
|
||||
input=data["input"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "GenericData":
|
||||
errors = {}
|
||||
|
||||
# Validate required parameters
|
||||
required_params = ["input"]
|
||||
for param in required_params:
|
||||
if param not in json_msg:
|
||||
errors[param] = "missing parameter"
|
||||
|
||||
if errors:
|
||||
raise JsonDataException(errors)
|
||||
|
||||
try:
|
||||
# Create clean data dict and delegate to from_dict
|
||||
clean_data = {
|
||||
"input": json_msg["input"]
|
||||
}
|
||||
|
||||
return cls.from_dict(clean_data)
|
||||
|
||||
except (json.JSONDecodeError, JsonDataException) as e:
|
||||
errors["parameters"] = str(e)
|
||||
raise JsonDataException(errors)
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def for_test(cls) -> "GenericData":
|
||||
pass
|
||||
|
||||
def generate_payload_json(self) -> Dict[str, Any]:
|
||||
return self.input
|
||||
|
||||
def count_workload(self) -> int:
|
||||
return self.input.get("max_tokens", 0)
|
||||
|
||||
@dataclass
|
||||
class GenericHandler(EndpointHandler[GenericData], ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def endpoint(self) -> str:
|
||||
pass
|
||||
|
||||
@property
|
||||
def healthcheck_endpoint(self) -> str:
|
||||
return os.environ.get('MODEL_HEALTH_ENDPOINT')
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[GenericData]:
|
||||
return GenericData
|
||||
|
||||
@abstractmethod
|
||||
def make_benchmark_payload(self) -> GenericData:
|
||||
pass
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
match model_response.status:
|
||||
case 200:
|
||||
# Check if the response is actually streaming based on response headers/content-type
|
||||
is_streaming_response = (
|
||||
model_response.content_type == "text/event-stream" or
|
||||
model_response.content_type == "application/x-ndjson" or
|
||||
model_response.headers.get("Transfer-Encoding") == "chunked" or
|
||||
"stream" in model_response.content_type.lower()
|
||||
)
|
||||
|
||||
if is_streaming_response:
|
||||
log.debug("Detected streaming response...")
|
||||
res = web.StreamResponse()
|
||||
res.content_type = model_response.content_type
|
||||
await res.prepare(client_request)
|
||||
async for chunk in model_response.content:
|
||||
await res.write(chunk)
|
||||
await res.write_eof()
|
||||
log.debug("Done streaming response")
|
||||
return res
|
||||
else:
|
||||
log.debug("Detected non-streaming response...")
|
||||
content = await model_response.read()
|
||||
return web.Response(
|
||||
body=content,
|
||||
status=200,
|
||||
content_type=model_response.content_type
|
||||
)
|
||||
case code:
|
||||
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||
return web.Response(status=code)
|
||||
|
||||
@dataclass
|
||||
class CompletionsData(GenericData):
|
||||
@classmethod
|
||||
def for_test(cls) -> "CompletionsData":
|
||||
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
||||
model = os.environ.get("MODEL_NAME")
|
||||
if not model:
|
||||
raise ValueError("MODEL_NAME environment variable not set")
|
||||
|
||||
test_input = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"temperature": 0.7
|
||||
}
|
||||
return cls(input=test_input)
|
||||
|
||||
@dataclass
|
||||
class CompletionsHandler(GenericHandler):
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/v1/completions"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[CompletionsData]:
|
||||
return CompletionsData
|
||||
|
||||
def make_benchmark_payload(self) -> CompletionsData:
|
||||
return CompletionsData.for_test()
|
||||
|
||||
@dataclass
|
||||
class ChatCompletionsData(GenericData):
|
||||
"""Chat completions-specific data implementation"""
|
||||
|
||||
@classmethod
|
||||
def for_test(cls) -> "ChatCompletionsData":
|
||||
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
||||
model = os.environ.get("MODEL_NAME")
|
||||
if not model:
|
||||
raise ValueError("MODEL_NAME environment variable not set")
|
||||
|
||||
# Chat completions use messages format instead of prompt
|
||||
test_input = {
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0.7
|
||||
}
|
||||
return cls(input=test_input)
|
||||
|
||||
@dataclass
|
||||
class ChatCompletionsHandler(GenericHandler):
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/v1/chat/completions"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[ChatCompletionsData]:
|
||||
return ChatCompletionsData
|
||||
|
||||
def make_benchmark_payload(self) -> ChatCompletionsData:
|
||||
return ChatCompletionsData.for_test()
|
||||
@@ -0,0 +1,58 @@
|
||||
import os
|
||||
import logging
|
||||
from .data_types.server import CompletionsHandler, ChatCompletionsHandler
|
||||
from aiohttp import web
|
||||
from lib.backend import Backend, LogAction
|
||||
from lib.server import start_server
|
||||
|
||||
# This line indicates that the inference server is listening
|
||||
MODEL_SERVER_START_LOG_MSG = [
|
||||
"Application startup complete.", # vLLM
|
||||
"llama runner started", # Ollama
|
||||
'"message":"Connected","target":"text_generation_router"', # TGI
|
||||
'"message":"Connected","target":"text_generation_router::server"', # TGI
|
||||
]
|
||||
|
||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||
"INFO exited: vllm", # vLLM
|
||||
"RuntimeError: Engine", # vLLM
|
||||
"Error: pull model manifest:", # Ollama
|
||||
"stalled; retrying", # Ollama
|
||||
"Error: WebserverFailed", # TGI
|
||||
"Error: DownloadError", # TGI
|
||||
"Error: ShardCannotStart", #TGI
|
||||
]
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
backend = Backend(
|
||||
model_server_url=os.environ.get("MODEL_SERVER_URL"),
|
||||
model_log_file=os.environ.get("MODEL_LOG"),
|
||||
allow_parallel_requests=True,
|
||||
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
|
||||
log_actions=[
|
||||
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
|
||||
(LogAction.Info, '"message":"Download'),
|
||||
*[
|
||||
(LogAction.ModelError, error_msg)
|
||||
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
async def handle_ping(_):
|
||||
return web.Response(body="pong")
|
||||
|
||||
routes = [
|
||||
web.post("/v1/completions", backend.create_handler(CompletionsHandler())),
|
||||
web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())),
|
||||
web.get("/ping", handle_ping),
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_server(backend, routes)
|
||||
@@ -0,0 +1,28 @@
|
||||
from lib.test_utils import test_load_cmd, test_args
|
||||
from .data_types.server import CompletionsData
|
||||
import os
|
||||
|
||||
WORKER_ENDPOINT = "/v1/completions"
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Check if MODEL_NAME environment variable is set
|
||||
model_name_set = os.environ.get("MODEL_NAME") is not None
|
||||
|
||||
# Add model argument - required only if MODEL_NAME is not set
|
||||
test_args.add_argument(
|
||||
"--model",
|
||||
dest="model",
|
||||
required=not model_name_set,
|
||||
help="Model to use for completions request (required if MODEL_NAME env var not set)"
|
||||
)
|
||||
|
||||
# Parse known args to get model early, before test_load_cmd adds its args
|
||||
known_args, _ = test_args.parse_known_args()
|
||||
|
||||
# Set environment variable if model was provided
|
||||
if hasattr(known_args, 'model') and known_args.model:
|
||||
os.environ["MODEL_NAME"] = known_args.model
|
||||
print(f"Set MODEL_NAME environment variable to: {known_args.model}")
|
||||
|
||||
# Now call test_load_cmd normally - it will add its own args and re-parse
|
||||
test_load_cmd(CompletionsData, WORKER_ENDPOINT, arg_parser=test_args)
|
||||
Reference in New Issue
Block a user