# Vast PyWorker

## Hello_world example

There is a hello_world PyWorker implementation under `workers/hello_world`. This PyWorker is
created for an LLM model server that runs on port 5001 and has two API endpoints:

1. `/generate`: generates a full response to the prompt and sends a JSON response
2. `/generate_stream`: streams a response one token at a time

Both of these endpoints take the same API JSON payload:

```
{
    "prompt": String,
    "max_response_tokens": Number | null
}
```

We want the PyWorker to also expose two endpoints that correspond to the above endpoints.

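As a rough illustration of the payload above, the sketch below builds and serializes a request body. The helper name `build_payload` is hypothetical and not part of PyWorker; it only assumes the two fields shown above, with `max_response_tokens` allowed to be `null`.

```python
import json
from typing import Any, Dict, Optional


def build_payload(prompt: str, max_response_tokens: Optional[int] = None) -> Dict[str, Any]:
    """Build the JSON body shared by /generate and /generate_stream (hypothetical helper)."""
    return {"prompt": prompt, "max_response_tokens": max_response_tokens}


# Python's None is serialized as JSON null
body = json.dumps(build_payload("Hello"))
```

The same body can be sent to either endpoint, since both accept the same payload.
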
### Structure

All PyWorkers have four files:

```
.
└── workers
    └── hello_world
        ├── __init__.py
        ├── data_types.py # contains data types representing model API endpoints
        ├── server.py # contains endpoint handlers
        └── test_load.py # script for load testing
```

All of the classes follow strict type hinting. It is recommended that you type hint all of your functions.
This will allow your IDE, or VSCode with the `pyright` plugin, to find any type errors in your implementation.
You can also install `pyright` with `sudo npm install -g pyright` and run `pyright` in the root of the project to find
any type errors.

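To make the type hinting recommendation concrete, here is a minimal sketch of fully annotated functions. The function names are made up for illustration and do not exist in this repo.

```python
from typing import List


def count_words(prompt: str) -> int:
    """Return the number of whitespace-separated words in a prompt."""
    return len(prompt.split())


def total_words(prompts: List[str]) -> int:
    """Sum word counts over several prompts."""
    return sum(count_words(p) for p in prompts)


# A type checker such as pyright should flag the commented-out call below,
# because an int is passed where a str is expected:
# count_words(42)
```

With annotations in place, mistakes like the commented-out call are reported before the code runs rather than surfacing as runtime errors.
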
### data_types.py: Contains data types representing model API endpoints

This file defines the structure of the data your model server expects (its API contract) and, critically, how PyWorker *interprets* that data for autoscaling purposes. You define Python data classes that mirror the JSON payloads your model's API uses.

These classes **must** inherit from `lib.data_types.ApiPayload`. This inheritance is not just for structure; it's how PyWorker knows how to:

* **Parse Incoming Requests:** Convert JSON from clients into usable Python objects.
* **Calculate Workload:** Determine the computational cost of a request.
* **Generate Test Data:** Create realistic inputs for benchmarking.
* **Format Requests for the Model Server:** Prepare data for the underlying model.

```python
import dataclasses
import inspect
import random
from typing import Dict, Any

from transformers import AutoTokenizer  # used to count tokens in a prompt
import nltk  # used to download a list of all words to generate a random prompt and benchmark the LLM model

from lib.data_types import ApiPayload, JsonDataException

nltk.download("words")
WORD_LIST = nltk.corpus.words.words()

# you can use any tokenizer that fits your LLM. `openai-gpt` is free to use and is a good fit for most LLMs
tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")

@dataclasses.dataclass
class InputData(ApiPayload):
    prompt: str
    max_response_tokens: int

    @classmethod
    def for_test(cls) -> "ApiPayload":
        """defines how to create a payload for load testing"""
        prompt = " ".join(random.choices(WORD_LIST, k=250))
        return cls(prompt=prompt, max_response_tokens=300)

    def generate_payload_json(self) -> Dict[str, Any]:
        """defines how to convert an ApiPayload to JSON that will be sent to model API"""
        return dataclasses.asdict(self)

    def count_workload(self) -> float:
        """defines how to calculate workload for a payload"""
        return len(tokenizer.tokenize(self.prompt))

    @classmethod
    def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
        """
        defines how to transform JSON data to AuthData and payload type,
        in this case `InputData` defined above represents the data sent to the model API.
        AuthData is data generated by autoscaler in order to authenticate payloads.
        In this case, the transformation is simple and 1:1. That is not always the case. See comfyui's PyWorker
        for more complicated examples
        """
        # collect every missing field so the caller sees all errors at once
        errors = {}
        for param in inspect.signature(cls).parameters:
            if param not in json_msg:
                errors[param] = "missing parameter"
        if errors:
            raise JsonDataException(errors)
        # keep only the keys that match the dataclass fields
        return cls(
            **{
                k: v
                for k, v in json_msg.items()
                if k in inspect.signature(cls).parameters
            }
        )

```

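The validation in `from_json_msg` can be tried in isolation. The sketch below repeats the same pattern on a stand-in dataclass so it runs without the `lib` package; `DemoPayload` and the local `JsonDataException` are hypothetical stand-ins, and the real `lib.data_types.JsonDataException` may store its errors differently.

```python
import dataclasses
import inspect
from typing import Any, Dict


class JsonDataException(Exception):
    """Stand-in for lib.data_types.JsonDataException; carries per-field errors."""

    def __init__(self, errors: Dict[str, str]) -> None:
        super().__init__(str(errors))
        self.errors = errors


@dataclasses.dataclass
class DemoPayload:
    prompt: str
    max_response_tokens: int

    @classmethod
    def from_json_msg(cls, json_msg: Dict[str, Any]) -> "DemoPayload":
        # for a dataclass, the constructor signature lists its fields
        params = inspect.signature(cls).parameters
        errors = {name: "missing parameter" for name in params if name not in json_msg}
        if errors:
            raise JsonDataException(errors)
        # keys that are not dataclass fields are ignored
        return cls(**{k: v for k, v in json_msg.items() if k in params})


ok = DemoPayload.from_json_msg({"prompt": "hi", "max_response_tokens": 8, "extra": 1})

try:
    DemoPayload.from_json_msg({"prompt": "hi"})
except JsonDataException as exc:
    missing = exc.errors
```

Here `ok` is built from the two known fields only, and `missing` names the field that was absent from the second message.
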
### server.py: Creating Your Model's API Endpoints

This section guides you through creating the core of your custom model API: the `EndpointHandler`. Think of `EndpointHandler` as the bridge between incoming requests from users and your underlying model. It's the key to making your model accessible and scalable.

**Why use an `EndpointHandler`?**

* **Organized Request Handling:** It provides a structured way to handle different types of requests (like generating text, generating images, or performing other model-specific tasks).
* **Scalability:** By separating request handling from the model itself, you can easily scale your API to handle many concurrent users.
* **Flexibility:** You can customize how requests are processed, validated, and transformed before being sent to your model.
* **Standard Interface:** It provides a consistent interface for interacting with your model, regardless of the underlying implementation.

For every model API endpoint you want to expose (e.g., `/generate`, `/generate_stream`), you'll implement an `EndpointHandler`. This class is responsible for the following tasks, each handled by a key method:

* **Receiving and validating incoming requests (`get_data_from_request`):** This method ensures the request contains the necessary data (authentication and payload) and is in the correct format. It's the entry point for all requests.
* **Defining the endpoint (`endpoint`):** This method specifies the URL endpoint on the model API server where requests will be sent (e.g., `/generate`).
* **Specifying the payload type (`payload_cls`):** This method indicates the specific `ApiPayload` class used for this endpoint, defining the structure of the request data.
* **Creating benchmark payloads (`make_benchmark_payload`):** This method creates payloads specifically for benchmarking the model's performance.
* **Handling the model's response (`generate_client_response`):** This method takes the response from the model API server and transforms it into the format expected by the client making the request to your PyWorker. This allows you to customize the output as needed.

The `EndpointHandler` class has several abstract methods that you *must* implement to define the behavior of your specific endpoints. Here, we'll implement two common endpoints: `/generate` (for synchronous requests) and `/generate_stream` (for streaming responses):

```python

"""
AuthData is a dataclass that represents authentication data sent from the autoscaler to a client requesting a route.
See Vast's Serverless documentation for how routing and AuthData work.
When a user receives a route for this PyWorker, they'll call the PyWorker's API with the following JSON:
{
    auth_data: AuthData,
    payload : InputData # defined above
}
"""
import dataclasses
import logging
from typing import Any, Dict, Type, Union

from aiohttp import web, ClientResponse

from lib.data_types import EndpointHandler, JsonDataException
from lib.server import start_server
from .data_types import InputData

# module-level logger used by the handlers below
log = logging.getLogger(__name__)

# This class is the implementer for the '/generate' endpoint of model API
@dataclasses.dataclass
class GenerateHandler(EndpointHandler[InputData]):

    @property
    def endpoint(self) -> str:
        # the API endpoint
        return "/generate"

    @classmethod
    def payload_cls(cls) -> Type[InputData]:
        """this function should just return the ApiPayload subclass used by this handler"""
        return InputData

    def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
        """
        defines how to convert `InputData` defined above, to
        JSON data to be sent to the model API. This function too is a simple dataclass -> JSON, but
        can be more complicated. See comfyui for an example
        """
        return dataclasses.asdict(payload)

    def make_benchmark_payload(self) -> InputData:
        """
        defines how to generate an InputData for benchmarking. This needs to be defined in only
        one EndpointHandler, the one passed to the backend as the benchmark handler. Here we use the .for_test()
        method on InputData. However, in some cases you might need to fine tune your InputData used for
        benchmarking to closely resemble the average request users call the endpoint with in order to get best
        autoscaling performance
        """
        return InputData.for_test()

    async def generate_client_response(
        self, client_request: web.Request, model_response: ClientResponse
    ) -> Union[web.Response, web.StreamResponse]:
        """
        defines how to convert a model API response to a response to PyWorker client
        """
        _ = client_request
        match model_response.status:
            case 200:
                log.debug("SUCCESS")
                data = await model_response.json()
                return web.json_response(data=data)
            case code:
                # any other status is passed back to the client unchanged
                log.debug("SENDING RESPONSE: ERROR: unknown code")
                return web.Response(status=code)

```

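The request envelope described in the docstring above can be sketched as follows. The helper `build_worker_request` is hypothetical, and the contents of `auth_data` are treated as opaque here: the real fields come from the autoscaler's route response and are not reproduced in this README.

```python
import json
from typing import Any, Dict


def build_worker_request(auth_data: Dict[str, Any], prompt: str, max_response_tokens: int) -> str:
    """Wrap the model payload and the autoscaler's auth data in one JSON body."""
    envelope = {
        "auth_data": auth_data,  # passed through unchanged, as received from the autoscaler
        "payload": {"prompt": prompt, "max_response_tokens": max_response_tokens},
    }
    return json.dumps(envelope)


# placeholder only: real AuthData fields come from the autoscaler's route response
fake_auth: Dict[str, Any] = {"placeholder": "value"}
request_body = build_worker_request(fake_auth, "Hello", 64)
```

The `payload` object has the same shape as `InputData`, so `from_json_msg` can parse it once the PyWorker has checked `auth_data`.
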
We also implement `GenerateStreamHandler` for streaming responses. It is identical to `GenerateHandler`, except for
the endpoint name and how we create a web response, as it is a streaming response:

```python
class GenerateStreamHandler(EndpointHandler[InputData]):
    @property
    def endpoint(self) -> str:
        return "/generate_stream"

    @classmethod
    def payload_cls(cls) -> Type[InputData]:
        return InputData

    def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
        return dataclasses.asdict(payload)

    def make_benchmark_payload(self) -> InputData:
        return InputData.for_test()

    async def generate_client_response(
        self, client_request: web.Request, model_response: ClientResponse
    ) -> Union[web.Response, web.StreamResponse]:
        match model_response.status:
            case 200:
                log.debug("Streaming response...")
                res = web.StreamResponse()
                res.content_type = "text/event-stream"
                # send headers to the client before writing any chunks
                await res.prepare(client_request)
                # forward each chunk as soon as the model server produces it
                async for chunk in model_response.content:
                    await res.write(chunk)
                await res.write_eof()
                log.debug("Done streaming response")
                return res
            case code:
                log.debug("SENDING RESPONSE: ERROR: unknown code")
                return web.Response(status=code)

```

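The relay loop in the streaming handler can be exercised without a server. The sketch below uses only `asyncio`; `fake_model_stream` and `FakeStreamResponse` are hypothetical stand-ins for `model_response.content` and `web.StreamResponse`, so it shows the forwarding pattern rather than aiohttp's behavior.

```python
import asyncio
from typing import AsyncIterator, List


async def fake_model_stream() -> AsyncIterator[bytes]:
    """Stand-in for model_response.content: yields chunks as they arrive."""
    for token in (b"Hello", b" ", b"world"):
        await asyncio.sleep(0)  # yield control, as a network read would
        yield token


class FakeStreamResponse:
    """Stand-in for web.StreamResponse that records what was written."""

    def __init__(self) -> None:
        self.chunks: List[bytes] = []
        self.closed = False

    async def write(self, chunk: bytes) -> None:
        self.chunks.append(chunk)

    async def write_eof(self) -> None:
        self.closed = True


async def relay(source: AsyncIterator[bytes], sink: FakeStreamResponse) -> None:
    """Forward each chunk immediately instead of buffering the full body."""
    async for chunk in source:
        await sink.write(chunk)
    await sink.write_eof()


sink = FakeStreamResponse()
asyncio.run(relay(fake_model_stream(), sink))
```

Because each chunk is written as soon as it is read, the client can start receiving tokens before the model has finished generating.
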
You can now instantiate a `Backend` and use it to handle requests.

```python
import os

from lib.backend import Backend, LogAction

# the url and port of model API
MODEL_SERVER_URL = "http://0.0.0.0:5001"


# This is the log line that is emitted once the server has started
MODEL_SERVER_START_LOG_MSG = "server has started"
MODEL_SERVER_ERROR_LOG_MSGS = [
    "Exception: corrupted model file"  # message in the logs indicating the unrecoverable error
]

backend = Backend(
    model_server_url=MODEL_SERVER_URL,
    # location of model log file
    model_log_file=os.environ["MODEL_LOG"],
    # for some model backends that can only handle one request at a time, be sure to set this to False to
    # let PyWorker handle queueing requests.
    allow_parallel_requests=True,
    # give the backend an EndpointHandler instance that is used for benchmarking
    # number of benchmark runs and number of words for a random benchmark run are given
    benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
    # defines how to handle specific log messages. See docstring of LogAction for details
    log_actions=[
        (LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
        (LogAction.Info, '"message":"Download'),
        *[
            (LogAction.ModelError, error_msg)
            for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
        ],
    ],
)

# this is a simple ping handler for PyWorker
async def handle_ping(_: web.Request):
    return web.Response(body="pong")

# this is a handler for forwarding a health check to model API
async def handle_healthcheck(_: web.Request):
    healthcheck_res = await backend.session.get("/healthcheck")
    return web.Response(body=healthcheck_res.content, status=healthcheck_res.status)

routes = [
    web.post("/generate", backend.create_handler(GenerateHandler())),
    web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
    web.get("/ping", handle_ping),
    web.get("/healthcheck", handle_healthcheck),
]

if __name__ == "__main__":
    # start server, called from start_server.sh
    start_server(backend, routes)
```

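To give a feel for what the `log_actions` pairs express, here is a small sketch that maps log lines to actions. It is not the `Backend` implementation: `DemoLogAction` is a hypothetical enum that only mirrors the names used above, and matching by substring is an assumption (the real matching rule is described in the `LogAction` docstring).

```python
from enum import Enum, auto
from typing import List, Optional, Tuple


class DemoLogAction(Enum):
    """Hypothetical stand-in mirroring the LogAction names used above."""

    ModelLoaded = auto()
    ModelError = auto()
    Info = auto()


DEMO_LOG_ACTIONS: List[Tuple[DemoLogAction, str]] = [
    (DemoLogAction.ModelLoaded, "server has started"),
    (DemoLogAction.Info, '"message":"Download'),
    (DemoLogAction.ModelError, "Exception: corrupted model file"),
]


def classify(line: str) -> Optional[DemoLogAction]:
    """Return the first action whose pattern occurs in the log line, if any."""
    # assumption: substring match; the real Backend may match differently
    for action, pattern in DEMO_LOG_ACTIONS:
        if pattern in line:
            return action
    return None
```

Lines that match no pattern return `None` in this sketch, which corresponds to log output the worker does not need to react to.
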
### test_load.py

Here you can create a script that allows you to test an endpoint group whose instances run this PyWorker:

```python
from lib.test_harness import run
from .data_types import InputData

WORKER_ENDPOINT = "/generate"

if __name__ == "__main__":
    run(InputData.for_test(), WORKER_ENDPOINT)
```

You can then run the following command from the root of this repo to load test the endpoint group:

```sh
# sends 1000 requests at the rate of 0.5 requests per second
python3 -m workers.hello_world.test_load -n 1000 -rps 0.5 -k "$API_KEY" -e "$ENDPOINT_GROUP_NAME"
```
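For a sense of what `-n 1000 -rps 0.5` implies, the sketch below computes when each request would be sent. It assumes evenly spaced requests; the actual test harness may pace them differently, and `send_offsets` is a hypothetical helper, not part of `lib.test_harness`.

```python
from typing import List


def send_offsets(n: int, rps: float) -> List[float]:
    """Seconds after start at which each of n requests is sent at a steady rate."""
    if rps <= 0:
        raise ValueError("rps must be positive")
    interval = 1.0 / rps  # 0.5 requests per second means one request every 2 seconds
    return [i * interval for i in range(n)]


offsets = send_offsets(1000, 0.5)
# under this assumption the last request goes out 999 intervals after the first
```

Under this assumption, a run with these flags takes roughly 33 minutes just to send all requests, which is worth knowing before starting a load test.
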