Use PyWorker SDK (#67)
* Change PyWorker to Worker SDK * Moved /lib to vast-sdk (https://github.com/vast-ai/vast-sdk)
This commit is contained in:
@@ -1,73 +0,0 @@
|
||||
import dataclasses
|
||||
import random
|
||||
import inspect
|
||||
from typing import Dict, Any
|
||||
|
||||
from transformers import OpenAIGPTTokenizer
|
||||
import nltk
|
||||
|
||||
from lib.data_types import ApiPayload, JsonDataException
|
||||
|
||||
nltk.download("words")
|
||||
WORD_LIST = nltk.corpus.words.words()
|
||||
|
||||
tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class InputParameters:
|
||||
max_new_tokens: int = 256
|
||||
|
||||
@classmethod
|
||||
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputParameters":
|
||||
errors = {}
|
||||
for param in inspect.signature(cls).parameters:
|
||||
if param not in json_msg:
|
||||
errors[param] = "missing parameter"
|
||||
if errors:
|
||||
raise JsonDataException(errors)
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in json_msg.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class InputData(ApiPayload):
|
||||
inputs: str
|
||||
parameters: InputParameters
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "InputData":
|
||||
return cls(
|
||||
inputs=data["inputs"], parameters=InputParameters(**data["parameters"])
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def for_test(cls) -> "InputData":
|
||||
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
||||
return cls(inputs=prompt, parameters=InputParameters())
|
||||
|
||||
def generate_payload_json(self) -> Dict[str, Any]:
|
||||
return dataclasses.asdict(self)
|
||||
|
||||
def count_workload(self) -> int:
|
||||
return self.parameters.max_new_tokens
|
||||
|
||||
@classmethod
|
||||
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
|
||||
errors = {}
|
||||
for param in inspect.signature(cls).parameters:
|
||||
if param not in json_msg:
|
||||
errors[param] = "missing parameter"
|
||||
if errors:
|
||||
raise JsonDataException(errors)
|
||||
try:
|
||||
parameters = InputParameters.from_json_msg(json_msg["parameters"])
|
||||
return cls(inputs=json_msg["inputs"], parameters=parameters)
|
||||
except JsonDataException as e:
|
||||
errors["parameters"] = e.message
|
||||
raise JsonDataException(errors)
|
||||
@@ -1,130 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Union, Type
|
||||
import dataclasses
|
||||
|
||||
from aiohttp import web, ClientResponse
|
||||
|
||||
from lib.backend import Backend, LogAction
|
||||
from lib.data_types import EndpointHandler
|
||||
from lib.server import start_server
|
||||
from .data_types import InputData
|
||||
|
||||
|
||||
MODEL_SERVER_URL = "http://0.0.0.0:5001"
|
||||
|
||||
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
||||
MODEL_SERVER_START_LOG_MSG = [
|
||||
'"message":"Connected","target":"text_generation_router"',
|
||||
'"message":"Connected","target":"text_generation_router::server"',
|
||||
]
|
||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||
"Error: WebserverFailed",
|
||||
"Error: DownloadError",
|
||||
"Error: ShardCannotStart",
|
||||
]
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class GenerateHandler(EndpointHandler[InputData]):
|
||||
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/generate"
|
||||
|
||||
@property
|
||||
def healthcheck_endpoint(self) -> str:
|
||||
return f"{MODEL_SERVER_URL}/health"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[InputData]:
|
||||
return InputData
|
||||
|
||||
def make_benchmark_payload(self) -> InputData:
|
||||
return InputData.for_test()
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
_ = client_request
|
||||
match model_response.status:
|
||||
case 200:
|
||||
log.debug("SUCCESS")
|
||||
data = await model_response.json()
|
||||
return web.json_response(data=data)
|
||||
case code:
|
||||
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||
return web.Response(status=code)
|
||||
|
||||
|
||||
class GenerateStreamHandler(EndpointHandler[InputData]):
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/generate_stream"
|
||||
|
||||
@property
|
||||
def healthcheck_endpoint(self) -> str:
|
||||
return f"{MODEL_SERVER_URL}/health"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[InputData]:
|
||||
return InputData
|
||||
|
||||
def make_benchmark_payload(self) -> InputData:
|
||||
return InputData.for_test()
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
match model_response.status:
|
||||
case 200:
|
||||
log.debug("Streaming response...")
|
||||
res = web.StreamResponse()
|
||||
res.content_type = "text/event-stream"
|
||||
await res.prepare(client_request)
|
||||
async for chunk in model_response.content:
|
||||
await res.write(chunk)
|
||||
await res.write_eof()
|
||||
log.debug("Done streaming response")
|
||||
return res
|
||||
case code:
|
||||
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||
return web.Response(status=code)
|
||||
|
||||
|
||||
backend = Backend(
|
||||
model_server_url=MODEL_SERVER_URL,
|
||||
model_log_file=os.environ["MODEL_LOG"],
|
||||
allow_parallel_requests=True,
|
||||
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
|
||||
log_actions=[
|
||||
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
|
||||
(LogAction.Info, '"message":"Download'),
|
||||
*[
|
||||
(LogAction.ModelError, error_msg)
|
||||
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def handle_ping(_):
|
||||
return web.Response(body="pong")
|
||||
|
||||
|
||||
routes = [
|
||||
web.post("/generate", backend.create_handler(GenerateHandler())),
|
||||
web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
|
||||
web.get("/ping", handle_ping),
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_server(backend, routes)
|
||||
@@ -1,7 +0,0 @@
|
||||
from lib.test_utils import test_load_cmd, test_args
|
||||
from .data_types import InputData
|
||||
|
||||
WORKER_ENDPOINT = "/generate"
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_load_cmd(InputData, WORKER_ENDPOINT, arg_parser=test_args)
|
||||
@@ -0,0 +1,76 @@
|
||||
import nltk
|
||||
import random
|
||||
|
||||
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
|
||||
|
||||
# TGI model configuration
|
||||
MODEL_SERVER_URL = 'http://0.0.0.0'
|
||||
MODEL_SERVER_PORT = 5001
|
||||
MODEL_LOG_FILE = "/workspace/infer.log"
|
||||
MODEL_HEALTHCHECK_ENDPOINT = "/health"
|
||||
|
||||
# TGI-specific log messages
|
||||
MODEL_LOAD_LOG_MSG = [
|
||||
'"message":"Connected","target":"text_generation_router"',
|
||||
'"message":"Connected","target":"text_generation_router::server"',
|
||||
]
|
||||
|
||||
MODEL_ERROR_LOG_MSGS = [
|
||||
"Error: WebserverFailed",
|
||||
"Error: DownloadError",
|
||||
"Error: ShardCannotStart",
|
||||
]
|
||||
|
||||
MODEL_INFO_LOG_MSGS = [
|
||||
'"message":"Download'
|
||||
]
|
||||
|
||||
nltk.download("words")
|
||||
WORD_LIST = nltk.corpus.words.words()
|
||||
|
||||
|
||||
def benchmark_generator() -> dict:
|
||||
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
||||
|
||||
benchmark_data = {
|
||||
"inputs": prompt,
|
||||
"parameters": {
|
||||
"max_new_tokens": 128,
|
||||
"temperature": 0.7,
|
||||
"return_full_text": False
|
||||
}
|
||||
}
|
||||
|
||||
return benchmark_data
|
||||
|
||||
worker_config = WorkerConfig(
|
||||
model_server_url=MODEL_SERVER_URL,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
model_log_file=MODEL_LOG_FILE,
|
||||
model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
|
||||
handlers=[
|
||||
HandlerConfig(
|
||||
route="/generate",
|
||||
allow_parallel_requests=True,
|
||||
max_queue_time=60.0,
|
||||
benchmark_config=BenchmarkConfig(
|
||||
generator=benchmark_generator,
|
||||
concurrency=50
|
||||
),
|
||||
workload_calculator= lambda x: x["parameters"]["max_new_tokens"]
|
||||
),
|
||||
HandlerConfig(
|
||||
route="/generate_stream",
|
||||
allow_parallel_requests=True,
|
||||
max_queue_time=60.0,
|
||||
workload_calculator= lambda x: x["parameters"]["max_new_tokens"]
|
||||
)
|
||||
],
|
||||
log_action_config=LogActionConfig(
|
||||
on_load=MODEL_LOAD_LOG_MSG,
|
||||
on_error=MODEL_ERROR_LOG_MSGS,
|
||||
on_info=MODEL_INFO_LOG_MSGS
|
||||
)
|
||||
)
|
||||
|
||||
Worker(worker_config).run()
|
||||
Reference in New Issue
Block a user