remove input wrapping for vllm
This commit is contained in:
+23
-31
@@ -102,12 +102,10 @@ async def call_completions(client: Serverless, *, model: str, prompt: str, **kwa
|
|||||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"input": {
|
"model": model,
|
||||||
"model": model,
|
"prompt": prompt,
|
||||||
"prompt": prompt,
|
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
log.debug("POST /v1/completions %s", json.dumps(payload)[:500])
|
log.debug("POST /v1/completions %s", json.dumps(payload)[:500])
|
||||||
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
|
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
|
||||||
@@ -118,14 +116,12 @@ async def call_chat_completions(client: Serverless, *, model: str, messages: Lis
|
|||||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"input": {
|
"model": model,
|
||||||
"model": model,
|
"messages": messages,
|
||||||
"messages": messages,
|
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
|
||||||
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
|
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
|
||||||
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
log.debug("POST /v1/chat/completions %s", json.dumps(payload)[:500])
|
log.debug("POST /v1/chat/completions %s", json.dumps(payload)[:500])
|
||||||
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"])
|
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"])
|
||||||
@@ -137,14 +133,12 @@ async def stream_completions(client: Serverless, *, model: str, prompt: str, **k
|
|||||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"input": {
|
"model": model,
|
||||||
"model": model,
|
"prompt": prompt,
|
||||||
"prompt": prompt,
|
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
"stream": True,
|
||||||
"stream": True,
|
**({"stop": kwargs["stop"]} if "stop" in kwargs else {}),
|
||||||
**({"stop": kwargs["stop"]} if "stop" in kwargs else {}),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
log.debug("STREAM /v1/completions %s", json.dumps(payload)[:500])
|
log.debug("STREAM /v1/completions %s", json.dumps(payload)[:500])
|
||||||
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
|
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
|
||||||
@@ -155,15 +149,13 @@ async def stream_chat_completions(client: Serverless, *, model: str, messages: L
|
|||||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"input": {
|
"model": model,
|
||||||
"model": model,
|
"messages": messages,
|
||||||
"messages": messages,
|
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
"stream": True,
|
||||||
"stream": True,
|
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
|
||||||
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
|
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
|
||||||
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
log.debug("STREAM /v1/chat/completions %s", json.dumps(payload)[:500])
|
log.debug("STREAM /v1/chat/completions %s", json.dumps(payload)[:500])
|
||||||
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
|
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
|
||||||
|
|||||||
@@ -1,61 +0,0 @@
|
|||||||
import os
|
|
||||||
import logging
|
|
||||||
from .data_types.server import CompletionsHandler, ChatCompletionsHandler
|
|
||||||
from aiohttp import web
|
|
||||||
from lib.backend import Backend, LogAction
|
|
||||||
from lib.server import start_server
|
|
||||||
|
|
||||||
# This line indicates that the inference server is listening
|
|
||||||
MODEL_SERVER_START_LOG_MSG = [
|
|
||||||
"Application startup complete.", # vLLM
|
|
||||||
"llama runner started", # Ollama
|
|
||||||
'"message":"Connected","target":"text_generation_router"', # TGI
|
|
||||||
'"message":"Connected","target":"text_generation_router::server"', # TGI
|
|
||||||
"main: model loaded" # llama.cpp
|
|
||||||
]
|
|
||||||
|
|
||||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
|
||||||
"INFO exited: vllm", # vLLM
|
|
||||||
"RuntimeError: Engine", # vLLM
|
|
||||||
"Error: pull model manifest:", # Ollama
|
|
||||||
"stalled; retrying", # Ollama
|
|
||||||
"Error: WebserverFailed", # TGI
|
|
||||||
"Error: DownloadError", # TGI
|
|
||||||
"Error: ShardCannotStart", # TGI
|
|
||||||
]
|
|
||||||
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.DEBUG,
|
|
||||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
|
||||||
datefmt="%Y-%m-%d %H:%M:%S",
|
|
||||||
)
|
|
||||||
log = logging.getLogger(__file__)
|
|
||||||
|
|
||||||
backend = Backend(
|
|
||||||
model_server_url=os.environ["MODEL_SERVER_URL"],
|
|
||||||
model_log_file=os.environ["MODEL_LOG"],
|
|
||||||
allow_parallel_requests=True,
|
|
||||||
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
|
|
||||||
log_actions=[
|
|
||||||
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
|
|
||||||
(LogAction.Info, '"message":"Download'),
|
|
||||||
*[
|
|
||||||
(LogAction.ModelError, error_msg)
|
|
||||||
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
|
||||||
],
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_ping(_):
|
|
||||||
return web.Response(body="pong")
|
|
||||||
|
|
||||||
|
|
||||||
routes = [
|
|
||||||
web.post("/v1/completions", backend.create_handler(CompletionsHandler())),
|
|
||||||
web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())),
|
|
||||||
web.get("/ping", handle_ping),
|
|
||||||
]
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
start_server(backend, routes)
|
|
||||||
Reference in New Issue
Block a user