fix pyright errors + revert to old way of handling cancelled api requests

This commit is contained in:
Nader Arbabian
2025-07-17 15:18:21 -07:00
parent 9e369c55a5
commit 4ac51947b4
7 changed files with 265 additions and 234 deletions
+16 -2
View File
@@ -5,7 +5,7 @@ import base64
import subprocess import subprocess
import dataclasses import dataclasses
import logging import logging
from asyncio import sleep, gather, Semaphore from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
from functools import cached_property from functools import cached_property
from distutils.util import strtobool from distutils.util import strtobool
@@ -123,6 +123,12 @@ class Backend:
return web.json_response(dict(error="invalid JSON"), status=422) return web.json_response(dict(error="invalid JSON"), status=422)
workload = payload.count_workload() workload = payload.count_workload()
async def cancel_api_call_if_disconnected() -> web.Response:
await request.wait_for_disconnection()
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
self.metrics._request_canceled(workload=workload, reqnum=auth_data.reqnum)
return web.Response(status=500)
async def make_request() -> Union[web.Response, web.StreamResponse]: async def make_request() -> Union[web.Response, web.StreamResponse]:
log.debug(f"got request, {auth_data.reqnum}") log.debug(f"got request, {auth_data.reqnum}")
self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum) self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum)
@@ -168,7 +174,15 @@ class Backend:
return web.Response(status=401) return web.Response(status=401)
try: try:
return await make_request() done, pending = await wait(
[
create_task(make_request()),
create_task(cancel_api_call_if_disconnected()),
],
return_when=FIRST_COMPLETED,
)
[task.cancel() for task in pending]
return done.pop().result()
except Exception as e: except Exception as e:
log.debug(f"Exception in main handler loop {e}") log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500) return web.Response(status=500)
+1 -1
View File
@@ -1,4 +1,4 @@
aiohttp~=3.11 aiohttp==3.10.1
anyio~=4.4 anyio~=4.4
lib~=4.0 lib~=4.0
nltk~=3.9 nltk~=3.9
+93 -81
View File
@@ -19,6 +19,7 @@ COMPLETIONS_PROMPT = "the capital of USA is"
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language." CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?" TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?"
class APIClient: class APIClient:
"""Lightweight client focused solely on API communication""" """Lightweight client focused solely on API communication"""
@@ -26,10 +27,13 @@ class APIClient:
DEFAULT_COST = 100 DEFAULT_COST = 100
DEFAULT_TIMEOUT = 4 DEFAULT_TIMEOUT = 4
def __init__(self, endpoint_group_name: str, api_key: str, server_url: str): def __init__(
self, endpoint_group_name: str, api_key: str, server_url: str, instance: str
):
self.endpoint_group_name = endpoint_group_name self.endpoint_group_name = endpoint_group_name
self.api_key = api_key self.api_key = api_key
self.server_url = server_url self.server_url = server_url
self.instance = instance
self.endpoint_api_key = self._get_endpoint_api_key() self.endpoint_api_key = self._get_endpoint_api_key()
def _get_endpoint_api_key(self) -> Optional[str]: def _get_endpoint_api_key(self) -> Optional[str]:
@@ -37,6 +41,7 @@ class APIClient:
endpoint_api_key = Endpoint.get_endpoint_api_key( endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=self.endpoint_group_name, endpoint_name=self.endpoint_group_name,
account_api_key=self.api_key, account_api_key=self.api_key,
instance=self.instance,
) )
if not endpoint_api_key: if not endpoint_api_key:
log.error(f"Failed to get API key for endpoint {self.endpoint_group_name}") log.error(f"Failed to get API key for endpoint {self.endpoint_group_name}")
@@ -71,21 +76,21 @@ class APIClient:
"url": message["url"], "url": message["url"],
} }
def _make_request(self, payload: Dict[str, Any], endpoint: str, method: str = "POST", def _make_request(
stream: bool = False) -> Union[Dict[str, Any], Iterator[str]]: self,
payload: Dict[str, Any],
endpoint: str,
method: str = "POST",
stream: bool = False,
) -> Union[Dict[str, Any], Iterator[str]]:
"""Make request directly to the specific worker endpoint""" """Make request directly to the specific worker endpoint"""
# Get worker URL and auth data # Get worker URL and auth data
cost = payload.get('max_tokens') cost = payload.get("max_tokens", self.DEFAULT_COST)
message = self._get_worker_url(cost=cost) message = self._get_worker_url(cost=cost)
worker_url = message["url"] worker_url = message["url"]
auth_data = self._create_auth_data(message) auth_data = self._create_auth_data(message)
req_data = { req_data = {"payload": {"input": payload}, "auth_data": auth_data}
"payload": {
"input": payload
},
"auth_data": auth_data
}
url = urljoin(worker_url, endpoint) url = urljoin(worker_url, endpoint)
log.debug(f"Making direct request to: {url}") log.debug(f"Making direct request to: {url}")
@@ -124,23 +129,22 @@ class APIClient:
log.error(f"Error handling streaming response: {e}") log.error(f"Error handling streaming response: {e}")
raise raise
def call_completions(
def call_completions(self, config: CompletionConfig) -> Union[Dict[str, Any], Iterator[str]]: self, config: CompletionConfig
) -> Union[Dict[str, Any], Iterator[str]]:
payload = config.to_dict() payload = config.to_dict()
return self._make_request( return self._make_request(
payload=payload, payload=payload, endpoint="/v1/completions", stream=config.stream
endpoint="/v1/completions",
stream=config.stream
) )
def call_chat_completions(self, config: ChatCompletionConfig) -> Union[Dict[str, Any], Iterator[str]]: def call_chat_completions(
self, config: ChatCompletionConfig
) -> Union[Dict[str, Any], Iterator[str]]:
payload = config.to_dict() payload = config.to_dict()
return self._make_request( return self._make_request(
payload=payload, payload=payload, endpoint="/v1/chat/completions", stream=config.stream
endpoint="/v1/chat/completions",
stream=config.stream
) )
@@ -151,7 +155,9 @@ class ToolManager:
def list_files() -> str: def list_files() -> str:
"""Execute ls on current directory""" """Execute ls on current directory"""
try: try:
result = subprocess.run(['ls', '-la', '.'], capture_output=True, text=True, timeout=10) result = subprocess.run(
["ls", "-la", "."], capture_output=True, text=True, timeout=10
)
if result.returncode == 0: if result.returncode == 0:
return result.stdout return result.stdout
else: else:
@@ -162,18 +168,16 @@ class ToolManager:
@staticmethod @staticmethod
def get_ls_tool_definition() -> List[Dict[str, Any]]: def get_ls_tool_definition() -> List[Dict[str, Any]]:
"""Get the ls tool definition""" """Get the ls tool definition"""
return [{ return [
"type": "function", {
"function": { "type": "function",
"name": "list_files", "function": {
"description": "List files and directories in the cwd", "name": "list_files",
"parameters": { "description": "List files and directories in the cwd",
"type": "object", "parameters": {"type": "object", "properties": {}, "required": []},
"properties": {}, },
"required": []
}
} }
}] ]
def execute_tool_call(self, tool_call: Dict[str, Any]) -> str: def execute_tool_call(self, tool_call: Dict[str, Any]) -> str:
"""Execute a tool call and return the result""" """Execute a tool call and return the result"""
@@ -188,12 +192,16 @@ class ToolManager:
class APIDemo: class APIDemo:
"""Demo and testing functionality for the API client""" """Demo and testing functionality for the API client"""
def __init__(self, client: APIClient, model: str, tool_manager: ToolManager = None): def __init__(
self, client: APIClient, model: str, tool_manager: Optional[ToolManager] = None
):
self.client = client self.client = client
self.model = model self.model = model
self.tool_manager = tool_manager or ToolManager() self.tool_manager = tool_manager or ToolManager()
def handle_streaming_response(self, response_stream, show_reasoning: bool = True) -> str: def handle_streaming_response(
self, response_stream, show_reasoning: bool = True
) -> str:
""" """
Handle streaming chat response and display all output. Handle streaming chat response and display all output.
""" """
@@ -260,27 +268,25 @@ class APIDemo:
return full_response return full_response
def test_tool_support(self) -> bool: def test_tool_support(self) -> bool:
"""Test if the endpoint supports function calling""" """Test if the endpoint supports function calling"""
log.debug("Testing endpoint tool calling support...") log.debug("Testing endpoint tool calling support...")
# Try a simple request with minimal tools to test support # Try a simple request with minimal tools to test support
messages = [{"role": "user", "content": "Hello"}] messages = [{"role": "user", "content": "Hello"}]
minimal_tool = [{ minimal_tool = [
"type": "function", {
"function": { "type": "function",
"name": "test_function", "function": {"name": "test_function", "description": "Test function"},
"description": "Test function"
} }
}] ]
config = ChatCompletionConfig( config = ChatCompletionConfig(
model=self.model, model=self.model,
messages=messages, messages=messages,
max_tokens=10, max_tokens=10,
tools=minimal_tool, tools=minimal_tool,
tool_choice="none" # Don't actually call the tool tool_choice="none", # Don't actually call the tool
) )
try: try:
@@ -297,12 +303,12 @@ class APIDemo:
print("=" * 60) print("=" * 60)
config = CompletionConfig( config = CompletionConfig(
model=self.model, model=self.model, prompt=COMPLETIONS_PROMPT, stream=False
prompt=COMPLETIONS_PROMPT,
stream=False
) )
log.info(f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'") log.info(
f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'"
)
response = self.client.call_completions(config) response = self.client.call_completions(config)
if isinstance(response, dict): if isinstance(response, dict):
@@ -316,7 +322,9 @@ class APIDemo:
Demo: test chat completions endpoint with optional streaming Demo: test chat completions endpoint with optional streaming
""" """
print("=" * 60) print("=" * 60)
print(f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}") print(
f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}"
)
print("=" * 60) print("=" * 60)
config = ChatCompletionConfig( config = ChatCompletionConfig(
@@ -334,6 +342,7 @@ class APIDemo:
except Exception as e: except Exception as e:
log.error(f"\nError during streaming: {e}") log.error(f"\nError during streaming: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
return return
@@ -342,7 +351,9 @@ class APIDemo:
choice = response.get("choices", [{}])[0] choice = response.get("choices", [{}])[0]
message = choice.get("message", {}) message = choice.get("message", {})
content = message.get("content", "") content = message.get("content", "")
reasoning = message.get("reasoning_content", "") or message.get("reasoning", "") reasoning = message.get("reasoning_content", "") or message.get(
"reasoning", ""
)
if reasoning: if reasoning:
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m") print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
@@ -353,8 +364,6 @@ class APIDemo:
else: else:
log.error("Unexpected response format") log.error("Unexpected response format")
def demo_ls_tool(self) -> None: def demo_ls_tool(self) -> None:
"""Demo: ask LLM to list files in the current directory and describe what it sees""" """Demo: ask LLM to list files in the current directory and describe what it sees"""
print("=" * 60) print("=" * 60)
@@ -366,15 +375,13 @@ class APIDemo:
return return
# Request with tool available # Request with tool available
messages = [ messages = [{"role": "user", "content": TOOLS_PROMPT}]
{"role": "user", "content": TOOLS_PROMPT}
]
config = ChatCompletionConfig( config = ChatCompletionConfig(
model=self.model, model=self.model,
messages=messages, messages=messages,
tools=self.tool_manager.get_ls_tool_definition(), tools=self.tool_manager.get_ls_tool_definition(),
tool_choice="auto" tool_choice="auto",
) )
log.info(f"Making initial request with tool using model '{self.model}'...") log.info(f"Making initial request with tool using model '{self.model}'...")
@@ -391,7 +398,9 @@ class APIDemo:
# Check for tool calls # Check for tool calls
tool_calls = message.get("tool_calls") tool_calls = message.get("tool_calls")
if not tool_calls: if not tool_calls:
raise ValueError("No tool calls made - model may not support function calling") raise ValueError(
"No tool calls made - model may not support function calling"
)
print(f"Tool calls detected: {len(tool_calls)}") print(f"Tool calls detected: {len(tool_calls)}")
@@ -405,17 +414,19 @@ class APIDemo:
# Add tool result and continue conversation # Add tool result and continue conversation
messages.append(message) # Add assistant's message with tool call messages.append(message) # Add assistant's message with tool call
messages.append({ messages.append(
"role": "tool", {
"tool_call_id": tool_call["id"], "role": "tool",
"content": tool_result "tool_call_id": tool_call["id"],
}) "content": tool_result,
}
)
# Get final response # Get final response
final_config = ChatCompletionConfig( final_config = ChatCompletionConfig(
model=self.model, model=self.model,
messages=messages, messages=messages,
tools=self.tool_manager.get_ls_tool_definition() tools=self.tool_manager.get_ls_tool_definition(),
) )
print("Getting final response...") print("Getting final response...")
@@ -447,10 +458,10 @@ class APIDemo:
try: try:
user_input = input("You: ").strip() user_input = input("You: ").strip()
if user_input.lower() == 'quit': if user_input.lower() == "quit":
print("👋 Goodbye!") print("👋 Goodbye!")
break break
elif user_input.lower() == 'clear': elif user_input.lower() == "clear":
messages = [] messages = []
print("Chat history cleared") print("Chat history cleared")
continue continue
@@ -460,16 +471,15 @@ class APIDemo:
messages.append({"role": "user", "content": user_input}) messages.append({"role": "user", "content": user_input})
config = ChatCompletionConfig( config = ChatCompletionConfig(
model=self.model, model=self.model, messages=messages, stream=True, temperature=0.7
messages=messages,
stream=True,
temperature=0.7
) )
print("Assistant: ", end="", flush=True) print("Assistant: ", end="", flush=True)
response = self.client.call_chat_completions(config) response = self.client.call_chat_completions(config)
assistant_content = self.handle_streaming_response(response, show_reasoning=True) assistant_content = self.handle_streaming_response(
response, show_reasoning=True
)
# Add assistant response to conversation history # Add assistant response to conversation history
messages.append({"role": "assistant", "content": assistant_content}) messages.append({"role": "assistant", "content": assistant_content})
@@ -488,44 +498,43 @@ def main():
# Add mandatory model argument # Add mandatory model argument
test_args.add_argument( test_args.add_argument(
"--model", "--model", required=True, help="Model to use for requests (required)"
required=True,
help="Model to use for requests (required)"
) )
# Add test mode arguments # Add test mode arguments
test_args.add_argument( test_args.add_argument(
"--completion", "--completion", action="store_true", help="Test completions endpoint"
action="store_true",
help="Test completions endpoint"
) )
test_args.add_argument( test_args.add_argument(
"--chat", "--chat",
action="store_true", action="store_true",
help="Test chat completions endpoint (non-streaming)" help="Test chat completions endpoint (non-streaming)",
) )
test_args.add_argument( test_args.add_argument(
"--chat-stream", "--chat-stream",
action="store_true", action="store_true",
help="Test chat completions endpoint with streaming" help="Test chat completions endpoint with streaming",
) )
test_args.add_argument( test_args.add_argument(
"--tools", "--tools",
action="store_true", action="store_true",
help="Test function calling with ls tool (non-streaming)" help="Test function calling with ls tool (non-streaming)",
) )
test_args.add_argument( test_args.add_argument(
"--interactive", "--interactive",
action="store_true", action="store_true",
help="Start interactive streaming chat session" help="Start interactive streaming chat session",
) )
args = test_args.parse_args() args = test_args.parse_args()
# Check that only one test mode is selected # Check that only one test mode is selected
test_modes = [ test_modes = [
args.completion, args.chat, args.chat_stream, args.completion,
args.tools, args.interactive args.chat,
args.chat_stream,
args.tools,
args.interactive,
] ]
selected_count = sum(test_modes) selected_count = sum(test_modes)
@@ -536,7 +545,9 @@ def main():
print(" --chat-stream : Test chat completions endpoint with streaming") print(" --chat-stream : Test chat completions endpoint with streaming")
print(" --tools : Test function calling with ls tool (non-streaming)") print(" --tools : Test function calling with ls tool (non-streaming)")
print(" --interactive : Start interactive streaming chat session") print(" --interactive : Start interactive streaming chat session")
print(f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT") print(
f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT"
)
sys.exit(1) sys.exit(1)
elif selected_count > 1: elif selected_count > 1:
print("Please specify exactly one test mode") print("Please specify exactly one test mode")
@@ -547,7 +558,8 @@ def main():
client = APIClient( client = APIClient(
endpoint_group_name=args.endpoint_group_name, endpoint_group_name=args.endpoint_group_name,
api_key=args.api_key, api_key=args.api_key,
server_url=args.server_url server_url=args.server_url,
instance=args.instance,
) )
# Create tool manager and demo (passing the model parameter) # Create tool manager and demo (passing the model parameter)
+7 -3
View File
@@ -6,8 +6,10 @@ from typing import Optional, List, Dict, Any
class SerializableDataclass: class SerializableDataclass:
def _serialize_recursive(self, obj: Any) -> Any: def _serialize_recursive(self, obj: Any) -> Any:
if is_dataclass(obj): if is_dataclass(obj):
return {field.name: self._serialize_recursive(getattr(obj, field.name)) return {
for field in fields(obj)} field.name: self._serialize_recursive(getattr(obj, field.name))
for field in fields(obj)
}
elif isinstance(obj, dict): elif isinstance(obj, dict):
return {key: self._serialize_recursive(value) for key, value in obj.items()} return {key: self._serialize_recursive(value) for key, value in obj.items()}
elif isinstance(obj, (list, tuple)): elif isinstance(obj, (list, tuple)):
@@ -27,6 +29,7 @@ class SerializableDataclass:
@dataclass @dataclass
class CompletionConfig(SerializableDataclass): class CompletionConfig(SerializableDataclass):
"""Configuration for completion requests""" """Configuration for completion requests"""
model: str model: str
prompt: str = "Hello" prompt: str = "Hello"
max_tokens: int = 256 max_tokens: int = 256
@@ -39,8 +42,9 @@ class CompletionConfig(SerializableDataclass):
@dataclass @dataclass
class ChatCompletionConfig(SerializableDataclass): class ChatCompletionConfig(SerializableDataclass):
"""Configuration for chat completion requests""" """Configuration for chat completion requests"""
model: str model: str
messages: list = None messages: list = field(default_factory=list)
max_tokens: int = 2096 max_tokens: int = 2096
temperature: float = 0.7 temperature: float = 0.7
top_k: int = 20 top_k: int = 20
+20 -21
View File
@@ -2,7 +2,7 @@ import os, json, random
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from lib.data_types import EndpointHandler, ApiPayload, JsonDataException from lib.data_types import EndpointHandler, ApiPayload, JsonDataException
from typing import Union, Type, Dict, Any from typing import Union, Type, Dict, Any, Optional
from aiohttp import web, ClientResponse from aiohttp import web, ClientResponse
import nltk import nltk
import logging import logging
@@ -14,15 +14,15 @@ log = logging.getLogger(__name__)
""" """
Generic dataclass accepts any dictionary in input. Generic dataclass accepts any dictionary in input.
""" """
@dataclass @dataclass
class GenericData(ApiPayload, ABC): class GenericData(ApiPayload, ABC):
input: Dict[str, Any] input: Dict[str, Any]
@classmethod @classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GenericData": def from_dict(cls, data: Dict[str, Any]) -> "GenericData":
return cls( return cls(input=data["input"])
input=data["input"]
)
@classmethod @classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "GenericData": def from_json_msg(cls, json_msg: Dict[str, Any]) -> "GenericData":
@@ -39,9 +39,7 @@ class GenericData(ApiPayload, ABC):
try: try:
# Create clean data dict and delegate to from_dict # Create clean data dict and delegate to from_dict
clean_data = { clean_data = {"input": json_msg["input"]}
"input": json_msg["input"]
}
return cls.from_dict(clean_data) return cls.from_dict(clean_data)
@@ -60,6 +58,7 @@ class GenericData(ApiPayload, ABC):
def count_workload(self) -> int: def count_workload(self) -> int:
return self.input.get("max_tokens", 0) return self.input.get("max_tokens", 0)
@dataclass @dataclass
class GenericHandler(EndpointHandler[GenericData], ABC): class GenericHandler(EndpointHandler[GenericData], ABC):
@@ -69,8 +68,8 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
pass pass
@property @property
def healthcheck_endpoint(self) -> str: def healthcheck_endpoint(self) -> Optional[str]:
return os.environ.get('MODEL_HEALTH_ENDPOINT') return os.environ.get("MODEL_HEALTH_ENDPOINT")
@classmethod @classmethod
def payload_cls(cls) -> Type[GenericData]: def payload_cls(cls) -> Type[GenericData]:
@@ -82,15 +81,15 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
async def generate_client_response( async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]: ) -> Union[web.Response, web.StreamResponse]:
match model_response.status: match model_response.status:
case 200: case 200:
# Check if the response is actually streaming based on response headers/content-type # Check if the response is actually streaming based on response headers/content-type
is_streaming_response = ( is_streaming_response = (
model_response.content_type == "text/event-stream" or model_response.content_type == "text/event-stream"
model_response.content_type == "application/x-ndjson" or or model_response.content_type == "application/x-ndjson"
model_response.headers.get("Transfer-Encoding") == "chunked" or or model_response.headers.get("Transfer-Encoding") == "chunked"
"stream" in model_response.content_type.lower() or "stream" in model_response.content_type.lower()
) )
if is_streaming_response: if is_streaming_response:
@@ -109,12 +108,13 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
return web.Response( return web.Response(
body=content, body=content,
status=200, status=200,
content_type=model_response.content_type content_type=model_response.content_type,
) )
case code: case code:
log.debug("SENDING RESPONSE: ERROR: unknown code") log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code) return web.Response(status=code)
@dataclass @dataclass
class CompletionsData(GenericData): class CompletionsData(GenericData):
@classmethod @classmethod
@@ -124,13 +124,10 @@ class CompletionsData(GenericData):
if not model: if not model:
raise ValueError("MODEL_NAME environment variable not set") raise ValueError("MODEL_NAME environment variable not set")
test_input = { test_input = {"model": model, "prompt": prompt, "temperature": 0.7}
"model": model,
"prompt": prompt,
"temperature": 0.7
}
return cls(input=test_input) return cls(input=test_input)
@dataclass @dataclass
class CompletionsHandler(GenericHandler): class CompletionsHandler(GenericHandler):
@property @property
@@ -144,6 +141,7 @@ class CompletionsHandler(GenericHandler):
def make_benchmark_payload(self) -> CompletionsData: def make_benchmark_payload(self) -> CompletionsData:
return CompletionsData.for_test() return CompletionsData.for_test()
@dataclass @dataclass
class ChatCompletionsData(GenericData): class ChatCompletionsData(GenericData):
"""Chat completions-specific data implementation""" """Chat completions-specific data implementation"""
@@ -159,10 +157,11 @@ class ChatCompletionsData(GenericData):
test_input = { test_input = {
"model": model, "model": model,
"messages": [{"role": "user", "content": prompt}], "messages": [{"role": "user", "content": prompt}],
"temperature": 0.7 "temperature": 0.7,
} }
return cls(input=test_input) return cls(input=test_input)
@dataclass @dataclass
class ChatCompletionsHandler(GenericHandler): class ChatCompletionsHandler(GenericHandler):
@property @property
+15 -13
View File
@@ -7,20 +7,20 @@ from lib.server import start_server
# This line indicates that the inference server is listening # This line indicates that the inference server is listening
MODEL_SERVER_START_LOG_MSG = [ MODEL_SERVER_START_LOG_MSG = [
"Application startup complete.", # vLLM "Application startup complete.", # vLLM
"llama runner started", # Ollama "llama runner started", # Ollama
'"message":"Connected","target":"text_generation_router"', # TGI '"message":"Connected","target":"text_generation_router"', # TGI
'"message":"Connected","target":"text_generation_router::server"', # TGI '"message":"Connected","target":"text_generation_router::server"', # TGI
] ]
MODEL_SERVER_ERROR_LOG_MSGS = [ MODEL_SERVER_ERROR_LOG_MSGS = [
"INFO exited: vllm", # vLLM "INFO exited: vllm", # vLLM
"RuntimeError: Engine", # vLLM "RuntimeError: Engine", # vLLM
"Error: pull model manifest:", # Ollama "Error: pull model manifest:", # Ollama
"stalled; retrying", # Ollama "stalled; retrying", # Ollama
"Error: WebserverFailed", # TGI "Error: WebserverFailed", # TGI
"Error: DownloadError", # TGI "Error: DownloadError", # TGI
"Error: ShardCannotStart", #TGI "Error: ShardCannotStart", # TGI
] ]
logging.basicConfig( logging.basicConfig(
@@ -31,8 +31,8 @@ logging.basicConfig(
log = logging.getLogger(__file__) log = logging.getLogger(__file__)
backend = Backend( backend = Backend(
model_server_url=os.environ.get("MODEL_SERVER_URL"), model_server_url=os.environ["MODEL_SERVER_URL"],
model_log_file=os.environ.get("MODEL_LOG"), model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=True, allow_parallel_requests=True,
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256), benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
log_actions=[ log_actions=[
@@ -45,9 +45,11 @@ backend = Backend(
], ],
) )
async def handle_ping(_): async def handle_ping(_):
return web.Response(body="pong") return web.Response(body="pong")
routes = [ routes = [
web.post("/v1/completions", backend.create_handler(CompletionsHandler())), web.post("/v1/completions", backend.create_handler(CompletionsHandler())),
web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())), web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())),
+2 -2
View File
@@ -13,14 +13,14 @@ if __name__ == "__main__":
"--model", "--model",
dest="model", dest="model",
required=not model_name_set, required=not model_name_set,
help="Model to use for completions request (required if MODEL_NAME env var not set)" help="Model to use for completions request (required if MODEL_NAME env var not set)",
) )
# Parse known args to get model early, before test_load_cmd adds its args # Parse known args to get model early, before test_load_cmd adds its args
known_args, _ = test_args.parse_known_args() known_args, _ = test_args.parse_known_args()
# Set environment variable if model was provided # Set environment variable if model was provided
if hasattr(known_args, 'model') and known_args.model: if hasattr(known_args, "model") and known_args.model:
os.environ["MODEL_NAME"] = known_args.model os.environ["MODEL_NAME"] = known_args.model
print(f"Set MODEL_NAME environment variable to: {known_args.model}") print(f"Set MODEL_NAME environment variable to: {known_args.model}")