remove input wrapping for vllm

This commit is contained in:
Lucas Armand
2025-12-12 11:48:54 -08:00
parent 2b30c69933
commit ccd29ed8b6
2 changed files with 23 additions and 92 deletions
+23 -31
View File
@@ -102,12 +102,10 @@ async def call_completions(client: Serverless, *, model: str, prompt: str, **kwa
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = { payload = {
"input": { "model": model,
"model": model, "prompt": prompt,
"prompt": prompt, "max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS), "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
}
} }
log.debug("POST /v1/completions %s", json.dumps(payload)[:500]) log.debug("POST /v1/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"]) resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
@@ -118,14 +116,12 @@ async def call_chat_completions(client: Serverless, *, model: str, messages: Lis
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = { payload = {
"input": { "model": model,
"model": model, "messages": messages,
"messages": messages, "max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS), "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), **({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}), **({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
}
} }
log.debug("POST /v1/chat/completions %s", json.dumps(payload)[:500]) log.debug("POST /v1/chat/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"]) resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"])
@@ -137,14 +133,12 @@ async def stream_completions(client: Serverless, *, model: str, prompt: str, **k
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = { payload = {
"input": { "model": model,
"model": model, "prompt": prompt,
"prompt": prompt, "max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS), "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), "stream": True,
"stream": True, **({"stop": kwargs["stop"]} if "stop" in kwargs else {}),
**({"stop": kwargs["stop"]} if "stop" in kwargs else {}),
}
} }
log.debug("STREAM /v1/completions %s", json.dumps(payload)[:500]) log.debug("STREAM /v1/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True) resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
@@ -155,15 +149,13 @@ async def stream_chat_completions(client: Serverless, *, model: str, messages: L
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = { payload = {
"input": { "model": model,
"model": model, "messages": messages,
"messages": messages, "max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS), "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), "stream": True,
"stream": True, **({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}), **({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
}
} }
log.debug("STREAM /v1/chat/completions %s", json.dumps(payload)[:500]) log.debug("STREAM /v1/chat/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"], stream=True) resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
-61
View File
@@ -1,61 +0,0 @@
import os
import logging
from .data_types.server import CompletionsHandler, ChatCompletionsHandler
from aiohttp import web
from lib.backend import Backend, LogAction
from lib.server import start_server
# This line indicates that the inference server is listening
MODEL_SERVER_START_LOG_MSG = [
"Application startup complete.", # vLLM
"llama runner started", # Ollama
'"message":"Connected","target":"text_generation_router"', # TGI
'"message":"Connected","target":"text_generation_router::server"', # TGI
"main: model loaded" # llama.cpp
]
MODEL_SERVER_ERROR_LOG_MSGS = [
"INFO exited: vllm", # vLLM
"RuntimeError: Engine", # vLLM
"Error: pull model manifest:", # Ollama
"stalled; retrying", # Ollama
"Error: WebserverFailed", # TGI
"Error: DownloadError", # TGI
"Error: ShardCannotStart", # TGI
]
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
backend = Backend(
model_server_url=os.environ["MODEL_SERVER_URL"],
model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=True,
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
log_actions=[
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
(LogAction.Info, '"message":"Download'),
*[
(LogAction.ModelError, error_msg)
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
],
],
)
async def handle_ping(_):
return web.Response(body="pong")
routes = [
web.post("/v1/completions", backend.create_handler(CompletionsHandler())),
web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())),
web.get("/ping", handle_ping),
]
if __name__ == "__main__":
start_server(backend, routes)