diff --git a/workers/openai/README.md b/workers/openai/README.md new file mode 100644 index 0000000..2436784 --- /dev/null +++ b/workers/openai/README.md @@ -0,0 +1,80 @@ +# OpenAI Compatible PyWorker + +This is the base PyWorker for OpenAI compatible inference servers. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's. + +## Instance Setup + +1. Pick a template + +This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker. + +- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20%2B%20Qwen%2FQwen3-8B%20(Serverless)) (recommended) +- [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless)) +- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20%2B%20Qwen3-8B%20(Serverless)) + + +All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected. + +2. Follow the [getting started guide](https://docs.vast.ai/serverless/getting-started) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface. + +## Client Setup (Demo) + +1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client. + +```bash +git clone https://github.com/vast-ai/pyworker +cd pyworker +pip install uv +uv venv -p 3.12 +source .venv/bin/activate +uv pip install -r requirements.txt +``` + +## Using the Test Client + +Several examples have been provided in the client to help you get started with your own implementation. + +### Completions + +Call to `/v1/completions` with json response + +```bash +python -m workers.openai.client -k -e --completion --model +``` + +### Chat Completion (json) + +Call to `/v1/chat/completions` with json response + +```bash +python -m workers.openai.client -k -e --chat --model +``` + +### Chat Completion (streaming) + +Call to `/v1/chat/completions` with streaming response + +```bash +python -m workers.openai.client -k -e --chat-stream --model +``` + +### Tool Use (json) + +Call to `/v1/chat/completions` with tool and json response. + +This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model. + +```bash +python -m workers.openai.client -k -e --tools --model +``` + +### Interactive Chat (streaming) + +Interactive session with calls to `/v1/chat/completions`. + +Type `clear` to clear the chat history or `quit` to exit. + +```bash +python -m workers.openai.client -k -e --interactive --model +``` + diff --git a/workers/openai/README.templates.md b/workers/openai/README.templates.md new file mode 100644 index 0000000..f4d7c2b --- /dev/null +++ b/workers/openai/README.templates.md @@ -0,0 +1,77 @@ +# + (serverless) + +Run with our serverless autoscaling infrastructure. + +See the [serverless documentation](https://docs.vast.ai/serverless) and the [Getting Started](https://docs.vast.ai/serverless/getting-started) guide for in-depth details about how to use these templates. + +## Configuration + +Two environment variables are provided to help you configure the server: + +| Variable | Default Value | Used For | +| --- | --- | --- | +| `MODEL_NAME` | `` | The model to load. Also accepts [hf.co/repo/model](#) links | +| `` | `` | Arguments to pass to the `` command | + +This template has been configured to work with VRAM. Setting alternative models and server arguments will change the VRAM requirements. Check model cards and for guidance. + +## Usage + +We have provided a demonstration client to help you implement this template into your own infrastructure + +### Client Setup + +Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client. + +```bash +git clone https://github.com/vast-ai/pyworker +cd pyworker +pip install uv +uv venv -p 3.12 +source .venv/bin/activate +uv pip install -r requirements.txt +``` + +### Completions + +Call to `/v1/completions` with json response + +```bash +python -m workers.openai.client -k -e --completion --model +``` + +### Chat Completion (json) + +Call to `/v1/chat/completions` with json response + +```bash +python -m workers.openai.client -k -e --chat --model +``` + +### Chat Completion (streaming) + +Call to `/v1/chat/completions` with streaming response + +```bash +python -m workers.openai.client -k -e --chat-stream --model +``` + +### Tool Use (json) + +Call to `/v1/chat/completions` with tool and json response. + +This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model. + +```bash +python -m workers.openai.client -k -e --tools --model +``` + +### Interactive Chat (streaming) + +Interactive session with calls to `/v1/chat/completions`. + +Type `clear` to clear the chat history or `quit` to exit. + +```bash +python -m workers.openai.client -k -e --interactive --model +``` diff --git a/workers/openai/__init__.py b/workers/openai/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/workers/openai/client.py b/workers/openai/client.py new file mode 100644 index 0000000..2748aab --- /dev/null +++ b/workers/openai/client.py @@ -0,0 +1,578 @@ +import logging +import sys +import json +import subprocess +from urllib.parse import urljoin +from typing import Dict, Any, Optional, Iterator, Union, List +import requests +from utils.endpoint_util import Endpoint +from .data_types.client import CompletionConfig, ChatCompletionConfig + +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s[%(levelname)-5s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +log = logging.getLogger(__file__) + +COMPLETIONS_PROMPT = "the capital of USA is" +CHAT_PROMPT = "Think step by step: Tell me about the Python programming language." +TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?" + +class APIClient: + """Lightweight client focused solely on API communication""" + + # Remove the generic WORKER_ENDPOINT since we're now going direct + DEFAULT_COST = 100 + DEFAULT_TIMEOUT = 4 + + def __init__(self, endpoint_group_name: str, api_key: str, server_url: str): + self.endpoint_group_name = endpoint_group_name + self.api_key = api_key + self.server_url = server_url + self.endpoint_api_key = self._get_endpoint_api_key() + + def _get_endpoint_api_key(self) -> Optional[str]: + """Get the endpoint API key""" + endpoint_api_key = Endpoint.get_endpoint_api_key( + endpoint_name=self.endpoint_group_name, + account_api_key=self.api_key, + ) + if not endpoint_api_key: + log.error(f"Failed to get API key for endpoint {self.endpoint_group_name}") + return endpoint_api_key + + def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]: + """Get worker URL and auth data from routing service""" + if not self.endpoint_api_key: + raise ValueError("No valid endpoint API key available") + + route_payload = { + "endpoint": self.endpoint_group_name, + "api_key": self.endpoint_api_key, + "cost": cost, + } + + response = requests.post( + urljoin(self.server_url, "/route/"), + json=route_payload, + timeout=self.DEFAULT_TIMEOUT, + ) + response.raise_for_status() + return response.json() + + def _create_auth_data(self, message: Dict[str, Any]) -> Dict[str, Any]: + """Create auth data from routing response""" + return { + "signature": message["signature"], + "cost": message["cost"], + "endpoint": message["endpoint"], + "reqnum": message["reqnum"], + "url": message["url"], + } + + def _make_request(self, payload: Dict[str, Any], endpoint: str, method: str = "POST", + stream: bool = False) -> Union[Dict[str, Any], Iterator[str]]: + """Make request directly to the specific worker endpoint""" + # Get worker URL and auth data + cost = payload.get('max_tokens') + message = self._get_worker_url(cost=cost) + worker_url = message["url"] + auth_data = self._create_auth_data(message) + + req_data = { + "payload": { + "input": payload + }, + "auth_data": auth_data + } + + url = urljoin(worker_url, endpoint) + log.debug(f"Making direct request to: {url}") + log.debug(f"Payload: {req_data}") + + # Make the request using the specified method + if method.upper() == "POST": + response = requests.post(url, json=req_data, stream=stream) + elif method.upper() == "GET": + response = requests.get(url, params=req_data, stream=stream) + else: + raise ValueError(f"Unsupported HTTP method: {method}") + + response.raise_for_status() + + if stream: + return self._handle_streaming_response(response) + else: + return response.json() + + def _handle_streaming_response(self, response: requests.Response) -> Iterator[str]: + """Handle streaming response and yield tokens""" + try: + for line in response.iter_lines(decode_unicode=True): + if line: + if line.startswith("data: "): + data_str = line[6:] + if data_str.strip() == "[DONE]": + break + try: + data = json.loads(data_str) + yield data # Yield the full chunk + except json.JSONDecodeError: + continue + except Exception as e: + log.error(f"Error handling streaming response: {e}") + raise + + + def call_completions(self, config: CompletionConfig) -> Union[Dict[str, Any], Iterator[str]]: + payload = config.to_dict() + + return self._make_request( + payload=payload, + endpoint="/v1/completions", + stream=config.stream + ) + + def call_chat_completions(self, config: ChatCompletionConfig) -> Union[Dict[str, Any], Iterator[str]]: + payload = config.to_dict() + + return self._make_request( + payload=payload, + endpoint="/v1/chat/completions", + stream=config.stream + ) + + +class ToolManager: + """Handles tool definitions and execution""" + + @staticmethod + def list_files() -> str: + """Execute ls on current directory""" + try: + result = subprocess.run(['ls', '-la', '.'], capture_output=True, text=True, timeout=10) + if result.returncode == 0: + return result.stdout + else: + return f"Error: {result.stderr}" + except Exception as e: + return f"Error running ls: {e}" + + @staticmethod + def get_ls_tool_definition() -> List[Dict[str, Any]]: + """Get the ls tool definition""" + return [{ + "type": "function", + "function": { + "name": "list_files", + "description": "List files and directories in the cwd", + "parameters": { + "type": "object", + "properties": {}, + "required": [] + } + } + }] + + def execute_tool_call(self, tool_call: Dict[str, Any]) -> str: + """Execute a tool call and return the result""" + function_name = tool_call["function"]["name"] + + if function_name == "list_files": + return self.list_files() + else: + raise ValueError(f"Unknown tool function: {function_name}") + + +class APIDemo: + """Demo and testing functionality for the API client""" + + def __init__(self, client: APIClient, model: str, tool_manager: ToolManager = None): + self.client = client + self.model = model + self.tool_manager = tool_manager or ToolManager() + + def handle_streaming_response(self, response_stream, show_reasoning: bool = True) -> str: + """ + Handle streaming chat response and display all output. + """ + + full_response = "" + reasoning_content = "" + reasoning_started = False + content_started = False + + for chunk in response_stream: + # Normalize the chunk + if isinstance(chunk, str): + chunk = chunk.strip() + if chunk.startswith("data: "): + chunk = chunk[6:].strip() + if chunk in ["[DONE]", ""]: + continue + try: + parsed_chunk = json.loads(chunk) + except json.JSONDecodeError: + continue + elif isinstance(chunk, dict): + parsed_chunk = chunk + else: + continue + + # Parse delta from the chunk + choices = parsed_chunk.get("choices", []) + if not choices: + continue + + delta = choices[0].get("delta", {}) + reasoning_token = delta.get("reasoning_content", "") + content_token = delta.get("content", "") + + # Print reasoning token if applicable + if show_reasoning and reasoning_token: + if not reasoning_started: + print("\n🧠 Reasoning: ", end="", flush=True) + reasoning_started = True + print(f"\033[90m{reasoning_token}\033[0m", end="", flush=True) + reasoning_content += reasoning_token + + # Print content token + if content_token: + if not content_started: + if show_reasoning and reasoning_started: + print(f"\nšŸ’¬ Response: ", end="", flush=True) + else: + print("Assistant: ", end="", flush=True) + content_started = True + print(content_token, end="", flush=True) + full_response += content_token + + print() # Ensure newline after response + + if show_reasoning: + if reasoning_started or content_started: + print("\nStreaming completed.") + if reasoning_started: + print(f"Reasoning tokens: {len(reasoning_content.split())}") + if content_started: + print(f"Response tokens: {len(full_response.split())}") + + return full_response + + + def test_tool_support(self) -> bool: + """Test if the endpoint supports function calling""" + log.debug("Testing endpoint tool calling support...") + + # Try a simple request with minimal tools to test support + messages = [{"role": "user", "content": "Hello"}] + minimal_tool = [{ + "type": "function", + "function": { + "name": "test_function", + "description": "Test function" + } + }] + + config = ChatCompletionConfig( + model=self.model, + messages=messages, + max_tokens=10, + tools=minimal_tool, + tool_choice="none" # Don't actually call the tool + ) + + try: + response = self.client.call_chat_completions(config) + return True + except Exception as e: + log.error(f"Error: Endpoint does not support tool calling: {e}") + return False + + def demo_completions(self) -> None: + """Demo: test basic completions endpoint""" + print("=" * 60) + print("COMPLETIONS DEMO") + print("=" * 60) + + config = CompletionConfig( + model=self.model, + prompt=COMPLETIONS_PROMPT, + stream=False + ) + + log.info(f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'") + response = self.client.call_completions(config) + + if isinstance(response, dict): + print("\nResponse:") + print(json.dumps(response, indent=2)) + else: + log.error("Unexpected response format") + + def demo_chat(self, use_streaming: bool = True) -> None: + """ + Demo: test chat completions endpoint with optional streaming + """ + print("=" * 60) + print(f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}") + print("=" * 60) + + config = ChatCompletionConfig( + model=self.model, + messages=[{"role": "user", "content": CHAT_PROMPT}], + stream=use_streaming, + ) + + log.info(f"Testing chat completions with model '{self.model}'...") + response = self.client.call_chat_completions(config) + + if use_streaming: + try: + self.handle_streaming_response(response, show_reasoning=True) + except Exception as e: + log.error(f"\nError during streaming: {e}") + import traceback + traceback.print_exc() + return + + else: + if isinstance(response, dict): + choice = response.get("choices", [{}])[0] + message = choice.get("message", {}) + content = message.get("content", "") + reasoning = message.get("reasoning_content", "") or message.get("reasoning", "") + + if reasoning: + print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m") + + print(f"\nšŸ’¬ Assistant: {content}") + print(f"\nFull Response:") + print(json.dumps(response, indent=2)) + else: + log.error("Unexpected response format") + + + + def demo_ls_tool(self) -> None: + """Demo: ask LLM to list files in the current directory and describe what it sees""" + print("=" * 60) + print("TOOL USE DEMO: List Directory Contents") + print("=" * 60) + + # Test if tools are supported first + if not self.test_tool_support(): + return + + # Request with tool available + messages = [ + {"role": "user", "content": TOOLS_PROMPT} + ] + + config = ChatCompletionConfig( + model=self.model, + messages=messages, + tools=self.tool_manager.get_ls_tool_definition(), + tool_choice="auto" + ) + + log.info(f"Making initial request with tool using model '{self.model}'...") + response = self.client.call_chat_completions(config) + + if not isinstance(response, dict): + raise ValueError("Expected dict response for tool use") + + choice = response.get("choices", [{}])[0] + message = choice.get("message", {}) + + print(f"Assistant response: {message.get('content', 'No content')}") + + # Check for tool calls + tool_calls = message.get("tool_calls") + if not tool_calls: + raise ValueError("No tool calls made - model may not support function calling") + + print(f"Tool calls detected: {len(tool_calls)}") + + # Execute the tool call + for tool_call in tool_calls: + function_name = tool_call["function"]["name"] + print(f"Executing tool: {function_name}") + + tool_result = self.tool_manager.execute_tool_call(tool_call) + print(f"Tool result:\n{tool_result}") + + # Add tool result and continue conversation + messages.append(message) # Add assistant's message with tool call + messages.append({ + "role": "tool", + "tool_call_id": tool_call["id"], + "content": tool_result + }) + + # Get final response + final_config = ChatCompletionConfig( + model=self.model, + messages=messages, + tools=self.tool_manager.get_ls_tool_definition() + ) + + print("Getting final response...") + final_response = self.client.call_chat_completions(final_config) + + if isinstance(final_response, dict): + final_choice = final_response.get("choices", [{}])[0] + final_message = final_choice.get("message", {}) + final_content = final_message.get("content", "") + + print("\n" + "=" * 60) + print("FINAL LLM ANALYSIS:") + print("=" * 60) + print(final_content) + print("=" * 60) + + def interactive_chat(self) -> None: + """Interactive chat session with streaming""" + print("=" * 60) + print("INTERACTIVE STREAMING CHAT") + print("=" * 60) + print(f"Using model: {self.model}") + print("Type 'quit' to exit, 'clear' to clear history") + print() + + messages = [] + + while True: + try: + user_input = input("You: ").strip() + + if user_input.lower() == 'quit': + print("šŸ‘‹ Goodbye!") + break + elif user_input.lower() == 'clear': + messages = [] + print("Chat history cleared") + continue + elif not user_input: + continue + + messages.append({"role": "user", "content": user_input}) + + config = ChatCompletionConfig( + model=self.model, + messages=messages, + stream=True, + temperature=0.7 + ) + + print("Assistant: ", end="", flush=True) + + response = self.client.call_chat_completions(config) + assistant_content = self.handle_streaming_response(response, show_reasoning=True) + + # Add assistant response to conversation history + messages.append({"role": "assistant", "content": assistant_content}) + + except KeyboardInterrupt: + print("\nšŸ‘‹ Chat interrupted. Goodbye!") + break + except Exception as e: + log.error(f"\nError: {e}") + continue + + +def main(): + """Main function with CLI switches for different tests""" + from lib.test_utils import test_args + + # Add mandatory model argument + test_args.add_argument( + "--model", + required=True, + help="Model to use for requests (required)" + ) + + # Add test mode arguments + test_args.add_argument( + "--completion", + action="store_true", + help="Test completions endpoint" + ) + test_args.add_argument( + "--chat", + action="store_true", + help="Test chat completions endpoint (non-streaming)" + ) + test_args.add_argument( + "--chat-stream", + action="store_true", + help="Test chat completions endpoint with streaming" + ) + test_args.add_argument( + "--tools", + action="store_true", + help="Test function calling with ls tool (non-streaming)" + ) + test_args.add_argument( + "--interactive", + action="store_true", + help="Start interactive streaming chat session" + ) + + args = test_args.parse_args() + + # Check that only one test mode is selected + test_modes = [ + args.completion, args.chat, args.chat_stream, + args.tools, args.interactive + ] + selected_count = sum(test_modes) + + if selected_count == 0: + print("Please specify exactly one test mode:") + print(" --completion : Test completions endpoint") + print(" --chat : Test chat completions endpoint (non-streaming)") + print(" --chat-stream : Test chat completions endpoint with streaming") + print(" --tools : Test function calling with ls tool (non-streaming)") + print(" --interactive : Start interactive streaming chat session") + print(f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT") + sys.exit(1) + elif selected_count > 1: + print("Please specify exactly one test mode") + sys.exit(1) + + try: + # Create the core API client + client = APIClient( + endpoint_group_name=args.endpoint_group_name, + api_key=args.api_key, + server_url=args.server_url + ) + + # Create tool manager and demo (passing the model parameter) + tool_manager = ToolManager() + demo = APIDemo(client, args.model, tool_manager) + + print(f"Using model: {args.model}") + print("=" * 60) + + # Run the selected test + if args.completion: + demo.demo_completions() + elif args.chat: + demo.demo_chat(use_streaming=False) + elif args.chat_stream: + demo.demo_chat(use_streaming=True) + elif args.tools: + demo.demo_ls_tool() + elif args.interactive: + demo.interactive_chat() + + except Exception as e: + log.error(f"Error during test: {e}", exc_info=True) + sys.exit(1) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/workers/openai/data_types/__init__.py b/workers/openai/data_types/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/workers/openai/data_types/client.py b/workers/openai/data_types/client.py new file mode 100644 index 0000000..205d596 --- /dev/null +++ b/workers/openai/data_types/client.py @@ -0,0 +1,54 @@ +import json +from dataclasses import dataclass, field, fields, is_dataclass +from typing import Optional, List, Dict, Any + + +class SerializableDataclass: + def _serialize_recursive(self, obj: Any) -> Any: + if is_dataclass(obj): + return {field.name: self._serialize_recursive(getattr(obj, field.name)) + for field in fields(obj)} + elif isinstance(obj, dict): + return {key: self._serialize_recursive(value) for key, value in obj.items()} + elif isinstance(obj, (list, tuple)): + return [self._serialize_recursive(item) for item in obj] + elif isinstance(obj, set): + return [self._serialize_recursive(item) for item in obj] + else: + return obj + + def to_dict(self) -> Dict[str, Any]: + return self._serialize_recursive(self) + + def to_json(self, indent: int = 2) -> str: + return json.dumps(self.to_dict(), indent=indent) + + +@dataclass +class CompletionConfig(SerializableDataclass): + """Configuration for completion requests""" + model: str + prompt: str = "Hello" + max_tokens: int = 256 + temperature: float = 0.7 + top_k: int = 20 + top_p: float = 0.4 + stream: bool = False + + +@dataclass +class ChatCompletionConfig(SerializableDataclass): + """Configuration for chat completion requests""" + model: str + messages: list = None + max_tokens: int = 2096 + temperature: float = 0.7 + top_k: int = 20 + top_p: float = 0.4 + stream: bool = False + tools: Optional[List[Dict[str, Any]]] = field(default_factory=list) + tool_choice: str = "auto" + + def __post_init__(self): + if self.messages is None: + self.messages = [{"role": "user", "content": "Hello"}] \ No newline at end of file diff --git a/workers/openai/data_types/server.py b/workers/openai/data_types/server.py new file mode 100644 index 0000000..f0e341e --- /dev/null +++ b/workers/openai/data_types/server.py @@ -0,0 +1,177 @@ +import os, json, random +from abc import ABC, abstractmethod +from dataclasses import dataclass +from lib.data_types import EndpointHandler, ApiPayload, JsonDataException +from typing import Union, Type, Dict, Any +from aiohttp import web, ClientResponse +import nltk +import logging + +nltk.download("words") +WORD_LIST = nltk.corpus.words.words() +log = logging.getLogger(__name__) + +""" +Generic dataclass accepts any dictionary in input. +""" +@dataclass +class GenericData(ApiPayload, ABC): + input: Dict[str, Any] + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GenericData": + return cls( + input=data["input"] + ) + + @classmethod + def from_json_msg(cls, json_msg: Dict[str, Any]) -> "GenericData": + errors = {} + + # Validate required parameters + required_params = ["input"] + for param in required_params: + if param not in json_msg: + errors[param] = "missing parameter" + + if errors: + raise JsonDataException(errors) + + try: + # Create clean data dict and delegate to from_dict + clean_data = { + "input": json_msg["input"] + } + + return cls.from_dict(clean_data) + + except (json.JSONDecodeError, JsonDataException) as e: + errors["parameters"] = str(e) + raise JsonDataException(errors) + + @classmethod + @abstractmethod + def for_test(cls) -> "GenericData": + pass + + def generate_payload_json(self) -> Dict[str, Any]: + return self.input + + def count_workload(self) -> int: + return self.input.get("max_tokens", 0) + +@dataclass +class GenericHandler(EndpointHandler[GenericData], ABC): + + @property + @abstractmethod + def endpoint(self) -> str: + pass + + @property + def healthcheck_endpoint(self) -> str: + return os.environ.get('MODEL_HEALTH_ENDPOINT') + + @classmethod + def payload_cls(cls) -> Type[GenericData]: + return GenericData + + @abstractmethod + def make_benchmark_payload(self) -> GenericData: + pass + + async def generate_client_response( + self, client_request: web.Request, model_response: ClientResponse + ) -> Union[web.Response, web.StreamResponse]: + match model_response.status: + case 200: + # Check if the response is actually streaming based on response headers/content-type + is_streaming_response = ( + model_response.content_type == "text/event-stream" or + model_response.content_type == "application/x-ndjson" or + model_response.headers.get("Transfer-Encoding") == "chunked" or + "stream" in model_response.content_type.lower() + ) + + if is_streaming_response: + log.debug("Detected streaming response...") + res = web.StreamResponse() + res.content_type = model_response.content_type + await res.prepare(client_request) + async for chunk in model_response.content: + await res.write(chunk) + await res.write_eof() + log.debug("Done streaming response") + return res + else: + log.debug("Detected non-streaming response...") + content = await model_response.read() + return web.Response( + body=content, + status=200, + content_type=model_response.content_type + ) + case code: + log.debug("SENDING RESPONSE: ERROR: unknown code") + return web.Response(status=code) + +@dataclass +class CompletionsData(GenericData): + @classmethod + def for_test(cls) -> "CompletionsData": + prompt = " ".join(random.choices(WORD_LIST, k=int(250))) + model = os.environ.get("MODEL_NAME") + if not model: + raise ValueError("MODEL_NAME environment variable not set") + + test_input = { + "model": model, + "prompt": prompt, + "temperature": 0.7 + } + return cls(input=test_input) + +@dataclass +class CompletionsHandler(GenericHandler): + @property + def endpoint(self) -> str: + return "/v1/completions" + + @classmethod + def payload_cls(cls) -> Type[CompletionsData]: + return CompletionsData + + def make_benchmark_payload(self) -> CompletionsData: + return CompletionsData.for_test() + +@dataclass +class ChatCompletionsData(GenericData): + """Chat completions-specific data implementation""" + + @classmethod + def for_test(cls) -> "ChatCompletionsData": + prompt = " ".join(random.choices(WORD_LIST, k=int(250))) + model = os.environ.get("MODEL_NAME") + if not model: + raise ValueError("MODEL_NAME environment variable not set") + + # Chat completions use messages format instead of prompt + test_input = { + "model": model, + "messages": [{"role": "user", "content": prompt}], + "temperature": 0.7 + } + return cls(input=test_input) + +@dataclass +class ChatCompletionsHandler(GenericHandler): + @property + def endpoint(self) -> str: + return "/v1/chat/completions" + + @classmethod + def payload_cls(cls) -> Type[ChatCompletionsData]: + return ChatCompletionsData + + def make_benchmark_payload(self) -> ChatCompletionsData: + return ChatCompletionsData.for_test() diff --git a/workers/openai/server.py b/workers/openai/server.py new file mode 100644 index 0000000..3eee141 --- /dev/null +++ b/workers/openai/server.py @@ -0,0 +1,58 @@ +import os +import logging +from .data_types.server import CompletionsHandler, ChatCompletionsHandler +from aiohttp import web +from lib.backend import Backend, LogAction +from lib.server import start_server + +# This line indicates that the inference server is listening +MODEL_SERVER_START_LOG_MSG = [ + "Application startup complete.", # vLLM + "llama runner started", # Ollama + '"message":"Connected","target":"text_generation_router"', # TGI + '"message":"Connected","target":"text_generation_router::server"', # TGI +] + +MODEL_SERVER_ERROR_LOG_MSGS = [ + "INFO exited: vllm", # vLLM + "RuntimeError: Engine", # vLLM + "Error: pull model manifest:", # Ollama + "stalled; retrying", # Ollama + "Error: WebserverFailed", # TGI + "Error: DownloadError", # TGI + "Error: ShardCannotStart", #TGI +] + +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s[%(levelname)-5s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +log = logging.getLogger(__file__) + +backend = Backend( + model_server_url=os.environ.get("MODEL_SERVER_URL"), + model_log_file=os.environ.get("MODEL_LOG"), + allow_parallel_requests=True, + benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256), + log_actions=[ + *[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG], + (LogAction.Info, '"message":"Download'), + *[ + (LogAction.ModelError, error_msg) + for error_msg in MODEL_SERVER_ERROR_LOG_MSGS + ], + ], +) + +async def handle_ping(_): + return web.Response(body="pong") + +routes = [ + web.post("/v1/completions", backend.create_handler(CompletionsHandler())), + web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())), + web.get("/ping", handle_ping), +] + +if __name__ == "__main__": + start_server(backend, routes) diff --git a/workers/openai/test_load.py b/workers/openai/test_load.py new file mode 100644 index 0000000..385837c --- /dev/null +++ b/workers/openai/test_load.py @@ -0,0 +1,28 @@ +from lib.test_utils import test_load_cmd, test_args +from .data_types.server import CompletionsData +import os + +WORKER_ENDPOINT = "/v1/completions" + +if __name__ == "__main__": + # Check if MODEL_NAME environment variable is set + model_name_set = os.environ.get("MODEL_NAME") is not None + + # Add model argument - required only if MODEL_NAME is not set + test_args.add_argument( + "--model", + dest="model", + required=not model_name_set, + help="Model to use for completions request (required if MODEL_NAME env var not set)" + ) + + # Parse known args to get model early, before test_load_cmd adds its args + known_args, _ = test_args.parse_known_args() + + # Set environment variable if model was provided + if hasattr(known_args, 'model') and known_args.model: + os.environ["MODEL_NAME"] = known_args.model + print(f"Set MODEL_NAME environment variable to: {known_args.model}") + + # Now call test_load_cmd normally - it will add its own args and re-parse + test_load_cmd(CompletionsData, WORKER_ENDPOINT, arg_parser=test_args) \ No newline at end of file