initial commit
This commit is contained in:
@@ -0,0 +1,300 @@
|
||||
# Vast PyWorker
|
||||
|
||||
## Hello_world example
|
||||
|
||||
There is a hello_world PyWorker implementation under `workers/hello_world`. This PyWorker is
|
||||
created for an LLM model server that runs on port 5001 and has two API endpoints:
|
||||
|
||||
1. `/generate`: generates a full response to the prompt and sends a JSON response
|
||||
2. `/generate_stream`: streams a response one token at a time
|
||||
|
||||
Both of these endpoints take the same API JSON payload:
|
||||
|
||||
```
|
||||
{
|
||||
"prompt": String,
|
||||
"max_response_tokens": Number | null
|
||||
}
|
||||
```
|
||||
|
||||
We want the PyWorker to also expose two endpoints, one for each of the above endpoints.
|
||||
|
||||
### Structure
|
||||
|
||||
All PyWorkers have four files:
|
||||
|
||||
```
|
||||
.
|
||||
└── workers
|
||||
└── hello_world
|
||||
├── __init__.py
|
||||
├── data_types.py # contains data types representing model API endpoints
|
||||
├── server.py # contains endpoint handlers
|
||||
├── client.py # a script to call an endpoint through the autoscaler
|
||||
└── test_load.py # script for load testing
|
||||
|
||||
```
|
||||
|
||||
All of the classes follow strict type hinting. It is recommended that you type hint all of your functions.
|
||||
This will allow your IDE or VSCode with `pyright` plugin to find any type errors in your implementation.
|
||||
You can also install `pyright` with `sudo npm install -g pyright` and run `pyright` in the root of the project to find
|
||||
any type errors.
|
||||
|
||||
#### data_types.py
|
||||
|
||||
Data classes representing the model API are defined here. They must inherit from
|
||||
`lib.data_types.ApiPayload`. `ApiPayload` is an abstract class and you need to define several functions for it:
|
||||
|
||||
```python
|
||||
import dataclasses
|
||||
import random
import inspect
|
||||
from typing import Dict, Any
|
||||
|
||||
from transformers import AutoTokenizer # used to count tokens in a prompt
|
||||
import nltk # used to download a list of all words to generate a random prompt and benchmark the LLM model
|
||||
|
||||
from lib.data_types import ApiPayload, JsonDataException
|
||||
|
||||
nltk.download("words")
|
||||
WORD_LIST = nltk.corpus.words.words()
|
||||
|
||||
# you can use any tokenizer that fits your LLM. `openai-gpt` is free to use and is a good fit for most LLMs
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")
|
||||
|
||||
@dataclasses.dataclass
|
||||
class InputData(ApiPayload):
|
||||
prompt: str
|
||||
max_response_tokens: int
|
||||
|
||||
@classmethod
|
||||
def for_test(cls) -> "ApiPayload":
|
||||
"""defines how to create a payload for load testing"""
|
||||
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
||||
return cls(prompt=prompt, max_response_tokens=300)
|
||||
|
||||
def generate_payload_json(self) -> Dict[str, Any]:
|
||||
"""defines how to convert an ApiPayload to JSON that will be sent to model API"""
|
||||
return dataclasses.asdict(self)
|
||||
|
||||
def count_workload(self) -> float:
|
||||
"""defines how to calculate workload for a payload"""
|
||||
return len(tokenizer.tokenize(self.prompt))
|
||||
|
||||
@classmethod
|
||||
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
|
||||
"""
|
||||
defines how to transform JSON data to AuthData and payload type,
|
||||
in this case `InputData` defined above represents the data sent to the model API.
|
||||
AuthData is data generated by autoscaler in order to authenticate payloads.
|
||||
In this case, the transformation is simple and 1:1. That is not always the case. See comfyui's PyWorker
|
||||
for more complicated examples
|
||||
"""
|
||||
errors = {}
|
||||
for param in inspect.signature(cls).parameters:
|
||||
if param not in json_msg:
|
||||
errors[param] = "missing parameter"
|
||||
if errors:
|
||||
raise JsonDataException(errors)
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in json_msg.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
#### server.py
|
||||
|
||||
For every model API endpoint you want to use, you must implement an `EndpointHandler`. This class handles incoming
|
||||
requests, processes them, sends them to the model API server, and finally returns an HTTP response.
|
||||
`EndpointHandler` has several abstract functions that must be implemented. Here, we implement two, one
|
||||
for `/generate`, and one for `/generate_stream`:
|
||||
|
||||
```python
|
||||
|
||||
"""
|
||||
AuthData is a dataclass that represents Authentication data sent from Autoscaler to client requesting a route.
|
||||
When a user requests a route from autoscaler, see Vast's Serverless documentation for how routing and AuthData
|
||||
work.
|
||||
When a user receives a route for this PyWorker, they'll call PyWorkers API with the following JSON:
|
||||
{
|
||||
auth_data: AuthData,
|
||||
payload : InputData # defined above
|
||||
}
|
||||
"""
|
||||
from aiohttp import web
|
||||
|
||||
from lib.data_types import EndpointHandler, JsonDataException
|
||||
from lib.server import start_server
|
||||
from .data_types import InputData
|
||||
|
||||
# This class is the implementer for the '/generate' endpoint of model API
|
||||
@dataclasses.dataclass
|
||||
class GenerateHandler(EndpointHandler[InputData]):
|
||||
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
# the API endpoint
|
||||
return "/generate"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[InputData]:
|
||||
"""this function should just return ApiPayload subclass used by this handler"""
|
||||
return InputData
|
||||
|
||||
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
|
||||
"""
|
||||
defines how to convert `InputData` defined above, to
|
||||
JSON data to be sent to the model API. This function too is a simple dataclass -> JSON, but
|
||||
can be more complicated, See comfyui for an example
|
||||
"""
|
||||
return dataclasses.asdict(payload)
|
||||
|
||||
def make_benchmark_payload(self) -> InputData:
|
||||
"""
|
||||
defines how to generate an InputData for benchmarking. This needs to be defined in only
|
||||
one EndpointHandler, the one passed to the backend as the benchmark handler. Here we use the .for_test()
|
||||
method on InputData. However, in some cases you might need to fine tune your InputData used for
|
||||
benchmarking to closely resemble the average request users call the endpoint with in order to get best
|
||||
autoscaling performance
|
||||
"""
|
||||
return InputData.for_test()
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
"""
|
||||
defines how to convert a model API response to a response to PyWorker client
|
||||
"""
|
||||
_ = client_request
|
||||
match model_response.status:
|
||||
case 200:
|
||||
log.debug("SUCCESS")
|
||||
data = await model_response.json()
|
||||
return web.json_response(data=data)
|
||||
case code:
|
||||
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||
return web.Response(status=code)
|
||||
|
||||
|
||||
```
|
||||
|
||||
We also handle `GenerateStreamHandler` for streaming responses. It is identical to `GenerateHandler`, except for
|
||||
the endpoint name and how we create a web response, as it is a streaming response:
|
||||
|
||||
```python
|
||||
class GenerateStreamHandler(EndpointHandler[InputData]):
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/generate_stream"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[InputData]:
|
||||
return InputData
|
||||
|
||||
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
|
||||
return dataclasses.asdict(payload)
|
||||
|
||||
def make_benchmark_payload(self) -> InputData:
|
||||
return InputData.for_test()
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
match model_response.status:
|
||||
case 200:
|
||||
log.debug("Streaming response...")
|
||||
res = web.StreamResponse()
|
||||
res.content_type = "text/event-stream"
|
||||
await res.prepare(client_request)
|
||||
async for chunk in model_response.content:
|
||||
await res.write(chunk)
|
||||
await res.write_eof()
|
||||
log.debug("Done streaming response")
|
||||
return res
|
||||
case code:
|
||||
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||
return web.Response(status=code)
|
||||
|
||||
|
||||
```
|
||||
|
||||
You can now instantiate a Backend and use it to handle requests.
|
||||
|
||||
```python
|
||||
from lib.backend import Backend, LogAction
|
||||
|
||||
# the url and port of model API
|
||||
MODEL_SERVER_URL = "http://0.0.0.0:5001"
|
||||
|
||||
|
||||
# This is the log line that is emitted once the server has started
|
||||
MODEL_SERVER_START_LOG_MSG = "server has started"
|
||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||
"Exception: corrupted model file" # message in the logs indicating the unrecoverable error
|
||||
]
|
||||
|
||||
backend = Backend(
|
||||
model_server_url=MODEL_SERVER_URL,
|
||||
# location of model log file
|
||||
model_log_file=os.environ["MODEL_LOG"],
|
||||
# for some model backends that can only handle one request at a time, be sure to set this to False to
|
||||
# let PyWorker handle queueing requests.
|
||||
allow_parallel_requests=True,
|
||||
# give the backend an EndpointHandler instance that is used for benchmarking
|
||||
# the number of benchmark runs and the number of words for a random benchmark run are given
|
||||
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
|
||||
# defines how to handle specific log messages. See docstring of LogAction for details
|
||||
log_actions=[
|
||||
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
|
||||
(LogAction.Info, '"message":"Download'),
|
||||
*[
|
||||
(LogAction.ModelError, error_msg)
|
||||
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
# this is a simple ping handler for PyWorker
|
||||
async def handle_ping(_: web.Request):
|
||||
return web.Response(body="pong")
|
||||
|
||||
# this is a handler for forwarding a health check to model API
|
||||
async def handle_healthcheck(_: web.Request):
|
||||
healthcheck_res = await backend.session.get("/healthcheck")
|
||||
return web.Response(body=healthcheck_res.content, status=healthcheck_res.status)
|
||||
|
||||
routes = [
|
||||
web.post("/generate", backend.create_handler(GenerateHandler())),
|
||||
web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
|
||||
web.get("/ping", handle_ping),
|
||||
web.get("/healthcheck", handle_healthcheck),
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
# start server, called from start_server.sh
|
||||
start_server(backend, routes)
|
||||
```
|
||||
|
||||
#### test_load.py
|
||||
|
||||
Here you can create a script that allows you to test an endpoint group running instances with this PyWorker.
|
||||
|
||||
```python
|
||||
from lib.test_harness import run
|
||||
from .data_types import InputData
|
||||
|
||||
WORKER_ENDPOINT = "/generate"
|
||||
|
||||
if __name__ == "__main__":
|
||||
run(InputData.for_test(), WORKER_ENDPOINT)
|
||||
```
|
||||
|
||||
You can then run the following command from the root of this repo to load test the endpoint group:
|
||||
|
||||
```sh
|
||||
# sends 1000 requests at the rate of 0.5 requests per second
|
||||
python3 -m workers.hello_world.test_load -n 1000 -rps 0.5 -k "$API_KEY" -e "$ENDPOINT_GROUP_NAME"
|
||||
```
|
||||
@@ -0,0 +1,48 @@
|
||||
import dataclasses
|
||||
import random
|
||||
import inspect
|
||||
from typing import Dict, Any
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
import nltk
|
||||
|
||||
from lib.data_types import ApiPayload, JsonDataException
|
||||
|
||||
nltk.download("words")
|
||||
WORD_LIST = nltk.corpus.words.words()
|
||||
|
||||
# used to count tokens and workload for the LLM
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")
|
||||
|
||||
|
||||
@dataclasses.dataclass
class InputData(ApiPayload):
    """Payload sent to the model API's generate endpoints.

    Instances are built from client JSON with ``from_json_msg`` and turned
    back into a plain dict with ``generate_payload_json``.
    """

    # prompt text; also the input to ``count_workload``
    prompt: str
    # NOTE(review): the README documents this field as `Number | null` —
    # confirm whether ``None`` should be accepted here.
    max_response_tokens: int

    @classmethod
    def for_test(cls) -> "InputData":
        """Return a payload whose prompt is 250 words drawn at random from WORD_LIST."""
        prompt = " ".join(random.choices(WORD_LIST, k=250))
        return cls(prompt=prompt, max_response_tokens=300)

    def generate_payload_json(self) -> Dict[str, Any]:
        """Return this payload as a plain dict of its dataclass fields."""
        return dataclasses.asdict(self)

    def count_workload(self) -> int:
        """Return the number of tokens the module-level tokenizer yields for the prompt."""
        return len(tokenizer.tokenize(self.prompt))

    @classmethod
    def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
        """Build an ``InputData`` from a decoded JSON message.

        Every constructor parameter must be present in ``json_msg``; keys that
        are not constructor parameters are ignored.

        Raises:
            JsonDataException: with a ``{parameter: "missing parameter"}``
                mapping when one or more parameters are absent.
        """
        # Resolve the constructor parameters once instead of re-inspecting the
        # signature for every key of the incoming message.
        params = inspect.signature(cls).parameters
        errors = {
            name: "missing parameter" for name in params if name not in json_msg
        }
        if errors:
            raise JsonDataException(errors)
        # All parameters are known to be present at this point, so picking
        # them by name is equivalent to filtering json_msg by parameter name.
        return cls(**{name: json_msg[name] for name in params})
|
||||
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
PyWorker works as a man-in-the-middle between the client and model API. Its function is:
|
||||
1. receive request from client, update metrics such as workload of a request, number of pending requests, etc.
|
||||
2a. transform the data and forward the transformed data to model API
|
||||
2b. send updated metrics to autoscaler
|
||||
3. transform response from model API(if needed) and forward the response to client
|
||||
|
||||
PyWorker forwards requests to many model API endpoints. Each endpoint must have an EndpointHandler. You can also
|
||||
write a function that just forwards requests that don't generate anything with the model to the model API without an
|
||||
EndpointHandler. This is useful for endpoints such as healthchecks. See below for example
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import dataclasses
|
||||
from typing import Dict, Any, Union, Type
|
||||
|
||||
from aiohttp import web, ClientResponse
|
||||
|
||||
from lib.backend import Backend, LogAction
|
||||
from lib.data_types import EndpointHandler
|
||||
from lib.server import start_server
|
||||
from .data_types import InputData
|
||||
|
||||
# the url and port of model API
# NOTE(review): 0.0.0.0 is a bind-all address; confirm whether 127.0.0.1 is the
# intended client-side target.
MODEL_SERVER_URL = "http://0.0.0.0:5001"

# This is the log line that is emitted once the server has started
MODEL_SERVER_START_LOG_MSG = "infer server has started"
# messages in the logs indicating an unrecoverable error; the trailing comma
# keeps a later addition from being silently concatenated with this string
MODEL_SERVER_ERROR_LOG_MSGS = [
    "Exception: corrupted model file",
]

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s[%(levelname)-5s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
# Name the logger after the module rather than its file path: a path contains
# dots (".py"), which the logging module would treat as hierarchy separators.
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# This class is the implementer for the '/generate' endpoint of model API
|
||||
@dataclasses.dataclass
class GenerateHandler(EndpointHandler[InputData]):
    """Endpoint handler for the model API's '/generate' endpoint."""

    @property
    def endpoint(self) -> str:
        """Model API path served by this handler."""
        return "/generate"

    @classmethod
    def payload_cls(cls) -> Type[InputData]:
        """Payload class used by this handler."""
        return InputData

    def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
        """Convert ``payload`` into the dict that is sent to the model API."""
        return dataclasses.asdict(payload)

    def make_benchmark_payload(self) -> InputData:
        """Create the payload used for benchmarking.

        Only the handler passed to the backend as the benchmark handler needs
        to define this.
        """
        return InputData.for_test()

    async def generate_client_response(
        self, client_request: web.Request, model_response: ClientResponse
    ) -> Union[web.Response, web.StreamResponse]:
        """Turn the model API response into the response for the PyWorker client.

        A 200 from the model API is relayed as a JSON response; any other
        status is returned as an empty response with that status.
        ``client_request`` is not used here.
        """
        status = model_response.status
        if status != 200:
            log.debug("SENDING RESPONSE: ERROR: unknown code")
            return web.Response(status=status)
        log.debug("SUCCESS")
        body = await model_response.json()
        return web.json_response(data=body)
|
||||
|
||||
|
||||
# This is the same as GenerateHandler, except that it calls a streaming endpoint of the model API and streams the
|
||||
# response, which itself is streaming, back to the client.
|
||||
# it is nearly identical to handler as above, but it calls a different model API endpoint and it streams the
|
||||
# streaming response from model API to client
|
||||
class GenerateStreamHandler(EndpointHandler[InputData]):
    """Endpoint handler for the model API's '/generate_stream' endpoint."""

    @property
    def endpoint(self) -> str:
        """Model API path served by this handler."""
        return "/generate_stream"

    @classmethod
    def payload_cls(cls) -> Type[InputData]:
        """Payload class used by this handler."""
        return InputData

    def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
        """Convert ``payload`` into the dict that is sent to the model API."""
        return dataclasses.asdict(payload)

    def make_benchmark_payload(self) -> InputData:
        """Create the payload used for benchmarking."""
        return InputData.for_test()

    async def generate_client_response(
        self, client_request: web.Request, model_response: ClientResponse
    ) -> Union[web.Response, web.StreamResponse]:
        """Relay the model API's streaming response to the PyWorker client.

        On a 200 the model response content is written to the client piece by
        piece as a ``text/event-stream`` response; any other status is
        returned as an empty response with that status.
        """
        status = model_response.status
        if status != 200:
            log.debug("SENDING RESPONSE: ERROR: unknown code")
            return web.Response(status=status)

        log.debug("Streaming response...")
        stream = web.StreamResponse()
        stream.content_type = "text/event-stream"
        # headers must be sent before any body data can be written
        await stream.prepare(client_request)
        async for piece in model_response.content:
            await stream.write(piece)
        await stream.write_eof()
        log.debug("Done streaming response")
        return stream
|
||||
|
||||
|
||||
# This is the backend instance of pyworker. Only one must be made which uses EndpointHandlers to process
|
||||
# incoming requests
|
||||
backend = Backend(
    model_server_url=MODEL_SERVER_URL,
    # path of the model log file; read from the MODEL_LOG environment
    # variable, so a missing variable fails here with a KeyError
    model_log_file=os.environ["MODEL_LOG"],
    allow_parallel_requests=True,
    # handler instance used for benchmarking, configured with the number of
    # benchmark runs and the number of words for a random benchmark run
    benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
    # pairs of (action, log message); see the docstring of LogAction for details
    log_actions=[
        (LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
        (LogAction.Info, '"message":"Download'),
        # one ModelError entry per known unrecoverable error message
        *(
            (LogAction.ModelError, fatal_msg)
            for fatal_msg in MODEL_SERVER_ERROR_LOG_MSGS
        ),
    ],
)
|
||||
|
||||
|
||||
# this is a simple ping handler for pyworker
|
||||
async def handle_ping(_: web.Request) -> web.Response:
    """Answer any request with the body "pong"; the request itself is ignored."""
    pong = web.Response(body="pong")
    return pong
|
||||
|
||||
|
||||
# this is a handler for forwarding a health check to modelAPI
|
||||
async def handle_healthcheck(_: web.Request) -> web.Response:
    """Forward a health check to the model API and relay its answer.

    Issues a GET for "/healthcheck" through ``backend.session`` and returns a
    response carrying the same status and content. The incoming request is
    ignored.
    """
    # NOTE(review): presumably backend.session is an HTTP client session bound
    # to the model server URL — confirm in lib.backend.
    upstream = await backend.session.get("/healthcheck")
    return web.Response(body=upstream.content, status=upstream.status)
|
||||
|
||||
|
||||
# routes backed by an EndpointHandler, wrapped by the backend
routes = [
    web.post("/generate", backend.create_handler(GenerateHandler())),
    web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
]
# plain routes that do not go through an EndpointHandler
routes += [
    web.get("/ping", handle_ping),
    web.get("/healthcheck", handle_healthcheck),
]

if __name__ == "__main__":
    # start the PyWorker server
    start_server(backend, routes)
|
||||
@@ -0,0 +1,7 @@
|
||||
from lib.test_utils import test_load_cmd, test_args
|
||||
from .data_types import InputData
|
||||
|
||||
# PyWorker endpoint that the load test targets
WORKER_ENDPOINT = "/generate"

if __name__ == "__main__":
    # one generated test payload is handed to the load-test command
    test_payload = InputData.for_test()
    test_load_cmd(test_payload, WORKER_ENDPOINT, arg_parser=test_args)
|
||||
Reference in New Issue
Block a user