Compare commits

..

8 Commits

Author SHA1 Message Date
Lucas Armand 405a8f1c0d returned to worker-sdk 2025-12-10 16:37:09 -08:00
Lucas Armand 12f4f23d39 remove parse request 2025-12-10 15:16:23 -08:00
Lucas Armand e2a771bb5a update ace and wan workers 2025-12-10 15:09:27 -08:00
Lucas Armand 0cd64adfc4 remove input 2025-12-10 14:47:47 -08:00
Lucas Armand 6f795b8fb8 remove input from workers 2025-12-10 14:46:10 -08:00
Lucas Armand 4bcc508473 reduce vllm benchmark runs to 2 2025-11-25 16:54:17 -08:00
Lucas Armand 74d7330800 add wan and ace workers 2025-11-25 16:08:40 -08:00
Lucas Armand 2ce0450809 Add worker.pys 2025-11-25 16:08:38 -08:00
12 changed files with 801 additions and 352 deletions
+1 -1
View File
@@ -8,4 +8,4 @@ Requests~=2.32
transformers~=4.52
utils==1.0.*
hf_transfer>=0.1.9
vastai-sdk>=0.2.0
git+https://github.com/vast-ai/vast-sdk.git@worker-sdk
+13 -2
View File
@@ -133,8 +133,19 @@ cd "$SERVER_DIR"
echo "launching PyWorker server"
set +e
python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG"
# Try worker entrypoint first
echo "trying workers.${BACKEND}.worker"
python3 -m "workers.${BACKEND}.worker" |& tee -a "$PYWORKER_LOG"
PY_STATUS=${PIPESTATUS[0]}
# If that fails, fall back to server
if [ "${PY_STATUS}" -ne 0 ]; then
echo "workers.${BACKEND}.worker failed with status ${PY_STATUS}, trying workers.${BACKEND}.server"
python3 -m "workers.${BACKEND}.server" |& tee -a "$PYWORKER_LOG"
PY_STATUS=${PIPESTATUS[0]}
fi
set -e
if [ "${PY_STATUS}" -ne 0 ]; then
@@ -171,4 +182,4 @@ JSON
done
fi
echo "launching PyWorker server done"
echo "launching PyWorker server done"
+184
View File
@@ -0,0 +1,184 @@
import random
import sys
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
# ComyUI model configuration
MODEL_SERVER_URL = 'http://127.0.0.1'
MODEL_SERVER_PORT = 18288
MODEL_LOG_FILE = '/var/log/portal/comfyui.log'
MODEL_HEALTHCHECK_ENDPOINT = "/health"
# ComyUI-specific log messages
MODEL_LOAD_LOG_MSG = [
"To see the GUI go to: "
]
MODEL_ERROR_LOG_MSGS = [
"MetadataIncompleteBuffer",
"Value not in list: ",
"[ERROR] Provisioning Script failed"
]
MODEL_INFO_LOG_MSGS = [
'"message":"Downloading'
]
benchmark_lyrics = [
"[verse]\nGuardian cloaked in twilight hue\nShadows melt where he breaks through\nEchoes swirl in mystic flight\nHooded hero owns the night\n\n[verse]\nThrough the chaos shapes arise\nFeral whispers, glowing eyes\nOrcs and creatures side by side\nMarch within the inky tide\n\n[chorus]\nRise above the fear and gloom\nLet your courage fully bloom\nIn the darkness stand your ground\nHear the night proclaim your sound",
"[verse]\nMorning sun on fields of gold\nGentle stories unfold\nEvery breeze a quiet song\nWhere the peaceful hearts belong\n\n[verse]\nLanterns glow at stable doors\nRustling leaves on orchard floors\nSimple joys in every hand\nLife grows soft in fertile land\n\n[chorus]\nLet the day drift slow and free\nRoot your soul where you can be\nIn this haven warm and bright\nFeel the earth breathe pure delight",
"[verse]\nLittle feet on dusty ground\nChasing dreams without a sound\nSoccer ball in morning light\nHopes take wing in youthful flight\n\n[verse]\nChrome reflections paint the day\nSwagger in the steps that play\nCopper tones in shining air\nChildhood gleaming everywhere\n\n[chorus]\nKick the world with boundless cheer\nHold the magic close and near\nIn each moment bold and true\nLet the sky belong to you",
"[verse]\nSunset bleeds across the street\nGilded calm in summer heat\nLow-rise towers rimmed with fire\nDreams ignite as lights climb higher\n\n[verse]\nFootsteps scatter through the haze\nFutures shimmer in the blaze\nEvery window tells a tale\nFloating through a tangerine veil\n\n[chorus]\nLet the neon softly glow\nLet your restless heartbeat slow\nIn this city forged in light\nCarry hope into the night",
"[verse]\nOcean breathes in rolling arcs\nSprays of diamond, glowing sparks\nWaves unfold a perfect line\nNatures rhythm feels divine\n\n[verse]\nSun above in golden sweep\nPaints the rise of every deep\nShimmer drifting through the blue\nWorld reborn in every view\n\n[chorus]\nLet the tide pull you along\nHear the waters ancient song\nIn the cresting waves youll find\nQuiet peace for heart and mind",
"[verse]\nGlass aglow with swirling light\nFruits and mints in colors bright\nIcy whispers clink and chime\nFlowing forms suspend in time\n\n[verse]\nCreamy spirals drift within\nGentle currents slowly spin\nWarm reflections lingering sweet\nMixing flavors at your feet\n\n[chorus]\nSip the glow and let it rise\nTaste the sunset in disguise\nIn this moment clear and true\nLet the warmth flow into you",
"[verse]\nEngines rumble down the lane\nCopper clouds of steam and rain\nOilpunk dreams in metal shine\nRider drifting down the line\n\n[verse]\nLeather jacket, steady glare\nStories sparking in the air\nMagazine lights frame his face\nKing of roads in timeless grace\n\n[chorus]\nThrottle up beyond the bend\nFeel the force of steel ascend\nRide the night and hold on tight\nClaim the world in streaks of light",
"[verse]\nCut-out shapes in swirling play\nTextures dance in bold array\nCats in denim, grinning wide\nStrut across the patterned tide\n\n[verse]\nPosters hum with neon glow\nSurreal scenes begin to grow\nColors crisp as folded art\nPatchwork beating like a heart\n\n[chorus]\nLet the collage come alive\nWatch the vibrant pieces thrive\nIn this joyful, crafted space\nEvery shape finds its own place",
"[verse]\nTiny world in crystal glass\nAncient tales behind the mass\nVillage lights in winter gleam\nFrozen in a mystic dream\n\n[verse]\nLantern beams in swirling air\nSoft enchantment everywhere\nShadows drift with gentle grace\nMagic sealed within the space\n\n[chorus]\nHold the sphere and you will see\nEchoes of a memory\nIn the glow of fragile light\nLives a realm of pure delight",
"[verse]\nArmor hums with power bright\nChopping sparks in jungle night\nMecha spirits shift and scream\nThrough the ferns like shattered beams\n\n[verse]\nAxes blaze in glowing arcs\nLighting up the shadowed marks\nNature roars in trembling air\nClash of steel and cosmic flare\n\n[chorus]\nRaise the fire, strike the ground\nLet your legend shake the sound\nIn the wild where echoes roam\nForge the fight and carve your home",
"[verse]\nCrowds ignite in vibrant flare\nBeats explode through smoky air\nDJ robes replaced with flame\nPope on decks in holy frame\n\n[verse]\nLeather gleams in blinding light\nTurntables spin with sacred might\nChoirs echo in the bass\nHeaven pulses through the place\n\n[chorus]\nLift the roof and shake the floor\nSacred rhythm evermore\nLet the music take control\nFeel the blessing in your soul",
]
benchmark_dataset = [
{
"input": {
"request_id": "",
"workflow_json": {
"14": {
"inputs": {
"tags": "funk, pop, soul, rock, melodic, guitar, drums, bass, keyboard, percussion, 105 BPM, energetic, upbeat, groovy, vibrant, dynamic",
"lyrics": lyrics,
"lyrics_strength": 0.99,
"clip": ["40", 1]
},
"class_type": "TextEncodeAceStepAudio",
"_meta": {
"title": "TextEncodeAceStepAudio"
}
},
"17": {
"inputs": {
"seconds": 180,
"batch_size": 1
},
"class_type": "EmptyAceStepLatentAudio",
"_meta": {
"title": "EmptyAceStepLatentAudio"
}
},
"18": {
"inputs": {
"samples": ["52", 0],
"vae": ["40", 2]
},
"class_type": "VAEDecodeAudio",
"_meta": {
"title": "VAE Decode Audio"
}
},
"40": {
"inputs": {
"ckpt_name": "ace_step_v1_3.5b.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"44": {
"inputs": {
"conditioning": ["14", 0]
},
"class_type": "ConditioningZeroOut",
"_meta": {
"title": "ConditioningZeroOut"
}
},
"49": {
"inputs": {
"model": ["51", 0],
"operation": ["50", 0]
},
"class_type": "LatentApplyOperationCFG",
"_meta": {
"title": "LatentApplyOperationCFG"
}
},
"50": {
"inputs": {
"multiplier": 1.15
},
"class_type": "LatentOperationTonemapReinhard",
"_meta": {
"title": "LatentOperationTonemapReinhard"
}
},
"51": {
"inputs": {
"shift": 6,
"model": ["40", 0]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "ModelSamplingSD3"
}
},
"52": {
"inputs": {
"seed": "__RANDOM_INT__",
"steps": 65,
"cfg": 4,
"sampler_name": "er_sde",
"scheduler": "linear_quadratic",
"denoise": 1,
"model": ["49", 0],
"positive": ["14", 0],
"negative": ["44", 0],
"latent_image": ["17", 0]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"59": {
"inputs": {
"filename_prefix": "audio/ComfyUI",
"quality": "V0",
"audioUI": "",
"audio": ["18", 0]
},
"class_type": "SaveAudioMP3",
"_meta": {
"title": "Save Audio (MP3)"
}
}
}
}
} for lyrics in benchmark_lyrics
]
worker_config = WorkerConfig(
model_server_url=MODEL_SERVER_URL,
model_server_port=MODEL_SERVER_PORT,
model_log_file=MODEL_LOG_FILE,
model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
handlers=[
HandlerConfig(
route="/generate/sync",
allow_parallel_requests=False,
max_queue_time=10.0,
benchmark_config=BenchmarkConfig(
dataset=benchmark_dataset,
runs=1
),
workload_calculator= lambda _ : 1000.0
)
],
log_action_config=LogActionConfig(
on_load=MODEL_LOAD_LOG_MSG,
on_error=MODEL_ERROR_LOG_MSGS,
on_info=MODEL_INFO_LOG_MSGS
)
)
Worker(worker_config).run()
+81
View File
@@ -0,0 +1,81 @@
import random
import sys
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
# ComyUI model configuration
MODEL_SERVER_URL = 'http://127.0.0.1'
MODEL_SERVER_PORT = 18288
MODEL_LOG_FILE = '/var/log/portal/comfyui.log'
MODEL_HEALTHCHECK_ENDPOINT = "/health"
# ComyUI-specific log messages
MODEL_LOAD_LOG_MSG = [
"To see the GUI go to: "
]
MODEL_ERROR_LOG_MSGS = [
"MetadataIncompleteBuffer",
"Value not in list: ",
"[ERROR] Provisioning Script failed"
]
MODEL_INFO_LOG_MSGS = [
'"message":"Downloading'
]
benchmark_prompts = [
"Cartoon hoodie hero; orc, anime cat, bunny; black goo; buff; vector on white.",
"Cozy farming-game scene with fine details.",
"2D vector child with soccer ball; airbrush chrome; swagger; antique copper.",
"Realistic futuristic downtown of low buildings at sunset.",
"Perfect wave front view; sunny seascape; ultra-detailed water; artful feel.",
"Clear cup with ice, fruit, mint; creamy swirls; fluid-sim CGI; warm glow.",
"Male biker with backpack on motorcycle; oilpunk; award-worthy magazine cover.",
"Collage for textile; surreal cartoon cat in cap/jeans before poster; crisp.",
"Medieval village inside glass sphere; volumetric light; macro focus.",
"Iron Man with glowing axe; mecha sci-fi; jungle scene; dynamic light.",
"Pope Francis DJ in leather jacket, mixing on giant console; dramatic.",
]
benchmark_dataset = [
{
"input": {
"request_id": f"test-{random.randint(1000, 99999)}",
"modifier": "Text2Image",
"modifications": {
"prompt": prompt,
"width": 512,
"height": 512,
"steps": 20,
"seed": random.randint(0, sys.maxsize)
}
}
} for prompt in benchmark_prompts
]
worker_config = WorkerConfig(
model_server_url=MODEL_SERVER_URL,
model_server_port=MODEL_SERVER_PORT,
model_log_file=MODEL_LOG_FILE,
model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
handlers=[
HandlerConfig(
route="/generate/sync",
allow_parallel_requests=False,
max_queue_time=10.0,
benchmark_config=BenchmarkConfig(
dataset=benchmark_dataset,
)
)
],
log_action_config=LogActionConfig(
on_load=MODEL_LOAD_LOG_MSG,
on_error=MODEL_ERROR_LOG_MSGS,
on_info=MODEL_INFO_LOG_MSGS
)
)
Worker(worker_config).run()
+22 -29
View File
@@ -8,13 +8,14 @@ This is the base PyWorker for OpenAI compatible inference servers. See the [Ser
This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker.
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20(Serverless)) (recommended)
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20%2B%20Qwen%2FQwen3-8B%20(Serverless)) (recommended)
- [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless))
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20%2B%20Qwen3-8B%20(Serverless))
All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected.
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
2. Follow the [getting started guide](https://docs.vast.ai/serverless/getting-started) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
## Client Setup (Demo)
@@ -33,30 +34,12 @@ uv pip install -r requirements.txt
Several examples have been provided in the client to help you get started with your own implementation.
First, set your API key as an environment variable:
### Completions
Call to `/v1/completions` with json response
```bash
export VAST_API_KEY=<your_api_key>
```
The `--model` and `--endpoint` flags are optional. If not provided, they default to `Qwen/Qwen3-8B` and `my-vllm-endpoint` respectively.
### Chat Completion (streaming)
Call to `/v1/chat/completions` with streaming response
```bash
python -m workers.openai.client --chat-stream --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Interactive Chat (streaming)
Interactive session with calls to `/v1/chat/completions`.
Type `clear` to clear the chat history or `quit` to exit.
```bash
python -m workers.openai.client --interactive --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
```
### Chat Completion (json)
@@ -64,7 +47,15 @@ python -m workers.openai.client --interactive --endpoint <ENDPOINT_NAME> --model
Call to `/v1/chat/completions` with json response
```bash
python -m workers.openai.client --chat --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
```
### Chat Completion (streaming)
Call to `/v1/chat/completions` with streaming response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
```
### Tool Use (json)
@@ -74,14 +65,16 @@ Call to `/v1/chat/completions` with tool and json response.
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
```bash
python -m workers.openai.client --tools --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
```
### Completions
### Interactive Chat (streaming)
Call to `/v1/completions` with json response
Interactive session with calls to `/v1/chat/completions`.
Type `clear` to clear the chat history or `quit` to exit.
```bash
python -m workers.openai.client --completion --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
```
+16 -32
View File
@@ -18,7 +18,7 @@ logging.basicConfig(
log = logging.getLogger(__file__)
# ---------------------- Prompts ----------------------
COMPLETIONS_PROMPT = "Zebras are primarily grazers and can subsist on lower-quality vegetation. They are preyed on mainly by"
COMPLETIONS_PROMPT = "the capital of USA is"
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
TOOLS_PROMPT = (
"Can you list the files in the current working directory and tell me what you see? "
@@ -97,9 +97,9 @@ def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[
# ---- OpenAI-compatible calls (non-streaming) ----
async def call_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs) -> Dict[str, Any]:
async def call_completions(client: Serverless, *, model: str, prompt: str, **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=endpoint_name)
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = {
"input": {
@@ -113,9 +113,9 @@ async def call_completions(client: Serverless, *, model: str, prompt: str, endpo
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
return resp["response"]
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs) -> Dict[str, Any]:
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=endpoint_name)
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = {
"input": {
@@ -132,9 +132,9 @@ async def call_chat_completions(client: Serverless, *, model: str, messages: Lis
return resp["response"]
# ---- Streaming variants ----
async def stream_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs):
async def stream_completions(client: Serverless, *, model: str, prompt: str, **kwargs):
endpoint = await client.get_endpoint(name=endpoint_name)
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = {
"input": {
@@ -150,9 +150,9 @@ async def stream_completions(client: Serverless, *, model: str, prompt: str, end
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
return resp["response"] # async generator
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs):
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs):
endpoint = await client.get_endpoint(name=endpoint_name)
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = {
"input": {
@@ -174,10 +174,9 @@ async def stream_chat_completions(client: Serverless, *, model: str, messages: L
class APIDemo:
"""Demo and testing functionality for the API client"""
def __init__(self, client: Serverless, model: str, endpoint_name: str, tool_manager: Optional[ToolManager] = None):
def __init__(self, client: Serverless, model: str, tool_manager: Optional[ToolManager] = None):
self.client = client
self.model = model
self.endpoint_name = endpoint_name
self.tool_manager = tool_manager or ToolManager()
# ----- Streaming handler -----
@@ -186,15 +185,10 @@ class APIDemo:
reasoning_content = ""
printed_reasoning = False
printed_answer = False
finish_reason = None
async for chunk in stream:
choice = (chunk.get("choices") or [{}])[0]
delta = choice.get("delta", {})
# Track finish reason
if choice.get("finish_reason"):
finish_reason = choice.get("finish_reason")
# reasoning tokens
rc = delta.get("reasoning_content")
@@ -225,8 +219,6 @@ class APIDemo:
print(f"Reasoning tokens: {len(reasoning_content.split())}")
if printed_answer:
print(f"Response tokens: {len(full_response.split())}")
if finish_reason:
print(f"Finish reason: {finish_reason}")
return full_response
@@ -239,7 +231,6 @@ class APIDemo:
client=self.client,
model=self.model,
prompt=COMPLETIONS_PROMPT,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
@@ -258,7 +249,6 @@ class APIDemo:
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE
)
@@ -271,7 +261,6 @@ class APIDemo:
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE
)
@@ -298,7 +287,6 @@ class APIDemo:
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
tools=minimal_tool,
tool_choice="none",
max_tokens=10
@@ -324,7 +312,6 @@ class APIDemo:
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
tools=self.tool_manager.get_ls_tool_definition(),
tool_choice="auto",
max_tokens=MAX_TOKENS,
@@ -402,7 +389,6 @@ class APIDemo:
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
@@ -441,6 +427,7 @@ class APIDemo:
print("=" * 60)
print("INTERACTIVE STREAMING CHAT")
print("=" * 60)
print(f"Using model: {self.model}")
print("Type 'quit' to exit, 'clear' to clear history")
print()
@@ -466,8 +453,7 @@ class APIDemo:
stream = await stream_chat_completions(
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
messages=messages,
max_tokens=MAX_TOKENS,
temperature=0.7
)
@@ -487,8 +473,8 @@ class APIDemo:
# ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
p.add_argument("--model", default=DEFAULT_MODEL, help=f"Model to use for requests (default: {DEFAULT_MODEL})")
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
p.add_argument("--model", required=True, help="Model to use for requests (required)")
p.add_argument("--endpoint", default="my-vllm-endpoint", help="Vast endpoint name (default: my-vllm-endpoint)")
modes = p.add_mutually_exclusive_group(required=False)
modes.add_argument("--completion", action="store_true", help="Test completions endpoint")
@@ -516,14 +502,12 @@ async def main_async():
print("Please specify exactly one test mode")
sys.exit(1)
print("=" * 60)
print(f"Using model: {args.model}")
print(f"Using endpoint: {args.endpoint}")
print("=" * 60)
try:
async with Serverless() as client:
demo = APIDemo(client, args.model, args.endpoint, ToolManager())
demo = APIDemo(client, args.model, ToolManager())
if args.completion:
await demo.demo_completions()
-1
View File
@@ -35,7 +35,6 @@ backend = Backend(
model_server_url=os.environ["MODEL_SERVER_URL"],
model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=True,
max_wait_time=600.0,
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
log_actions=[
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
+78
View File
@@ -0,0 +1,78 @@
import nltk
import random
import os
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
# vLLM model configuration
MODEL_SERVER_URL = 'http://127.0.0.1'
MODEL_SERVER_PORT = 18000
MODEL_LOG_FILE = '/var/log/portal/vllm.log'
MODEL_HEALTHCHECK_ENDPOINT = "/health"
# vLLM-specific log messages
MODEL_LOAD_LOG_MSG = [
"Application startup complete.",
]
MODEL_ERROR_LOG_MSGS = [
"INFO exited: vllm",
"RuntimeError: Engine",
"Traceback (most recent call last):"
]
MODEL_INFO_LOG_MSGS = [
'"message":"Download'
]
nltk.download("words")
WORD_LIST = nltk.corpus.words.words()
def completions_benchmark_generator() -> dict:
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
model = os.environ.get("MODEL_NAME")
if not model:
raise ValueError("MODEL_NAME environment variable not set")
benchmark_data = {
"model": model,
"prompt": prompt,
"temperature": 0.7,
"max_tokens": 500,
}
return benchmark_data
worker_config = WorkerConfig(
model_server_url=MODEL_SERVER_URL,
model_server_port=MODEL_SERVER_PORT,
model_log_file=MODEL_LOG_FILE,
model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
handlers=[
HandlerConfig(
route="/v1/completions",
workload_calculator= lambda data: data.get("max_tokens", 0),
allow_parallel_requests=True,
max_queue_time=60.0,
benchmark_config=BenchmarkConfig(
generator=completions_benchmark_generator,
concurrency=100,
runs=2
)
),
HandlerConfig(
route="/v1/chat/completions",
workload_calculator= lambda data: data.get("max_tokens", 0),
allow_parallel_requests=True,
max_queue_time=60.0,
)
],
log_action_config=LogActionConfig(
on_load=MODEL_LOAD_LOG_MSG,
on_error=MODEL_ERROR_LOG_MSGS,
on_info=MODEL_INFO_LOG_MSGS
)
)
Worker(worker_config).run()
+9 -93
View File
@@ -1,103 +1,19 @@
# HuggingFace TGI PyWorker
This is the base PyWorker for TGI, designed to create PyWorkers that can utilize various LLMs. It offers two primary endpoints:
This is the base PyWorker for HuggingFace Text Generation Inference (TGI) servers. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
1. `generate`: Generates the LLM's response to a given prompt in a single request.
2. `generate_stream`: Streams the LLM's response token by token.
## Instance Setup
1. Pick a template
This worker is compatible with any TGI backend. We have a template you can use or you can create your own.
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20(Serverless))
The template can be configured via the template interface. You may want to change the model or startup arguments.
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
## Client Setup (Demo)
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
```bash
git clone https://github.com/vast-ai/pyworker
cd pyworker
pip install uv
uv venv -p 3.12
source .venv/bin/activate
uv pip install -r requirements.txt
```
## Using the Test Client
The test client demonstrates both streaming and non-streaming generation using TGI's native API.
First, set your API key as an environment variable:
```bash
export VAST_API_KEY=<your_api_key>
```
The `--endpoint` flag is optional. If not provided, it defaults to `my-tgi-endpoint`.
### Generate (Streaming)
Call to `/generate_stream` with streaming response:
```bash
python -m workers.tgi.client --generate-stream --endpoint <ENDPOINT_NAME>
```
### Generate (Non-Streaming)
Call to `/generate` with json response:
```bash
python -m workers.tgi.client --generate --endpoint <ENDPOINT_NAME>
```
### Interactive Session (Streaming)
Interactive session with streaming responses. Type `quit` to exit.
```bash
python -m workers.tgi.client --interactive --endpoint <ENDPOINT_NAME>
```
## API Endpoints
TGI provides two primary endpoints:
### Generate (Non-Streaming)
`/generate` - Returns the complete response in a single request.
Both endpoints use the following API payload format:
```json
{
"inputs": "Your prompt here",
"inputs": "PROMPT",
"parameters": {
"max_new_tokens": 1024,
"temperature": 0.7,
"return_full_text": false
"max_new_tokens": 250
}
}
```
### Generate Stream (Streaming)
`/generate_stream` - Streams the response token by token.
```json
{
"inputs": "Your prompt here",
"parameters": {
"max_new_tokens": 1024,
"temperature": 0.7,
"do_sample": true,
"return_full_text": false
}
}
```
## Performance Notes
The `max_new_tokens` parameter (not the prompt size) primarily impacts performance. For example, if an instance is benchmarked to process 100 tokens per second, a request with `max_new_tokens = 200` will take approximately 2 seconds to complete.
Note that the max_new_tokens parameter, rather than the prompt size, impacts performance. For example, if an
instance is benchmarked to process 100 tokens per second, a request with max_new_tokens = 200 will take
approximately 2 seconds to complete.
+33 -194
View File
@@ -1,222 +1,61 @@
import logging
import json
import os
import sys
import argparse
from vastai import Serverless
import asyncio
# ---------------------- Logging ----------------------
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
# ---------------------- Defaults ----------------------
DEFAULT_PROMPT = "Think step by step: Tell me about the Python programming language."
ENDPOINT_NAME = "TGI-Prod2" # change this to your TGI endpoint name
ENDPOINT_NAME = "my-tgi-endpoint" # Change this to match your endpoint name
MAX_TOKENS = 1024
DEFAULT_TEMPERATURE = 0.7
PROMPT = "Think step by step: Tell me about the Python programming language."
# ---------------------- API Calls ----------------------
async def call_generate(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs) -> dict:
"""Non-streaming generation via /generate endpoint"""
endpoint = await client.get_endpoint(name=endpoint_name)
async def call_generate(client: Serverless) -> None:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = {
"inputs": prompt,
"inputs": PROMPT,
"parameters": {
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"return_full_text": False,
"max_new_tokens": MAX_TOKENS,
"temperature": 0.7,
"return_full_text": False
}
}
log.debug("POST /generate %s", json.dumps(payload)[:500])
resp = await endpoint.request("/generate", payload, cost=payload["parameters"]["max_new_tokens"])
return resp["response"]
resp = await endpoint.request("/generate", payload, cost=MAX_TOKENS)
print(resp["response"]["generated_text"])
async def call_generate_stream(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs):
"""Streaming generation via /generate_stream endpoint"""
endpoint = await client.get_endpoint(name=endpoint_name)
async def call_generate_stream(client: Serverless) -> None:
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
payload = {
"inputs": prompt,
"inputs": PROMPT,
"parameters": {
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"max_new_tokens": MAX_TOKENS,
"temperature": 0.7,
"do_sample": True,
"return_full_text": False,
}
}
log.debug("STREAM /generate_stream %s", json.dumps(payload)[:500])
resp = await endpoint.request(
"/generate_stream",
payload,
cost=payload["parameters"]["max_new_tokens"],
cost=MAX_TOKENS,
stream=True,
)
return resp["response"] # async generator
stream = resp["response"]
printed_answer = False
async for event in stream:
tok = (event.get("token") or {}).get("text")
if tok:
if not printed_answer:
printed_answer = True
print("Answer:\n", end="", flush=True)
print(tok, end="", flush=True)
# ---------------------- Demo Runner ----------------------
class APIDemo:
"""Demo and testing functionality for the TGI API client"""
def __init__(self, client: Serverless, endpoint_name: str):
self.client = client
self.endpoint_name = endpoint_name
async def handle_streaming_response(self, stream) -> str:
"""Process streaming response and print tokens"""
full_response = ""
printed_answer = False
async for event in stream:
tok = (event.get("token") or {}).get("text")
if tok:
if not printed_answer:
printed_answer = True
print("\n💬 Response: ", end="", flush=True)
print(tok, end="", flush=True)
full_response += tok
print() # newline
if printed_answer:
print(f"\nStreaming completed. Response tokens: {len(full_response.split())}")
return full_response
async def demo_generate(self) -> None:
"""Demo non-streaming generation"""
print("=" * 60)
print("GENERATE DEMO (NON-STREAMING)")
print("=" * 60)
response = await call_generate(
client=self.client,
endpoint_name=self.endpoint_name,
prompt=DEFAULT_PROMPT,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
print(f"\n💬 Response: {response.get('generated_text', '')}")
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
async def demo_generate_stream(self) -> None:
"""Demo streaming generation"""
print("=" * 60)
print("GENERATE DEMO (STREAMING)")
print("=" * 60)
stream = await call_generate_stream(
client=self.client,
endpoint_name=self.endpoint_name,
prompt=DEFAULT_PROMPT,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
try:
await self.handle_streaming_response(stream)
except Exception as e:
log.error("\nError during streaming: %s", e, exc_info=True)
async def interactive_chat(self) -> None:
"""Interactive session with streaming generation"""
print("=" * 60)
print("INTERACTIVE STREAMING SESSION")
print("=" * 60)
print(f"Using endpoint: {self.endpoint_name}")
print("Type 'quit' to exit")
print()
while True:
try:
user_input = input("You: ").strip()
if user_input.lower() == "quit":
print("👋 Goodbye!")
break
elif not user_input:
continue
print("Assistant: ", end="", flush=True)
stream = await call_generate_stream(
client=self.client,
endpoint_name=self.endpoint_name,
prompt=user_input,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
full_response = ""
async for event in stream:
tok = (event.get("token") or {}).get("text")
if tok:
print(tok, end="", flush=True)
full_response += tok
print() # newline
except KeyboardInterrupt:
print("\n👋 Session interrupted. Goodbye!")
break
except Exception as e:
log.error("\nError: %s", e)
continue
# ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast TGI Demo (Serverless SDK)")
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
modes = p.add_mutually_exclusive_group(required=False)
modes.add_argument("--generate", action="store_true", help="Test generate endpoint (non-streaming)")
modes.add_argument("--generate-stream", action="store_true", help="Test generate endpoint with streaming")
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming session")
return p
async def main_async():
args = build_arg_parser().parse_args()
selected = sum([args.generate, args.generate_stream, args.interactive])
if selected == 0:
print("Please specify exactly one test mode:")
print(" --generate : Test generate endpoint (non-streaming)")
print(" --generate-stream : Test generate endpoint with streaming")
print(" --interactive : Start interactive streaming session")
print(f"\nExample: python {os.path.basename(sys.argv[0])} --generate-stream --endpoint my-tgi-endpoint")
sys.exit(1)
elif selected > 1:
print("Please specify exactly one test mode")
sys.exit(1)
print("=" * 60)
print(f"Using endpoint: {args.endpoint}")
try:
async with Serverless() as client:
demo = APIDemo(client, args.endpoint)
if args.generate:
await demo.demo_generate()
elif args.generate_stream:
await demo.demo_generate_stream()
elif args.interactive:
await demo.interactive_chat()
except Exception as e:
log.error("Error during test: %s", e, exc_info=True)
sys.exit(1)
async def main():
async with Serverless() as client:
await call_generate(client)
await call_generate_stream(client)
if __name__ == "__main__":
asyncio.run(main_async())
asyncio.run(main())
+76
View File
@@ -0,0 +1,76 @@
import nltk
import random
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
# TGI model configuration
MODEL_SERVER_URL = 'http://0.0.0.0'
MODEL_SERVER_PORT = 5001
MODEL_LOG_FILE = "/workspace/infer.log"
MODEL_HEALTHCHECK_ENDPOINT = "/health"
# TGI-specific log messages
MODEL_LOAD_LOG_MSG = [
'"message":"Connected","target":"text_generation_router"',
'"message":"Connected","target":"text_generation_router::server"',
]
MODEL_ERROR_LOG_MSGS = [
"Error: WebserverFailed",
"Error: DownloadError",
"Error: ShardCannotStart",
]
MODEL_INFO_LOG_MSGS = [
'"message":"Download'
]
nltk.download("words")
WORD_LIST = nltk.corpus.words.words()
def benchmark_generator() -> dict:
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
benchmark_data = {
"inputs": prompt,
"parameters": {
"max_new_tokens": 128,
"temperature": 0.7,
"return_full_text": False
}
}
return benchmark_data
worker_config = WorkerConfig(
model_server_url=MODEL_SERVER_URL,
model_server_port=MODEL_SERVER_PORT,
model_log_file=MODEL_LOG_FILE,
model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
handlers=[
HandlerConfig(
route="/generate",
allow_parallel_requests=True,
max_queue_time=60.0,
benchmark_config=BenchmarkConfig(
generator=benchmark_generator,
concurrency=50
),
workload_calculator= lambda x: x["parameters"]["max_new_tokens"]
),
HandlerConfig(
route="/generate_stream",
allow_parallel_requests=True,
max_queue_time=60.0,
workload_calculator= lambda x: x["parameters"]["max_new_tokens"]
)
],
log_action_config=LogActionConfig(
on_load=MODEL_LOAD_LOG_MSG,
on_error=MODEL_ERROR_LOG_MSGS,
on_info=MODEL_INFO_LOG_MSGS
)
)
Worker(worker_config).run()
+288
View File
@@ -0,0 +1,288 @@
import random
import sys
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
# ComyUI model configuration
MODEL_SERVER_URL = 'http://127.0.0.1'
MODEL_SERVER_PORT = 18288
MODEL_LOG_FILE = '/var/log/portal/comfyui.log'
MODEL_HEALTHCHECK_ENDPOINT = "/health"
# ComyUI-specific log messages
MODEL_LOAD_LOG_MSG = [
"To see the GUI go to: "
]
MODEL_ERROR_LOG_MSGS = [
"MetadataIncompleteBuffer",
"Value not in list: ",
"[ERROR] Provisioning Script failed"
]
MODEL_INFO_LOG_MSGS = [
'"message":"Downloading'
]
benchmark_prompts = [
"Cartoon hoodie hero; orc, anime cat, bunny; black goo; buff; vector on white.",
"Cozy farming-game scene with fine details.",
"2D vector child with soccer ball; airbrush chrome; swagger; antique copper.",
"Realistic futuristic downtown of low buildings at sunset.",
"Perfect wave front view; sunny seascape; ultra-detailed water; artful feel.",
"Clear cup with ice, fruit, mint; creamy swirls; fluid-sim CGI; warm glow.",
"Male biker with backpack on motorcycle; oilpunk; award-worthy magazine cover.",
"Collage for textile; surreal cartoon cat in cap/jeans before poster; crisp.",
"Medieval village inside glass sphere; volumetric light; macro focus.",
"Iron Man with glowing axe; mecha sci-fi; jungle scene; dynamic light.",
"Pope Francis DJ in leather jacket, mixing on giant console; dramatic.",
]
benchmark_dataset = [
{
"input": {
"workflow_json": {
"90": {
"inputs": {
"clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors",
"type": "wan",
"device": "default"
},
"class_type": "CLIPLoader",
"_meta": {
"title": "Load CLIP"
}
},
"91": {
"inputs": {
"text": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,裸露,NSFW",
"clip": [
"90",
0
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Negative Prompt)"
}
},
"92": {
"inputs": {
"vae_name": "wan_2.1_vae.safetensors"
},
"class_type": "VAELoader",
"_meta": {
"title": "Load VAE"
}
},
"93": {
"inputs": {
"shift": 8.000000000000002,
"model": [
"101",
0
]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "ModelSamplingSD3"
}
},
"94": {
"inputs": {
"shift": 8,
"model": [
"102",
0
]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "ModelSamplingSD3"
}
},
"95": {
"inputs": {
"add_noise": "disable",
"noise_seed": 0,
"steps": 20,
"cfg": 3.5,
"sampler_name": "euler",
"scheduler": "simple",
"start_at_step": 10,
"end_at_step": 10000,
"return_with_leftover_noise": "disable",
"model": [
"94",
0
],
"positive": [
"99",
0
],
"negative": [
"91",
0
],
"latent_image": [
"96",
0
]
},
"class_type": "KSamplerAdvanced",
"_meta": {
"title": "KSampler (Advanced)"
}
},
"96": {
"inputs": {
"add_noise": "enable",
"noise_seed": "__RANDOM_INT__",
"steps": 20,
"cfg": 3.5,
"sampler_name": "euler",
"scheduler": "simple",
"start_at_step": 0,
"end_at_step": 10,
"return_with_leftover_noise": "enable",
"model": [
"93",
0
],
"positive": [
"99",
0
],
"negative": [
"91",
0
],
"latent_image": [
"104",
0
]
},
"class_type": "KSamplerAdvanced",
"_meta": {
"title": "KSampler (Advanced)"
}
},
"97": {
"inputs": {
"samples": [
"95",
0
],
"vae": [
"92",
0
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"98": {
"inputs": {
"filename_prefix": "video/ComfyUI",
"format": "auto",
"codec": "auto",
"video": [
"100",
0
]
},
"class_type": "SaveVideo",
"_meta": {
"title": "Save Video"
}
},
"99": {
"inputs": {
"text":prompt,
"clip": [
"90",
0
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Positive Prompt)"
}
},
"100": {
"inputs": {
"fps": 16,
"images": [
"97",
0
]
},
"class_type": "CreateVideo",
"_meta": {
"title": "Create Video"
}
},
"101": {
"inputs": {
"unet_name": "wan2.2_t2v_high_noise_14B_fp8_scaled.safetensors",
"weight_dtype": "default"
},
"class_type": "UNETLoader",
"_meta": {
"title": "Load Diffusion Model"
}
},
"102": {
"inputs": {
"unet_name": "wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors",
"weight_dtype": "default"
},
"class_type": "UNETLoader",
"_meta": {
"title": "Load Diffusion Model"
}
},
"104": {
"inputs": {
"width": 640,
"height": 640,
"length": 81,
"batch_size": 1
},
"class_type": "EmptyHunyuanLatentVideo",
"_meta": {
"title": "EmptyHunyuanLatentVideo"
}
}
}
}
} for prompt in benchmark_prompts
]
worker_config = WorkerConfig(
model_server_url=MODEL_SERVER_URL,
model_server_port=MODEL_SERVER_PORT,
model_log_file=MODEL_LOG_FILE,
model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
handlers=[
HandlerConfig(
route="/generate/sync",
allow_parallel_requests=False,
max_queue_time=10.0,
benchmark_config=BenchmarkConfig(
dataset=benchmark_dataset,
runs=1
),
workload_calculator= lambda _ : 10000.0
)
],
log_action_config=LogActionConfig(
on_load=MODEL_LOAD_LOG_MSG,
on_error=MODEL_ERROR_LOG_MSGS,
on_info=MODEL_INFO_LOG_MSGS
)
)
Worker(worker_config).run()