diff --git a/lib/backend.py b/lib/backend.py index dea39a3..117ea6d 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -5,7 +5,7 @@ import base64 import subprocess import dataclasses import logging -from asyncio import sleep, gather, Semaphore +from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional from functools import cached_property from distutils.util import strtobool @@ -123,6 +123,12 @@ class Backend: return web.json_response(dict(error="invalid JSON"), status=422) workload = payload.count_workload() + async def cancel_api_call_if_disconnected() -> web.Response: + await request.wait_for_disconnection() + log.debug(f"request with reqnum: {auth_data.reqnum} was canceled") + self.metrics._request_canceled(workload=workload, reqnum=auth_data.reqnum) + return web.Response(status=500) + async def make_request() -> Union[web.Response, web.StreamResponse]: log.debug(f"got request, {auth_data.reqnum}") self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum) @@ -168,7 +174,15 @@ class Backend: return web.Response(status=401) try: - return await make_request() + done, pending = await wait( + [ + create_task(make_request()), + create_task(cancel_api_call_if_disconnected()), + ], + return_when=FIRST_COMPLETED, + ) + [task.cancel() for task in pending] + return done.pop().result() except Exception as e: log.debug(f"Exception in main handler loop {e}") return web.Response(status=500) diff --git a/requirements.txt b/requirements.txt index 6e753c2..007aebc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -aiohttp~=3.11 +aiohttp==3.10.1 anyio~=4.4 lib~=4.0 nltk~=3.9 diff --git a/workers/openai/client.py b/workers/openai/client.py index 2748aab..4dbf099 100644 --- a/workers/openai/client.py +++ b/workers/openai/client.py @@ -19,40 +19,45 @@ COMPLETIONS_PROMPT = "the capital of USA is" CHAT_PROMPT = "Think step by step: Tell me about the Python programming language." TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?" + class APIClient: """Lightweight client focused solely on API communication""" - + # Remove the generic WORKER_ENDPOINT since we're now going direct DEFAULT_COST = 100 DEFAULT_TIMEOUT = 4 - - def __init__(self, endpoint_group_name: str, api_key: str, server_url: str): + + def __init__( + self, endpoint_group_name: str, api_key: str, server_url: str, instance: str + ): self.endpoint_group_name = endpoint_group_name self.api_key = api_key self.server_url = server_url + self.instance = instance self.endpoint_api_key = self._get_endpoint_api_key() - + def _get_endpoint_api_key(self) -> Optional[str]: """Get the endpoint API key""" endpoint_api_key = Endpoint.get_endpoint_api_key( endpoint_name=self.endpoint_group_name, account_api_key=self.api_key, + instance=self.instance, ) if not endpoint_api_key: log.error(f"Failed to get API key for endpoint {self.endpoint_group_name}") return endpoint_api_key - + def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]: """Get worker URL and auth data from routing service""" if not self.endpoint_api_key: raise ValueError("No valid endpoint API key available") - + route_payload = { "endpoint": self.endpoint_group_name, "api_key": self.endpoint_api_key, "cost": cost, } - + response = requests.post( urljoin(self.server_url, "/route/"), json=route_payload, @@ -60,7 +65,7 @@ class APIClient: ) response.raise_for_status() return response.json() - + def _create_auth_data(self, message: Dict[str, Any]) -> Dict[str, Any]: """Create auth data from routing response""" return { @@ -70,27 +75,27 @@ class APIClient: "reqnum": message["reqnum"], "url": message["url"], } - - def _make_request(self, payload: Dict[str, Any], endpoint: str, method: str = "POST", - stream: bool = False) -> Union[Dict[str, Any], Iterator[str]]: + + def _make_request( + self, + payload: Dict[str, Any], + endpoint: str, + method: str = "POST", + stream: bool = False, + ) -> Union[Dict[str, Any], Iterator[str]]: """Make request directly to the specific worker endpoint""" # Get worker URL and auth data - cost = payload.get('max_tokens') + cost = payload.get("max_tokens", self.DEFAULT_COST) message = self._get_worker_url(cost=cost) worker_url = message["url"] auth_data = self._create_auth_data(message) - - req_data = { - "payload": { - "input": payload - }, - "auth_data": auth_data - } + + req_data = {"payload": {"input": payload}, "auth_data": auth_data} url = urljoin(worker_url, endpoint) log.debug(f"Making direct request to: {url}") log.debug(f"Payload: {req_data}") - + # Make the request using the specified method if method.upper() == "POST": response = requests.post(url, json=req_data, stream=stream) @@ -98,14 +103,14 @@ class APIClient: response = requests.get(url, params=req_data, stream=stream) else: raise ValueError(f"Unsupported HTTP method: {method}") - + response.raise_for_status() - + if stream: return self._handle_streaming_response(response) else: return response.json() - + def _handle_streaming_response(self, response: requests.Response) -> Iterator[str]: """Handle streaming response and yield tokens""" try: @@ -124,61 +129,60 @@ class APIClient: log.error(f"Error handling streaming response: {e}") raise - - def call_completions(self, config: CompletionConfig) -> Union[Dict[str, Any], Iterator[str]]: - payload = config.to_dict() - - return self._make_request( - payload=payload, - endpoint="/v1/completions", - stream=config.stream - ) - - def call_chat_completions(self, config: ChatCompletionConfig) -> Union[Dict[str, Any], Iterator[str]]: + def call_completions( + self, config: CompletionConfig + ) -> Union[Dict[str, Any], Iterator[str]]: payload = config.to_dict() return self._make_request( - payload=payload, - endpoint="/v1/chat/completions", - stream=config.stream + payload=payload, endpoint="/v1/completions", stream=config.stream + ) + + def call_chat_completions( + self, config: ChatCompletionConfig + ) -> Union[Dict[str, Any], Iterator[str]]: + payload = config.to_dict() + + return self._make_request( + payload=payload, endpoint="/v1/chat/completions", stream=config.stream ) class ToolManager: """Handles tool definitions and execution""" - + @staticmethod def list_files() -> str: """Execute ls on current directory""" try: - result = subprocess.run(['ls', '-la', '.'], capture_output=True, text=True, timeout=10) + result = subprocess.run( + ["ls", "-la", "."], capture_output=True, text=True, timeout=10 + ) if result.returncode == 0: return result.stdout else: return f"Error: {result.stderr}" except Exception as e: return f"Error running ls: {e}" - + @staticmethod def get_ls_tool_definition() -> List[Dict[str, Any]]: """Get the ls tool definition""" - return [{ - "type": "function", - "function": { - "name": "list_files", - "description": "List files and directories in the cwd", - "parameters": { - "type": "object", - "properties": {}, - "required": [] - } + return [ + { + "type": "function", + "function": { + "name": "list_files", + "description": "List files and directories in the cwd", + "parameters": {"type": "object", "properties": {}, "required": []}, + }, } - }] - + ] + def execute_tool_call(self, tool_call: Dict[str, Any]) -> str: """Execute a tool call and return the result""" function_name = tool_call["function"]["name"] - + if function_name == "list_files": return self.list_files() else: @@ -187,13 +191,17 @@ class ToolManager: class APIDemo: """Demo and testing functionality for the API client""" - - def __init__(self, client: APIClient, model: str, tool_manager: ToolManager = None): + + def __init__( + self, client: APIClient, model: str, tool_manager: Optional[ToolManager] = None + ): self.client = client self.model = model self.tool_manager = tool_manager or ToolManager() - - def handle_streaming_response(self, response_stream, show_reasoning: bool = True) -> str: + + def handle_streaming_response( + self, response_stream, show_reasoning: bool = True + ) -> str: """ Handle streaming chat response and display all output. """ @@ -260,178 +268,181 @@ class APIDemo: return full_response - def test_tool_support(self) -> bool: """Test if the endpoint supports function calling""" log.debug("Testing endpoint tool calling support...") - + # Try a simple request with minimal tools to test support messages = [{"role": "user", "content": "Hello"}] - minimal_tool = [{ - "type": "function", - "function": { - "name": "test_function", - "description": "Test function" + minimal_tool = [ + { + "type": "function", + "function": {"name": "test_function", "description": "Test function"}, } - }] - + ] + config = ChatCompletionConfig( model=self.model, messages=messages, max_tokens=10, tools=minimal_tool, - tool_choice="none" # Don't actually call the tool + tool_choice="none", # Don't actually call the tool ) - + try: response = self.client.call_chat_completions(config) return True except Exception as e: log.error(f"Error: Endpoint does not support tool calling: {e}") return False - + def demo_completions(self) -> None: """Demo: test basic completions endpoint""" print("=" * 60) print("COMPLETIONS DEMO") print("=" * 60) - + config = CompletionConfig( - model=self.model, - prompt=COMPLETIONS_PROMPT, - stream=False + model=self.model, prompt=COMPLETIONS_PROMPT, stream=False + ) + + log.info( + f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'" ) - - log.info(f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'") response = self.client.call_completions(config) - + if isinstance(response, dict): print("\nResponse:") print(json.dumps(response, indent=2)) else: log.error("Unexpected response format") - + def demo_chat(self, use_streaming: bool = True) -> None: """ Demo: test chat completions endpoint with optional streaming """ print("=" * 60) - print(f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}") + print( + f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}" + ) print("=" * 60) - + config = ChatCompletionConfig( model=self.model, messages=[{"role": "user", "content": CHAT_PROMPT}], stream=use_streaming, ) - + log.info(f"Testing chat completions with model '{self.model}'...") response = self.client.call_chat_completions(config) - + if use_streaming: try: self.handle_streaming_response(response, show_reasoning=True) except Exception as e: log.error(f"\nError during streaming: {e}") import traceback + traceback.print_exc() return - + else: if isinstance(response, dict): choice = response.get("choices", [{}])[0] message = choice.get("message", {}) content = message.get("content", "") - reasoning = message.get("reasoning_content", "") or message.get("reasoning", "") - + reasoning = message.get("reasoning_content", "") or message.get( + "reasoning", "" + ) + if reasoning: print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m") - + print(f"\nšŸ’¬ Assistant: {content}") print(f"\nFull Response:") print(json.dumps(response, indent=2)) else: log.error("Unexpected response format") - - def demo_ls_tool(self) -> None: """Demo: ask LLM to list files in the current directory and describe what it sees""" print("=" * 60) print("TOOL USE DEMO: List Directory Contents") print("=" * 60) - + # Test if tools are supported first if not self.test_tool_support(): return - + # Request with tool available - messages = [ - {"role": "user", "content": TOOLS_PROMPT} - ] - + messages = [{"role": "user", "content": TOOLS_PROMPT}] + config = ChatCompletionConfig( model=self.model, messages=messages, tools=self.tool_manager.get_ls_tool_definition(), - tool_choice="auto" + tool_choice="auto", ) - + log.info(f"Making initial request with tool using model '{self.model}'...") response = self.client.call_chat_completions(config) - + if not isinstance(response, dict): raise ValueError("Expected dict response for tool use") - + choice = response.get("choices", [{}])[0] message = choice.get("message", {}) - + print(f"Assistant response: {message.get('content', 'No content')}") - + # Check for tool calls tool_calls = message.get("tool_calls") if not tool_calls: - raise ValueError("No tool calls made - model may not support function calling") - + raise ValueError( + "No tool calls made - model may not support function calling" + ) + print(f"Tool calls detected: {len(tool_calls)}") - + # Execute the tool call for tool_call in tool_calls: function_name = tool_call["function"]["name"] print(f"Executing tool: {function_name}") - + tool_result = self.tool_manager.execute_tool_call(tool_call) print(f"Tool result:\n{tool_result}") - + # Add tool result and continue conversation messages.append(message) # Add assistant's message with tool call - messages.append({ - "role": "tool", - "tool_call_id": tool_call["id"], - "content": tool_result - }) - + messages.append( + { + "role": "tool", + "tool_call_id": tool_call["id"], + "content": tool_result, + } + ) + # Get final response final_config = ChatCompletionConfig( model=self.model, messages=messages, - tools=self.tool_manager.get_ls_tool_definition() + tools=self.tool_manager.get_ls_tool_definition(), ) - + print("Getting final response...") final_response = self.client.call_chat_completions(final_config) - + if isinstance(final_response, dict): final_choice = final_response.get("choices", [{}])[0] final_message = final_choice.get("message", {}) final_content = final_message.get("content", "") - + print("\n" + "=" * 60) print("FINAL LLM ANALYSIS:") print("=" * 60) print(final_content) print("=" * 60) - + def interactive_chat(self) -> None: """Interactive chat session with streaming""" print("=" * 60) @@ -440,40 +451,39 @@ class APIDemo: print(f"Using model: {self.model}") print("Type 'quit' to exit, 'clear' to clear history") print() - + messages = [] - + while True: try: user_input = input("You: ").strip() - - if user_input.lower() == 'quit': + + if user_input.lower() == "quit": print("šŸ‘‹ Goodbye!") break - elif user_input.lower() == 'clear': + elif user_input.lower() == "clear": messages = [] print("Chat history cleared") continue elif not user_input: continue - + messages.append({"role": "user", "content": user_input}) - + config = ChatCompletionConfig( - model=self.model, - messages=messages, - stream=True, - temperature=0.7 + model=self.model, messages=messages, stream=True, temperature=0.7 ) - + print("Assistant: ", end="", flush=True) - + response = self.client.call_chat_completions(config) - assistant_content = self.handle_streaming_response(response, show_reasoning=True) - + assistant_content = self.handle_streaming_response( + response, show_reasoning=True + ) + # Add assistant response to conversation history messages.append({"role": "assistant", "content": assistant_content}) - + except KeyboardInterrupt: print("\nšŸ‘‹ Chat interrupted. Goodbye!") break @@ -485,50 +495,49 @@ class APIDemo: def main(): """Main function with CLI switches for different tests""" from lib.test_utils import test_args - + # Add mandatory model argument test_args.add_argument( - "--model", - required=True, - help="Model to use for requests (required)" + "--model", required=True, help="Model to use for requests (required)" ) - + # Add test mode arguments test_args.add_argument( - "--completion", - action="store_true", - help="Test completions endpoint" + "--completion", action="store_true", help="Test completions endpoint" ) test_args.add_argument( - "--chat", + "--chat", action="store_true", - help="Test chat completions endpoint (non-streaming)" + help="Test chat completions endpoint (non-streaming)", ) test_args.add_argument( - "--chat-stream", + "--chat-stream", action="store_true", - help="Test chat completions endpoint with streaming" + help="Test chat completions endpoint with streaming", ) test_args.add_argument( - "--tools", + "--tools", action="store_true", - help="Test function calling with ls tool (non-streaming)" + help="Test function calling with ls tool (non-streaming)", ) test_args.add_argument( - "--interactive", + "--interactive", action="store_true", - help="Start interactive streaming chat session" + help="Start interactive streaming chat session", ) - + args = test_args.parse_args() - + # Check that only one test mode is selected test_modes = [ - args.completion, args.chat, args.chat_stream, - args.tools, args.interactive + args.completion, + args.chat, + args.chat_stream, + args.tools, + args.interactive, ] selected_count = sum(test_modes) - + if selected_count == 0: print("Please specify exactly one test mode:") print(" --completion : Test completions endpoint") @@ -536,27 +545,30 @@ def main(): print(" --chat-stream : Test chat completions endpoint with streaming") print(" --tools : Test function calling with ls tool (non-streaming)") print(" --interactive : Start interactive streaming chat session") - print(f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT") + print( + f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT" + ) sys.exit(1) elif selected_count > 1: print("Please specify exactly one test mode") sys.exit(1) - + try: # Create the core API client client = APIClient( endpoint_group_name=args.endpoint_group_name, api_key=args.api_key, - server_url=args.server_url + server_url=args.server_url, + instance=args.instance, ) - + # Create tool manager and demo (passing the model parameter) tool_manager = ToolManager() demo = APIDemo(client, args.model, tool_manager) - + print(f"Using model: {args.model}") print("=" * 60) - + # Run the selected test if args.completion: demo.demo_completions() @@ -568,11 +580,11 @@ def main(): demo.demo_ls_tool() elif args.interactive: demo.interactive_chat() - + except Exception as e: log.error(f"Error during test: {e}", exc_info=True) sys.exit(1) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/workers/openai/data_types/client.py b/workers/openai/data_types/client.py index 205d596..444ae2d 100644 --- a/workers/openai/data_types/client.py +++ b/workers/openai/data_types/client.py @@ -3,11 +3,13 @@ from dataclasses import dataclass, field, fields, is_dataclass from typing import Optional, List, Dict, Any -class SerializableDataclass: +class SerializableDataclass: def _serialize_recursive(self, obj: Any) -> Any: if is_dataclass(obj): - return {field.name: self._serialize_recursive(getattr(obj, field.name)) - for field in fields(obj)} + return { + field.name: self._serialize_recursive(getattr(obj, field.name)) + for field in fields(obj) + } elif isinstance(obj, dict): return {key: self._serialize_recursive(value) for key, value in obj.items()} elif isinstance(obj, (list, tuple)): @@ -16,10 +18,10 @@ class SerializableDataclass: return [self._serialize_recursive(item) for item in obj] else: return obj - + def to_dict(self) -> Dict[str, Any]: return self._serialize_recursive(self) - + def to_json(self, indent: int = 2) -> str: return json.dumps(self.to_dict(), indent=indent) @@ -27,6 +29,7 @@ class SerializableDataclass: @dataclass class CompletionConfig(SerializableDataclass): """Configuration for completion requests""" + model: str prompt: str = "Hello" max_tokens: int = 256 @@ -39,8 +42,9 @@ class CompletionConfig(SerializableDataclass): @dataclass class ChatCompletionConfig(SerializableDataclass): """Configuration for chat completion requests""" + model: str - messages: list = None + messages: list = field(default_factory=list) max_tokens: int = 2096 temperature: float = 0.7 top_k: int = 20 @@ -48,7 +52,7 @@ class ChatCompletionConfig(SerializableDataclass): stream: bool = False tools: Optional[List[Dict[str, Any]]] = field(default_factory=list) tool_choice: str = "auto" - + def __post_init__(self): if self.messages is None: - self.messages = [{"role": "user", "content": "Hello"}] \ No newline at end of file + self.messages = [{"role": "user", "content": "Hello"}] diff --git a/workers/openai/data_types/server.py b/workers/openai/data_types/server.py index f0e341e..dd9b45c 100644 --- a/workers/openai/data_types/server.py +++ b/workers/openai/data_types/server.py @@ -2,7 +2,7 @@ import os, json, random from abc import ABC, abstractmethod from dataclasses import dataclass from lib.data_types import EndpointHandler, ApiPayload, JsonDataException -from typing import Union, Type, Dict, Any +from typing import Union, Type, Dict, Any, Optional from aiohttp import web, ClientResponse import nltk import logging @@ -10,41 +10,39 @@ import logging nltk.download("words") WORD_LIST = nltk.corpus.words.words() log = logging.getLogger(__name__) - + """ Generic dataclass accepts any dictionary in input. """ + + @dataclass class GenericData(ApiPayload, ABC): input: Dict[str, Any] @classmethod def from_dict(cls, data: Dict[str, Any]) -> "GenericData": - return cls( - input=data["input"] - ) - + return cls(input=data["input"]) + @classmethod def from_json_msg(cls, json_msg: Dict[str, Any]) -> "GenericData": errors = {} - + # Validate required parameters required_params = ["input"] for param in required_params: if param not in json_msg: errors[param] = "missing parameter" - + if errors: raise JsonDataException(errors) - + try: # Create clean data dict and delegate to from_dict - clean_data = { - "input": json_msg["input"] - } - + clean_data = {"input": json_msg["input"]} + return cls.from_dict(clean_data) - + except (json.JSONDecodeError, JsonDataException) as e: errors["parameters"] = str(e) raise JsonDataException(errors) @@ -59,7 +57,8 @@ class GenericData(ApiPayload, ABC): def count_workload(self) -> int: return self.input.get("max_tokens", 0) - + + @dataclass class GenericHandler(EndpointHandler[GenericData], ABC): @@ -67,10 +66,10 @@ class GenericHandler(EndpointHandler[GenericData], ABC): @abstractmethod def endpoint(self) -> str: pass - + @property - def healthcheck_endpoint(self) -> str: - return os.environ.get('MODEL_HEALTH_ENDPOINT') + def healthcheck_endpoint(self) -> Optional[str]: + return os.environ.get("MODEL_HEALTH_ENDPOINT") @classmethod def payload_cls(cls) -> Type[GenericData]: @@ -82,17 +81,17 @@ class GenericHandler(EndpointHandler[GenericData], ABC): async def generate_client_response( self, client_request: web.Request, model_response: ClientResponse - ) -> Union[web.Response, web.StreamResponse]: + ) -> Union[web.Response, web.StreamResponse]: match model_response.status: case 200: # Check if the response is actually streaming based on response headers/content-type is_streaming_response = ( - model_response.content_type == "text/event-stream" or - model_response.content_type == "application/x-ndjson" or - model_response.headers.get("Transfer-Encoding") == "chunked" or - "stream" in model_response.content_type.lower() + model_response.content_type == "text/event-stream" + or model_response.content_type == "application/x-ndjson" + or model_response.headers.get("Transfer-Encoding") == "chunked" + or "stream" in model_response.content_type.lower() ) - + if is_streaming_response: log.debug("Detected streaming response...") res = web.StreamResponse() @@ -109,12 +108,13 @@ class GenericHandler(EndpointHandler[GenericData], ABC): return web.Response( body=content, status=200, - content_type=model_response.content_type + content_type=model_response.content_type, ) case code: log.debug("SENDING RESPONSE: ERROR: unknown code") return web.Response(status=code) + @dataclass class CompletionsData(GenericData): @classmethod @@ -123,55 +123,54 @@ class CompletionsData(GenericData): model = os.environ.get("MODEL_NAME") if not model: raise ValueError("MODEL_NAME environment variable not set") - - test_input = { - "model": model, - "prompt": prompt, - "temperature": 0.7 - } + + test_input = {"model": model, "prompt": prompt, "temperature": 0.7} return cls(input=test_input) - + + @dataclass class CompletionsHandler(GenericHandler): @property def endpoint(self) -> str: return "/v1/completions" - + @classmethod def payload_cls(cls) -> Type[CompletionsData]: return CompletionsData - + def make_benchmark_payload(self) -> CompletionsData: return CompletionsData.for_test() + @dataclass class ChatCompletionsData(GenericData): """Chat completions-specific data implementation""" - + @classmethod def for_test(cls) -> "ChatCompletionsData": prompt = " ".join(random.choices(WORD_LIST, k=int(250))) model = os.environ.get("MODEL_NAME") if not model: raise ValueError("MODEL_NAME environment variable not set") - + # Chat completions use messages format instead of prompt test_input = { "model": model, "messages": [{"role": "user", "content": prompt}], - "temperature": 0.7 + "temperature": 0.7, } return cls(input=test_input) - -@dataclass -class ChatCompletionsHandler(GenericHandler): + + +@dataclass +class ChatCompletionsHandler(GenericHandler): @property def endpoint(self) -> str: return "/v1/chat/completions" - + @classmethod def payload_cls(cls) -> Type[ChatCompletionsData]: return ChatCompletionsData - + def make_benchmark_payload(self) -> ChatCompletionsData: return ChatCompletionsData.for_test() diff --git a/workers/openai/server.py b/workers/openai/server.py index 3eee141..bfeb819 100644 --- a/workers/openai/server.py +++ b/workers/openai/server.py @@ -7,20 +7,20 @@ from lib.server import start_server # This line indicates that the inference server is listening MODEL_SERVER_START_LOG_MSG = [ - "Application startup complete.", # vLLM - "llama runner started", # Ollama - '"message":"Connected","target":"text_generation_router"', # TGI - '"message":"Connected","target":"text_generation_router::server"', # TGI + "Application startup complete.", # vLLM + "llama runner started", # Ollama + '"message":"Connected","target":"text_generation_router"', # TGI + '"message":"Connected","target":"text_generation_router::server"', # TGI ] MODEL_SERVER_ERROR_LOG_MSGS = [ - "INFO exited: vllm", # vLLM - "RuntimeError: Engine", # vLLM - "Error: pull model manifest:", # Ollama - "stalled; retrying", # Ollama - "Error: WebserverFailed", # TGI - "Error: DownloadError", # TGI - "Error: ShardCannotStart", #TGI + "INFO exited: vllm", # vLLM + "RuntimeError: Engine", # vLLM + "Error: pull model manifest:", # Ollama + "stalled; retrying", # Ollama + "Error: WebserverFailed", # TGI + "Error: DownloadError", # TGI + "Error: ShardCannotStart", # TGI ] logging.basicConfig( @@ -31,8 +31,8 @@ logging.basicConfig( log = logging.getLogger(__file__) backend = Backend( - model_server_url=os.environ.get("MODEL_SERVER_URL"), - model_log_file=os.environ.get("MODEL_LOG"), + model_server_url=os.environ["MODEL_SERVER_URL"], + model_log_file=os.environ["MODEL_LOG"], allow_parallel_requests=True, benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256), log_actions=[ @@ -45,9 +45,11 @@ backend = Backend( ], ) + async def handle_ping(_): return web.Response(body="pong") + routes = [ web.post("/v1/completions", backend.create_handler(CompletionsHandler())), web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())), diff --git a/workers/openai/test_load.py b/workers/openai/test_load.py index 385837c..0c45524 100644 --- a/workers/openai/test_load.py +++ b/workers/openai/test_load.py @@ -7,22 +7,22 @@ WORKER_ENDPOINT = "/v1/completions" if __name__ == "__main__": # Check if MODEL_NAME environment variable is set model_name_set = os.environ.get("MODEL_NAME") is not None - + # Add model argument - required only if MODEL_NAME is not set test_args.add_argument( - "--model", + "--model", dest="model", required=not model_name_set, - help="Model to use for completions request (required if MODEL_NAME env var not set)" + help="Model to use for completions request (required if MODEL_NAME env var not set)", ) - + # Parse known args to get model early, before test_load_cmd adds its args known_args, _ = test_args.parse_known_args() - + # Set environment variable if model was provided - if hasattr(known_args, 'model') and known_args.model: + if hasattr(known_args, "model") and known_args.model: os.environ["MODEL_NAME"] = known_args.model print(f"Set MODEL_NAME environment variable to: {known_args.model}") - + # Now call test_load_cmd normally - it will add its own args and re-parse - test_load_cmd(CompletionsData, WORKER_ENDPOINT, arg_parser=test_args) \ No newline at end of file + test_load_cmd(CompletionsData, WORKER_ENDPOINT, arg_parser=test_args)