Use PyWorker SDK (#67)
* Change PyWorker to Worker SDK * Moved /lib to vast-sdk (https://github.com/vast-ai/vast-sdk)
This commit is contained in:
@@ -0,0 +1,78 @@
|
||||
import nltk
|
||||
import random
|
||||
import os
|
||||
|
||||
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
|
||||
|
||||
# vLLM model configuration
|
||||
MODEL_SERVER_URL = 'http://127.0.0.1'
|
||||
MODEL_SERVER_PORT = 18000
|
||||
MODEL_LOG_FILE = '/var/log/portal/vllm.log'
|
||||
MODEL_HEALTHCHECK_ENDPOINT = "/health"
|
||||
|
||||
# vLLM-specific log messages
|
||||
MODEL_LOAD_LOG_MSG = [
|
||||
"Application startup complete.",
|
||||
]
|
||||
|
||||
MODEL_ERROR_LOG_MSGS = [
|
||||
"INFO exited: vllm",
|
||||
"RuntimeError: Engine",
|
||||
"Traceback (most recent call last):"
|
||||
]
|
||||
|
||||
MODEL_INFO_LOG_MSGS = [
|
||||
'"message":"Download'
|
||||
]
|
||||
|
||||
nltk.download("words")
|
||||
WORD_LIST = nltk.corpus.words.words()
|
||||
|
||||
|
||||
def completions_benchmark_generator() -> dict:
|
||||
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
||||
model = os.environ.get("MODEL_NAME")
|
||||
if not model:
|
||||
raise ValueError("MODEL_NAME environment variable not set")
|
||||
|
||||
benchmark_data = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 500,
|
||||
}
|
||||
|
||||
return benchmark_data
|
||||
|
||||
worker_config = WorkerConfig(
|
||||
model_server_url=MODEL_SERVER_URL,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
model_log_file=MODEL_LOG_FILE,
|
||||
model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
|
||||
handlers=[
|
||||
HandlerConfig(
|
||||
route="/v1/completions",
|
||||
workload_calculator= lambda data: data.get("max_tokens", 0),
|
||||
allow_parallel_requests=True,
|
||||
max_queue_time=60.0,
|
||||
benchmark_config=BenchmarkConfig(
|
||||
generator=completions_benchmark_generator,
|
||||
concurrency=100,
|
||||
runs=2
|
||||
)
|
||||
),
|
||||
HandlerConfig(
|
||||
route="/v1/chat/completions",
|
||||
workload_calculator= lambda data: data.get("max_tokens", 0),
|
||||
allow_parallel_requests=True,
|
||||
max_queue_time=60.0,
|
||||
)
|
||||
],
|
||||
log_action_config=LogActionConfig(
|
||||
on_load=MODEL_LOAD_LOG_MSG,
|
||||
on_error=MODEL_ERROR_LOG_MSGS,
|
||||
on_info=MODEL_INFO_LOG_MSGS
|
||||
)
|
||||
)
|
||||
|
||||
Worker(worker_config).run()
|
||||
Reference in New Issue
Block a user