Use PyWorker SDK (#67)
* Change PyWorker to Worker SDK * Moved /lib to vast-sdk (https://github.com/vast-ai/vast-sdk)
This commit is contained in:
@@ -1,77 +0,0 @@
|
||||
# <INFERENCE_SERVER> + <MODEL_NAME> (serverless)
|
||||
|
||||
Run <INFERENCE_SERVER> with our serverless autoscaling infrastructure.
|
||||
|
||||
See the [serverless documentation](https://docs.vast.ai/serverless) and the [Getting Started](https://docs.vast.ai/serverless/getting-started) guide for in-depth details about how to use these templates.
|
||||
|
||||
## Configuration
|
||||
|
||||
Two environment variables are provided to help you configure the <INFERENCE_SERVER> server:
|
||||
|
||||
| Variable | Default Value | Used For |
|
||||
| --- | --- | --- |
|
||||
| `MODEL_NAME` | `<MODEL_NAME>` | The model to load. Also accepts [hf.co/repo/model](#) links |
|
||||
| `<ARGS_VAR>` | `<ARGS_VAL>` | Arguments to pass to the `<ARGS_RECEIVER>` command |
|
||||
|
||||
This template has been configured to work with <MIN_VRAM> VRAM. Setting alternative models and server arguments will change the VRAM requirements. Check model cards and <INFERENCE_SERVER_DOCS> for guidance.
|
||||
|
||||
## Usage
|
||||
|
||||
We have provided a demonstration client to help you implement this template into your own infrastructure
|
||||
|
||||
### Client Setup
|
||||
|
||||
Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/vast-ai/pyworker
|
||||
cd pyworker
|
||||
pip install uv
|
||||
uv venv -p 3.12
|
||||
source .venv/bin/activate
|
||||
uv pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### Completions
|
||||
|
||||
Call to `/v1/completions` with json response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Chat Completion (json)
|
||||
|
||||
Call to `/v1/chat/completions` with json response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Chat Completion (streaming)
|
||||
|
||||
Call to `/v1/chat/completions` with streaming response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Tool Use (json)
|
||||
|
||||
Call to `/v1/chat/completions` with tool and json response.
|
||||
|
||||
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Interactive Chat (streaming)
|
||||
|
||||
Interactive session with calls to `/v1/chat/completions`.
|
||||
|
||||
Type `clear` to clear the chat history or `quit` to exit.
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
|
||||
```
|
||||
+27
-35
@@ -102,15 +102,13 @@ async def call_completions(client: Serverless, *, model: str, prompt: str, endpo
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
}
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
}
|
||||
log.debug("POST /v1/completions %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
|
||||
resp = await endpoint.request("/v1/completions", payload, cost=payload["max_tokens"])
|
||||
return resp["response"]
|
||||
|
||||
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs) -> Dict[str, Any]:
|
||||
@@ -118,17 +116,15 @@ async def call_chat_completions(client: Serverless, *, model: str, messages: Lis
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
|
||||
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
|
||||
}
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
|
||||
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
|
||||
}
|
||||
log.debug("POST /v1/chat/completions %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"])
|
||||
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["max_tokens"])
|
||||
return resp["response"]
|
||||
|
||||
# ---- Streaming variants ----
|
||||
@@ -137,17 +133,15 @@ async def stream_completions(client: Serverless, *, model: str, prompt: str, end
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
"stream": True,
|
||||
**({"stop": kwargs["stop"]} if "stop" in kwargs else {}),
|
||||
}
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
"stream": True,
|
||||
**({"stop": kwargs["stop"]} if "stop" in kwargs else {}),
|
||||
}
|
||||
log.debug("STREAM /v1/completions %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
|
||||
resp = await endpoint.request("/v1/completions", payload, cost=payload["max_tokens"], stream=True)
|
||||
return resp["response"] # async generator
|
||||
|
||||
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs):
|
||||
@@ -155,18 +149,16 @@ async def stream_chat_completions(client: Serverless, *, model: str, messages: L
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
"stream": True,
|
||||
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
|
||||
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
|
||||
}
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
"stream": True,
|
||||
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
|
||||
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
|
||||
}
|
||||
log.debug("STREAM /v1/chat/completions %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
|
||||
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["max_tokens"], stream=True)
|
||||
return resp["response"] # async generator
|
||||
|
||||
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
import json
|
||||
from dataclasses import dataclass, field, fields, is_dataclass
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
|
||||
class SerializableDataclass:
|
||||
def _serialize_recursive(self, obj: Any) -> Any:
|
||||
if is_dataclass(obj):
|
||||
return {
|
||||
field.name: self._serialize_recursive(getattr(obj, field.name))
|
||||
for field in fields(obj)
|
||||
}
|
||||
elif isinstance(obj, dict):
|
||||
return {key: self._serialize_recursive(value) for key, value in obj.items()}
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
return [self._serialize_recursive(item) for item in obj]
|
||||
elif isinstance(obj, set):
|
||||
return [self._serialize_recursive(item) for item in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return self._serialize_recursive(self)
|
||||
|
||||
def to_json(self, indent: int = 2) -> str:
|
||||
return json.dumps(self.to_dict(), indent=indent)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompletionConfig(SerializableDataclass):
|
||||
"""Configuration for completion requests"""
|
||||
|
||||
model: str
|
||||
prompt: str = "Hello"
|
||||
max_tokens: int = 256
|
||||
temperature: float = 0.7
|
||||
top_k: int = 20
|
||||
top_p: float = 0.4
|
||||
stream: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatCompletionConfig(SerializableDataclass):
|
||||
"""Configuration for chat completion requests"""
|
||||
|
||||
model: str
|
||||
messages: list = field(default_factory=list)
|
||||
max_tokens: int = 2096
|
||||
temperature: float = 0.7
|
||||
top_k: int = 20
|
||||
top_p: float = 0.4
|
||||
stream: bool = False
|
||||
tools: Optional[List[Dict[str, Any]]] = field(default_factory=list)
|
||||
tool_choice: str = "auto"
|
||||
|
||||
def __post_init__(self):
|
||||
if self.messages is None:
|
||||
self.messages = [{"role": "user", "content": "Hello"}]
|
||||
@@ -1,207 +0,0 @@
|
||||
import os, json, random
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from lib.data_types import EndpointHandler, ApiPayload, JsonDataException
|
||||
from typing import Union, Type, Dict, Any, Optional
|
||||
from aiohttp import web, ClientResponse
|
||||
import nltk
|
||||
import logging
|
||||
|
||||
nltk.download("words")
|
||||
WORD_LIST = nltk.corpus.words.words()
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
"""
|
||||
Generic dataclass accepts any dictionary in input.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenericData(ApiPayload, ABC):
|
||||
input: Dict[str, Any]
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "GenericData":
|
||||
return cls(input=data["input"])
|
||||
|
||||
@classmethod
|
||||
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "GenericData":
|
||||
errors = {}
|
||||
|
||||
# Validate required parameters
|
||||
required_params = ["input"]
|
||||
for param in required_params:
|
||||
if param not in json_msg:
|
||||
errors[param] = "missing parameter"
|
||||
|
||||
if errors:
|
||||
raise JsonDataException(errors)
|
||||
|
||||
try:
|
||||
# Create clean data dict and delegate to from_dict
|
||||
clean_data = {"input": json_msg["input"]}
|
||||
|
||||
return cls.from_dict(clean_data)
|
||||
|
||||
except (json.JSONDecodeError, JsonDataException) as e:
|
||||
errors["parameters"] = str(e)
|
||||
raise JsonDataException(errors)
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def for_test(cls) -> "GenericData":
|
||||
pass
|
||||
|
||||
def generate_payload_json(self) -> Dict[str, Any]:
|
||||
return self.input
|
||||
|
||||
def count_workload(self) -> int:
|
||||
return self.input.get("max_tokens", 0)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenericHandler(EndpointHandler[GenericData], ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def endpoint(self) -> str:
|
||||
pass
|
||||
|
||||
@property
|
||||
def healthcheck_endpoint(self) -> Optional[str]:
|
||||
return os.environ.get("MODEL_HEALTH_ENDPOINT")
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[GenericData]:
|
||||
return GenericData
|
||||
|
||||
@abstractmethod
|
||||
def make_benchmark_payload(self) -> GenericData:
|
||||
pass
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
match model_response.status:
|
||||
case 200:
|
||||
# Check if the response is actually streaming based on response headers/content-type
|
||||
is_streaming_response = (
|
||||
model_response.content_type == "text/event-stream"
|
||||
or model_response.content_type == "application/x-ndjson"
|
||||
or model_response.headers.get("Transfer-Encoding") == "chunked"
|
||||
or "stream" in model_response.content_type.lower()
|
||||
)
|
||||
|
||||
if is_streaming_response:
|
||||
log.debug("Detected streaming response...")
|
||||
res = web.StreamResponse()
|
||||
res.content_type = model_response.content_type
|
||||
await res.prepare(client_request)
|
||||
async for chunk in model_response.content:
|
||||
await res.write(chunk)
|
||||
await res.write_eof()
|
||||
log.debug("Done streaming response")
|
||||
return res
|
||||
else:
|
||||
log.debug("Detected non-streaming response...")
|
||||
content = await model_response.read()
|
||||
return web.Response(
|
||||
body=content,
|
||||
status=200,
|
||||
content_type=model_response.content_type,
|
||||
)
|
||||
case code:
|
||||
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||
return web.Response(status=code)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompletionsData(GenericData):
|
||||
@classmethod
|
||||
def for_test(cls) -> "CompletionsData":
|
||||
system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
|
||||
|
||||
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
|
||||
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
|
||||
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
|
||||
genus Equus with horses and asses, the three groups being the only living members of the family
|
||||
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
|
||||
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
|
||||
woodlands, shrublands, and mountainous areas.
|
||||
|
||||
Please answer the following question based on the above context."""
|
||||
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
|
||||
model = os.environ.get("MODEL_NAME")
|
||||
if not model:
|
||||
raise ValueError("MODEL_NAME environment variable not set")
|
||||
|
||||
test_input = {
|
||||
"model": model,
|
||||
"prompt": f"{system_prompt}\n\n{unique_question}",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 500,
|
||||
}
|
||||
return cls(input=test_input)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompletionsHandler(GenericHandler):
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/v1/completions"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[CompletionsData]:
|
||||
return CompletionsData
|
||||
|
||||
def make_benchmark_payload(self) -> CompletionsData:
|
||||
return CompletionsData.for_test()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatCompletionsData(GenericData):
|
||||
"""Chat completions-specific data implementation"""
|
||||
|
||||
@classmethod
|
||||
def for_test(cls) -> "ChatCompletionsData":
|
||||
system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
|
||||
|
||||
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
|
||||
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
|
||||
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
|
||||
genus Equus with horses and asses, the three groups being the only living members of the family
|
||||
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
|
||||
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
|
||||
woodlands, shrublands, and mountainous areas.
|
||||
|
||||
Please answer the following question based on the above context."""
|
||||
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
|
||||
model = os.environ.get("MODEL_NAME")
|
||||
if not model:
|
||||
raise ValueError("MODEL_NAME environment variable not set")
|
||||
|
||||
# Chat completions use messages format instead of prompt
|
||||
test_input = {
|
||||
"model": model,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt}, # Shared prefix
|
||||
{"role": "user", "content": unique_question} # Unique per request
|
||||
],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 500,
|
||||
}
|
||||
return cls(input=test_input)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatCompletionsHandler(GenericHandler):
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/v1/chat/completions"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[ChatCompletionsData]:
|
||||
return ChatCompletionsData
|
||||
|
||||
def make_benchmark_payload(self) -> ChatCompletionsData:
|
||||
return ChatCompletionsData.for_test()
|
||||
@@ -1,62 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
from .data_types.server import CompletionsHandler, ChatCompletionsHandler
|
||||
from aiohttp import web
|
||||
from lib.backend import Backend, LogAction
|
||||
from lib.server import start_server
|
||||
|
||||
# This line indicates that the inference server is listening
|
||||
MODEL_SERVER_START_LOG_MSG = [
|
||||
"Application startup complete.", # vLLM
|
||||
"llama runner started", # Ollama
|
||||
'"message":"Connected","target":"text_generation_router"', # TGI
|
||||
'"message":"Connected","target":"text_generation_router::server"', # TGI
|
||||
"main: model loaded" # llama.cpp
|
||||
]
|
||||
|
||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||
"INFO exited: vllm", # vLLM
|
||||
"RuntimeError: Engine", # vLLM
|
||||
"Error: pull model manifest:", # Ollama
|
||||
"stalled; retrying", # Ollama
|
||||
"Error: WebserverFailed", # TGI
|
||||
"Error: DownloadError", # TGI
|
||||
"Error: ShardCannotStart", # TGI
|
||||
]
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
backend = Backend(
|
||||
model_server_url=os.environ["MODEL_SERVER_URL"],
|
||||
model_log_file=os.environ["MODEL_LOG"],
|
||||
allow_parallel_requests=True,
|
||||
max_wait_time=600.0,
|
||||
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
|
||||
log_actions=[
|
||||
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
|
||||
(LogAction.Info, '"message":"Download'),
|
||||
*[
|
||||
(LogAction.ModelError, error_msg)
|
||||
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def handle_ping(_):
|
||||
return web.Response(body="pong")
|
||||
|
||||
|
||||
routes = [
|
||||
web.post("/v1/completions", backend.create_handler(CompletionsHandler())),
|
||||
web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())),
|
||||
web.get("/ping", handle_ping),
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_server(backend, routes)
|
||||
@@ -1,434 +0,0 @@
|
||||
from lib.test_utils import test_args
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
from lib.data_types import AuthData
|
||||
from .data_types.server import CompletionsData
|
||||
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
import requests
|
||||
from dataclasses import dataclass
|
||||
from collections import Counter
|
||||
from urllib.parse import urljoin, urlparse
|
||||
import re
|
||||
|
||||
# Headless plotting
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import logging
|
||||
logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED
|
||||
from requests.adapters import HTTPAdapter
|
||||
|
||||
def get_incremented_path(path: str) -> str:
|
||||
base, ext = os.path.splitext(path)
|
||||
if not os.path.exists(path):
|
||||
return path
|
||||
i = 1
|
||||
while os.path.exists(f"{base}-{i}{ext}"):
|
||||
i += 1
|
||||
return f"{base}-{i}{ext}"
|
||||
|
||||
WORKER_ENDPOINT = "/v1/completions" # This will return the full text output at once. Latency metrics reflect that (ie not measuring TTFT)
|
||||
|
||||
@dataclass
|
||||
class ReqResult:
|
||||
worker_url: str
|
||||
route_ms: float
|
||||
worker_ms: float
|
||||
total_ms: float
|
||||
ok: bool
|
||||
error: str = ""
|
||||
status_code: int = 0
|
||||
t_start: float = 0.0
|
||||
t_end: float = 0.0
|
||||
workload: float = 0.0
|
||||
|
||||
def do_one(endpoint_name: str,
|
||||
endpoint_id: int,
|
||||
endpoint_api_key: str,
|
||||
server_url: str,
|
||||
worker_endpoint: str,
|
||||
payload,
|
||||
results_list,
|
||||
t0,
|
||||
status_samples,
|
||||
route_session,
|
||||
worker_session):
|
||||
try:
|
||||
workload = payload.count_workload()
|
||||
route_payload = {"endpoint": endpoint_name, "api_key": endpoint_api_key, "cost": workload}
|
||||
headers = {"Authorization": f"Bearer {endpoint_api_key}"}
|
||||
start = time.time()
|
||||
r0 = route_session.post(urljoin(server_url, "/route/"), json=route_payload, headers=headers, timeout=4)
|
||||
t_after_route = time.time()
|
||||
if r0.status_code != 200:
|
||||
results_list.append(ReqResult(worker_url="",
|
||||
route_ms=(t_after_route - start) * 1000.0,
|
||||
worker_ms=0.0,
|
||||
total_ms=(t_after_route - start) * 1000.0,
|
||||
ok=False,
|
||||
error=f"route error {r0.reason} {r0.text}",
|
||||
status_code=r0.status_code,
|
||||
t_start=start - t0,
|
||||
t_end=t_after_route - t0,
|
||||
workload=workload))
|
||||
return
|
||||
msg = r0.json()
|
||||
|
||||
# 1) Check if we got a worker back from route
|
||||
worker_url = msg.get("url", "")
|
||||
if not worker_url:
|
||||
status = msg.get("status", "")
|
||||
m = re.search(r"total workers:\s*(\d+).*loading workers:\s*(\d+).*standby workers:\s*(\d+).*error workers:\s*(\d+)", status, re.I | re.S)
|
||||
if m:
|
||||
tot, loading, standby, err = map(int, m.groups())
|
||||
idle = max(tot - loading - standby - err, 0)
|
||||
status_samples.append((time.time() - t0, idle))
|
||||
|
||||
# 2) If we got a worker, send the request
|
||||
if worker_url:
|
||||
req = dict(payload=payload.__dict__, auth_data=AuthData.from_json_msg(msg).__dict__)
|
||||
t_before_worker = time.time()
|
||||
r1 = worker_session.post(
|
||||
urljoin(worker_url, worker_endpoint),
|
||||
json=req,
|
||||
verify=get_cert_file_path(),
|
||||
timeout=(4, 120),
|
||||
)
|
||||
t_after_worker = time.time()
|
||||
if r1.status_code != 200:
|
||||
results_list.append(ReqResult(worker_url=worker_url,
|
||||
route_ms=(t_after_route - start) * 1000.0,
|
||||
worker_ms=(t_after_worker - t_before_worker) * 1000.0,
|
||||
total_ms=(t_after_worker - start) * 1000.0,
|
||||
ok=False,
|
||||
error=f"worker inference error {r1.reason} {r1.text}",
|
||||
status_code=r1.status_code,
|
||||
t_start=start - t0,
|
||||
t_end=t_after_worker - t0,
|
||||
workload=workload))
|
||||
return
|
||||
# Success case
|
||||
results_list.append(ReqResult(worker_url=worker_url,
|
||||
route_ms=(t_after_route - start) * 1000.0,
|
||||
worker_ms=(t_after_worker - t_before_worker) * 1000.0,
|
||||
total_ms=(t_after_worker - start) * 1000.0,
|
||||
ok=True,
|
||||
error="",
|
||||
status_code=200,
|
||||
t_start=start - t0,
|
||||
t_end=t_after_worker - t0,
|
||||
workload=workload))
|
||||
|
||||
# 3) If so, sample via /get_endpoint_workers/ for eligible (idle) worker tracking
|
||||
if worker_url:
|
||||
try:
|
||||
r_status = route_session.post(
|
||||
urljoin(server_url, "/get_endpoint_workers/"),
|
||||
json={"id": endpoint_id},
|
||||
headers={"Authorization": f"Bearer {endpoint_api_key}"},
|
||||
timeout=3,
|
||||
)
|
||||
if r_status.status_code == 200:
|
||||
workers = r_status.json()
|
||||
idle = 0
|
||||
for w in workers:
|
||||
st = str(w.get("status", "")).lower()
|
||||
if (st in ("idle")):
|
||||
idle += 1
|
||||
status_samples.append((time.time() - t0, idle))
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
t = time.time()
|
||||
results_list.append(ReqResult(worker_url="",
|
||||
route_ms=0.0,
|
||||
worker_ms=0.0,
|
||||
total_ms=0.0,
|
||||
ok=False,
|
||||
error=f"unknown error {e}",
|
||||
status_code=0,
|
||||
t_start=t - t0,
|
||||
t_end=t - t0,
|
||||
workload=0.0))
|
||||
|
||||
def run_load_with_metrics(num_requests: int,
|
||||
requests_per_second: float,
|
||||
endpoint_group_name: str,
|
||||
account_api_key: str,
|
||||
server_url: str,
|
||||
worker_endpoint: str,
|
||||
instance: str,
|
||||
out_path: str):
|
||||
|
||||
ep_info = Endpoint.get_endpoint_info(endpoint_name=endpoint_group_name,
|
||||
account_api_key=account_api_key,
|
||||
instance=instance)
|
||||
if not ep_info or not ep_info.get("api_key") or not ep_info.get("id"):
|
||||
print(f"Endpoint {endpoint_group_name} not found for API key")
|
||||
return
|
||||
endpoint_id = int(ep_info["id"])
|
||||
endpoint_api_key = ep_info["api_key"]
|
||||
|
||||
t0 = time.time()
|
||||
results = []
|
||||
status_samples = []
|
||||
max_concurrency = int(os.environ.get("MAX_CONCURRENCY", "8192"))
|
||||
submit_queue_factor = 2 # cap queued tasks to reduce memory
|
||||
|
||||
# Shared HTTP sessions with connection pooling (persistent connections)
|
||||
def make_session(pool_connections: int, pool_maxsize: int) -> requests.Session:
|
||||
sess = requests.Session()
|
||||
adapter = HTTPAdapter(pool_connections=pool_connections, pool_maxsize=pool_maxsize, max_retries=0)
|
||||
sess.mount("https://", adapter)
|
||||
sess.mount("http://", adapter)
|
||||
return sess
|
||||
|
||||
# Router: mostly single host, small connection pool is sufficient
|
||||
route_session = make_session(pool_connections=1, pool_maxsize=max_concurrency)
|
||||
# Workers: many hosts; allow many pools and per-host concurrency up to max_concurrency
|
||||
worker_session = make_session(pool_connections=64, pool_maxsize=max_concurrency // 8)
|
||||
|
||||
# Fire requests using a thread pool, scheduling at requested RPS
|
||||
inflight = set()
|
||||
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
|
||||
for i in range(num_requests):
|
||||
# Pace submissions to RPS
|
||||
target_time = t0 + i / max(requests_per_second, 1e-9)
|
||||
sleep_s = target_time - time.time()
|
||||
if sleep_s > 0:
|
||||
time.sleep(min(sleep_s, 0.5)) # sleep in chunks to stay responsive
|
||||
|
||||
payload = CompletionsData.for_test()
|
||||
fut = executor.submit(
|
||||
do_one,
|
||||
endpoint_group_name,
|
||||
endpoint_id,
|
||||
endpoint_api_key,
|
||||
server_url,
|
||||
worker_endpoint,
|
||||
payload,
|
||||
results,
|
||||
t0,
|
||||
status_samples,
|
||||
route_session,
|
||||
worker_session,
|
||||
)
|
||||
inflight.add(fut)
|
||||
# Prevent unbounded queue growth
|
||||
if len(inflight) >= max_concurrency * submit_queue_factor:
|
||||
done, not_done = wait(inflight, return_when=FIRST_COMPLETED)
|
||||
inflight = not_done
|
||||
# Wait for all outstanding tasks
|
||||
if inflight:
|
||||
wait(inflight)
|
||||
# Close sessions
|
||||
try:
|
||||
route_session.close()
|
||||
finally:
|
||||
worker_session.close()
|
||||
|
||||
# Aggregate results
|
||||
oks = [r for r in results if r.ok]
|
||||
errs = [r for r in results if not r.ok]
|
||||
total_reqs = len(results)
|
||||
succ = len(oks)
|
||||
|
||||
total_ms = np.array([r.total_ms for r in oks]) if succ else np.array([])
|
||||
worker_ms = np.array([r.worker_ms for r in oks]) if succ else np.array([])
|
||||
route_ms = np.array([r.route_ms for r in oks]) if succ else np.array([])
|
||||
|
||||
avg_total = float(np.mean(total_ms)) if succ else 0.0
|
||||
avg_worker = float(np.mean(worker_ms)) if succ else 0.0
|
||||
avg_route = float(np.mean(route_ms)) if succ else 0.0
|
||||
p50_total, p95_total = (float(np.percentile(total_ms, 50)), float(np.percentile(total_ms, 95))) if succ else (0.0, 0.0)
|
||||
|
||||
# Distribution over workers (by host:port)
|
||||
hosts = [urlparse(r.worker_url).netloc for r in oks if r.worker_url]
|
||||
dist = Counter(hosts)
|
||||
|
||||
# Idle over time (mode per second)
|
||||
idle_ts, idle_vals = [], []
|
||||
if status_samples:
|
||||
buckets = {}
|
||||
for ts, idle in status_samples:
|
||||
k = int(ts)
|
||||
buckets.setdefault(k, []).append(idle)
|
||||
keys = sorted(buckets.keys())
|
||||
idle_ts = keys
|
||||
# Use the most frequent sampled value per second (mode) to keep integer counts
|
||||
idle_vals = []
|
||||
for k in keys:
|
||||
vals_k = [int(v) for v in buckets[k]]
|
||||
if vals_k:
|
||||
cnt = Counter(vals_k)
|
||||
idle_vals.append(cnt.most_common(1)[0][0])
|
||||
else:
|
||||
idle_vals.append(0)
|
||||
|
||||
print(f"\nResults: total={total_reqs} success={succ} errors={len(errs)}")
|
||||
print(f"Avg latency (ms): {avg_total:.1f} p50: {p50_total:.1f} p95: {p95_total:.1f}")
|
||||
print(f"Avg route latency (ms): {avg_route:.1f} Avg worker latency (ms): {avg_worker:.1f}")
|
||||
if errs:
|
||||
print("Sample errors:")
|
||||
for e in errs[:5]:
|
||||
print(f" {e.status_code} {e.error}")
|
||||
|
||||
# Plot: 2x3 grid
|
||||
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
|
||||
fig.suptitle(f"Load test: {endpoint_group_name} n={total_reqs}, rps={requests_per_second}, success={succ}")
|
||||
|
||||
# Dist per worker
|
||||
ax0 = axes[0, 0]
|
||||
if dist:
|
||||
items = sorted(dist.items(), key=lambda kv: kv[1], reverse=True)
|
||||
labels, counts = zip(*items)
|
||||
ax0.bar(range(len(labels)), counts)
|
||||
ax0.set_xticks(range(len(labels)))
|
||||
ax0.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
|
||||
ax0.set_title("Request distribution over workers")
|
||||
ax0.set_ylabel("count")
|
||||
|
||||
# Latency histogram (total)
|
||||
ax1 = axes[0, 1]
|
||||
if succ:
|
||||
ax1.hist(total_ms, bins=30)
|
||||
ax1.set_title("Total latency (ms)")
|
||||
ax1.set_xlabel("ms")
|
||||
ax1.set_ylabel("freq")
|
||||
|
||||
# Eligible workers over time
|
||||
ax_idle = axes[0, 2]
|
||||
if idle_ts:
|
||||
ax_idle.plot(idle_ts, idle_vals, "-o", ms=3)
|
||||
ax_idle.set_title("Eligible workers over time")
|
||||
ax_idle.set_xlabel("time (s)")
|
||||
ax_idle.set_ylabel("eligible count")
|
||||
|
||||
# Throughput over time (completions/sec)
|
||||
ax_idle = axes[1, 0]
|
||||
ax_idle.clear()
|
||||
if succ:
|
||||
per_sec = {}
|
||||
for r in oks:
|
||||
s = int(r.t_end)
|
||||
per_sec[s] = per_sec.get(s, 0) + 1
|
||||
ts = sorted(per_sec.keys())
|
||||
vals = [per_sec[t] for t in ts]
|
||||
ax_idle.plot(ts, vals, "-o", ms=3)
|
||||
ax_idle.set_title("Completions per second")
|
||||
ax_idle.set_xlabel("time (s)")
|
||||
ax_idle.set_ylabel("completions / sec")
|
||||
|
||||
# Summary text
|
||||
ax3 = axes[1, 1]
|
||||
ax3.axis("off")
|
||||
text = (
|
||||
f"Total requests: {total_reqs}\n"
|
||||
f"Success: {succ} Errors: {len(errs)}\n"
|
||||
f"Avg total latency: {avg_total:.1f} ms\n"
|
||||
f"p50: {p50_total:.1f} ms p95: {p95_total:.1f} ms\n"
|
||||
f"Avg route latency: {avg_route:.1f} ms\n"
|
||||
f"Avg worker latency: {avg_worker:.1f} ms\n"
|
||||
f"300 errors: {len([r for r in errs if r.status_code >= 300 and r.status_code < 400])}\n"
|
||||
f"429 errors: {len([r for r in errs if r.status_code == 429])}\n"
|
||||
f"500 errors: {len([r for r in errs if r.status_code >= 500])}\n"
|
||||
f"Other errors: {len([r for r in errs if r.status_code not in [300, 429, 500]])}\n"
|
||||
)
|
||||
ax3.set_title("Summary")
|
||||
ax3.text(0.02, 0.98, text, va="top", ha="left", fontsize=11, transform=ax3.transAxes)
|
||||
|
||||
# Error count over time
|
||||
ax_errors = axes[1, 2]
|
||||
all_end_times = [int(r.t_end) for r in results if r.t_end > 0]
|
||||
if all_end_times:
|
||||
min_second = min(all_end_times)
|
||||
max_second = max(all_end_times)
|
||||
# Count errors per second
|
||||
errors_per_second = {}
|
||||
for result in errs:
|
||||
second = int(result.t_end)
|
||||
errors_per_second[second] = errors_per_second.get(second, 0) + 1
|
||||
# Create complete timeline including zeros
|
||||
time_seconds = list(range(min_second, max_second + 1))
|
||||
error_counts = [errors_per_second.get(sec, 0) for sec in time_seconds]
|
||||
ax_errors.plot(time_seconds, error_counts, "-o", ms=3)
|
||||
ax_errors.set_title("Errors per second")
|
||||
ax_errors.set_xlabel("time (s)")
|
||||
ax_errors.set_ylabel("errors / sec")
|
||||
|
||||
# Ensure unique output path and create directory if needed
|
||||
final_out_path = get_incremented_path(out_path)
|
||||
out_dir = os.path.dirname(final_out_path)
|
||||
if out_dir:
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
plt.tight_layout(rect=[0, 0, 1, 0.96])
|
||||
plt.savefig(final_out_path, dpi=120)
|
||||
print(f"Saved report to: {final_out_path}")
|
||||
|
||||
# Per-worker latency boxplot (top 12 by volume)
|
||||
groups = {}
|
||||
for r in oks:
|
||||
host = urlparse(r.worker_url).netloc
|
||||
groups.setdefault(host, []).append(r.total_ms)
|
||||
items = sorted(groups.items(), key=lambda kv: len(kv[1]), reverse=True)[:12]
|
||||
if items:
|
||||
labels, data = zip(*items)
|
||||
fig2, axb = plt.subplots(1, 1, figsize=(12, 5))
|
||||
axb.boxplot(data, showfliers=False)
|
||||
axb.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
|
||||
axb.set_title("Per-worker latency (ms)")
|
||||
axb.set_ylabel("ms")
|
||||
plt.tight_layout()
|
||||
extra_out = get_incremented_path(os.path.splitext(out_path)[0] + "-workers.png")
|
||||
plt.savefig(extra_out, dpi=120)
|
||||
fig2.tight_layout()
|
||||
fig2.savefig(extra_out, dpi=120)
|
||||
print(f"Saved worker latency plot to: {extra_out}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Check if MODEL_NAME environment variable is set
|
||||
model_name_set = os.environ.get("MODEL_NAME") is not None
|
||||
|
||||
# Add model argument - required only if MODEL_NAME is not set
|
||||
test_args.add_argument(
|
||||
"--model",
|
||||
dest="model",
|
||||
required=not model_name_set,
|
||||
help="Model to use for completions request (required if MODEL_NAME env var not set)",
|
||||
)
|
||||
|
||||
# Parse known args to get model early, before adding load args
|
||||
known_args, _ = test_args.parse_known_args()
|
||||
if hasattr(known_args, "model") and known_args.model:
|
||||
os.environ["MODEL_NAME"] = known_args.model
|
||||
print(f"Set MODEL_NAME environment variable to: {known_args.model}")
|
||||
|
||||
# Load test args
|
||||
test_args.add_argument("-n", dest="num_requests", type=int, required=True, help="total number of requests")
|
||||
test_args.add_argument("-rps", dest="requests_per_second", type=float, required=True, help="requests per second")
|
||||
test_args.add_argument("--out", dest="out_path", type=str, default="load_test_report.png", help="path to save the report image")
|
||||
args = test_args.parse_args()
|
||||
|
||||
server_url = {
|
||||
"prod": "https://run.vast.ai",
|
||||
"alpha": "https://run-alpha.vast.ai",
|
||||
"candidate": "https://run-candidate.vast.ai",
|
||||
"local": "http://localhost:8080"
|
||||
}.get(args.instance, "http://localhost:8080")
|
||||
|
||||
run_load_with_metrics(
|
||||
num_requests=args.num_requests,
|
||||
requests_per_second=args.requests_per_second,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
account_api_key=args.api_key,
|
||||
server_url=server_url,
|
||||
worker_endpoint=WORKER_ENDPOINT,
|
||||
instance=args.instance,
|
||||
out_path=args.out_path,
|
||||
)
|
||||
@@ -0,0 +1,78 @@
|
||||
import nltk
|
||||
import random
|
||||
import os
|
||||
|
||||
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
|
||||
|
||||
# vLLM model configuration
|
||||
MODEL_SERVER_URL = 'http://127.0.0.1'
|
||||
MODEL_SERVER_PORT = 18000
|
||||
MODEL_LOG_FILE = '/var/log/portal/vllm.log'
|
||||
MODEL_HEALTHCHECK_ENDPOINT = "/health"
|
||||
|
||||
# vLLM-specific log messages
|
||||
MODEL_LOAD_LOG_MSG = [
|
||||
"Application startup complete.",
|
||||
]
|
||||
|
||||
MODEL_ERROR_LOG_MSGS = [
|
||||
"INFO exited: vllm",
|
||||
"RuntimeError: Engine",
|
||||
"Traceback (most recent call last):"
|
||||
]
|
||||
|
||||
MODEL_INFO_LOG_MSGS = [
|
||||
'"message":"Download'
|
||||
]
|
||||
|
||||
nltk.download("words")
|
||||
WORD_LIST = nltk.corpus.words.words()
|
||||
|
||||
|
||||
def completions_benchmark_generator() -> dict:
|
||||
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
||||
model = os.environ.get("MODEL_NAME")
|
||||
if not model:
|
||||
raise ValueError("MODEL_NAME environment variable not set")
|
||||
|
||||
benchmark_data = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 500,
|
||||
}
|
||||
|
||||
return benchmark_data
|
||||
|
||||
worker_config = WorkerConfig(
|
||||
model_server_url=MODEL_SERVER_URL,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
model_log_file=MODEL_LOG_FILE,
|
||||
model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
|
||||
handlers=[
|
||||
HandlerConfig(
|
||||
route="/v1/completions",
|
||||
workload_calculator= lambda data: data.get("max_tokens", 0),
|
||||
allow_parallel_requests=True,
|
||||
max_queue_time=60.0,
|
||||
benchmark_config=BenchmarkConfig(
|
||||
generator=completions_benchmark_generator,
|
||||
concurrency=100,
|
||||
runs=2
|
||||
)
|
||||
),
|
||||
HandlerConfig(
|
||||
route="/v1/chat/completions",
|
||||
workload_calculator= lambda data: data.get("max_tokens", 0),
|
||||
allow_parallel_requests=True,
|
||||
max_queue_time=60.0,
|
||||
)
|
||||
],
|
||||
log_action_config=LogActionConfig(
|
||||
on_load=MODEL_LOAD_LOG_MSG,
|
||||
on_error=MODEL_ERROR_LOG_MSGS,
|
||||
on_info=MODEL_INFO_LOG_MSGS
|
||||
)
|
||||
)
|
||||
|
||||
Worker(worker_config).run()
|
||||
Reference in New Issue
Block a user