From 2ce0450809e843a31c8ef8731f03c74ce3853180 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Tue, 25 Nov 2025 13:33:12 -0800 Subject: [PATCH] Add worker.pys --- requirements.txt | 2 +- start_server.sh | 15 +++++- workers/comfyui-json/worker.py | 84 ++++++++++++++++++++++++++++++++++ workers/openai/worker.py | 77 +++++++++++++++++++++++++++++++ workers/tgi/worker.py | 81 ++++++++++++++++++++++++++++++++ 5 files changed, 256 insertions(+), 3 deletions(-) create mode 100644 workers/comfyui-json/worker.py create mode 100644 workers/openai/worker.py create mode 100644 workers/tgi/worker.py diff --git a/requirements.txt b/requirements.txt index 377b20a..f0a66ae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,4 @@ Requests~=2.32 transformers~=4.52 utils==1.0.* hf_transfer>=0.1.9 -vastai-sdk>=0.2.0 \ No newline at end of file +git+https://github.com/vast-ai/vast-sdk.git@worker-sdk \ No newline at end of file diff --git a/start_server.sh b/start_server.sh index 4b07e01..3bb80c4 100755 --- a/start_server.sh +++ b/start_server.sh @@ -133,8 +133,19 @@ cd "$SERVER_DIR" echo "launching PyWorker server" set +e -python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG" + +# Try worker entrypoint first +echo "trying workers.${BACKEND}.worker" +python3 -m "workers.${BACKEND}.worker" |& tee -a "$PYWORKER_LOG" PY_STATUS=${PIPESTATUS[0]} + +# If that fails, fall back to server +if [ "${PY_STATUS}" -ne 0 ]; then + echo "workers.${BACKEND}.worker failed with status ${PY_STATUS}, trying workers.${BACKEND}.server" + python3 -m "workers.${BACKEND}.server" |& tee -a "$PYWORKER_LOG" + PY_STATUS=${PIPESTATUS[0]} +fi + set -e if [ "${PY_STATUS}" -ne 0 ]; then @@ -171,4 +182,4 @@ JSON done fi -echo "launching PyWorker server done" \ No newline at end of file +echo "launching PyWorker server done" diff --git a/workers/comfyui-json/worker.py b/workers/comfyui-json/worker.py new file mode 100644 index 0000000..46ca82d --- /dev/null +++ b/workers/comfyui-json/worker.py @@ -0,0 +1,84 @@ +import random +import sys + +from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig + +# ComyUI model configuration +MODEL_SERVER_URL = 'http://127.0.0.1' +MODEL_SERVER_PORT = 18288 +MODEL_LOG_FILE = '/var/log/portal/comfyui.log' +MODEL_HEALTHCHECK_ENDPOINT = "/health" + +# ComyUI-specific log messages +MODEL_LOAD_LOG_MSG = [ + "To see the GUI go to: " +] + +MODEL_ERROR_LOG_MSGS = [ + "MetadataIncompleteBuffer", + "Value not in list: ", + "[ERROR] Provisioning Script failed" +] + +MODEL_INFO_LOG_MSGS = [ + '"message":"Downloading' +] + +benchmark_prompts = [ + "Cartoon hoodie hero; orc, anime cat, bunny; black goo; buff; vector on white.", + "Cozy farming-game scene with fine details.", + "2D vector child with soccer ball; airbrush chrome; swagger; antique copper.", + "Realistic futuristic downtown of low buildings at sunset.", + "Perfect wave front view; sunny seascape; ultra-detailed water; artful feel.", + "Clear cup with ice, fruit, mint; creamy swirls; fluid-sim CGI; warm glow.", + "Male biker with backpack on motorcycle; oilpunk; award-worthy magazine cover.", + "Collage for textile; surreal cartoon cat in cap/jeans before poster; crisp.", + "Medieval village inside glass sphere; volumetric light; macro focus.", + "Iron Man with glowing axe; mecha sci-fi; jungle scene; dynamic light.", + "Pope Francis DJ in leather jacket, mixing on giant console; dramatic.", +] + + +def parse_request(json_msg): + return {"input" : json_msg} + +benchmark_dataset = [ + { + "input": { + "request_id": f"test-{random.randint(1000, 99999)}", + "modifier": "Text2Image", + "modifications": { + "prompt": prompt, + "width": 512, + "height": 512, + "steps": 20, + "seed": random.randint(0, sys.maxsize) + } + } + } for prompt in benchmark_prompts +] + +worker_config = WorkerConfig( + model_server_url=MODEL_SERVER_URL, + model_server_port=MODEL_SERVER_PORT, + model_log_file=MODEL_LOG_FILE, + model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT, + handlers=[ + HandlerConfig( + route="/generate/sync", + allow_parallel_requests=False, + max_queue_time=10.0, + request_parser=parse_request, + benchmark_config=BenchmarkConfig( + dataset=benchmark_dataset, + ) + ) + ], + log_action_config=LogActionConfig( + on_load=MODEL_LOAD_LOG_MSG, + on_error=MODEL_ERROR_LOG_MSGS, + on_info=MODEL_INFO_LOG_MSGS + ) +) + +Worker(worker_config).run() \ No newline at end of file diff --git a/workers/openai/worker.py b/workers/openai/worker.py new file mode 100644 index 0000000..105b8df --- /dev/null +++ b/workers/openai/worker.py @@ -0,0 +1,77 @@ +import nltk +import random +import os + +from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig + +# vLLM model configuration +MODEL_SERVER_URL = 'http://127.0.0.1' +MODEL_SERVER_PORT = 18000 +MODEL_LOG_FILE = '/var/log/portal/vllm.log' +MODEL_HEALTHCHECK_ENDPOINT = "/health" + +# vLLM-specific log messages +MODEL_LOAD_LOG_MSG = [ + "Application startup complete.", +] + +MODEL_ERROR_LOG_MSGS = [ + "INFO exited: vllm", + "RuntimeError: Engine", + "Traceback (most recent call last):" +] + +MODEL_INFO_LOG_MSGS = [ + '"message":"Download' +] + +nltk.download("words") +WORD_LIST = nltk.corpus.words.words() + + +def completions_benchmark_generator() -> dict: + prompt = " ".join(random.choices(WORD_LIST, k=int(250))) + model = os.environ.get("MODEL_NAME") + if not model: + raise ValueError("MODEL_NAME environment variable not set") + + benchmark_data = { + "model": model, + "prompt": prompt, + "temperature": 0.7, + "max_tokens": 500, + } + + return benchmark_data + +worker_config = WorkerConfig( + model_server_url=MODEL_SERVER_URL, + model_server_port=MODEL_SERVER_PORT, + model_log_file=MODEL_LOG_FILE, + model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT, + handlers=[ + HandlerConfig( + route="/v1/completions", + workload_calculator= lambda data: data.get("max_tokens", 0), + allow_parallel_requests=True, + max_queue_time=60.0, + benchmark_config=BenchmarkConfig( + generator=completions_benchmark_generator, + concurrency=100 + ) + ), + HandlerConfig( + route="/v1/chat/completions", + workload_calculator= lambda data: data.get("max_tokens", 0), + allow_parallel_requests=True, + max_queue_time=60.0, + ) + ], + log_action_config=LogActionConfig( + on_load=MODEL_LOAD_LOG_MSG, + on_error=MODEL_ERROR_LOG_MSGS, + on_info=MODEL_INFO_LOG_MSGS + ) +) + +Worker(worker_config).run() \ No newline at end of file diff --git a/workers/tgi/worker.py b/workers/tgi/worker.py new file mode 100644 index 0000000..a2f40a2 --- /dev/null +++ b/workers/tgi/worker.py @@ -0,0 +1,81 @@ +import nltk +import random + +from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig + +# TGI model configuration +MODEL_SERVER_URL = 'http://0.0.0.0' +MODEL_SERVER_PORT = 5001 +MODEL_LOG_FILE = "/workspace/infer.log" +MODEL_HEALTHCHECK_ENDPOINT = "/health" + +# TGI-specific log messages +MODEL_LOAD_LOG_MSG = [ + '"message":"Connected","target":"text_generation_router"', + '"message":"Connected","target":"text_generation_router::server"', +] + +MODEL_ERROR_LOG_MSGS = [ + "Error: WebserverFailed", + "Error: DownloadError", + "Error: ShardCannotStart", +] + +MODEL_INFO_LOG_MSGS = [ + '"message":"Download' +] + +nltk.download("words") +WORD_LIST = nltk.corpus.words.words() + + +def benchmark_generator() -> dict: + prompt = " ".join(random.choices(WORD_LIST, k=int(250))) + + benchmark_data = { + "inputs": prompt, + "parameters": { + "max_new_tokens": 128, + "temperature": 0.7, + "return_full_text": False + } + } + + return benchmark_data + +def parse_request(json_msg): + return {"input" : json_msg} + +worker_config = WorkerConfig( + model_server_url=MODEL_SERVER_URL, + model_server_port=MODEL_SERVER_PORT, + model_log_file=MODEL_LOG_FILE, + model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT, + handlers=[ + HandlerConfig( + route="/generate", + allow_parallel_requests=True, + max_queue_time=60.0, + request_parser=parse_request, + benchmark_config=BenchmarkConfig( + generator=benchmark_generator, + concurrency=50 + ), + workload_calculator= lambda x: x["parameters"]["max_new_tokens"] + ), + HandlerConfig( + route="/generate_stream", + allow_parallel_requests=True, + max_queue_time=60.0, + request_parser=parse_request, + workload_calculator= lambda x: x["parameters"]["max_new_tokens"] + ) + ], + log_action_config=LogActionConfig( + on_load=MODEL_LOAD_LOG_MSG, + on_error=MODEL_ERROR_LOG_MSGS, + on_info=MODEL_INFO_LOG_MSGS + ) +) + +Worker(worker_config).run() \ No newline at end of file