86 lines
2.3 KiB
Python
86 lines
2.3 KiB
Python
import nltk
|
|
import random
|
|
import os
|
|
|
|
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
|
|
|
|
# vLLM model configuration
|
|
MODEL_SERVER_URL = 'http://127.0.0.1'
|
|
MODEL_SERVER_PORT = 18000
|
|
MODEL_LOG_FILE = '/var/log/portal/vllm.log'
|
|
MODEL_HEALTHCHECK_ENDPOINT = "/health"
|
|
|
|
# vLLM-specific log messages
|
|
MODEL_LOAD_LOG_MSG = [
|
|
"Application startup complete.",
|
|
]
|
|
|
|
MODEL_ERROR_LOG_MSGS = [
|
|
"INFO exited: vllm",
|
|
"RuntimeError: Engine",
|
|
"Traceback (most recent call last):"
|
|
]
|
|
|
|
MODEL_INFO_LOG_MSGS = [
|
|
'"message":"Download'
|
|
]
|
|
|
|
nltk.download("words")
|
|
WORD_LIST = nltk.corpus.words.words()
|
|
|
|
def request_parser(request):
|
|
data = request
|
|
if request.get("input") is not None:
|
|
data = request.get("input")
|
|
return data
|
|
|
|
|
|
def completions_benchmark_generator() -> dict:
|
|
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
|
model = os.environ.get("MODEL_NAME")
|
|
if not model:
|
|
raise ValueError("MODEL_NAME environment variable not set")
|
|
|
|
benchmark_data = {
|
|
"model": model,
|
|
"prompt": prompt,
|
|
"temperature": 0.7,
|
|
"max_tokens": 500,
|
|
}
|
|
|
|
return benchmark_data
|
|
|
|
worker_config = WorkerConfig(
|
|
model_server_url=MODEL_SERVER_URL,
|
|
model_server_port=MODEL_SERVER_PORT,
|
|
model_log_file=MODEL_LOG_FILE,
|
|
model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
|
|
handlers=[
|
|
HandlerConfig(
|
|
route="/v1/completions",
|
|
workload_calculator= lambda data: data.get("max_tokens", 0),
|
|
allow_parallel_requests=True,
|
|
request_parser=request_parser,
|
|
max_queue_time=600.0,
|
|
benchmark_config=BenchmarkConfig(
|
|
generator=completions_benchmark_generator,
|
|
concurrency=100,
|
|
runs=2
|
|
)
|
|
),
|
|
HandlerConfig(
|
|
route="/v1/chat/completions",
|
|
workload_calculator= lambda data: data.get("max_tokens", 0),
|
|
allow_parallel_requests=True,
|
|
request_parser=request_parser
|
|
max_queue_time=600.0,
|
|
)
|
|
],
|
|
log_action_config=LogActionConfig(
|
|
on_load=MODEL_LOAD_LOG_MSG,
|
|
on_error=MODEL_ERROR_LOG_MSGS,
|
|
on_info=MODEL_INFO_LOG_MSGS
|
|
)
|
|
)
|
|
|
|
Worker(worker_config).run() |