remove input wrapping for vllm

This commit is contained in:
Lucas Armand
2025-12-12 11:48:54 -08:00
parent 2b30c69933
commit ccd29ed8b6
2 changed files with 23 additions and 92 deletions
-8
View File
@@ -102,13 +102,11 @@ async def call_completions(client: Serverless, *, model: str, prompt: str, **kwa
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = { payload = {
"input": {
"model": model, "model": model,
"prompt": prompt, "prompt": prompt,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS), "max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
} }
}
log.debug("POST /v1/completions %s", json.dumps(payload)[:500]) log.debug("POST /v1/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"]) resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
return resp["response"] return resp["response"]
@@ -118,7 +116,6 @@ async def call_chat_completions(client: Serverless, *, model: str, messages: Lis
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = { payload = {
"input": {
"model": model, "model": model,
"messages": messages, "messages": messages,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS), "max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
@@ -126,7 +123,6 @@ async def call_chat_completions(client: Serverless, *, model: str, messages: Lis
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}), **({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}), **({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
} }
}
log.debug("POST /v1/chat/completions %s", json.dumps(payload)[:500]) log.debug("POST /v1/chat/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"]) resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"])
return resp["response"] return resp["response"]
@@ -137,7 +133,6 @@ async def stream_completions(client: Serverless, *, model: str, prompt: str, **k
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = { payload = {
"input": {
"model": model, "model": model,
"prompt": prompt, "prompt": prompt,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS), "max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
@@ -145,7 +140,6 @@ async def stream_completions(client: Serverless, *, model: str, prompt: str, **k
"stream": True, "stream": True,
**({"stop": kwargs["stop"]} if "stop" in kwargs else {}), **({"stop": kwargs["stop"]} if "stop" in kwargs else {}),
} }
}
log.debug("STREAM /v1/completions %s", json.dumps(payload)[:500]) log.debug("STREAM /v1/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True) resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
return resp["response"] # async generator return resp["response"] # async generator
@@ -155,7 +149,6 @@ async def stream_chat_completions(client: Serverless, *, model: str, messages: L
endpoint = await client.get_endpoint(name=ENDPOINT_NAME) endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = { payload = {
"input": {
"model": model, "model": model,
"messages": messages, "messages": messages,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS), "max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
@@ -164,7 +157,6 @@ async def stream_chat_completions(client: Serverless, *, model: str, messages: L
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}), **({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}), **({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
} }
}
log.debug("STREAM /v1/chat/completions %s", json.dumps(payload)[:500]) log.debug("STREAM /v1/chat/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"], stream=True) resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
return resp["response"] # async generator return resp["response"] # async generator
-61
View File
@@ -1,61 +0,0 @@
import os
import logging
from .data_types.server import CompletionsHandler, ChatCompletionsHandler
from aiohttp import web
from lib.backend import Backend, LogAction
from lib.server import start_server
# This line indicates that the inference server is listening
MODEL_SERVER_START_LOG_MSG = [
"Application startup complete.", # vLLM
"llama runner started", # Ollama
'"message":"Connected","target":"text_generation_router"', # TGI
'"message":"Connected","target":"text_generation_router::server"', # TGI
"main: model loaded" # llama.cpp
]
MODEL_SERVER_ERROR_LOG_MSGS = [
"INFO exited: vllm", # vLLM
"RuntimeError: Engine", # vLLM
"Error: pull model manifest:", # Ollama
"stalled; retrying", # Ollama
"Error: WebserverFailed", # TGI
"Error: DownloadError", # TGI
"Error: ShardCannotStart", # TGI
]
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
backend = Backend(
model_server_url=os.environ["MODEL_SERVER_URL"],
model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=True,
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
log_actions=[
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
(LogAction.Info, '"message":"Download'),
*[
(LogAction.ModelError, error_msg)
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
],
],
)
async def handle_ping(_):
return web.Response(body="pong")
routes = [
web.post("/v1/completions", backend.create_handler(CompletionsHandler())),
web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())),
web.get("/ping", handle_ping),
]
if __name__ == "__main__":
start_server(backend, routes)