diff --git a/workers/openai/client.py b/workers/openai/client.py index 8c88444..e4836a4 100644 --- a/workers/openai/client.py +++ b/workers/openai/client.py @@ -102,12 +102,10 @@ async def call_completions(client: Serverless, *, model: str, prompt: str, **kwa endpoint = await client.get_endpoint(name=ENDPOINT_NAME) payload = { - "input": { - "model": model, - "prompt": prompt, - "max_tokens": kwargs.get("max_tokens", MAX_TOKENS), - "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), - } + "model": model, + "prompt": prompt, + "max_tokens": kwargs.get("max_tokens", MAX_TOKENS), + "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), } log.debug("POST /v1/completions %s", json.dumps(payload)[:500]) resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"]) @@ -118,14 +116,12 @@ async def call_chat_completions(client: Serverless, *, model: str, messages: Lis endpoint = await client.get_endpoint(name=ENDPOINT_NAME) payload = { - "input": { - "model": model, - "messages": messages, - "max_tokens": kwargs.get("max_tokens", MAX_TOKENS), - "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), - **({"tools": kwargs["tools"]} if "tools" in kwargs else {}), - **({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}), - } + "model": model, + "messages": messages, + "max_tokens": kwargs.get("max_tokens", MAX_TOKENS), + "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), + **({"tools": kwargs["tools"]} if "tools" in kwargs else {}), + **({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}), } log.debug("POST /v1/chat/completions %s", json.dumps(payload)[:500]) resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"]) @@ -137,14 +133,12 @@ async def stream_completions(client: Serverless, *, model: str, prompt: str, **k endpoint = await client.get_endpoint(name=ENDPOINT_NAME) payload = { - "input": { - "model": model, - "prompt": prompt, - "max_tokens": kwargs.get("max_tokens", MAX_TOKENS), - "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), - "stream": True, - **({"stop": kwargs["stop"]} if "stop" in kwargs else {}), - } + "model": model, + "prompt": prompt, + "max_tokens": kwargs.get("max_tokens", MAX_TOKENS), + "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), + "stream": True, + **({"stop": kwargs["stop"]} if "stop" in kwargs else {}), } log.debug("STREAM /v1/completions %s", json.dumps(payload)[:500]) resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True) @@ -155,15 +149,13 @@ async def stream_chat_completions(client: Serverless, *, model: str, messages: L endpoint = await client.get_endpoint(name=ENDPOINT_NAME) payload = { - "input": { - "model": model, - "messages": messages, - "max_tokens": kwargs.get("max_tokens", MAX_TOKENS), - "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), - "stream": True, - **({"tools": kwargs["tools"]} if "tools" in kwargs else {}), - **({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}), - } + "model": model, + "messages": messages, + "max_tokens": kwargs.get("max_tokens", MAX_TOKENS), + "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), + "stream": True, + **({"tools": kwargs["tools"]} if "tools" in kwargs else {}), + **({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}), } log.debug("STREAM /v1/chat/completions %s", json.dumps(payload)[:500]) resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"], stream=True) diff --git a/workers/openai/server.py b/workers/openai/server.py deleted file mode 100644 index 8dc962f..0000000 --- a/workers/openai/server.py +++ /dev/null @@ -1,61 +0,0 @@ -import os -import logging -from .data_types.server import CompletionsHandler, ChatCompletionsHandler -from aiohttp import web -from lib.backend import Backend, LogAction -from lib.server import start_server - -# This line indicates that the inference server is listening -MODEL_SERVER_START_LOG_MSG = [ - "Application startup complete.", # vLLM - "llama runner started", # Ollama - '"message":"Connected","target":"text_generation_router"', # TGI - '"message":"Connected","target":"text_generation_router::server"', # TGI - "main: model loaded" # llama.cpp -] - -MODEL_SERVER_ERROR_LOG_MSGS = [ - "INFO exited: vllm", # vLLM - "RuntimeError: Engine", # vLLM - "Error: pull model manifest:", # Ollama - "stalled; retrying", # Ollama - "Error: WebserverFailed", # TGI - "Error: DownloadError", # TGI - "Error: ShardCannotStart", # TGI -] - -logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s[%(levelname)-5s] %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", -) -log = logging.getLogger(__file__) - -backend = Backend( - model_server_url=os.environ["MODEL_SERVER_URL"], - model_log_file=os.environ["MODEL_LOG"], - allow_parallel_requests=True, - benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256), - log_actions=[ - *[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG], - (LogAction.Info, '"message":"Download'), - *[ - (LogAction.ModelError, error_msg) - for error_msg in MODEL_SERVER_ERROR_LOG_MSGS - ], - ], -) - - -async def handle_ping(_): - return web.Response(body="pong") - - -routes = [ - web.post("/v1/completions", backend.create_handler(CompletionsHandler())), - web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())), - web.get("/ping", handle_ping), -] - -if __name__ == "__main__": - start_server(backend, routes)