Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 405a8f1c0d | |||
| 12f4f23d39 | |||
| e2a771bb5a | |||
| 0cd64adfc4 | |||
| 6f795b8fb8 | |||
| 4bcc508473 | |||
| 74d7330800 | |||
| 2ce0450809 |
+1
-1
@@ -8,4 +8,4 @@ Requests~=2.32
|
||||
transformers~=4.52
|
||||
utils==1.0.*
|
||||
hf_transfer>=0.1.9
|
||||
vastai-sdk>=0.2.0
|
||||
git+https://github.com/vast-ai/vast-sdk.git@worker-sdk
|
||||
+13
-2
@@ -133,8 +133,19 @@ cd "$SERVER_DIR"
|
||||
echo "launching PyWorker server"
|
||||
|
||||
set +e
|
||||
python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG"
|
||||
|
||||
# Try worker entrypoint first
|
||||
echo "trying workers.${BACKEND}.worker"
|
||||
python3 -m "workers.${BACKEND}.worker" |& tee -a "$PYWORKER_LOG"
|
||||
PY_STATUS=${PIPESTATUS[0]}
|
||||
|
||||
# If that fails, fall back to server
|
||||
if [ "${PY_STATUS}" -ne 0 ]; then
|
||||
echo "workers.${BACKEND}.worker failed with status ${PY_STATUS}, trying workers.${BACKEND}.server"
|
||||
python3 -m "workers.${BACKEND}.server" |& tee -a "$PYWORKER_LOG"
|
||||
PY_STATUS=${PIPESTATUS[0]}
|
||||
fi
|
||||
|
||||
set -e
|
||||
|
||||
if [ "${PY_STATUS}" -ne 0 ]; then
|
||||
@@ -171,4 +182,4 @@ JSON
|
||||
done
|
||||
fi
|
||||
|
||||
echo "launching PyWorker server done"
|
||||
echo "launching PyWorker server done"
|
||||
|
||||
@@ -0,0 +1,184 @@
|
||||
import random
|
||||
import sys
|
||||
|
||||
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
|
||||
|
||||
# ComfyUI model configuration
|
||||
MODEL_SERVER_URL = 'http://127.0.0.1'
|
||||
MODEL_SERVER_PORT = 18288
|
||||
MODEL_LOG_FILE = '/var/log/portal/comfyui.log'
|
||||
MODEL_HEALTHCHECK_ENDPOINT = "/health"
|
||||
|
||||
# ComfyUI-specific log messages
|
||||
MODEL_LOAD_LOG_MSG = [
|
||||
"To see the GUI go to: "
|
||||
]
|
||||
|
||||
MODEL_ERROR_LOG_MSGS = [
|
||||
"MetadataIncompleteBuffer",
|
||||
"Value not in list: ",
|
||||
"[ERROR] Provisioning Script failed"
|
||||
]
|
||||
|
||||
MODEL_INFO_LOG_MSGS = [
|
||||
'"message":"Downloading'
|
||||
]
|
||||
|
||||
benchmark_lyrics = [
|
||||
"[verse]\nGuardian cloaked in twilight hue\nShadows melt where he breaks through\nEchoes swirl in mystic flight\nHooded hero owns the night\n\n[verse]\nThrough the chaos shapes arise\nFeral whispers, glowing eyes\nOrcs and creatures side by side\nMarch within the inky tide\n\n[chorus]\nRise above the fear and gloom\nLet your courage fully bloom\nIn the darkness stand your ground\nHear the night proclaim your sound",
|
||||
"[verse]\nMorning sun on fields of gold\nGentle stories unfold\nEvery breeze a quiet song\nWhere the peaceful hearts belong\n\n[verse]\nLanterns glow at stable doors\nRustling leaves on orchard floors\nSimple joys in every hand\nLife grows soft in fertile land\n\n[chorus]\nLet the day drift slow and free\nRoot your soul where you can be\nIn this haven warm and bright\nFeel the earth breathe pure delight",
|
||||
"[verse]\nLittle feet on dusty ground\nChasing dreams without a sound\nSoccer ball in morning light\nHopes take wing in youthful flight\n\n[verse]\nChrome reflections paint the day\nSwagger in the steps that play\nCopper tones in shining air\nChildhood gleaming everywhere\n\n[chorus]\nKick the world with boundless cheer\nHold the magic close and near\nIn each moment bold and true\nLet the sky belong to you",
|
||||
"[verse]\nSunset bleeds across the street\nGilded calm in summer heat\nLow-rise towers rimmed with fire\nDreams ignite as lights climb higher\n\n[verse]\nFootsteps scatter through the haze\nFutures shimmer in the blaze\nEvery window tells a tale\nFloating through a tangerine veil\n\n[chorus]\nLet the neon softly glow\nLet your restless heartbeat slow\nIn this city forged in light\nCarry hope into the night",
|
||||
"[verse]\nOcean breathes in rolling arcs\nSprays of diamond, glowing sparks\nWaves unfold a perfect line\nNature’s rhythm feels divine\n\n[verse]\nSun above in golden sweep\nPaints the rise of every deep\nShimmer drifting through the blue\nWorld reborn in every view\n\n[chorus]\nLet the tide pull you along\nHear the water’s ancient song\nIn the cresting waves you’ll find\nQuiet peace for heart and mind",
|
||||
"[verse]\nGlass aglow with swirling light\nFruits and mints in colors bright\nIcy whispers clink and chime\nFlowing forms suspend in time\n\n[verse]\nCreamy spirals drift within\nGentle currents slowly spin\nWarm reflections lingering sweet\nMixing flavors at your feet\n\n[chorus]\nSip the glow and let it rise\nTaste the sunset in disguise\nIn this moment clear and true\nLet the warmth flow into you",
|
||||
"[verse]\nEngines rumble down the lane\nCopper clouds of steam and rain\nOilpunk dreams in metal shine\nRider drifting down the line\n\n[verse]\nLeather jacket, steady glare\nStories sparking in the air\nMagazine lights frame his face\nKing of roads in timeless grace\n\n[chorus]\nThrottle up beyond the bend\nFeel the force of steel ascend\nRide the night and hold on tight\nClaim the world in streaks of light",
|
||||
"[verse]\nCut-out shapes in swirling play\nTextures dance in bold array\nCats in denim, grinning wide\nStrut across the patterned tide\n\n[verse]\nPosters hum with neon glow\nSurreal scenes begin to grow\nColors crisp as folded art\nPatchwork beating like a heart\n\n[chorus]\nLet the collage come alive\nWatch the vibrant pieces thrive\nIn this joyful, crafted space\nEvery shape finds its own place",
|
||||
"[verse]\nTiny world in crystal glass\nAncient tales behind the mass\nVillage lights in winter gleam\nFrozen in a mystic dream\n\n[verse]\nLantern beams in swirling air\nSoft enchantment everywhere\nShadows drift with gentle grace\nMagic sealed within the space\n\n[chorus]\nHold the sphere and you will see\nEchoes of a memory\nIn the glow of fragile light\nLives a realm of pure delight",
|
||||
"[verse]\nArmor hums with power bright\nChopping sparks in jungle night\nMecha spirits shift and scream\nThrough the ferns like shattered beams\n\n[verse]\nAxes blaze in glowing arcs\nLighting up the shadowed marks\nNature roars in trembling air\nClash of steel and cosmic flare\n\n[chorus]\nRaise the fire, strike the ground\nLet your legend shake the sound\nIn the wild where echoes roam\nForge the fight and carve your home",
|
||||
"[verse]\nCrowds ignite in vibrant flare\nBeats explode through smoky air\nDJ robes replaced with flame\nPope on decks in holy frame\n\n[verse]\nLeather gleams in blinding light\nTurntables spin with sacred might\nChoirs echo in the bass\nHeaven pulses through the place\n\n[chorus]\nLift the roof and shake the floor\nSacred rhythm evermore\nLet the music take control\nFeel the blessing in your soul",
|
||||
]
|
||||
|
||||
benchmark_dataset = [
|
||||
{
|
||||
"input": {
|
||||
"request_id": "",
|
||||
"workflow_json": {
|
||||
"14": {
|
||||
"inputs": {
|
||||
"tags": "funk, pop, soul, rock, melodic, guitar, drums, bass, keyboard, percussion, 105 BPM, energetic, upbeat, groovy, vibrant, dynamic",
|
||||
"lyrics": lyrics,
|
||||
"lyrics_strength": 0.99,
|
||||
"clip": ["40", 1]
|
||||
},
|
||||
"class_type": "TextEncodeAceStepAudio",
|
||||
"_meta": {
|
||||
"title": "TextEncodeAceStepAudio"
|
||||
}
|
||||
},
|
||||
"17": {
|
||||
"inputs": {
|
||||
"seconds": 180,
|
||||
"batch_size": 1
|
||||
},
|
||||
"class_type": "EmptyAceStepLatentAudio",
|
||||
"_meta": {
|
||||
"title": "EmptyAceStepLatentAudio"
|
||||
}
|
||||
},
|
||||
"18": {
|
||||
"inputs": {
|
||||
"samples": ["52", 0],
|
||||
"vae": ["40", 2]
|
||||
},
|
||||
"class_type": "VAEDecodeAudio",
|
||||
"_meta": {
|
||||
"title": "VAE Decode Audio"
|
||||
}
|
||||
},
|
||||
"40": {
|
||||
"inputs": {
|
||||
"ckpt_name": "ace_step_v1_3.5b.safetensors"
|
||||
},
|
||||
"class_type": "CheckpointLoaderSimple",
|
||||
"_meta": {
|
||||
"title": "Load Checkpoint"
|
||||
}
|
||||
},
|
||||
"44": {
|
||||
"inputs": {
|
||||
"conditioning": ["14", 0]
|
||||
},
|
||||
"class_type": "ConditioningZeroOut",
|
||||
"_meta": {
|
||||
"title": "ConditioningZeroOut"
|
||||
}
|
||||
},
|
||||
"49": {
|
||||
"inputs": {
|
||||
"model": ["51", 0],
|
||||
"operation": ["50", 0]
|
||||
},
|
||||
"class_type": "LatentApplyOperationCFG",
|
||||
"_meta": {
|
||||
"title": "LatentApplyOperationCFG"
|
||||
}
|
||||
},
|
||||
"50": {
|
||||
"inputs": {
|
||||
"multiplier": 1.15
|
||||
},
|
||||
"class_type": "LatentOperationTonemapReinhard",
|
||||
"_meta": {
|
||||
"title": "LatentOperationTonemapReinhard"
|
||||
}
|
||||
},
|
||||
"51": {
|
||||
"inputs": {
|
||||
"shift": 6,
|
||||
"model": ["40", 0]
|
||||
},
|
||||
"class_type": "ModelSamplingSD3",
|
||||
"_meta": {
|
||||
"title": "ModelSamplingSD3"
|
||||
}
|
||||
},
|
||||
"52": {
|
||||
"inputs": {
|
||||
"seed": "__RANDOM_INT__",
|
||||
"steps": 65,
|
||||
"cfg": 4,
|
||||
"sampler_name": "er_sde",
|
||||
"scheduler": "linear_quadratic",
|
||||
"denoise": 1,
|
||||
"model": ["49", 0],
|
||||
"positive": ["14", 0],
|
||||
"negative": ["44", 0],
|
||||
"latent_image": ["17", 0]
|
||||
},
|
||||
"class_type": "KSampler",
|
||||
"_meta": {
|
||||
"title": "KSampler"
|
||||
}
|
||||
},
|
||||
"59": {
|
||||
"inputs": {
|
||||
"filename_prefix": "audio/ComfyUI",
|
||||
"quality": "V0",
|
||||
"audioUI": "",
|
||||
"audio": ["18", 0]
|
||||
},
|
||||
"class_type": "SaveAudioMP3",
|
||||
"_meta": {
|
||||
"title": "Save Audio (MP3)"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} for lyrics in benchmark_lyrics
|
||||
]
|
||||
|
||||
def _fixed_workload(_):
    """Report the same workload (1000.0) for every request, ignoring its payload."""
    return 1000.0


# Benchmark: a single pass over the prepared dataset.
_sync_benchmark = BenchmarkConfig(
    dataset=benchmark_dataset,
    runs=1
)

# The only route served by this worker; requests are handled one at a time.
_sync_handler = HandlerConfig(
    route="/generate/sync",
    allow_parallel_requests=False,
    max_queue_time=10.0,
    benchmark_config=_sync_benchmark,
    workload_calculator=_fixed_workload
)

# Log substrings mapped to the load / error / info actions.
_log_actions = LogActionConfig(
    on_load=MODEL_LOAD_LOG_MSG,
    on_error=MODEL_ERROR_LOG_MSGS,
    on_info=MODEL_INFO_LOG_MSGS
)

worker_config = WorkerConfig(
    model_server_url=MODEL_SERVER_URL,
    model_server_port=MODEL_SERVER_PORT,
    model_log_file=MODEL_LOG_FILE,
    model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
    handlers=[_sync_handler],
    log_action_config=_log_actions
)
|
||||
|
||||
Worker(worker_config).run()
|
||||
@@ -0,0 +1,81 @@
|
||||
import random
|
||||
import sys
|
||||
|
||||
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
|
||||
|
||||
# ComfyUI model configuration
|
||||
MODEL_SERVER_URL = 'http://127.0.0.1'
|
||||
MODEL_SERVER_PORT = 18288
|
||||
MODEL_LOG_FILE = '/var/log/portal/comfyui.log'
|
||||
MODEL_HEALTHCHECK_ENDPOINT = "/health"
|
||||
|
||||
# ComfyUI-specific log messages
|
||||
MODEL_LOAD_LOG_MSG = [
|
||||
"To see the GUI go to: "
|
||||
]
|
||||
|
||||
MODEL_ERROR_LOG_MSGS = [
|
||||
"MetadataIncompleteBuffer",
|
||||
"Value not in list: ",
|
||||
"[ERROR] Provisioning Script failed"
|
||||
]
|
||||
|
||||
MODEL_INFO_LOG_MSGS = [
|
||||
'"message":"Downloading'
|
||||
]
|
||||
|
||||
benchmark_prompts = [
|
||||
"Cartoon hoodie hero; orc, anime cat, bunny; black goo; buff; vector on white.",
|
||||
"Cozy farming-game scene with fine details.",
|
||||
"2D vector child with soccer ball; airbrush chrome; swagger; antique copper.",
|
||||
"Realistic futuristic downtown of low buildings at sunset.",
|
||||
"Perfect wave front view; sunny seascape; ultra-detailed water; artful feel.",
|
||||
"Clear cup with ice, fruit, mint; creamy swirls; fluid-sim CGI; warm glow.",
|
||||
"Male biker with backpack on motorcycle; oilpunk; award-worthy magazine cover.",
|
||||
"Collage for textile; surreal cartoon cat in cap/jeans before poster; crisp.",
|
||||
"Medieval village inside glass sphere; volumetric light; macro focus.",
|
||||
"Iron Man with glowing axe; mecha sci-fi; jungle scene; dynamic light.",
|
||||
"Pope Francis DJ in leather jacket, mixing on giant console; dramatic.",
|
||||
]
|
||||
|
||||
|
||||
|
||||
def _text2image_benchmark_entry(prompt):
    """Wrap one prompt in a Text2Image benchmark request with a random id and seed."""
    # Draw request_id before seed so the shared random generator is consumed
    # in the same order as the dict literal would evaluate them.
    request_id = f"test-{random.randint(1000, 99999)}"
    seed = random.randint(0, sys.maxsize)
    modifications = {
        "prompt": prompt,
        "width": 512,
        "height": 512,
        "steps": 20,
        "seed": seed
    }
    return {
        "input": {
            "request_id": request_id,
            "modifier": "Text2Image",
            "modifications": modifications
        }
    }


# One benchmark request per prompt, in the order the prompts are listed.
benchmark_dataset = [_text2image_benchmark_entry(text) for text in benchmark_prompts]
|
||||
|
||||
# Benchmark driven by the prepared prompt dataset.
_sync_benchmark = BenchmarkConfig(
    dataset=benchmark_dataset,
)

# The only route served by this worker; requests are handled one at a time.
_sync_handler = HandlerConfig(
    route="/generate/sync",
    allow_parallel_requests=False,
    max_queue_time=10.0,
    benchmark_config=_sync_benchmark
)

# Log substrings mapped to the load / error / info actions.
_log_actions = LogActionConfig(
    on_load=MODEL_LOAD_LOG_MSG,
    on_error=MODEL_ERROR_LOG_MSGS,
    on_info=MODEL_INFO_LOG_MSGS
)

worker_config = WorkerConfig(
    model_server_url=MODEL_SERVER_URL,
    model_server_port=MODEL_SERVER_PORT,
    model_log_file=MODEL_LOG_FILE,
    model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
    handlers=[_sync_handler],
    log_action_config=_log_actions
)
|
||||
|
||||
Worker(worker_config).run()
|
||||
+22
-29
@@ -8,13 +8,14 @@ This is the base PyWorker for OpenAI compatible inference servers. See the [Ser
|
||||
|
||||
This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker.
|
||||
|
||||
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20(Serverless)) (recommended)
|
||||
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20%2B%20Qwen%2FQwen3-8B%20(Serverless)) (recommended)
|
||||
- [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless))
|
||||
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20%2B%20Qwen3-8B%20(Serverless))
|
||||
|
||||
|
||||
All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected.
|
||||
|
||||
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
||||
2. Follow the [getting started guide](https://docs.vast.ai/serverless/getting-started) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
||||
|
||||
## Client Setup (Demo)
|
||||
|
||||
@@ -33,30 +34,12 @@ uv pip install -r requirements.txt
|
||||
|
||||
Several examples have been provided in the client to help you get started with your own implementation.
|
||||
|
||||
First, set your API key as an environment variable:
|
||||
### Completions
|
||||
|
||||
Call to `/v1/completions` with json response
|
||||
|
||||
```bash
|
||||
export VAST_API_KEY=<your_api_key>
|
||||
```
|
||||
|
||||
The `--model` and `--endpoint` flags are optional. If not provided, they default to `Qwen/Qwen3-8B` and `my-vllm-endpoint` respectively.
|
||||
|
||||
### Chat Completion (streaming)
|
||||
|
||||
Call to `/v1/chat/completions` with streaming response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client --chat-stream --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Interactive Chat (streaming)
|
||||
|
||||
Interactive session with calls to `/v1/chat/completions`.
|
||||
|
||||
Type `clear` to clear the chat history or `quit` to exit.
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client --interactive --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Chat Completion (json)
|
||||
@@ -64,7 +47,15 @@ python -m workers.openai.client --interactive --endpoint <ENDPOINT_NAME> --model
|
||||
Call to `/v1/chat/completions` with json response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client --chat --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Chat Completion (streaming)
|
||||
|
||||
Call to `/v1/chat/completions` with streaming response
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Tool Use (json)
|
||||
@@ -74,14 +65,16 @@ Call to `/v1/chat/completions` with tool and json response.
|
||||
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client --tools --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
### Completions
|
||||
### Interactive Chat (streaming)
|
||||
|
||||
Call to `/v1/completions` with json response
|
||||
Interactive session with calls to `/v1/chat/completions`.
|
||||
|
||||
Type `clear` to clear the chat history or `quit` to exit.
|
||||
|
||||
```bash
|
||||
python -m workers.openai.client --completion --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
|
||||
```
|
||||
|
||||
|
||||
+16
-32
@@ -18,7 +18,7 @@ logging.basicConfig(
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
# ---------------------- Prompts ----------------------
|
||||
COMPLETIONS_PROMPT = "Zebras are primarily grazers and can subsist on lower-quality vegetation. They are preyed on mainly by"
|
||||
COMPLETIONS_PROMPT = "the capital of USA is"
|
||||
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
|
||||
TOOLS_PROMPT = (
|
||||
"Can you list the files in the current working directory and tell me what you see? "
|
||||
@@ -97,9 +97,9 @@ def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[
|
||||
|
||||
|
||||
# ---- OpenAI-compatible calls (non-streaming) ----
|
||||
async def call_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs) -> Dict[str, Any]:
|
||||
async def call_completions(client: Serverless, *, model: str, prompt: str, **kwargs) -> Dict[str, Any]:
|
||||
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
@@ -113,9 +113,9 @@ async def call_completions(client: Serverless, *, model: str, prompt: str, endpo
|
||||
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
|
||||
return resp["response"]
|
||||
|
||||
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs) -> Dict[str, Any]:
|
||||
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
|
||||
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
@@ -132,9 +132,9 @@ async def call_chat_completions(client: Serverless, *, model: str, messages: Lis
|
||||
return resp["response"]
|
||||
|
||||
# ---- Streaming variants ----
|
||||
async def stream_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs):
|
||||
async def stream_completions(client: Serverless, *, model: str, prompt: str, **kwargs):
|
||||
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
@@ -150,9 +150,9 @@ async def stream_completions(client: Serverless, *, model: str, prompt: str, end
|
||||
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
|
||||
return resp["response"] # async generator
|
||||
|
||||
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs):
|
||||
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs):
|
||||
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
@@ -174,10 +174,9 @@ async def stream_chat_completions(client: Serverless, *, model: str, messages: L
|
||||
class APIDemo:
|
||||
"""Demo and testing functionality for the API client"""
|
||||
|
||||
def __init__(self, client: Serverless, model: str, endpoint_name: str, tool_manager: Optional[ToolManager] = None):
|
||||
def __init__(self, client: Serverless, model: str, tool_manager: Optional[ToolManager] = None):
|
||||
self.client = client
|
||||
self.model = model
|
||||
self.endpoint_name = endpoint_name
|
||||
self.tool_manager = tool_manager or ToolManager()
|
||||
|
||||
# ----- Streaming handler -----
|
||||
@@ -186,15 +185,10 @@ class APIDemo:
|
||||
reasoning_content = ""
|
||||
printed_reasoning = False
|
||||
printed_answer = False
|
||||
finish_reason = None
|
||||
|
||||
async for chunk in stream:
|
||||
choice = (chunk.get("choices") or [{}])[0]
|
||||
delta = choice.get("delta", {})
|
||||
|
||||
# Track finish reason
|
||||
if choice.get("finish_reason"):
|
||||
finish_reason = choice.get("finish_reason")
|
||||
|
||||
# reasoning tokens
|
||||
rc = delta.get("reasoning_content")
|
||||
@@ -225,8 +219,6 @@ class APIDemo:
|
||||
print(f"Reasoning tokens: {len(reasoning_content.split())}")
|
||||
if printed_answer:
|
||||
print(f"Response tokens: {len(full_response.split())}")
|
||||
if finish_reason:
|
||||
print(f"Finish reason: {finish_reason}")
|
||||
|
||||
return full_response
|
||||
|
||||
@@ -239,7 +231,6 @@ class APIDemo:
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
prompt=COMPLETIONS_PROMPT,
|
||||
endpoint_name=self.endpoint_name,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
)
|
||||
@@ -258,7 +249,6 @@ class APIDemo:
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
endpoint_name=self.endpoint_name,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE
|
||||
)
|
||||
@@ -271,7 +261,6 @@ class APIDemo:
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
endpoint_name=self.endpoint_name,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE
|
||||
)
|
||||
@@ -298,7 +287,6 @@ class APIDemo:
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
endpoint_name=self.endpoint_name,
|
||||
tools=minimal_tool,
|
||||
tool_choice="none",
|
||||
max_tokens=10
|
||||
@@ -324,7 +312,6 @@ class APIDemo:
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
endpoint_name=self.endpoint_name,
|
||||
tools=self.tool_manager.get_ls_tool_definition(),
|
||||
tool_choice="auto",
|
||||
max_tokens=MAX_TOKENS,
|
||||
@@ -402,7 +389,6 @@ class APIDemo:
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
endpoint_name=self.endpoint_name,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
)
|
||||
@@ -441,6 +427,7 @@ class APIDemo:
|
||||
print("=" * 60)
|
||||
print("INTERACTIVE STREAMING CHAT")
|
||||
print("=" * 60)
|
||||
print(f"Using model: {self.model}")
|
||||
print("Type 'quit' to exit, 'clear' to clear history")
|
||||
print()
|
||||
|
||||
@@ -466,8 +453,7 @@ class APIDemo:
|
||||
stream = await stream_chat_completions(
|
||||
client=self.client,
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
endpoint_name=self.endpoint_name,
|
||||
messages=messages,
|
||||
max_tokens=MAX_TOKENS,
|
||||
temperature=0.7
|
||||
)
|
||||
@@ -487,8 +473,8 @@ class APIDemo:
|
||||
# ---------------------- CLI ----------------------
|
||||
def build_arg_parser() -> argparse.ArgumentParser:
|
||||
p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
|
||||
p.add_argument("--model", default=DEFAULT_MODEL, help=f"Model to use for requests (default: {DEFAULT_MODEL})")
|
||||
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
|
||||
p.add_argument("--model", required=True, help="Model to use for requests (required)")
|
||||
p.add_argument("--endpoint", default="my-vllm-endpoint", help="Vast endpoint name (default: my-vllm-endpoint)")
|
||||
|
||||
modes = p.add_mutually_exclusive_group(required=False)
|
||||
modes.add_argument("--completion", action="store_true", help="Test completions endpoint")
|
||||
@@ -516,14 +502,12 @@ async def main_async():
|
||||
print("Please specify exactly one test mode")
|
||||
sys.exit(1)
|
||||
|
||||
print("=" * 60)
|
||||
print(f"Using model: {args.model}")
|
||||
print(f"Using endpoint: {args.endpoint}")
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
async with Serverless() as client:
|
||||
demo = APIDemo(client, args.model, args.endpoint, ToolManager())
|
||||
demo = APIDemo(client, args.model, ToolManager())
|
||||
|
||||
if args.completion:
|
||||
await demo.demo_completions()
|
||||
|
||||
@@ -35,7 +35,6 @@ backend = Backend(
|
||||
model_server_url=os.environ["MODEL_SERVER_URL"],
|
||||
model_log_file=os.environ["MODEL_LOG"],
|
||||
allow_parallel_requests=True,
|
||||
max_wait_time=600.0,
|
||||
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
|
||||
log_actions=[
|
||||
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
import nltk
|
||||
import random
|
||||
import os
|
||||
|
||||
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
|
||||
|
||||
# vLLM model configuration
|
||||
MODEL_SERVER_URL = 'http://127.0.0.1'
|
||||
MODEL_SERVER_PORT = 18000
|
||||
MODEL_LOG_FILE = '/var/log/portal/vllm.log'
|
||||
MODEL_HEALTHCHECK_ENDPOINT = "/health"
|
||||
|
||||
# vLLM-specific log messages
|
||||
MODEL_LOAD_LOG_MSG = [
|
||||
"Application startup complete.",
|
||||
]
|
||||
|
||||
MODEL_ERROR_LOG_MSGS = [
|
||||
"INFO exited: vllm",
|
||||
"RuntimeError: Engine",
|
||||
"Traceback (most recent call last):"
|
||||
]
|
||||
|
||||
MODEL_INFO_LOG_MSGS = [
|
||||
'"message":"Download'
|
||||
]
|
||||
|
||||
nltk.download("words")
|
||||
WORD_LIST = nltk.corpus.words.words()
|
||||
|
||||
|
||||
def completions_benchmark_generator(
    num_words=250,
    max_tokens=500,
    temperature=0.7,
    word_list=None,
) -> dict:
    """Build one randomized completions payload for benchmarking.

    Args:
        num_words: Number of random words joined (space-separated) into the
            prompt.
        max_tokens: Value placed in the payload's ``max_tokens`` field.
        temperature: Value placed in the payload's ``temperature`` field.
        word_list: Sequence of words to sample from. When ``None`` the
            module-level ``WORD_LIST`` is used.

    Returns:
        A dict with the keys ``model``, ``prompt``, ``temperature`` and
        ``max_tokens``.

    Raises:
        ValueError: If the ``MODEL_NAME`` environment variable is unset or
            empty.
    """
    # Fail fast: check configuration before doing any prompt-building work.
    model = os.environ.get("MODEL_NAME")
    if not model:
        raise ValueError("MODEL_NAME environment variable not set")

    if word_list is None:
        word_list = WORD_LIST

    # random.choices samples with replacement, so words may repeat.
    prompt = " ".join(random.choices(word_list, k=num_words))

    return {
        "model": model,
        "prompt": prompt,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
|
||||
|
||||
def _max_tokens_workload(data):
    """Return the request's ``max_tokens`` value as its workload, or 0 if absent."""
    return data.get("max_tokens", 0)


# Benchmark for the completions route: generated payloads, 100 concurrent, 2 runs.
_completions_benchmark = BenchmarkConfig(
    generator=completions_benchmark_generator,
    concurrency=100,
    runs=2
)

_completions_handler = HandlerConfig(
    route="/v1/completions",
    workload_calculator=_max_tokens_workload,
    allow_parallel_requests=True,
    max_queue_time=60.0,
    benchmark_config=_completions_benchmark
)

# The chat route uses the same settings but has no benchmark attached.
_chat_completions_handler = HandlerConfig(
    route="/v1/chat/completions",
    workload_calculator=_max_tokens_workload,
    allow_parallel_requests=True,
    max_queue_time=60.0,
)

# Log substrings mapped to the load / error / info actions.
_log_actions = LogActionConfig(
    on_load=MODEL_LOAD_LOG_MSG,
    on_error=MODEL_ERROR_LOG_MSGS,
    on_info=MODEL_INFO_LOG_MSGS
)

worker_config = WorkerConfig(
    model_server_url=MODEL_SERVER_URL,
    model_server_port=MODEL_SERVER_PORT,
    model_log_file=MODEL_LOG_FILE,
    model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
    handlers=[_completions_handler, _chat_completions_handler],
    log_action_config=_log_actions
)
|
||||
|
||||
Worker(worker_config).run()
|
||||
+9
-93
@@ -1,103 +1,19 @@
|
||||
# HuggingFace TGI PyWorker
|
||||
This is the base PyWorker for TGI, designed to create PyWorkers that can utilize various LLMs. It offers two primary endpoints:
|
||||
|
||||
This is the base PyWorker for HuggingFace Text Generation Inference (TGI) servers. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
|
||||
1. `generate`: Generates the LLM's response to a given prompt in a single request.
|
||||
2. `generate_stream`: Streams the LLM's response token by token.
|
||||
|
||||
## Instance Setup
|
||||
|
||||
1. Pick a template
|
||||
|
||||
This worker is compatible with any TGI backend. We have a template you can use or you can create your own.
|
||||
|
||||
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20(Serverless))
|
||||
|
||||
The template can be configured via the template interface. You may want to change the model or startup arguments.
|
||||
|
||||
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
||||
|
||||
## Client Setup (Demo)
|
||||
|
||||
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/vast-ai/pyworker
|
||||
cd pyworker
|
||||
pip install uv
|
||||
uv venv -p 3.12
|
||||
source .venv/bin/activate
|
||||
uv pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Using the Test Client
|
||||
|
||||
The test client demonstrates both streaming and non-streaming generation using TGI's native API.
|
||||
|
||||
First, set your API key as an environment variable:
|
||||
|
||||
```bash
|
||||
export VAST_API_KEY=<your_api_key>
|
||||
```
|
||||
|
||||
The `--endpoint` flag is optional. If not provided, it defaults to `my-tgi-endpoint`.
|
||||
|
||||
### Generate (Streaming)
|
||||
|
||||
Call to `/generate_stream` with streaming response:
|
||||
|
||||
```bash
|
||||
python -m workers.tgi.client --generate-stream --endpoint <ENDPOINT_NAME>
|
||||
```
|
||||
|
||||
### Generate (Non-Streaming)
|
||||
|
||||
Call to `/generate` with json response:
|
||||
|
||||
```bash
|
||||
python -m workers.tgi.client --generate --endpoint <ENDPOINT_NAME>
|
||||
```
|
||||
|
||||
### Interactive Session (Streaming)
|
||||
|
||||
Interactive session with streaming responses. Type `quit` to exit.
|
||||
|
||||
```bash
|
||||
python -m workers.tgi.client --interactive --endpoint <ENDPOINT_NAME>
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
TGI provides two primary endpoints:
|
||||
|
||||
### Generate (Non-Streaming)
|
||||
|
||||
`/generate` - Returns the complete response in a single request.
|
||||
Both endpoints use the following API payload format:
|
||||
|
||||
```json
|
||||
{
|
||||
"inputs": "Your prompt here",
|
||||
"inputs": "PROMPT",
|
||||
"parameters": {
|
||||
"max_new_tokens": 1024,
|
||||
"temperature": 0.7,
|
||||
"return_full_text": false
|
||||
"max_new_tokens": 250
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Generate Stream (Streaming)
|
||||
|
||||
`/generate_stream` - Streams the response token by token.
|
||||
|
||||
```json
|
||||
{
|
||||
"inputs": "Your prompt here",
|
||||
"parameters": {
|
||||
"max_new_tokens": 1024,
|
||||
"temperature": 0.7,
|
||||
"do_sample": true,
|
||||
"return_full_text": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Performance Notes
|
||||
|
||||
The `max_new_tokens` parameter (not the prompt size) primarily impacts performance. For example, if an instance is benchmarked to process 100 tokens per second, a request with `max_new_tokens = 200` will take approximately 2 seconds to complete.
|
||||
Note that the max_new_tokens parameter, rather than the prompt size, impacts performance. For example, if an
|
||||
instance is benchmarked to process 100 tokens per second, a request with max_new_tokens = 200 will take
|
||||
approximately 2 seconds to complete.
|
||||
|
||||
+33
-194
@@ -1,222 +1,61 @@
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
from vastai import Serverless
|
||||
import asyncio
|
||||
|
||||
# ---------------------- Logging ----------------------
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
# ---------------------- Defaults ----------------------
|
||||
DEFAULT_PROMPT = "Think step by step: Tell me about the Python programming language."
|
||||
|
||||
ENDPOINT_NAME = "TGI-Prod2" # change this to your TGI endpoint name
|
||||
ENDPOINT_NAME = "my-tgi-endpoint" # Change this to match your endpoint name
|
||||
MAX_TOKENS = 1024
|
||||
DEFAULT_TEMPERATURE = 0.7
|
||||
PROMPT = "Think step by step: Tell me about the Python programming language."
|
||||
|
||||
|
||||
# ---------------------- API Calls ----------------------
|
||||
async def call_generate(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs) -> dict:
|
||||
"""Non-streaming generation via /generate endpoint"""
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
async def call_generate(client: Serverless) -> None:
|
||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||
|
||||
payload = {
|
||||
"inputs": prompt,
|
||||
"inputs": PROMPT,
|
||||
"parameters": {
|
||||
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
"return_full_text": False,
|
||||
"max_new_tokens": MAX_TOKENS,
|
||||
"temperature": 0.7,
|
||||
"return_full_text": False
|
||||
}
|
||||
}
|
||||
log.debug("POST /generate %s", json.dumps(payload)[:500])
|
||||
resp = await endpoint.request("/generate", payload, cost=payload["parameters"]["max_new_tokens"])
|
||||
return resp["response"]
|
||||
|
||||
resp = await endpoint.request("/generate", payload, cost=MAX_TOKENS)
|
||||
|
||||
print(resp["response"]["generated_text"])
|
||||
|
||||
|
||||
async def call_generate_stream(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs):
|
||||
"""Streaming generation via /generate_stream endpoint"""
|
||||
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||
async def call_generate_stream(client: Serverless) -> None:
|
||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
||||
|
||||
payload = {
|
||||
"inputs": prompt,
|
||||
"inputs": PROMPT,
|
||||
"parameters": {
|
||||
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||
"max_new_tokens": MAX_TOKENS,
|
||||
"temperature": 0.7,
|
||||
"do_sample": True,
|
||||
"return_full_text": False,
|
||||
}
|
||||
}
|
||||
log.debug("STREAM /generate_stream %s", json.dumps(payload)[:500])
|
||||
|
||||
resp = await endpoint.request(
|
||||
"/generate_stream",
|
||||
payload,
|
||||
cost=payload["parameters"]["max_new_tokens"],
|
||||
cost=MAX_TOKENS,
|
||||
stream=True,
|
||||
)
|
||||
return resp["response"] # async generator
|
||||
stream = resp["response"]
|
||||
|
||||
printed_answer = False
|
||||
async for event in stream:
|
||||
tok = (event.get("token") or {}).get("text")
|
||||
if tok:
|
||||
if not printed_answer:
|
||||
printed_answer = True
|
||||
print("Answer:\n", end="", flush=True)
|
||||
print(tok, end="", flush=True)
|
||||
|
||||
# ---------------------- Demo Runner ----------------------
|
||||
class APIDemo:
    """Demo and testing functionality for the TGI API client.

    Holds a ``Serverless`` client and an endpoint name, and exposes one
    coroutine per demo mode: non-streaming, streaming, and interactive.
    """

    def __init__(self, client: Serverless, endpoint_name: str):
        """Store the client and the endpoint name used by every demo call."""
        self.client = client
        self.endpoint_name = endpoint_name

    async def handle_streaming_response(self, stream) -> str:
        """Print streamed tokens as they arrive and return the joined text.

        Each event is read with ``event.get("token")``; only events that
        carry a non-empty ``text`` under that key are printed and kept.

        Args:
            stream: Async iterable of event mappings.

        Returns:
            The concatenation of every printed token text.
        """
        pieces = []
        printed_answer = False

        async for event in stream:
            tok = (event.get("token") or {}).get("text")
            if tok:
                if not printed_answer:
                    # Emit the header once, right before the first token.
                    printed_answer = True
                    print("\n💬 Response: ", end="", flush=True)
                print(tok, end="", flush=True)
                pieces.append(tok)

        full_response = "".join(pieces)

        print()  # newline
        if printed_answer:
            # The figure is a whitespace-separated word count of the text.
            print(f"\nStreaming completed. Response tokens: {len(full_response.split())}")

        return full_response

    async def demo_generate(self) -> None:
        """Run one non-streaming generation and print the result."""
        print("=" * 60)
        print("GENERATE DEMO (NON-STREAMING)")
        print("=" * 60)

        response = await call_generate(
            client=self.client,
            endpoint_name=self.endpoint_name,
            prompt=DEFAULT_PROMPT,
            max_tokens=MAX_TOKENS,
            temperature=DEFAULT_TEMPERATURE,
        )

        print(f"\n💬 Response: {response.get('generated_text', '')}")
        print(f"\nFull Response:\n{json.dumps(response, indent=2)}")

    async def demo_generate_stream(self) -> None:
        """Run one streaming generation and print tokens as they arrive."""
        print("=" * 60)
        print("GENERATE DEMO (STREAMING)")
        print("=" * 60)

        stream = await call_generate_stream(
            client=self.client,
            endpoint_name=self.endpoint_name,
            prompt=DEFAULT_PROMPT,
            max_tokens=MAX_TOKENS,
            temperature=DEFAULT_TEMPERATURE,
        )

        try:
            await self.handle_streaming_response(stream)
        except Exception as e:
            # Best-effort demo: report the failure instead of propagating it.
            log.error("\nError during streaming: %s", e, exc_info=True)

    async def interactive_chat(self) -> None:
        """Read prompts from stdin and stream each answer until the user quits.

        The loop ends on the literal input ``quit``, on Ctrl-C, or when stdin
        reaches end-of-file. Any other error is logged and the loop continues
        with the next prompt.
        """
        print("=" * 60)
        print("INTERACTIVE STREAMING SESSION")
        print("=" * 60)
        print(f"Using endpoint: {self.endpoint_name}")
        print("Type 'quit' to exit")
        print()

        while True:
            try:
                user_input = input("You: ").strip()

                if user_input.lower() == "quit":
                    print("👋 Goodbye!")
                    break
                elif not user_input:
                    continue

                print("Assistant: ", end="", flush=True)
                stream = await call_generate_stream(
                    client=self.client,
                    endpoint_name=self.endpoint_name,
                    prompt=user_input,
                    max_tokens=MAX_TOKENS,
                    temperature=DEFAULT_TEMPERATURE,
                )

                async for event in stream:
                    tok = (event.get("token") or {}).get("text")
                    if tok:
                        print(tok, end="", flush=True)
                print()  # newline

            except (KeyboardInterrupt, EOFError):
                # EOFError must end the session: input() raises it on every
                # call once stdin is closed, so treating it as a recoverable
                # error would make this loop spin forever.
                print("\n👋 Session interrupted. Goodbye!")
                break
            except Exception as e:
                log.error("\nError: %s", e)
                continue
|
||||
|
||||
|
||||
# ---------------------- CLI ----------------------
|
||||
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the demo CLI: an endpoint option plus one flag per demo mode."""
    parser = argparse.ArgumentParser(description="Vast TGI Demo (Serverless SDK)")
    parser.add_argument(
        "--endpoint",
        default=ENDPOINT_NAME,
        help=f"Vast endpoint name (default: {ENDPOINT_NAME})",
    )

    # Mode flags exclude one another, but none is mandatory at parse time.
    mode_flags = (
        ("--generate", "Test generate endpoint (non-streaming)"),
        ("--generate-stream", "Test generate endpoint with streaming"),
        ("--interactive", "Start interactive streaming session"),
    )
    mode_group = parser.add_mutually_exclusive_group(required=False)
    for flag, description in mode_flags:
        mode_group.add_argument(flag, action="store_true", help=description)

    return parser
|
||||
|
||||
|
||||
async def main_async():
    """Parse the command line and run the selected demo mode.

    Exits with status 1 when no mode flag was given, or when the demo raises
    an exception (which is logged with its traceback first).
    """
    args = build_arg_parser().parse_args()

    # Only the "no mode selected" case needs handling here: the parser's
    # mutually exclusive group already rejects more than one mode flag
    # during parse_args().
    if not (args.generate or args.generate_stream or args.interactive):
        print("Please specify exactly one test mode:")
        print(" --generate : Test generate endpoint (non-streaming)")
        print(" --generate-stream : Test generate endpoint with streaming")
        print(" --interactive : Start interactive streaming session")
        print(f"\nExample: python {os.path.basename(sys.argv[0])} --generate-stream --endpoint my-tgi-endpoint")
        sys.exit(1)

    print("=" * 60)
    print(f"Using endpoint: {args.endpoint}")

    try:
        async with Serverless() as client:
            demo = APIDemo(client, args.endpoint)

            if args.generate:
                await demo.demo_generate()
            elif args.generate_stream:
                await demo.demo_generate_stream()
            elif args.interactive:
                await demo.interactive_chat()

    except Exception as e:
        # Top-level boundary: log the failure and turn it into an exit code.
        log.error("Error during test: %s", e, exc_info=True)
        sys.exit(1)
|
||||
|
||||
async def main():
    """Run the non-streaming call, then the streaming call, on one client."""
    async with Serverless() as serverless_client:
        # Both calls share the client opened by this context manager.
        await call_generate(serverless_client)
        await call_generate_stream(serverless_client)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main_async())
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
import nltk
|
||||
import random
|
||||
|
||||
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
|
||||
|
||||
# TGI model configuration
|
||||
MODEL_SERVER_URL = 'http://0.0.0.0'
|
||||
MODEL_SERVER_PORT = 5001
|
||||
MODEL_LOG_FILE = "/workspace/infer.log"
|
||||
MODEL_HEALTHCHECK_ENDPOINT = "/health"
|
||||
|
||||
# TGI-specific log messages
|
||||
MODEL_LOAD_LOG_MSG = [
|
||||
'"message":"Connected","target":"text_generation_router"',
|
||||
'"message":"Connected","target":"text_generation_router::server"',
|
||||
]
|
||||
|
||||
MODEL_ERROR_LOG_MSGS = [
|
||||
"Error: WebserverFailed",
|
||||
"Error: DownloadError",
|
||||
"Error: ShardCannotStart",
|
||||
]
|
||||
|
||||
MODEL_INFO_LOG_MSGS = [
|
||||
'"message":"Download'
|
||||
]
|
||||
|
||||
nltk.download("words")
|
||||
WORD_LIST = nltk.corpus.words.words()
|
||||
|
||||
|
||||
def benchmark_generator(
    *,
    word_list=None,
    prompt_words: int = 250,
    max_new_tokens: int = 128,
) -> dict:
    """Build one randomized ``/generate`` benchmark payload.

    Calling it with no arguments gives a prompt of 250 words sampled from
    ``WORD_LIST`` and a budget of 128 new tokens.

    Args:
        word_list: Sequence of words to sample the prompt from, with
            replacement. Defaults to the module-level ``WORD_LIST``.
        prompt_words: Number of words joined into the prompt.
        max_new_tokens: Value sent as ``parameters.max_new_tokens``.

    Returns:
        A dict with an ``inputs`` prompt string and a ``parameters`` dict.
    """
    # Resolve the default at call time so the module constant is only needed
    # when the caller does not supply their own vocabulary.
    vocabulary = WORD_LIST if word_list is None else word_list
    prompt = " ".join(random.choices(vocabulary, k=prompt_words))

    return {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": 0.7,
            "return_full_text": False,
        },
    }
|
||||
|
||||
def _requested_token_budget(payload):
    """Return the request's ``parameters.max_new_tokens`` as its workload."""
    return payload["parameters"]["max_new_tokens"]


# Non-streaming route; this is the only handler with a benchmark attached.
_generate_handler = HandlerConfig(
    route="/generate",
    allow_parallel_requests=True,
    max_queue_time=60.0,
    benchmark_config=BenchmarkConfig(
        generator=benchmark_generator,
        concurrency=50,
    ),
    workload_calculator=_requested_token_budget,
)

# Streaming route; same queueing and workload settings, no benchmark.
_generate_stream_handler = HandlerConfig(
    route="/generate_stream",
    allow_parallel_requests=True,
    max_queue_time=60.0,
    workload_calculator=_requested_token_budget,
)

# Log substrings mapped to the load / error / info actions.
_log_actions = LogActionConfig(
    on_load=MODEL_LOAD_LOG_MSG,
    on_error=MODEL_ERROR_LOG_MSGS,
    on_info=MODEL_INFO_LOG_MSGS,
)

worker_config = WorkerConfig(
    model_server_url=MODEL_SERVER_URL,
    model_server_port=MODEL_SERVER_PORT,
    model_log_file=MODEL_LOG_FILE,
    model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
    handlers=[_generate_handler, _generate_stream_handler],
    log_action_config=_log_actions,
)

Worker(worker_config).run()
|
||||
@@ -0,0 +1,288 @@
|
||||
import random
|
||||
import sys
|
||||
|
||||
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
|
||||
|
||||
# ComfyUI model configuration
|
||||
MODEL_SERVER_URL = 'http://127.0.0.1'
|
||||
MODEL_SERVER_PORT = 18288
|
||||
MODEL_LOG_FILE = '/var/log/portal/comfyui.log'
|
||||
MODEL_HEALTHCHECK_ENDPOINT = "/health"
|
||||
|
||||
# ComfyUI-specific log messages
|
||||
MODEL_LOAD_LOG_MSG = [
|
||||
"To see the GUI go to: "
|
||||
]
|
||||
|
||||
MODEL_ERROR_LOG_MSGS = [
|
||||
"MetadataIncompleteBuffer",
|
||||
"Value not in list: ",
|
||||
"[ERROR] Provisioning Script failed"
|
||||
]
|
||||
|
||||
MODEL_INFO_LOG_MSGS = [
|
||||
'"message":"Downloading'
|
||||
]
|
||||
|
||||
benchmark_prompts = [
|
||||
"Cartoon hoodie hero; orc, anime cat, bunny; black goo; buff; vector on white.",
|
||||
"Cozy farming-game scene with fine details.",
|
||||
"2D vector child with soccer ball; airbrush chrome; swagger; antique copper.",
|
||||
"Realistic futuristic downtown of low buildings at sunset.",
|
||||
"Perfect wave front view; sunny seascape; ultra-detailed water; artful feel.",
|
||||
"Clear cup with ice, fruit, mint; creamy swirls; fluid-sim CGI; warm glow.",
|
||||
"Male biker with backpack on motorcycle; oilpunk; award-worthy magazine cover.",
|
||||
"Collage for textile; surreal cartoon cat in cap/jeans before poster; crisp.",
|
||||
"Medieval village inside glass sphere; volumetric light; macro focus.",
|
||||
"Iron Man with glowing axe; mecha sci-fi; jungle scene; dynamic light.",
|
||||
"Pope Francis DJ in leather jacket, mixing on giant console; dramatic.",
|
||||
]
|
||||
|
||||
benchmark_dataset = [
|
||||
{
|
||||
"input": {
|
||||
"workflow_json": {
|
||||
"90": {
|
||||
"inputs": {
|
||||
"clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors",
|
||||
"type": "wan",
|
||||
"device": "default"
|
||||
},
|
||||
"class_type": "CLIPLoader",
|
||||
"_meta": {
|
||||
"title": "Load CLIP"
|
||||
}
|
||||
},
|
||||
"91": {
|
||||
"inputs": {
|
||||
"text": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,裸露,NSFW",
|
||||
"clip": [
|
||||
"90",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Negative Prompt)"
|
||||
}
|
||||
},
|
||||
"92": {
|
||||
"inputs": {
|
||||
"vae_name": "wan_2.1_vae.safetensors"
|
||||
},
|
||||
"class_type": "VAELoader",
|
||||
"_meta": {
|
||||
"title": "Load VAE"
|
||||
}
|
||||
},
|
||||
"93": {
|
||||
"inputs": {
|
||||
"shift": 8.000000000000002,
|
||||
"model": [
|
||||
"101",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "ModelSamplingSD3",
|
||||
"_meta": {
|
||||
"title": "ModelSamplingSD3"
|
||||
}
|
||||
},
|
||||
"94": {
|
||||
"inputs": {
|
||||
"shift": 8,
|
||||
"model": [
|
||||
"102",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "ModelSamplingSD3",
|
||||
"_meta": {
|
||||
"title": "ModelSamplingSD3"
|
||||
}
|
||||
},
|
||||
"95": {
|
||||
"inputs": {
|
||||
"add_noise": "disable",
|
||||
"noise_seed": 0,
|
||||
"steps": 20,
|
||||
"cfg": 3.5,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "simple",
|
||||
"start_at_step": 10,
|
||||
"end_at_step": 10000,
|
||||
"return_with_leftover_noise": "disable",
|
||||
"model": [
|
||||
"94",
|
||||
0
|
||||
],
|
||||
"positive": [
|
||||
"99",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"91",
|
||||
0
|
||||
],
|
||||
"latent_image": [
|
||||
"96",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "KSamplerAdvanced",
|
||||
"_meta": {
|
||||
"title": "KSampler (Advanced)"
|
||||
}
|
||||
},
|
||||
"96": {
|
||||
"inputs": {
|
||||
"add_noise": "enable",
|
||||
"noise_seed": "__RANDOM_INT__",
|
||||
"steps": 20,
|
||||
"cfg": 3.5,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "simple",
|
||||
"start_at_step": 0,
|
||||
"end_at_step": 10,
|
||||
"return_with_leftover_noise": "enable",
|
||||
"model": [
|
||||
"93",
|
||||
0
|
||||
],
|
||||
"positive": [
|
||||
"99",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"91",
|
||||
0
|
||||
],
|
||||
"latent_image": [
|
||||
"104",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "KSamplerAdvanced",
|
||||
"_meta": {
|
||||
"title": "KSampler (Advanced)"
|
||||
}
|
||||
},
|
||||
"97": {
|
||||
"inputs": {
|
||||
"samples": [
|
||||
"95",
|
||||
0
|
||||
],
|
||||
"vae": [
|
||||
"92",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "VAEDecode",
|
||||
"_meta": {
|
||||
"title": "VAE Decode"
|
||||
}
|
||||
},
|
||||
"98": {
|
||||
"inputs": {
|
||||
"filename_prefix": "video/ComfyUI",
|
||||
"format": "auto",
|
||||
"codec": "auto",
|
||||
"video": [
|
||||
"100",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "SaveVideo",
|
||||
"_meta": {
|
||||
"title": "Save Video"
|
||||
}
|
||||
},
|
||||
"99": {
|
||||
"inputs": {
|
||||
"text":prompt,
|
||||
"clip": [
|
||||
"90",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Positive Prompt)"
|
||||
}
|
||||
},
|
||||
"100": {
|
||||
"inputs": {
|
||||
"fps": 16,
|
||||
"images": [
|
||||
"97",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "CreateVideo",
|
||||
"_meta": {
|
||||
"title": "Create Video"
|
||||
}
|
||||
},
|
||||
"101": {
|
||||
"inputs": {
|
||||
"unet_name": "wan2.2_t2v_high_noise_14B_fp8_scaled.safetensors",
|
||||
"weight_dtype": "default"
|
||||
},
|
||||
"class_type": "UNETLoader",
|
||||
"_meta": {
|
||||
"title": "Load Diffusion Model"
|
||||
}
|
||||
},
|
||||
"102": {
|
||||
"inputs": {
|
||||
"unet_name": "wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors",
|
||||
"weight_dtype": "default"
|
||||
},
|
||||
"class_type": "UNETLoader",
|
||||
"_meta": {
|
||||
"title": "Load Diffusion Model"
|
||||
}
|
||||
},
|
||||
"104": {
|
||||
"inputs": {
|
||||
"width": 640,
|
||||
"height": 640,
|
||||
"length": 81,
|
||||
"batch_size": 1
|
||||
},
|
||||
"class_type": "EmptyHunyuanLatentVideo",
|
||||
"_meta": {
|
||||
"title": "EmptyHunyuanLatentVideo"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} for prompt in benchmark_prompts
|
||||
]
|
||||
|
||||
# The benchmark replays the prepared workflow dataset with runs=1.
_sync_benchmark = BenchmarkConfig(
    dataset=benchmark_dataset,
    runs=1,
)

# Single synchronous route; parallel requests are disabled and every request
# is assigned the same fixed workload of 10000.0.
_sync_handler = HandlerConfig(
    route="/generate/sync",
    allow_parallel_requests=False,
    max_queue_time=10.0,
    benchmark_config=_sync_benchmark,
    workload_calculator=lambda _: 10000.0,
)

# Log substrings mapped to the load / error / info actions.
_log_actions = LogActionConfig(
    on_load=MODEL_LOAD_LOG_MSG,
    on_error=MODEL_ERROR_LOG_MSGS,
    on_info=MODEL_INFO_LOG_MSGS,
)

worker_config = WorkerConfig(
    model_server_url=MODEL_SERVER_URL,
    model_server_port=MODEL_SERVER_PORT,
    model_log_file=MODEL_LOG_FILE,
    model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
    handlers=[_sync_handler],
    log_action_config=_log_actions,
)

Worker(worker_config).run()
|
||||
Reference in New Issue
Block a user