fix pyright errors + revert to old way of handling cancelled api requests (#23)

This commit is contained in:
Nader Arbabian
2025-07-17 16:59:06 -07:00
committed by GitHub
parent 9e369c55a5
commit be2aafdb1f
7 changed files with 265 additions and 234 deletions
+16 -2
View File
@@ -5,7 +5,7 @@ import base64
import subprocess import subprocess
import dataclasses import dataclasses
import logging import logging
from asyncio import sleep, gather, Semaphore from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
from functools import cached_property from functools import cached_property
from distutils.util import strtobool from distutils.util import strtobool
@@ -123,6 +123,12 @@ class Backend:
return web.json_response(dict(error="invalid JSON"), status=422) return web.json_response(dict(error="invalid JSON"), status=422)
workload = payload.count_workload() workload = payload.count_workload()
async def cancel_api_call_if_disconnected() -> web.Response:
await request.wait_for_disconnection()
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
self.metrics._request_canceled(workload=workload, reqnum=auth_data.reqnum)
return web.Response(status=500)
async def make_request() -> Union[web.Response, web.StreamResponse]: async def make_request() -> Union[web.Response, web.StreamResponse]:
log.debug(f"got request, {auth_data.reqnum}") log.debug(f"got request, {auth_data.reqnum}")
self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum) self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum)
@@ -168,7 +174,15 @@ class Backend:
return web.Response(status=401) return web.Response(status=401)
try: try:
return await make_request() done, pending = await wait(
[
create_task(make_request()),
create_task(cancel_api_call_if_disconnected()),
],
return_when=FIRST_COMPLETED,
)
[task.cancel() for task in pending]
return done.pop().result()
except Exception as e: except Exception as e:
log.debug(f"Exception in main handler loop {e}") log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500) return web.Response(status=500)
+1 -1
View File
@@ -1,4 +1,4 @@
aiohttp~=3.11 aiohttp==3.10.1
anyio~=4.4 anyio~=4.4
lib~=4.0 lib~=4.0
nltk~=3.9 nltk~=3.9
+172 -160
View File
@@ -19,40 +19,45 @@ COMPLETIONS_PROMPT = "the capital of USA is"
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language." CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?" TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?"
class APIClient: class APIClient:
"""Lightweight client focused solely on API communication""" """Lightweight client focused solely on API communication"""
# Remove the generic WORKER_ENDPOINT since we're now going direct # Remove the generic WORKER_ENDPOINT since we're now going direct
DEFAULT_COST = 100 DEFAULT_COST = 100
DEFAULT_TIMEOUT = 4 DEFAULT_TIMEOUT = 4
def __init__(self, endpoint_group_name: str, api_key: str, server_url: str): def __init__(
self, endpoint_group_name: str, api_key: str, server_url: str, instance: str
):
self.endpoint_group_name = endpoint_group_name self.endpoint_group_name = endpoint_group_name
self.api_key = api_key self.api_key = api_key
self.server_url = server_url self.server_url = server_url
self.instance = instance
self.endpoint_api_key = self._get_endpoint_api_key() self.endpoint_api_key = self._get_endpoint_api_key()
def _get_endpoint_api_key(self) -> Optional[str]: def _get_endpoint_api_key(self) -> Optional[str]:
"""Get the endpoint API key""" """Get the endpoint API key"""
endpoint_api_key = Endpoint.get_endpoint_api_key( endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=self.endpoint_group_name, endpoint_name=self.endpoint_group_name,
account_api_key=self.api_key, account_api_key=self.api_key,
instance=self.instance,
) )
if not endpoint_api_key: if not endpoint_api_key:
log.error(f"Failed to get API key for endpoint {self.endpoint_group_name}") log.error(f"Failed to get API key for endpoint {self.endpoint_group_name}")
return endpoint_api_key return endpoint_api_key
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]: def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
"""Get worker URL and auth data from routing service""" """Get worker URL and auth data from routing service"""
if not self.endpoint_api_key: if not self.endpoint_api_key:
raise ValueError("No valid endpoint API key available") raise ValueError("No valid endpoint API key available")
route_payload = { route_payload = {
"endpoint": self.endpoint_group_name, "endpoint": self.endpoint_group_name,
"api_key": self.endpoint_api_key, "api_key": self.endpoint_api_key,
"cost": cost, "cost": cost,
} }
response = requests.post( response = requests.post(
urljoin(self.server_url, "/route/"), urljoin(self.server_url, "/route/"),
json=route_payload, json=route_payload,
@@ -60,7 +65,7 @@ class APIClient:
) )
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
def _create_auth_data(self, message: Dict[str, Any]) -> Dict[str, Any]: def _create_auth_data(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""Create auth data from routing response""" """Create auth data from routing response"""
return { return {
@@ -70,27 +75,27 @@ class APIClient:
"reqnum": message["reqnum"], "reqnum": message["reqnum"],
"url": message["url"], "url": message["url"],
} }
def _make_request(self, payload: Dict[str, Any], endpoint: str, method: str = "POST", def _make_request(
stream: bool = False) -> Union[Dict[str, Any], Iterator[str]]: self,
payload: Dict[str, Any],
endpoint: str,
method: str = "POST",
stream: bool = False,
) -> Union[Dict[str, Any], Iterator[str]]:
"""Make request directly to the specific worker endpoint""" """Make request directly to the specific worker endpoint"""
# Get worker URL and auth data # Get worker URL and auth data
cost = payload.get('max_tokens') cost = payload.get("max_tokens", self.DEFAULT_COST)
message = self._get_worker_url(cost=cost) message = self._get_worker_url(cost=cost)
worker_url = message["url"] worker_url = message["url"]
auth_data = self._create_auth_data(message) auth_data = self._create_auth_data(message)
req_data = { req_data = {"payload": {"input": payload}, "auth_data": auth_data}
"payload": {
"input": payload
},
"auth_data": auth_data
}
url = urljoin(worker_url, endpoint) url = urljoin(worker_url, endpoint)
log.debug(f"Making direct request to: {url}") log.debug(f"Making direct request to: {url}")
log.debug(f"Payload: {req_data}") log.debug(f"Payload: {req_data}")
# Make the request using the specified method # Make the request using the specified method
if method.upper() == "POST": if method.upper() == "POST":
response = requests.post(url, json=req_data, stream=stream) response = requests.post(url, json=req_data, stream=stream)
@@ -98,14 +103,14 @@ class APIClient:
response = requests.get(url, params=req_data, stream=stream) response = requests.get(url, params=req_data, stream=stream)
else: else:
raise ValueError(f"Unsupported HTTP method: {method}") raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status() response.raise_for_status()
if stream: if stream:
return self._handle_streaming_response(response) return self._handle_streaming_response(response)
else: else:
return response.json() return response.json()
def _handle_streaming_response(self, response: requests.Response) -> Iterator[str]: def _handle_streaming_response(self, response: requests.Response) -> Iterator[str]:
"""Handle streaming response and yield tokens""" """Handle streaming response and yield tokens"""
try: try:
@@ -124,61 +129,60 @@ class APIClient:
log.error(f"Error handling streaming response: {e}") log.error(f"Error handling streaming response: {e}")
raise raise
def call_completions(
def call_completions(self, config: CompletionConfig) -> Union[Dict[str, Any], Iterator[str]]: self, config: CompletionConfig
payload = config.to_dict() ) -> Union[Dict[str, Any], Iterator[str]]:
return self._make_request(
payload=payload,
endpoint="/v1/completions",
stream=config.stream
)
def call_chat_completions(self, config: ChatCompletionConfig) -> Union[Dict[str, Any], Iterator[str]]:
payload = config.to_dict() payload = config.to_dict()
return self._make_request( return self._make_request(
payload=payload, payload=payload, endpoint="/v1/completions", stream=config.stream
endpoint="/v1/chat/completions", )
stream=config.stream
def call_chat_completions(
self, config: ChatCompletionConfig
) -> Union[Dict[str, Any], Iterator[str]]:
payload = config.to_dict()
return self._make_request(
payload=payload, endpoint="/v1/chat/completions", stream=config.stream
) )
class ToolManager: class ToolManager:
"""Handles tool definitions and execution""" """Handles tool definitions and execution"""
@staticmethod @staticmethod
def list_files() -> str: def list_files() -> str:
"""Execute ls on current directory""" """Execute ls on current directory"""
try: try:
result = subprocess.run(['ls', '-la', '.'], capture_output=True, text=True, timeout=10) result = subprocess.run(
["ls", "-la", "."], capture_output=True, text=True, timeout=10
)
if result.returncode == 0: if result.returncode == 0:
return result.stdout return result.stdout
else: else:
return f"Error: {result.stderr}" return f"Error: {result.stderr}"
except Exception as e: except Exception as e:
return f"Error running ls: {e}" return f"Error running ls: {e}"
@staticmethod @staticmethod
def get_ls_tool_definition() -> List[Dict[str, Any]]: def get_ls_tool_definition() -> List[Dict[str, Any]]:
"""Get the ls tool definition""" """Get the ls tool definition"""
return [{ return [
"type": "function", {
"function": { "type": "function",
"name": "list_files", "function": {
"description": "List files and directories in the cwd", "name": "list_files",
"parameters": { "description": "List files and directories in the cwd",
"type": "object", "parameters": {"type": "object", "properties": {}, "required": []},
"properties": {}, },
"required": []
}
} }
}] ]
def execute_tool_call(self, tool_call: Dict[str, Any]) -> str: def execute_tool_call(self, tool_call: Dict[str, Any]) -> str:
"""Execute a tool call and return the result""" """Execute a tool call and return the result"""
function_name = tool_call["function"]["name"] function_name = tool_call["function"]["name"]
if function_name == "list_files": if function_name == "list_files":
return self.list_files() return self.list_files()
else: else:
@@ -187,13 +191,17 @@ class ToolManager:
class APIDemo: class APIDemo:
"""Demo and testing functionality for the API client""" """Demo and testing functionality for the API client"""
def __init__(self, client: APIClient, model: str, tool_manager: ToolManager = None): def __init__(
self, client: APIClient, model: str, tool_manager: Optional[ToolManager] = None
):
self.client = client self.client = client
self.model = model self.model = model
self.tool_manager = tool_manager or ToolManager() self.tool_manager = tool_manager or ToolManager()
def handle_streaming_response(self, response_stream, show_reasoning: bool = True) -> str: def handle_streaming_response(
self, response_stream, show_reasoning: bool = True
) -> str:
""" """
Handle streaming chat response and display all output. Handle streaming chat response and display all output.
""" """
@@ -260,178 +268,181 @@ class APIDemo:
return full_response return full_response
def test_tool_support(self) -> bool: def test_tool_support(self) -> bool:
"""Test if the endpoint supports function calling""" """Test if the endpoint supports function calling"""
log.debug("Testing endpoint tool calling support...") log.debug("Testing endpoint tool calling support...")
# Try a simple request with minimal tools to test support # Try a simple request with minimal tools to test support
messages = [{"role": "user", "content": "Hello"}] messages = [{"role": "user", "content": "Hello"}]
minimal_tool = [{ minimal_tool = [
"type": "function", {
"function": { "type": "function",
"name": "test_function", "function": {"name": "test_function", "description": "Test function"},
"description": "Test function"
} }
}] ]
config = ChatCompletionConfig( config = ChatCompletionConfig(
model=self.model, model=self.model,
messages=messages, messages=messages,
max_tokens=10, max_tokens=10,
tools=minimal_tool, tools=minimal_tool,
tool_choice="none" # Don't actually call the tool tool_choice="none", # Don't actually call the tool
) )
try: try:
response = self.client.call_chat_completions(config) response = self.client.call_chat_completions(config)
return True return True
except Exception as e: except Exception as e:
log.error(f"Error: Endpoint does not support tool calling: {e}") log.error(f"Error: Endpoint does not support tool calling: {e}")
return False return False
def demo_completions(self) -> None: def demo_completions(self) -> None:
"""Demo: test basic completions endpoint""" """Demo: test basic completions endpoint"""
print("=" * 60) print("=" * 60)
print("COMPLETIONS DEMO") print("COMPLETIONS DEMO")
print("=" * 60) print("=" * 60)
config = CompletionConfig( config = CompletionConfig(
model=self.model, model=self.model, prompt=COMPLETIONS_PROMPT, stream=False
prompt=COMPLETIONS_PROMPT, )
stream=False
log.info(
f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'"
) )
log.info(f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'")
response = self.client.call_completions(config) response = self.client.call_completions(config)
if isinstance(response, dict): if isinstance(response, dict):
print("\nResponse:") print("\nResponse:")
print(json.dumps(response, indent=2)) print(json.dumps(response, indent=2))
else: else:
log.error("Unexpected response format") log.error("Unexpected response format")
def demo_chat(self, use_streaming: bool = True) -> None: def demo_chat(self, use_streaming: bool = True) -> None:
""" """
Demo: test chat completions endpoint with optional streaming Demo: test chat completions endpoint with optional streaming
""" """
print("=" * 60) print("=" * 60)
print(f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}") print(
f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}"
)
print("=" * 60) print("=" * 60)
config = ChatCompletionConfig( config = ChatCompletionConfig(
model=self.model, model=self.model,
messages=[{"role": "user", "content": CHAT_PROMPT}], messages=[{"role": "user", "content": CHAT_PROMPT}],
stream=use_streaming, stream=use_streaming,
) )
log.info(f"Testing chat completions with model '{self.model}'...") log.info(f"Testing chat completions with model '{self.model}'...")
response = self.client.call_chat_completions(config) response = self.client.call_chat_completions(config)
if use_streaming: if use_streaming:
try: try:
self.handle_streaming_response(response, show_reasoning=True) self.handle_streaming_response(response, show_reasoning=True)
except Exception as e: except Exception as e:
log.error(f"\nError during streaming: {e}") log.error(f"\nError during streaming: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
return return
else: else:
if isinstance(response, dict): if isinstance(response, dict):
choice = response.get("choices", [{}])[0] choice = response.get("choices", [{}])[0]
message = choice.get("message", {}) message = choice.get("message", {})
content = message.get("content", "") content = message.get("content", "")
reasoning = message.get("reasoning_content", "") or message.get("reasoning", "") reasoning = message.get("reasoning_content", "") or message.get(
"reasoning", ""
)
if reasoning: if reasoning:
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m") print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
print(f"\n💬 Assistant: {content}") print(f"\n💬 Assistant: {content}")
print(f"\nFull Response:") print(f"\nFull Response:")
print(json.dumps(response, indent=2)) print(json.dumps(response, indent=2))
else: else:
log.error("Unexpected response format") log.error("Unexpected response format")
def demo_ls_tool(self) -> None: def demo_ls_tool(self) -> None:
"""Demo: ask LLM to list files in the current directory and describe what it sees""" """Demo: ask LLM to list files in the current directory and describe what it sees"""
print("=" * 60) print("=" * 60)
print("TOOL USE DEMO: List Directory Contents") print("TOOL USE DEMO: List Directory Contents")
print("=" * 60) print("=" * 60)
# Test if tools are supported first # Test if tools are supported first
if not self.test_tool_support(): if not self.test_tool_support():
return return
# Request with tool available # Request with tool available
messages = [ messages = [{"role": "user", "content": TOOLS_PROMPT}]
{"role": "user", "content": TOOLS_PROMPT}
]
config = ChatCompletionConfig( config = ChatCompletionConfig(
model=self.model, model=self.model,
messages=messages, messages=messages,
tools=self.tool_manager.get_ls_tool_definition(), tools=self.tool_manager.get_ls_tool_definition(),
tool_choice="auto" tool_choice="auto",
) )
log.info(f"Making initial request with tool using model '{self.model}'...") log.info(f"Making initial request with tool using model '{self.model}'...")
response = self.client.call_chat_completions(config) response = self.client.call_chat_completions(config)
if not isinstance(response, dict): if not isinstance(response, dict):
raise ValueError("Expected dict response for tool use") raise ValueError("Expected dict response for tool use")
choice = response.get("choices", [{}])[0] choice = response.get("choices", [{}])[0]
message = choice.get("message", {}) message = choice.get("message", {})
print(f"Assistant response: {message.get('content', 'No content')}") print(f"Assistant response: {message.get('content', 'No content')}")
# Check for tool calls # Check for tool calls
tool_calls = message.get("tool_calls") tool_calls = message.get("tool_calls")
if not tool_calls: if not tool_calls:
raise ValueError("No tool calls made - model may not support function calling") raise ValueError(
"No tool calls made - model may not support function calling"
)
print(f"Tool calls detected: {len(tool_calls)}") print(f"Tool calls detected: {len(tool_calls)}")
# Execute the tool call # Execute the tool call
for tool_call in tool_calls: for tool_call in tool_calls:
function_name = tool_call["function"]["name"] function_name = tool_call["function"]["name"]
print(f"Executing tool: {function_name}") print(f"Executing tool: {function_name}")
tool_result = self.tool_manager.execute_tool_call(tool_call) tool_result = self.tool_manager.execute_tool_call(tool_call)
print(f"Tool result:\n{tool_result}") print(f"Tool result:\n{tool_result}")
# Add tool result and continue conversation # Add tool result and continue conversation
messages.append(message) # Add assistant's message with tool call messages.append(message) # Add assistant's message with tool call
messages.append({ messages.append(
"role": "tool", {
"tool_call_id": tool_call["id"], "role": "tool",
"content": tool_result "tool_call_id": tool_call["id"],
}) "content": tool_result,
}
)
# Get final response # Get final response
final_config = ChatCompletionConfig( final_config = ChatCompletionConfig(
model=self.model, model=self.model,
messages=messages, messages=messages,
tools=self.tool_manager.get_ls_tool_definition() tools=self.tool_manager.get_ls_tool_definition(),
) )
print("Getting final response...") print("Getting final response...")
final_response = self.client.call_chat_completions(final_config) final_response = self.client.call_chat_completions(final_config)
if isinstance(final_response, dict): if isinstance(final_response, dict):
final_choice = final_response.get("choices", [{}])[0] final_choice = final_response.get("choices", [{}])[0]
final_message = final_choice.get("message", {}) final_message = final_choice.get("message", {})
final_content = final_message.get("content", "") final_content = final_message.get("content", "")
print("\n" + "=" * 60) print("\n" + "=" * 60)
print("FINAL LLM ANALYSIS:") print("FINAL LLM ANALYSIS:")
print("=" * 60) print("=" * 60)
print(final_content) print(final_content)
print("=" * 60) print("=" * 60)
def interactive_chat(self) -> None: def interactive_chat(self) -> None:
"""Interactive chat session with streaming""" """Interactive chat session with streaming"""
print("=" * 60) print("=" * 60)
@@ -440,40 +451,39 @@ class APIDemo:
print(f"Using model: {self.model}") print(f"Using model: {self.model}")
print("Type 'quit' to exit, 'clear' to clear history") print("Type 'quit' to exit, 'clear' to clear history")
print() print()
messages = [] messages = []
while True: while True:
try: try:
user_input = input("You: ").strip() user_input = input("You: ").strip()
if user_input.lower() == 'quit': if user_input.lower() == "quit":
print("👋 Goodbye!") print("👋 Goodbye!")
break break
elif user_input.lower() == 'clear': elif user_input.lower() == "clear":
messages = [] messages = []
print("Chat history cleared") print("Chat history cleared")
continue continue
elif not user_input: elif not user_input:
continue continue
messages.append({"role": "user", "content": user_input}) messages.append({"role": "user", "content": user_input})
config = ChatCompletionConfig( config = ChatCompletionConfig(
model=self.model, model=self.model, messages=messages, stream=True, temperature=0.7
messages=messages,
stream=True,
temperature=0.7
) )
print("Assistant: ", end="", flush=True) print("Assistant: ", end="", flush=True)
response = self.client.call_chat_completions(config) response = self.client.call_chat_completions(config)
assistant_content = self.handle_streaming_response(response, show_reasoning=True) assistant_content = self.handle_streaming_response(
response, show_reasoning=True
)
# Add assistant response to conversation history # Add assistant response to conversation history
messages.append({"role": "assistant", "content": assistant_content}) messages.append({"role": "assistant", "content": assistant_content})
except KeyboardInterrupt: except KeyboardInterrupt:
print("\n👋 Chat interrupted. Goodbye!") print("\n👋 Chat interrupted. Goodbye!")
break break
@@ -485,50 +495,49 @@ class APIDemo:
def main(): def main():
"""Main function with CLI switches for different tests""" """Main function with CLI switches for different tests"""
from lib.test_utils import test_args from lib.test_utils import test_args
# Add mandatory model argument # Add mandatory model argument
test_args.add_argument( test_args.add_argument(
"--model", "--model", required=True, help="Model to use for requests (required)"
required=True,
help="Model to use for requests (required)"
) )
# Add test mode arguments # Add test mode arguments
test_args.add_argument( test_args.add_argument(
"--completion", "--completion", action="store_true", help="Test completions endpoint"
action="store_true",
help="Test completions endpoint"
) )
test_args.add_argument( test_args.add_argument(
"--chat", "--chat",
action="store_true", action="store_true",
help="Test chat completions endpoint (non-streaming)" help="Test chat completions endpoint (non-streaming)",
) )
test_args.add_argument( test_args.add_argument(
"--chat-stream", "--chat-stream",
action="store_true", action="store_true",
help="Test chat completions endpoint with streaming" help="Test chat completions endpoint with streaming",
) )
test_args.add_argument( test_args.add_argument(
"--tools", "--tools",
action="store_true", action="store_true",
help="Test function calling with ls tool (non-streaming)" help="Test function calling with ls tool (non-streaming)",
) )
test_args.add_argument( test_args.add_argument(
"--interactive", "--interactive",
action="store_true", action="store_true",
help="Start interactive streaming chat session" help="Start interactive streaming chat session",
) )
args = test_args.parse_args() args = test_args.parse_args()
# Check that only one test mode is selected # Check that only one test mode is selected
test_modes = [ test_modes = [
args.completion, args.chat, args.chat_stream, args.completion,
args.tools, args.interactive args.chat,
args.chat_stream,
args.tools,
args.interactive,
] ]
selected_count = sum(test_modes) selected_count = sum(test_modes)
if selected_count == 0: if selected_count == 0:
print("Please specify exactly one test mode:") print("Please specify exactly one test mode:")
print(" --completion : Test completions endpoint") print(" --completion : Test completions endpoint")
@@ -536,27 +545,30 @@ def main():
print(" --chat-stream : Test chat completions endpoint with streaming") print(" --chat-stream : Test chat completions endpoint with streaming")
print(" --tools : Test function calling with ls tool (non-streaming)") print(" --tools : Test function calling with ls tool (non-streaming)")
print(" --interactive : Start interactive streaming chat session") print(" --interactive : Start interactive streaming chat session")
print(f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT") print(
f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT"
)
sys.exit(1) sys.exit(1)
elif selected_count > 1: elif selected_count > 1:
print("Please specify exactly one test mode") print("Please specify exactly one test mode")
sys.exit(1) sys.exit(1)
try: try:
# Create the core API client # Create the core API client
client = APIClient( client = APIClient(
endpoint_group_name=args.endpoint_group_name, endpoint_group_name=args.endpoint_group_name,
api_key=args.api_key, api_key=args.api_key,
server_url=args.server_url server_url=args.server_url,
instance=args.instance,
) )
# Create tool manager and demo (passing the model parameter) # Create tool manager and demo (passing the model parameter)
tool_manager = ToolManager() tool_manager = ToolManager()
demo = APIDemo(client, args.model, tool_manager) demo = APIDemo(client, args.model, tool_manager)
print(f"Using model: {args.model}") print(f"Using model: {args.model}")
print("=" * 60) print("=" * 60)
# Run the selected test # Run the selected test
if args.completion: if args.completion:
demo.demo_completions() demo.demo_completions()
@@ -568,11 +580,11 @@ def main():
demo.demo_ls_tool() demo.demo_ls_tool()
elif args.interactive: elif args.interactive:
demo.interactive_chat() demo.interactive_chat()
except Exception as e: except Exception as e:
log.error(f"Error during test: {e}", exc_info=True) log.error(f"Error during test: {e}", exc_info=True)
sys.exit(1) sys.exit(1)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
+12 -8
View File
@@ -3,11 +3,13 @@ from dataclasses import dataclass, field, fields, is_dataclass
from typing import Optional, List, Dict, Any from typing import Optional, List, Dict, Any
class SerializableDataclass: class SerializableDataclass:
def _serialize_recursive(self, obj: Any) -> Any: def _serialize_recursive(self, obj: Any) -> Any:
if is_dataclass(obj): if is_dataclass(obj):
return {field.name: self._serialize_recursive(getattr(obj, field.name)) return {
for field in fields(obj)} field.name: self._serialize_recursive(getattr(obj, field.name))
for field in fields(obj)
}
elif isinstance(obj, dict): elif isinstance(obj, dict):
return {key: self._serialize_recursive(value) for key, value in obj.items()} return {key: self._serialize_recursive(value) for key, value in obj.items()}
elif isinstance(obj, (list, tuple)): elif isinstance(obj, (list, tuple)):
@@ -16,10 +18,10 @@ class SerializableDataclass:
return [self._serialize_recursive(item) for item in obj] return [self._serialize_recursive(item) for item in obj]
else: else:
return obj return obj
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
return self._serialize_recursive(self) return self._serialize_recursive(self)
def to_json(self, indent: int = 2) -> str: def to_json(self, indent: int = 2) -> str:
return json.dumps(self.to_dict(), indent=indent) return json.dumps(self.to_dict(), indent=indent)
@@ -27,6 +29,7 @@ class SerializableDataclass:
@dataclass @dataclass
class CompletionConfig(SerializableDataclass): class CompletionConfig(SerializableDataclass):
"""Configuration for completion requests""" """Configuration for completion requests"""
model: str model: str
prompt: str = "Hello" prompt: str = "Hello"
max_tokens: int = 256 max_tokens: int = 256
@@ -39,8 +42,9 @@ class CompletionConfig(SerializableDataclass):
@dataclass @dataclass
class ChatCompletionConfig(SerializableDataclass): class ChatCompletionConfig(SerializableDataclass):
"""Configuration for chat completion requests""" """Configuration for chat completion requests"""
model: str model: str
messages: list = None messages: list = field(default_factory=list)
max_tokens: int = 2096 max_tokens: int = 2096
temperature: float = 0.7 temperature: float = 0.7
top_k: int = 20 top_k: int = 20
@@ -48,7 +52,7 @@ class ChatCompletionConfig(SerializableDataclass):
stream: bool = False stream: bool = False
tools: Optional[List[Dict[str, Any]]] = field(default_factory=list) tools: Optional[List[Dict[str, Any]]] = field(default_factory=list)
tool_choice: str = "auto" tool_choice: str = "auto"
def __post_init__(self): def __post_init__(self):
if self.messages is None: if self.messages is None:
self.messages = [{"role": "user", "content": "Hello"}] self.messages = [{"role": "user", "content": "Hello"}]
+41 -42
View File
@@ -2,7 +2,7 @@ import os, json, random
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from lib.data_types import EndpointHandler, ApiPayload, JsonDataException from lib.data_types import EndpointHandler, ApiPayload, JsonDataException
from typing import Union, Type, Dict, Any from typing import Union, Type, Dict, Any, Optional
from aiohttp import web, ClientResponse from aiohttp import web, ClientResponse
import nltk import nltk
import logging import logging
@@ -10,41 +10,39 @@ import logging
nltk.download("words") nltk.download("words")
WORD_LIST = nltk.corpus.words.words() WORD_LIST = nltk.corpus.words.words()
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
""" """
Generic dataclass accepts any dictionary in input. Generic dataclass accepts any dictionary in input.
""" """
@dataclass @dataclass
class GenericData(ApiPayload, ABC): class GenericData(ApiPayload, ABC):
input: Dict[str, Any] input: Dict[str, Any]
@classmethod @classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GenericData": def from_dict(cls, data: Dict[str, Any]) -> "GenericData":
return cls( return cls(input=data["input"])
input=data["input"]
)
@classmethod @classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "GenericData": def from_json_msg(cls, json_msg: Dict[str, Any]) -> "GenericData":
errors = {} errors = {}
# Validate required parameters # Validate required parameters
required_params = ["input"] required_params = ["input"]
for param in required_params: for param in required_params:
if param not in json_msg: if param not in json_msg:
errors[param] = "missing parameter" errors[param] = "missing parameter"
if errors: if errors:
raise JsonDataException(errors) raise JsonDataException(errors)
try: try:
# Create clean data dict and delegate to from_dict # Create clean data dict and delegate to from_dict
clean_data = { clean_data = {"input": json_msg["input"]}
"input": json_msg["input"]
}
return cls.from_dict(clean_data) return cls.from_dict(clean_data)
except (json.JSONDecodeError, JsonDataException) as e: except (json.JSONDecodeError, JsonDataException) as e:
errors["parameters"] = str(e) errors["parameters"] = str(e)
raise JsonDataException(errors) raise JsonDataException(errors)
@@ -59,7 +57,8 @@ class GenericData(ApiPayload, ABC):
def count_workload(self) -> int:
    """Return the request's workload, taken from ``input["max_tokens"]``.

    Returns:
        ``max_tokens`` as an int, or 0 when the key is absent or its value
        cannot be converted (for example an explicit JSON ``null``).
    """
    raw = self.input.get("max_tokens", 0)
    try:
        # Clients control this value; coerce it so callers always get an int.
        return int(raw)
    except (TypeError, ValueError, OverflowError):
        return 0
@dataclass @dataclass
class GenericHandler(EndpointHandler[GenericData], ABC): class GenericHandler(EndpointHandler[GenericData], ABC):
@@ -67,10 +66,10 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
@abstractmethod @abstractmethod
def endpoint(self) -> str: def endpoint(self) -> str:
pass pass
@property
def healthcheck_endpoint(self) -> Optional[str]:
    """Return the ``MODEL_HEALTH_ENDPOINT`` environment variable.

    Returns:
        The variable's value, or ``None`` when it is not set.
    """
    return os.environ.get("MODEL_HEALTH_ENDPOINT")
@classmethod @classmethod
def payload_cls(cls) -> Type[GenericData]: def payload_cls(cls) -> Type[GenericData]:
@@ -82,17 +81,17 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
async def generate_client_response( async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]: ) -> Union[web.Response, web.StreamResponse]:
match model_response.status: match model_response.status:
case 200: case 200:
# Check if the response is actually streaming based on response headers/content-type # Check if the response is actually streaming based on response headers/content-type
is_streaming_response = ( is_streaming_response = (
model_response.content_type == "text/event-stream" or model_response.content_type == "text/event-stream"
model_response.content_type == "application/x-ndjson" or or model_response.content_type == "application/x-ndjson"
model_response.headers.get("Transfer-Encoding") == "chunked" or or model_response.headers.get("Transfer-Encoding") == "chunked"
"stream" in model_response.content_type.lower() or "stream" in model_response.content_type.lower()
) )
if is_streaming_response: if is_streaming_response:
log.debug("Detected streaming response...") log.debug("Detected streaming response...")
res = web.StreamResponse() res = web.StreamResponse()
@@ -109,12 +108,13 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
return web.Response( return web.Response(
body=content, body=content,
status=200, status=200,
content_type=model_response.content_type content_type=model_response.content_type,
) )
case code: case code:
log.debug("SENDING RESPONSE: ERROR: unknown code") log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code) return web.Response(status=code)
@dataclass @dataclass
class CompletionsData(GenericData): class CompletionsData(GenericData):
@classmethod @classmethod
@@ -123,55 +123,54 @@ class CompletionsData(GenericData):
model = os.environ.get("MODEL_NAME") model = os.environ.get("MODEL_NAME")
if not model: if not model:
raise ValueError("MODEL_NAME environment variable not set") raise ValueError("MODEL_NAME environment variable not set")
test_input = { test_input = {"model": model, "prompt": prompt, "temperature": 0.7}
"model": model,
"prompt": prompt,
"temperature": 0.7
}
return cls(input=test_input) return cls(input=test_input)
@dataclass
class CompletionsHandler(GenericHandler):
    """Handler for the ``/v1/completions`` route, using ``CompletionsData`` payloads."""

    @property
    def endpoint(self) -> str:
        """Path of the completions endpoint."""
        return "/v1/completions"

    @classmethod
    def payload_cls(cls) -> Type[CompletionsData]:
        """Payload class used for this endpoint."""
        return CompletionsData

    def make_benchmark_payload(self) -> CompletionsData:
        """Return a test payload built by ``CompletionsData.for_test()``."""
        return CompletionsData.for_test()
@dataclass
class ChatCompletionsData(GenericData):
    """Chat completions-specific data implementation"""

    @classmethod
    def for_test(cls) -> "ChatCompletionsData":
        """Build a random chat-completions payload for testing.

        The prompt is 250 words sampled with replacement from ``WORD_LIST``.

        Raises:
            ValueError: if the ``MODEL_NAME`` environment variable is unset or empty.
        """
        prompt = " ".join(random.choices(WORD_LIST, k=250))
        model = os.environ.get("MODEL_NAME")
        if not model:
            raise ValueError("MODEL_NAME environment variable not set")

        # Chat completions use messages format instead of prompt
        test_input = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.7,
        }
        return cls(input=test_input)
@dataclass
class ChatCompletionsHandler(GenericHandler):
    """Handler for the ``/v1/chat/completions`` route, using ``ChatCompletionsData`` payloads."""

    @property
    def endpoint(self) -> str:
        """Path of the chat-completions endpoint."""
        return "/v1/chat/completions"

    @classmethod
    def payload_cls(cls) -> Type[ChatCompletionsData]:
        """Payload class used for this endpoint."""
        return ChatCompletionsData

    def make_benchmark_payload(self) -> ChatCompletionsData:
        """Return a test payload built by ``ChatCompletionsData.for_test()``."""
        return ChatCompletionsData.for_test()
+15 -13
View File
@@ -7,20 +7,20 @@ from lib.server import start_server
# Any of these log lines indicates that the inference server is listening
MODEL_SERVER_START_LOG_MSG = [
    "Application startup complete.",  # vLLM
    "llama runner started",  # Ollama
    '"message":"Connected","target":"text_generation_router"',  # TGI
    '"message":"Connected","target":"text_generation_router::server"',  # TGI
]

# Log fragments treated as model-server errors, tagged with the backend that emits them
MODEL_SERVER_ERROR_LOG_MSGS = [
    "INFO exited: vllm",  # vLLM
    "RuntimeError: Engine",  # vLLM
    "Error: pull model manifest:",  # Ollama
    "stalled; retrying",  # Ollama
    "Error: WebserverFailed",  # TGI
    "Error: DownloadError",  # TGI
    "Error: ShardCannotStart",  # TGI
]
logging.basicConfig( logging.basicConfig(
@@ -31,8 +31,8 @@ logging.basicConfig(
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
backend = Backend( backend = Backend(
model_server_url=os.environ.get("MODEL_SERVER_URL"), model_server_url=os.environ["MODEL_SERVER_URL"],
model_log_file=os.environ.get("MODEL_LOG"), model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=True, allow_parallel_requests=True,
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256), benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
log_actions=[ log_actions=[
@@ -45,9 +45,11 @@ backend = Backend(
], ],
) )
async def handle_ping(_):
    """Answer any request with the body ``pong``; the request argument is ignored."""
    return web.Response(body="pong")
routes = [ routes = [
web.post("/v1/completions", backend.create_handler(CompletionsHandler())), web.post("/v1/completions", backend.create_handler(CompletionsHandler())),
web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())), web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())),
+8 -8
View File
@@ -7,22 +7,22 @@ WORKER_ENDPOINT = "/v1/completions"
if __name__ == "__main__":
    # Check if MODEL_NAME environment variable is set
    model_name_set = os.environ.get("MODEL_NAME") is not None

    # Add model argument - required only if MODEL_NAME is not set
    test_args.add_argument(
        "--model",
        dest="model",
        required=not model_name_set,
        help="Model to use for completions request (required if MODEL_NAME env var not set)",
    )

    # Parse known args to get model early, before test_load_cmd adds its args
    known_args, _ = test_args.parse_known_args()

    # Set environment variable if model was provided
    model = getattr(known_args, "model", None)
    if model:
        os.environ["MODEL_NAME"] = model
        print(f"Set MODEL_NAME environment variable to: {model}")

    # Now call test_load_cmd normally - it will add its own args and re-parse
    test_load_cmd(CompletionsData, WORKER_ENDPOINT, arg_parser=test_args)