Use PyWorker SDK (#67)

* Change PyWorker to Worker SDK
* Moved /lib to vast-sdk (https://github.com/vast-ai/vast-sdk)
This commit is contained in:
LucasArmandVast
2025-12-15 22:33:03 -05:00
committed by GitHub
parent 2ce741a8b7
commit 4380d98c01
54 changed files with 1622 additions and 4626 deletions
+168
View File
@@ -0,0 +1,168 @@
# ComfyUI ACE Step PyWorker
This is the PyWorker implementation for running **ACE Step v1 3.5B** text-to-music workflows in ComfyUI. It provides a unified interface for executing complete ComfyUI audio-generation workflows through a proxy-based architecture and returning generated audio assets.
Each request has a static cost of `1000`. ComfyUI does not support concurrent workloads, and there is no provision to run multiple ComfyUI instances per worker node.
## Requirements
This worker requires the following components:
- ComfyUI (https://github.com/comfyanonymous/ComfyUI)
- ComfyUI API Wrapper (https://github.com/ai-dock/comfyui-api-wrapper)
- ACE Step v1 3.5B model and required custom nodes
A Docker image is provided with the ACE Step model pre-installed, but any image may be used if the above requirements are met.
## Endpoint
The worker exposes a single synchronous endpoint:
- `/generate/sync`: Processes a complete ComfyUI workflow JSON and generates audio output
## Request Format
The ACE Step worker **only supports custom workflow mode**. Modifier-based workflows are not supported.
```json
{
"input": {
"request_id": "uuid-string",
"workflow_json": {
// Complete ComfyUI ACE Step workflow JSON
},
"s3": { },
"webhook": { }
}
}
```
## Request Fields
### Required Fields
- `input`: Container for all request parameters
- `input.workflow_json`: Complete ComfyUI workflow graph for ACE Step audio generation
### Optional Fields
- `input.request_id`: Client-defined request identifier
- `input.s3`: S3-compatible storage configuration
- `input.webhook`: Webhook configuration for completion notifications
The special string `"__RANDOM_INT__"` may be used in the workflow JSON and will be replaced with a random integer before submission to ComfyUI.
## S3 Configuration
Generated audio assets can be automatically uploaded to S3-compatible storage. Configuration can be supplied per request or via environment variables. Request-level values take precedence.
### Via Request JSON
```json
"s3": {
"access_key_id": "your-s3-access-key",
"secret_access_key": "your-s3-secret-access-key",
"endpoint_url": "https://s3.amazonaws.com",
"bucket_name": "your-bucket",
"region": "us-east-1"
}
```
### Via Environment Variables
```bash
S3_ACCESS_KEY_ID=your-key
S3_SECRET_ACCESS_KEY=your-secret
S3_BUCKET_NAME=your-bucket
S3_ENDPOINT_URL=https://s3.amazonaws.com
S3_REGION=us-east-1
```
## Webhook Configuration
Webhooks are triggered on request completion or failure.
### Via Request JSON
```json
"webhook": {
"url": "https://your-webhook-url",
"extra_params": {
"custom_field": "value"
}
}
```
### Via Environment Variables
```bash
WEBHOOK_URL=https://your-webhook-url
WEBHOOK_TIMEOUT=30
```
## Example Request
### ACE Step Text-to-Music Workflow
```json
{
"input": {
"workflow_json": {
"14": {
"inputs": {
"tags": "funk, pop, upbeat, 105 BPM",
"lyrics": "Turn it up and let it flow",
"lyrics_strength": 0.99,
"clip": ["40", 1]
},
"class_type": "TextEncodeAceStepAudio"
},
"17": {
"inputs": {
"seconds": 180,
"batch_size": 1
},
"class_type": "EmptyAceStepLatentAudio"
},
"40": {
"inputs": {
"ckpt_name": "ace_step_v1_3.5b.safetensors"
},
"class_type": "CheckpointLoaderSimple"
}
}
}
}
```
## Response Format
A successful response includes execution metadata, ComfyUI output details, and generated audio assets.
### Response Fields
- `id`: Unique request identifier
- `status`: `completed`, `failed`, `processing`, `generating`, or `queued`
- `message`: Human-readable status message
- `comfyui_response`: Raw response from ComfyUI, including execution status and progress
- `output`: Array of generated outputs
- `timings`: Timing information for the request
### Output Object
Each entry in `output` includes:
- `filename`: Generated file name (e.g., `.mp3`)
- `local_path`: File path on the worker
- `url`: Pre-signed download URL (if S3 is configured)
- `type`: Output type (`output`)
- `subfolder`: Output directory (e.g., `audio`)
- `node_id`: ComfyUI node that produced the output
- `output_type`: Output category (e.g., `audio`)
## Notes and Limitations
- Only full ComfyUI workflow JSONs are supported
- Concurrent requests are not supported per worker
- ACE Step model must be installed before processing requests
- Audio generation duration and runtime depend on workflow configuration
+149
View File
@@ -0,0 +1,149 @@
from vastai import Serverless
import asyncio
async def main():
async with Serverless() as client:
endpoint = await client.get_endpoint(name="my-ace-endpoint")
# ComfyUI API compatible json workflow for ACE Step
workflow = {
"14": {
"inputs": {
"tags": "funk, pop, soul, rock, melodic, guitar, drums, bass, keyboard, percussion, 105 BPM, energetic, upbeat, groovy, vibrant, dynamic",
"lyrics": "[verse]\nNeon lights they flicker bright\nCity hums in dead of night\nRhythms pulse through concrete veins\nLost in echoes of refrains\n\n[verse]\nBassline groovin in my chest\nHeartbeats match the citys zest\nElectric whispers fill the air\nSynthesized dreams everywhere\n\n[chorus]\nTurn it up and let it flow\nFeel the fire let it grow\nIn this rhythm we belong\nHear the night sing out our song",
"lyrics_strength": 0.99,
"clip": ["40", 1]
},
"class_type": "TextEncodeAceStepAudio",
"_meta": {
"title": "TextEncodeAceStepAudio"
}
},
"17": {
"inputs": {
"seconds": 180,
"batch_size": 1
},
"class_type": "EmptyAceStepLatentAudio",
"_meta": {
"title": "EmptyAceStepLatentAudio"
}
},
"18": {
"inputs": {
"samples": ["52", 0],
"vae": ["40", 2]
},
"class_type": "VAEDecodeAudio",
"_meta": {
"title": "VAE Decode Audio"
}
},
"40": {
"inputs": {
"ckpt_name": "ace_step_v1_3.5b.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"44": {
"inputs": {
"conditioning": ["14", 0]
},
"class_type": "ConditioningZeroOut",
"_meta": {
"title": "ConditioningZeroOut"
}
},
"49": {
"inputs": {
"model": ["51", 0],
"operation": ["50", 0]
},
"class_type": "LatentApplyOperationCFG",
"_meta": {
"title": "LatentApplyOperationCFG"
}
},
"50": {
"inputs": {
"multiplier": 1.15
},
"class_type": "LatentOperationTonemapReinhard",
"_meta": {
"title": "LatentOperationTonemapReinhard"
}
},
"51": {
"inputs": {
"shift": 6,
"model": ["40", 0]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "ModelSamplingSD3"
}
},
"52": {
"inputs": {
"seed": "__RANDOM_INT__",
"steps": 65,
"cfg": 4,
"sampler_name": "er_sde",
"scheduler": "linear_quadratic",
"denoise": 1,
"model": ["49", 0],
"positive": ["14", 0],
"negative": ["44", 0],
"latent_image": ["17", 0]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"59": {
"inputs": {
"filename_prefix": "audio/ComfyUI",
"quality": "V0",
"audioUI": "",
"audio": ["18", 0]
},
"class_type": "SaveAudioMP3",
"_meta": {
"title": "Save Audio (MP3)"
}
}
}
payload = {
"input": {
"request_id": "",
"workflow_json": workflow,
"s3": {
"access_key_id": "",
"secret_access_key": "",
"endpoint_url": "",
"bucket_name": "",
"region": ""
},
"webhook": {
"url": "",
"extra_params": {
"user_id": "12345",
"project_id": "abc-def"
}
}
}
}
response = await endpoint.request("/generate/sync", payload)
# Response contains status, output, and any errors
print(response["response"])
if __name__ == "__main__":
asyncio.run(main())
+184
View File
@@ -0,0 +1,184 @@
import random
import sys
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
# ComyUI model configuration
MODEL_SERVER_URL = 'http://127.0.0.1'
MODEL_SERVER_PORT = 18288
MODEL_LOG_FILE = '/var/log/portal/comfyui.log'
MODEL_HEALTHCHECK_ENDPOINT = "/health"
# ComyUI-specific log messages
MODEL_LOAD_LOG_MSG = [
"To see the GUI go to: "
]
MODEL_ERROR_LOG_MSGS = [
"MetadataIncompleteBuffer",
"Value not in list: ",
"[ERROR] Provisioning Script failed"
]
MODEL_INFO_LOG_MSGS = [
'"message":"Downloading'
]
benchmark_lyrics = [
"[verse]\nGuardian cloaked in twilight hue\nShadows melt where he breaks through\nEchoes swirl in mystic flight\nHooded hero owns the night\n\n[verse]\nThrough the chaos shapes arise\nFeral whispers, glowing eyes\nOrcs and creatures side by side\nMarch within the inky tide\n\n[chorus]\nRise above the fear and gloom\nLet your courage fully bloom\nIn the darkness stand your ground\nHear the night proclaim your sound",
"[verse]\nMorning sun on fields of gold\nGentle stories unfold\nEvery breeze a quiet song\nWhere the peaceful hearts belong\n\n[verse]\nLanterns glow at stable doors\nRustling leaves on orchard floors\nSimple joys in every hand\nLife grows soft in fertile land\n\n[chorus]\nLet the day drift slow and free\nRoot your soul where you can be\nIn this haven warm and bright\nFeel the earth breathe pure delight",
"[verse]\nLittle feet on dusty ground\nChasing dreams without a sound\nSoccer ball in morning light\nHopes take wing in youthful flight\n\n[verse]\nChrome reflections paint the day\nSwagger in the steps that play\nCopper tones in shining air\nChildhood gleaming everywhere\n\n[chorus]\nKick the world with boundless cheer\nHold the magic close and near\nIn each moment bold and true\nLet the sky belong to you",
"[verse]\nSunset bleeds across the street\nGilded calm in summer heat\nLow-rise towers rimmed with fire\nDreams ignite as lights climb higher\n\n[verse]\nFootsteps scatter through the haze\nFutures shimmer in the blaze\nEvery window tells a tale\nFloating through a tangerine veil\n\n[chorus]\nLet the neon softly glow\nLet your restless heartbeat slow\nIn this city forged in light\nCarry hope into the night",
"[verse]\nOcean breathes in rolling arcs\nSprays of diamond, glowing sparks\nWaves unfold a perfect line\nNatures rhythm feels divine\n\n[verse]\nSun above in golden sweep\nPaints the rise of every deep\nShimmer drifting through the blue\nWorld reborn in every view\n\n[chorus]\nLet the tide pull you along\nHear the waters ancient song\nIn the cresting waves youll find\nQuiet peace for heart and mind",
"[verse]\nGlass aglow with swirling light\nFruits and mints in colors bright\nIcy whispers clink and chime\nFlowing forms suspend in time\n\n[verse]\nCreamy spirals drift within\nGentle currents slowly spin\nWarm reflections lingering sweet\nMixing flavors at your feet\n\n[chorus]\nSip the glow and let it rise\nTaste the sunset in disguise\nIn this moment clear and true\nLet the warmth flow into you",
"[verse]\nEngines rumble down the lane\nCopper clouds of steam and rain\nOilpunk dreams in metal shine\nRider drifting down the line\n\n[verse]\nLeather jacket, steady glare\nStories sparking in the air\nMagazine lights frame his face\nKing of roads in timeless grace\n\n[chorus]\nThrottle up beyond the bend\nFeel the force of steel ascend\nRide the night and hold on tight\nClaim the world in streaks of light",
"[verse]\nCut-out shapes in swirling play\nTextures dance in bold array\nCats in denim, grinning wide\nStrut across the patterned tide\n\n[verse]\nPosters hum with neon glow\nSurreal scenes begin to grow\nColors crisp as folded art\nPatchwork beating like a heart\n\n[chorus]\nLet the collage come alive\nWatch the vibrant pieces thrive\nIn this joyful, crafted space\nEvery shape finds its own place",
"[verse]\nTiny world in crystal glass\nAncient tales behind the mass\nVillage lights in winter gleam\nFrozen in a mystic dream\n\n[verse]\nLantern beams in swirling air\nSoft enchantment everywhere\nShadows drift with gentle grace\nMagic sealed within the space\n\n[chorus]\nHold the sphere and you will see\nEchoes of a memory\nIn the glow of fragile light\nLives a realm of pure delight",
"[verse]\nArmor hums with power bright\nChopping sparks in jungle night\nMecha spirits shift and scream\nThrough the ferns like shattered beams\n\n[verse]\nAxes blaze in glowing arcs\nLighting up the shadowed marks\nNature roars in trembling air\nClash of steel and cosmic flare\n\n[chorus]\nRaise the fire, strike the ground\nLet your legend shake the sound\nIn the wild where echoes roam\nForge the fight and carve your home",
"[verse]\nCrowds ignite in vibrant flare\nBeats explode through smoky air\nDJ robes replaced with flame\nPope on decks in holy frame\n\n[verse]\nLeather gleams in blinding light\nTurntables spin with sacred might\nChoirs echo in the bass\nHeaven pulses through the place\n\n[chorus]\nLift the roof and shake the floor\nSacred rhythm evermore\nLet the music take control\nFeel the blessing in your soul",
]
benchmark_dataset = [
{
"input": {
"request_id": "",
"workflow_json": {
"14": {
"inputs": {
"tags": "funk, pop, soul, rock, melodic, guitar, drums, bass, keyboard, percussion, 105 BPM, energetic, upbeat, groovy, vibrant, dynamic",
"lyrics": lyrics,
"lyrics_strength": 0.99,
"clip": ["40", 1]
},
"class_type": "TextEncodeAceStepAudio",
"_meta": {
"title": "TextEncodeAceStepAudio"
}
},
"17": {
"inputs": {
"seconds": 180,
"batch_size": 1
},
"class_type": "EmptyAceStepLatentAudio",
"_meta": {
"title": "EmptyAceStepLatentAudio"
}
},
"18": {
"inputs": {
"samples": ["52", 0],
"vae": ["40", 2]
},
"class_type": "VAEDecodeAudio",
"_meta": {
"title": "VAE Decode Audio"
}
},
"40": {
"inputs": {
"ckpt_name": "ace_step_v1_3.5b.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"44": {
"inputs": {
"conditioning": ["14", 0]
},
"class_type": "ConditioningZeroOut",
"_meta": {
"title": "ConditioningZeroOut"
}
},
"49": {
"inputs": {
"model": ["51", 0],
"operation": ["50", 0]
},
"class_type": "LatentApplyOperationCFG",
"_meta": {
"title": "LatentApplyOperationCFG"
}
},
"50": {
"inputs": {
"multiplier": 1.15
},
"class_type": "LatentOperationTonemapReinhard",
"_meta": {
"title": "LatentOperationTonemapReinhard"
}
},
"51": {
"inputs": {
"shift": 6,
"model": ["40", 0]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "ModelSamplingSD3"
}
},
"52": {
"inputs": {
"seed": "__RANDOM_INT__",
"steps": 65,
"cfg": 4,
"sampler_name": "er_sde",
"scheduler": "linear_quadratic",
"denoise": 1,
"model": ["49", 0],
"positive": ["14", 0],
"negative": ["44", 0],
"latent_image": ["17", 0]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"59": {
"inputs": {
"filename_prefix": "audio/ComfyUI",
"quality": "V0",
"audioUI": "",
"audio": ["18", 0]
},
"class_type": "SaveAudioMP3",
"_meta": {
"title": "Save Audio (MP3)"
}
}
}
}
} for lyrics in benchmark_lyrics
]
worker_config = WorkerConfig(
model_server_url=MODEL_SERVER_URL,
model_server_port=MODEL_SERVER_PORT,
model_log_file=MODEL_LOG_FILE,
model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
handlers=[
HandlerConfig(
route="/generate/sync",
allow_parallel_requests=False,
max_queue_time=10.0,
benchmark_config=BenchmarkConfig(
dataset=benchmark_dataset,
runs=1
),
workload_calculator= lambda _ : 1000.0
)
],
log_action_config=LogActionConfig(
on_load=MODEL_LOAD_LOG_MSG,
on_error=MODEL_ERROR_LOG_MSGS,
on_info=MODEL_INFO_LOG_MSGS
)
)
Worker(worker_config).run()
+9 -1
View File
@@ -2,7 +2,7 @@
This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
The cost for each request has a static value of `1`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node.
The cost for each request has a static value of `100`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node.
## Instance Setup
@@ -302,3 +302,11 @@ WEBHOOK_TIMEOUT=30 # Webhook timeout in seconds
}
}
```
## Client Libraries
See the client example for implementation details on how to integrate with the ComfyUI worker.
---
See Vast's serverless documentation for more details on how to use ComfyUI with autoscaler.
-84
View File
@@ -1,84 +0,0 @@
import os
import sys
import random
import dataclasses
from typing import Dict, Any
from functools import cache
from math import ceil
from pathlib import Path
import json
import logging
from lib.data_types import ApiPayload, JsonDataException
log = logging.getLogger(__file__)
def count_workload() -> float:
# Always 100.0 where there is a single instance of ComfyUI handling requests
# Results will indicate % or a job completed per second. Avoids sub 0.1 sec performance indication
return 100.0
@dataclasses.dataclass
class ComfyWorkflowData(ApiPayload):
input: dict
@classmethod
def for_test(cls):
"""
If the user has provided a benchmark workflow we can use it here to properly gauge performance.
Otherwise, use the variables available to simulate workflows of the required running time
Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090)
"""
# Try to load benchmark.json
benchmark_file = Path("workers/comfyui-json/misc/benchmark.json")
if benchmark_file.exists():
try:
with open(benchmark_file, "r") as f:
benchmark_workflow = json.load(f)
return cls(
input={
"request_id": f"test-{random.randint(1000, 99999)}",
"workflow_json": benchmark_workflow
}
)
except (json.JSONDecodeError, IOError):
# JSON is malformed or file can't be read, fall through to default
log.error(f"Failed to benchmark using {benchmark_file}")
# Fallback: read prompts and construct payload
log.info("Using fallback method for benchmarking")
with open("workers/comfyui-json/misc/test_prompts.txt", "r") as f:
test_prompts = f.readlines()
test_prompt = random.choice(test_prompts).rstrip()
return cls(
input={
"request_id": f"test-{random.randint(1000, 99999)}",
"modifier": "Text2Image",
"modifications": {
"prompt": test_prompt,
"width": os.getenv('BENCHMARK_TEST_WIDTH', 512),
"height": os.getenv('BENCHMARK_TEST_HEIGHT', 512),
"steps": os.getenv('BENCHMARK_TEST_STEPS', 20),
"seed": random.randint(0, sys.maxsize),
}
}
)
def generate_payload_json(self) -> Dict[str, Any]:
# input is already a dict, just return it wrapped in the expected structure
return {"input": self.input}
def count_workload(self) -> float:
return count_workload()
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "ComfyWorkflowData":
# Extract required fields
if "input" not in json_msg:
raise JsonDataException({"input": "missing parameter"})
return cls(
input=json_msg["input"]
)
@@ -1,107 +0,0 @@
{
"3": {
"inputs": {
"seed": "__RANDOM_INT__",
"steps": 20,
"cfg": 8,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1,
"model": [
"4",
0
],
"positive": [
"6",
0
],
"negative": [
"7",
0
],
"latent_image": [
"5",
0
]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"4": {
"inputs": {
"ckpt_name": "v1-5-pruned-emaonly-fp16.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"5": {
"inputs": {
"width": 512,
"height": 512,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"6": {
"inputs": {
"text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"7": {
"inputs": {
"text": "text, watermark",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"8": {
"inputs": {
"samples": [
"3",
0
],
"vae": [
"4",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"8",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
}
}
+1
View File
@@ -0,0 +1 @@
# This folder is required for the provisioning scripts of ace and wan to complete.
@@ -1,34 +0,0 @@
cartoon character of a person with a hoodie , in style of cytus and deemo, ork, gold chains, realistic anime cat, dripping black goo, lineage revolution style, thug life, cute anthropomorphic bunny, balrog, arknights, aliased, very buff, black and red and yellow paint, painting illustration collage style, character composition in vector with white background
stardew valley, fine details
2D Vector Illustration of a child with soccer ball Art for Sublimation, Design Art, Chrome Art, Painting and Stunning Artwork, Highly Detailed Digital Painting, Airbrush Art, Highly Detailed Digital Artwork, Dramatic Artwork, stained antique yellow copper paint, digital airbrush art, detailed by Mark Brooks, Chicano airbrush art, Swagger! snake Culture
realistic futuristic city-downtown with short buildings, sunset
seascape by Ray Collins and artgerm, front view of a perfect wave, sunny background, ultra detailed water
inspired by realflow-cinema4d editor features, create image of a transparent luxury cup with ice fruits and mint, connected with white, yellow and pink cream, Slow - High Speed MO Photography, YouTube Video Screenshot, Abstract Clay, Transparent Cup , molecular gastronomy, wheel, 3D fluid,Simulation rendering, still video, 4k polymer clay futras photography, very surreal, Houdini Fluid Simulation, hyperrealistic CGI and FLUIDS & MULTIPHYSICS SIMULATION effect, with Somali Stain Lurex, Metallic Jacquard, Gold Thread, Mulberry Silk, Toub Saree, Warm background, a fantastic image worthy of an award.
biker with backpack on his back riding a motorcycle, Style by Ade Santora, Oilpunk, Cover photo, craig mullins style, on the cover of a magazine, Outdoor Magazine, inspired by Alex Petruk APe, image of a male biker, Cover of an award-winning magazine, the man has a backpack, photo for magazine, with a backpack, magazine cover
generate a collage-style illustration inspired by the Procreate raster graphic editor, photographic illustration with the theme, 2D vector, art for textile sublimation, containing surrealistic cartoon cat wearing a baseball cap and jeans standing in front of a poster, inspired by Sadao Watanabe, Doraemon, Japanese cartoon style, Eichiro Oda, Iconic high detail character, Director: Nakahara Nantenbō, Kastuhiro Otomo, image detailed, by Miyamoto, Hidetaka Miyazaki, Katsuhiro illustration, 8k, masterpiece, Minimize noise and grain in photo quality without lose quality and increase brightness and lighting,Symmetry and Alignment, Avoid asymmetrical shapes and out-of-focus points. Focus and Sharpness: Make sure the image is focused and sharp and encourages the viewer to see it as a work of art printed on fabric.
fantasy medieval village world inside a glass sphere , high detail, fantasy, realistic, light effect, hyper detail, volumetric lighting, cinematic, macro, depth of field, blur, red light and clouds from the back, highly detailed epic cinematic concept art cg render made in maya, blender and photoshop, octane render, excellent composition, dynamic dramatic cinematic lighting, aesthetic, very inspirational, world inside a glass sphere by james gurney by artgerm with james jean, joe fenton and tristan eaton by ross tran, fine details
Iron Man, (Arnold Tsang, Toru Nakayama), Masterpiece, Studio Quality, 6k , toa, toaair, 1boy, glowing, axe, mecha, science_fiction, solo, weapon, jungle , green_background, nature, outdoors, solo, tree, weapon, mask, dynamic lighting, detailed shading, digital texture painting
(Pope Francis) wearing leather jacket is a DJ in a nightclub, mixing live on stage, giant mixing table, a masterpiece
Pope Francis wearing biker (leather jacket), a masterpiece
Luke Skywalker ordering a burger and fries from the Death Star canteen.
I want to generate a group avatar for a Feishu group chat. The role of this group is daily software technical communication. Now the subject technology stacks that members of this group discuss daily include: algorithms, data structures, optimization, functional programming, and the programming languages often discussed are: TypeScript, Java, python, etc. I hope this avatar has a simple aesthetic, this avatar is a single person avatar
portrait Anime black girl cute-fine-face, pretty face, realistic shaded Perfect face, fine details. Anime. realistic shaded lighting by Ilya Kuvshinov Giuseppe Dangelico Pino and Michael Garmash and Rob Rey, IAMAG premiere, WLOP matte print, cute freckles, masterpiece
young Disney socialite wearing a beige miniskirt, dark brown turtleneck sweater, small neckless, cute-fine-face, anime. illustration, realistic shaded perfect face, brown hair, grey eyes, fine details, realistic shaded lighting by ilya kuvshinov giuseppe dangelico pino and michael garmash and rob rey, iamag premiere, wlop matte print, a masterpiece
Cute small cat sitting in a movie theater eating chicken wiggs watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
Cute small dog sitting in a movie theater eating popcorn watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
fox bracelet made of buckskin with fox features, rich details, fine carvings, studio lighting
crane buckskin bracelet with crane features, rich details, fine carvings, studio lighting
london luxurious interior living-room, light walls
Parisian luxurious interior penthouse bedroom, dark walls, wooden panels
cute girl, crop-top, blond hair, black glasses, stretching, with background by greg rutkowski makoto shinkai kyoto animation key art feminine mid shot
houses in front, houses background, straight houses, digital art, smooth, sharp focus, gravity falls style, doraemon style, shinchan style, anime style
Simplified technical drawing, Leonardo da Vinci, Mechanical Dinosaur Skeleton, Minimalistic annotations, Hand-drawn illustrations, Basic design and engineering, Wonder and curiosity
High quality 8K painting impressionist style of a Japanese modern city street with a girl on the foreground wearing a traditional wedding dress with a fox mask, staring at the sky, daylight
a landscape from the Moon with the Earth setting on the horizon, realistic, detailed
Isometric Atlantis city,great architecture with columns, great details, ornaments,seaweed, blue ambiance, 3D cartoon style, soft light, 45° view
A hyper realistic avatar of a guy riding on a black honda cbr 650r in leather suit,high detail, high quality,8K,photo realism
the street of amedieval fantasy town, at dawn, dark, highly detailed
overwhelmingly beautiful eagle framed with vector flowers, long shiny wavy flowing hair, polished, ultra detailed vector floral illustration mixed with hyper realism, muted pastel colors, vector floral details in background, muted colors, hyper detailed ultra intricate overwhelming realism in detailed complex scene with magical fantasy atmosphere, no signature, no watermark
a highly detailed matte painting of a man on a hill watching a rocket launch in the distance by studio ghibli, makoto shinkai, by artgerm, by wlop, by greg rutkowski, volumetric lighting, octane render, 4 k resolution, trending on artstation, masterpiece | hyperrealism| highly detailed| insanely detailed| intricate| cinematic lighting| depth of field
electronik robot and ofice ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
exquisitely intricately detailed illustration, of a small world with a lake and a rainbow, inside a closed glass jar.
-150
View File
@@ -1,150 +0,0 @@
import os
import logging
import dataclasses
import base64
from typing import Optional, Union, Type
import aiohttp
from aiohttp import web, ClientResponse
from lib.backend import Backend, LogAction
from lib.data_types import EndpointHandler
from lib.server import start_server
from .data_types import ComfyWorkflowData
MODEL_SERVER_URL = os.getenv("MODEL_SERVER_URL", "http://127.0.0.1:18288")
COMFYUI_URL = os.getenv("COMFYUI_URL", "http://127.0.0.1:18188") # Raw ComfyUI server
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: "
MODEL_SERVER_ERROR_LOG_MSGS = [
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
"Value not in list: ", # This error is emitted when the model file is not there at all
"[ERROR] Provisioning Script failed", # Error inserted by provisioning script if models/nodes fail to download
]
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
async def generate_client_response(
client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
# Check if the response is actually streaming based on response headers/content-type
is_streaming_response = (
model_response.content_type == "text/event-stream"
or model_response.content_type == "application/x-ndjson"
or model_response.headers.get("Transfer-Encoding") == "chunked"
or "stream" in model_response.content_type.lower()
)
if is_streaming_response:
log.debug("Detected streaming response...")
res = web.StreamResponse()
res.content_type = model_response.content_type
await res.prepare(client_request)
async for chunk in model_response.content:
await res.write(chunk)
await res.write_eof()
log.debug("Done streaming response")
return res
else:
log.debug("Detected non-streaming response...")
content = await model_response.read()
return web.Response(
body=content,
status=model_response.status,
content_type=model_response.content_type
)
@dataclasses.dataclass
class ComfyWorkflowHandler(EndpointHandler[ComfyWorkflowData]):
@property
def endpoint(self) -> str:
return "/generate/sync"
@property
def healthcheck_endpoint(self) -> Optional[str]:
return f"{MODEL_SERVER_URL}/health"
@classmethod
def payload_cls(cls) -> Type[ComfyWorkflowData]:
return ComfyWorkflowData
def make_benchmark_payload(self) -> ComfyWorkflowData:
return ComfyWorkflowData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
return await generate_client_response(client_request, model_response)
backend = Backend(
model_server_url=MODEL_SERVER_URL,
model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=False,
benchmark_handler=ComfyWorkflowHandler(
benchmark_runs=3, benchmark_words=100
),
log_actions=[
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
(LogAction.Info, "Downloading:"),
*[
(LogAction.ModelError, error_msg)
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
],
],
)
async def handle_ping(_):
return web.Response(body="pong")
async def handle_view(request: web.Request) -> web.Response:
"""Proxy /view requests to raw ComfyUI server to fetch generated images"""
# Forward query params to raw ComfyUI (not the API wrapper)
query_string = request.query_string
url = f"{COMFYUI_URL}/view?{query_string}"
log.debug(f"Proxying /view request to: {url}")
try:
async with aiohttp.ClientSession() as session:
async with session.get(url) as resp:
if resp.status == 200:
content = await resp.read()
return web.Response(
body=content,
status=200,
content_type=resp.content_type or "image/png"
)
else:
text = await resp.text()
return web.Response(
text=text,
status=resp.status,
content_type="text/plain"
)
except Exception as e:
log.error(f"Error proxying /view: {e}")
return web.Response(text=str(e), status=500)
routes = [
web.post("/generate/sync", backend.create_handler(ComfyWorkflowHandler())),
web.get("/view", handle_view),
web.get("/ping", handle_ping),
]
if __name__ == "__main__":
start_server(backend, routes)
-8
View File
@@ -1,8 +0,0 @@
from lib.test_utils import test_load_cmd, test_args
from .data_types import ComfyWorkflowData
WORKER_ENDPOINT = "/generate/sync"
if __name__ == "__main__":
test_load_cmd(ComfyWorkflowData, WORKER_ENDPOINT, arg_parser=test_args)
+81
View File
@@ -0,0 +1,81 @@
import random
import sys
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
# ComyUI model configuration
MODEL_SERVER_URL = 'http://127.0.0.1'
MODEL_SERVER_PORT = 18288
MODEL_LOG_FILE = '/var/log/portal/comfyui.log'
MODEL_HEALTHCHECK_ENDPOINT = "/health"
# ComyUI-specific log messages
MODEL_LOAD_LOG_MSG = [
"To see the GUI go to: "
]
MODEL_ERROR_LOG_MSGS = [
"MetadataIncompleteBuffer",
"Value not in list: ",
"[ERROR] Provisioning Script failed"
]
MODEL_INFO_LOG_MSGS = [
'"message":"Downloading'
]
benchmark_prompts = [
"Cartoon hoodie hero; orc, anime cat, bunny; black goo; buff; vector on white.",
"Cozy farming-game scene with fine details.",
"2D vector child with soccer ball; airbrush chrome; swagger; antique copper.",
"Realistic futuristic downtown of low buildings at sunset.",
"Perfect wave front view; sunny seascape; ultra-detailed water; artful feel.",
"Clear cup with ice, fruit, mint; creamy swirls; fluid-sim CGI; warm glow.",
"Male biker with backpack on motorcycle; oilpunk; award-worthy magazine cover.",
"Collage for textile; surreal cartoon cat in cap/jeans before poster; crisp.",
"Medieval village inside glass sphere; volumetric light; macro focus.",
"Iron Man with glowing axe; mecha sci-fi; jungle scene; dynamic light.",
"Pope Francis DJ in leather jacket, mixing on giant console; dramatic.",
]
benchmark_dataset = [
{
"input": {
"request_id": f"test-{random.randint(1000, 99999)}",
"modifier": "Text2Image",
"modifications": {
"prompt": prompt,
"width": 512,
"height": 512,
"steps": 20,
"seed": random.randint(0, sys.maxsize)
}
}
} for prompt in benchmark_prompts
]
worker_config = WorkerConfig(
model_server_url=MODEL_SERVER_URL,
model_server_port=MODEL_SERVER_PORT,
model_log_file=MODEL_LOG_FILE,
model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
handlers=[
HandlerConfig(
route="/generate/sync",
allow_parallel_requests=False,
max_queue_time=10.0,
benchmark_config=BenchmarkConfig(
dataset=benchmark_dataset,
)
)
],
log_action_config=LogActionConfig(
on_load=MODEL_LOAD_LOG_MSG,
on_error=MODEL_ERROR_LOG_MSGS,
on_info=MODEL_INFO_LOG_MSGS
)
)
Worker(worker_config).run()
-92
View File
@@ -1,92 +0,0 @@
This is the base PyWorker for comfyui. It can be used to create PyWorker that use various models and
workflows. It provides two endpoints:
1. `/prompt`: Uses the default comfy workflow defined under `misc/default_workflows`
2. `/custom_workflow`: Allows the client to send their own comfy workflow with each API request.
To use the comfyui PyWorker, `$COMFY_MODEL` env variable must be set in the template. Current options are
`sd3` and `flux`. Each have example clients.
To add new models, a JSON with name `$COMFY_MODEL.json` must be created under `misc/default_workflows`
NOTE: default workflows follow this format:
```json
{
"input": {
"handler": "RawWorkflow",
"aws_access_key_id": "your-s3-access-key",
"aws_secret_access_key": "your-s3-secret-access-key",
"aws_endpoint_url": "https://my-endpoint.backblaze.com",
"aws_bucket_name": "your-bucket",
"webhook_url": "your-webhook-url",
"webhook_extra_params": {},
"workflow_json": {}
}
}
```
You can ignore all of these fields except for `workflow_json`.
Fields written as "{{FOO}}" will be replaced using data from a user request. For example, SD3's workflow has the
following nodes:
```json
"5": {
"inputs": {
"width": "{{WIDTH}}",
"height": "{{HEIGHT}}",
"batch_size": 1
},
"6": {
"inputs": {
"text": "{{PROMPT}}",
"clip": ["11", 0]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
...
"17": {
"inputs": {
"scheduler": "simple",
"steps": "{{STEPS}}",
"denoise": 1,
"model": ["12", 0]
},
"class_type": "BasicScheduler",
"_meta": {
"title": "BasicScheduler"
}
},
...
"25": {
"inputs": {
"noise_seed": "{{SEED}}"
},
"class_type": "RandomNoise",
"_meta": {
"title": "RandomNoise"
}
}
```
Incoming requests have the following JSON format:
```json
{
prompt: str
width: int
height: int
steps: int
seed: int
}
```
Each value in those fields with replace the placeholder of the same name in the default workflow.
See Vast's serverless documentation for more details on how to use comfyui with autoscaler
-170
View File
@@ -1,170 +0,0 @@
import logging
from urllib.parse import urljoin
import requests
from lib.test_utils import print_truncate_res
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from vastai import Serverless
ENDPOINT_NAME = "my-comfyui-endpoint"
COST = 100 # Use a constant cost for image generation
def call_default_workflow(client: Serverless) -> None:
WORKER_ENDPOINT = "/prompt"
COST = 100
route_payload = {
"endpoint": endpoint_group_name,
"api_key": api_key,
"cost": COST,
}
response = requests.post(
urljoin(server_url, "/route/"),
json=route_payload,
timeout=4,
)
response.raise_for_status()
message = response.json()
url = message["url"]
auth_data = dict(
signature=message["signature"],
cost=message["cost"],
endpoint=message["endpoint"],
reqnum=message["reqnum"],
url=message["url"],
)
payload = dict(
prompt="a fat fluffy cat", width=1024, height=1024, steps=20, seed=123456789
)
req_data = dict(payload=payload, auth_data=auth_data)
url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {url}")
response = requests.post(
url,
json=req_data,
verify=get_cert_file_path(),
)
response.raise_for_status()
print_truncate_res(str(response.json()))
def call_custom_workflow_for_sd3(
endpoint_group_name: str, api_key: str, server_url: str
) -> None:
WORKER_ENDPOINT = "/custom-workflow"
COST = 100
route_payload = {
"endpoint": endpoint_group_name,
"api_key": api_key,
"cost": COST,
}
response = requests.post(
urljoin(server_url, "/route/"),
json=route_payload,
timeout=4,
)
response.raise_for_status()
message = response.json()
url = message["url"]
auth_data = dict(
signature=message["signature"],
cost=message["cost"],
endpoint=message["endpoint"],
reqnum=message["reqnum"],
url=message["url"],
request_idx=message["request_idx"],
)
workflow = {
"3": {
"inputs": {
"seed": 156680208700286,
"steps": 20,
"cfg": 8,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1,
"model": ["4", 0],
"positive": ["6", 0],
"negative": ["7", 0],
"latent_image": ["5", 0],
},
"class_type": "KSampler",
},
"4": {
"inputs": {"ckpt_name": "sd3_medium_incl_clips_t5xxlfp16.safetensors"},
"class_type": "CheckpointLoaderSimple",
},
"5": {
"inputs": {"width": 512, "height": 512, "batch_size": 1},
"class_type": "EmptyLatentImage",
},
"6": {
"inputs": {
"text": "beautiful scenery nature glass bottle landscape, purple galaxy bottle",
"clip": ["4", 1],
},
"class_type": "CLIPTextEncode",
},
"7": {
"inputs": {"text": "text, watermark", "clip": ["4", 1]},
"class_type": "CLIPTextEncode",
},
"8": {
"inputs": {"samples": ["3", 0], "vae": ["4", 2]},
"class_type": "VAEDecode",
},
"9": {
"inputs": {"filename_prefix": "ComfyUI", "images": ["8", 0]},
"class_type": "SaveImage",
},
}
# these values should match the values in the custom workflow above,
# they are used to calculate workload
custom_fields = dict(
steps=20,
width=512,
height=512,
)
req_data = dict(
payload=dict(custom_fields=custom_fields, workflow=workflow),
auth_data=auth_data,
)
url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {url}")
response = requests.post(
url,
json=req_data,
verify=get_cert_file_path(),
)
response.raise_for_status()
print_truncate_res(str(response.json()))
if __name__ == "__main__":
from lib.test_utils import test_args
args = test_args.parse_args()
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if endpoint_api_key:
try:
call_default_workflow(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
call_custom_workflow_for_sd3(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
except Exception as e:
log.error(f"Error during API call: {e}")
else:
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ")
-205
View File
@@ -1,205 +0,0 @@
import sys
import os
import json
import random
import dataclasses
import inspect
from typing import Dict, Any
from functools import cache
from math import ceil
from enum import Enum
from lib.data_types import ApiPayload, JsonDataException
with open("workers/comfyui/misc/test_prompts.txt", "r") as f:
test_prompts = f.readlines()
class Model(Enum):
Flux = "flux"
Sd3 = "sd3"
def get_request_time(self) -> int:
match self:
case Model.Flux:
return 23
case Model.Sd3:
return 6
@cache
def get_model() -> Model:
match os.environ.get("COMFY_MODEL"):
case "flux":
return Model.Flux
case "sd3":
return Model.Sd3
case None:
raise Exception(
"For comfyui pyworker, $COMFY_MODEL must be set in the vast template"
)
case model:
raise Exception(f"Unsupported comfyui model: {model}")
@cache
def get_request_template() -> str:
with open(f"workers/comfyui/misc/default_workflows/{get_model().value}.json") as f:
return f.read()
def count_workload(width: int, height: int, steps: int) -> float:
"""
we want to normalize the workload is a number such that cur_perf(tokens/second) for 1024x1024 image with
28 steps is 200 tokens on a 4090.
in order get that we calculate the
A = ( absolute workload based on given data )
B = ( absolute workload for a 1024x1024 image with 28 steps )
and adjust the workload to 200 tokens by A/B.
we then adjust for difference between Flux and SD3 by multiplying this value by expected request time for a
standard image(23s for Flux, 6s for SD3).
On a 4090, this would give us a workload that would give a cur_perf(workload / request_time) of around 200
"""
def _calculate_absolute_tokens(width_: int, height_: int, steps_: int) -> float:
"""
This is based on how openai counts image generation tokens, see: https://openai.com/api/pricing/
we count how many 512x512 grids are needed to cover the image.
each tile is then counted as 175 tokens.
each image generation also has constant of 85 base tokens.
we then adjust the count based on the number of steps. The baseline number of steps is assumed to be 28.
Some testing with flux gave me this data:
steps(X) | request time(Y)
__________|_________________
07(0.25x) | 11s (0.47x)
14(0.50x) | 15s (0.65x)
21(0.75x) | 20s (0.86x)
28(1.00x) | 23s (1.00x)
35(1.25x) | 28s (1.21x)
42(1.50x) | 32s (1.39x)
49(1.75x) | 37s (1.60x)
this gives a linear regression of Y = 0.61*X + 6.57
we can use this as an adjustment_factor for token count
adjustment_factor = (0.61 * steps + 6.57)
"""
width_grids = ceil(width_ / 512)
height_grids = ceil(height_ / 512)
tokens = 85 + width_grids * height_grids * 175
adjustment_factor = 0.61 * steps_ + 6.57
return tokens * adjustment_factor
REQUEST_TIME_FOR_STANDARD_IMAGE = get_model().get_request_time()
absolute_tokens = _calculate_absolute_tokens(
width_=width, height_=height, steps_=steps
)
absolute_tokens_standard_image = _calculate_absolute_tokens(
width_=1024, height_=1024, steps_=28
)
return REQUEST_TIME_FOR_STANDARD_IMAGE * (
(absolute_tokens / absolute_tokens_standard_image) * 200
)
@dataclasses.dataclass
class DefaultComfyWorkflowData(ApiPayload):
prompt: str
width: int
height: int
steps: int
seed: int
@classmethod
def for_test(cls):
test_prompt = random.choice(test_prompts).rstrip()
return cls(
prompt=test_prompt,
width=1024,
height=1024,
steps=28,
seed=random.randint(0, sys.maxsize),
)
def generate_payload_json(
self,
) -> Dict[str, Any]:
return json.loads(
get_request_template()
.replace("{{PROMPT}}", self.prompt)
# these values should be of int type. Since "{{VAR}}" is wrapped with " in the template
# to make the JSON valid, we must replace the double quotes. i.e. "{{WIDTH}}" -> 1024 and not "1024"
.replace('"{{WIDTH}}"', str(self.width))
.replace('"{{HEIGHT}}"', str(self.height))
.replace('"{{STEPS}}"', str(self.steps))
.replace('"{{SEED}}"', str(self.seed))
)
def count_workload(self) -> float:
return count_workload(width=self.width, height=self.height, steps=self.steps)
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "DefaultComfyWorkflowData":
errors = {}
for param in inspect.signature(cls).parameters:
if param not in json_msg:
errors[param] = "missing parameter"
if errors:
raise JsonDataException(errors)
return cls(
**{
k: v
for k, v in json_msg.items()
if k in inspect.signature(cls).parameters
}
)
@dataclasses.dataclass
class CustomComfyWorkflowData(ApiPayload):
custom_fields: Dict[str, int]
workflow: Dict[str, Any]
@classmethod
def for_test(cls):
raise NotImplementedError("Custom comfy workflow is not used for testing")
def count_workload(self) -> float:
return count_workload(
width=int(self.custom_fields.get("width", 1024)),
height=int(self.custom_fields.get("height", 1024)),
steps=int(self.custom_fields.get("steps", 28)),
)
def generate_payload_json(self) -> Dict[str, Any]:
template_json = json.loads(get_request_template())
template_json["input"]["workflow_json"] = self.workflow
return template_json
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "CustomComfyWorkflowData":
errors = {}
for param in inspect.signature(cls).parameters:
if param not in json_msg:
errors[param] = "missing parameter"
if errors:
raise JsonDataException(errors)
return cls(
**{
k: v
for k, v in json_msg.items()
if k in inspect.signature(cls).parameters
}
)
@@ -1,137 +0,0 @@
{
"input": {
"handler": "RawWorkflow",
"aws_access_key_id": "your-s3-access-key",
"aws_secret_access_key": "your-s3-secret-access-key",
"aws_endpoint_url": "https://my-endpoint.backblaze.com",
"aws_bucket_name": "your-bucket",
"webhook_url": "your-webhook-url",
"webhook_extra_params": {},
"workflow_json": {
"5": {
"inputs": {
"width": "{{WIDTH}}",
"height": "{{HEIGHT}}",
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"6": {
"inputs": {
"text": "{{PROMPT}}",
"clip": ["11", 0]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"8": {
"inputs": {
"samples": ["13", 0],
"vae": ["10", 0]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": ["8", 0]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
},
"10": {
"inputs": {
"vae_name": "ae.safetensors"
},
"class_type": "VAELoader",
"_meta": {
"title": "Load VAE"
}
},
"11": {
"inputs": {
"clip_name1": "t5xxl_fp16.safetensors",
"clip_name2": "clip_l.safetensors",
"type": "flux"
},
"class_type": "DualCLIPLoader",
"_meta": {
"title": "DualCLIPLoader"
}
},
"12": {
"inputs": {
"unet_name": "flux1-dev.safetensors",
"weight_dtype": "default"
},
"class_type": "UNETLoader",
"_meta": {
"title": "Load Diffusion Model"
}
},
"13": {
"inputs": {
"noise": ["25", 0],
"guider": ["22", 0],
"sampler": ["16", 0],
"sigmas": ["17", 0],
"latent_image": ["5", 0]
},
"class_type": "SamplerCustomAdvanced",
"_meta": {
"title": "SamplerCustomAdvanced"
}
},
"16": {
"inputs": {
"sampler_name": "euler"
},
"class_type": "KSamplerSelect",
"_meta": {
"title": "KSamplerSelect"
}
},
"17": {
"inputs": {
"scheduler": "simple",
"steps": "{{STEPS}}",
"denoise": 1,
"model": ["12", 0]
},
"class_type": "BasicScheduler",
"_meta": {
"title": "BasicScheduler"
}
},
"22": {
"inputs": {
"model": ["12", 0],
"conditioning": ["6", 0]
},
"class_type": "BasicGuider",
"_meta": {
"title": "BasicGuider"
}
},
"25": {
"inputs": {
"noise_seed": "{{SEED}}"
},
"class_type": "RandomNoise",
"_meta": {
"title": "RandomNoise"
}
}
}
}
}
@@ -1,142 +0,0 @@
{
"input": {
"handler": "RawWorkflow",
"aws_access_key_id": "your-s3-access-key",
"aws_secret_access_key": "your-s3-secret-access-key",
"aws_endpoint_url": "https://my-endpoint.backblaze.com",
"aws_bucket_name": "your-bucket",
"webhook_url": "your-webhook-url",
"webhook_extra_params": {},
"workflow_json": {
"6": {
"inputs": {
"text": "{{PROMPT}}",
"clip": ["252", 1]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"13": {
"inputs": {
"shift": 3,
"model": ["252", 0]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "ModelSamplingSD3"
}
},
"67": {
"inputs": {
"conditioning": ["71", 0]
},
"class_type": "ConditioningZeroOut",
"_meta": {
"title": "ConditioningZeroOut"
}
},
"68": {
"inputs": {
"start": 0.1,
"end": 1,
"conditioning": ["67", 0]
},
"class_type": "ConditioningSetTimestepRange",
"_meta": {
"title": "ConditioningSetTimestepRange"
}
},
"69": {
"inputs": {
"conditioning_1": ["68", 0],
"conditioning_2": ["70", 0]
},
"class_type": "ConditioningCombine",
"_meta": {
"title": "Conditioning (Combine)"
}
},
"70": {
"inputs": {
"start": 0,
"end": 0.1,
"conditioning": ["71", 0]
},
"class_type": "ConditioningSetTimestepRange",
"_meta": {
"title": "ConditioningSetTimestepRange"
}
},
"71": {
"inputs": {
"text": "bad quality, poor quality, doll, disfigured, jpg, toy, bad anatomy, missing limbs, missing fingers, 3d, cgi",
"clip": ["252", 1]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Negative Prompt)"
}
},
"135": {
"inputs": {
"width": "{{WIDTH}}",
"height": "{{HEIGHT}}",
"batch_size": 1
},
"class_type": "EmptySD3LatentImage",
"_meta": {
"title": "EmptySD3LatentImage"
}
},
"231": {
"inputs": {
"samples": ["271", 0],
"vae": ["252", 2]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"233": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": ["231", 0]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
},
"252": {
"inputs": {
"ckpt_name": "sd3_medium_incl_clips_t5xxlfp16.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"271": {
"inputs": {
"seed": "{{SEED}}",
"steps": "{{STEPS}}",
"cfg": 4.5,
"sampler_name": "dpmpp_2m",
"scheduler": "sgm_uniform",
"denoise": 1,
"model": ["13", 0],
"positive": ["6", 0],
"negative": ["69", 0],
"latent_image": ["135", 0]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
}
}
}
}
-34
View File
@@ -1,34 +0,0 @@
cartoon character of a person with a hoodie , in style of cytus and deemo, ork, gold chains, realistic anime cat, dripping black goo, lineage revolution style, thug life, cute anthropomorphic bunny, balrog, arknights, aliased, very buff, black and red and yellow paint, painting illustration collage style, character composition in vector with white background
stardew valley, fine details
2D Vector Illustration of a child with soccer ball Art for Sublimation, Design Art, Chrome Art, Painting and Stunning Artwork, Highly Detailed Digital Painting, Airbrush Art, Highly Detailed Digital Artwork, Dramatic Artwork, stained antique yellow copper paint, digital airbrush art, detailed by Mark Brooks, Chicano airbrush art, Swagger! snake Culture
realistic futuristic city-downtown with short buildings, sunset
seascape by Ray Collins and artgerm, front view of a perfect wave, sunny background, ultra detailed water
inspired by realflow-cinema4d editor features, create image of a transparent luxury cup with ice fruits and mint, connected with white, yellow and pink cream, Slow - High Speed MO Photography, YouTube Video Screenshot, Abstract Clay, Transparent Cup , molecular gastronomy, wheel, 3D fluid,Simulation rendering, still video, 4k polymer clay futras photography, very surreal, Houdini Fluid Simulation, hyperrealistic CGI and FLUIDS & MULTIPHYSICS SIMULATION effect, with Somali Stain Lurex, Metallic Jacquard, Gold Thread, Mulberry Silk, Toub Saree, Warm background, a fantastic image worthy of an award.
biker with backpack on his back riding a motorcycle, Style by Ade Santora, Oilpunk, Cover photo, craig mullins style, on the cover of a magazine, Outdoor Magazine, inspired by Alex Petruk APe, image of a male biker, Cover of an award-winning magazine, the man has a backpack, photo for magazine, with a backpack, magazine cover
generate a collage-style illustration inspired by the Procreate raster graphic editor, photographic illustration with the theme, 2D vector, art for textile sublimation, containing surrealistic cartoon cat wearing a baseball cap and jeans standing in front of a poster, inspired by Sadao Watanabe, Doraemon, Japanese cartoon style, Eichiro Oda, Iconic high detail character, Director: Nakahara Nantenbō, Kastuhiro Otomo, image detailed, by Miyamoto, Hidetaka Miyazaki, Katsuhiro illustration, 8k, masterpiece, Minimize noise and grain in photo quality without lose quality and increase brightness and lighting,Symmetry and Alignment, Avoid asymmetrical shapes and out-of-focus points. Focus and Sharpness: Make sure the image is focused and sharp and encourages the viewer to see it as a work of art printed on fabric.
fantasy medieval village world inside a glass sphere , high detail, fantasy, realistic, light effect, hyper detail, volumetric lighting, cinematic, macro, depth of field, blur, red light and clouds from the back, highly detailed epic cinematic concept art cg render made in maya, blender and photoshop, octane render, excellent composition, dynamic dramatic cinematic lighting, aesthetic, very inspirational, world inside a glass sphere by james gurney by artgerm with james jean, joe fenton and tristan eaton by ross tran, fine details
Iron Man, (Arnold Tsang, Toru Nakayama), Masterpiece, Studio Quality, 6k , toa, toaair, 1boy, glowing, axe, mecha, science_fiction, solo, weapon, jungle , green_background, nature, outdoors, solo, tree, weapon, mask, dynamic lighting, detailed shading, digital texture painting
(Pope Francis) wearing leather jacket is a DJ in a nightclub, mixing live on stage, giant mixing table, a masterpiece
Pope Francis wearing biker (leather jacket), a masterpiece
Luke Skywalker ordering a burger and fries from the Death Star canteen.
I want to generate a group avatar for a Feishu group chat. The role of this group is daily software technical communication. Now the subject technology stacks that members of this group discuss daily include: algorithms, data structures, optimization, functional programming, and the programming languages often discussed are: TypeScript, Java, python, etc. I hope this avatar has a simple aesthetic, this avatar is a single person avatar
portrait Anime black girl cute-fine-face, pretty face, realistic shaded Perfect face, fine details. Anime. realistic shaded lighting by Ilya Kuvshinov Giuseppe Dangelico Pino and Michael Garmash and Rob Rey, IAMAG premiere, WLOP matte print, cute freckles, masterpiece
young Disney socialite wearing a beige miniskirt, dark brown turtleneck sweater, small neckless, cute-fine-face, anime. illustration, realistic shaded perfect face, brown hair, grey eyes, fine details, realistic shaded lighting by ilya kuvshinov giuseppe dangelico pino and michael garmash and rob rey, iamag premiere, wlop matte print, a masterpiece
Cute small cat sitting in a movie theater eating chicken wiggs watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
Cute small dog sitting in a movie theater eating popcorn watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
fox bracelet made of buckskin with fox features, rich details, fine carvings, studio lighting
crane buckskin bracelet with crane features, rich details, fine carvings, studio lighting
london luxurious interior living-room, light walls
Parisian luxurious interior penthouse bedroom, dark walls, wooden panels
cute girl, crop-top, blond hair, black glasses, stretching, with background by greg rutkowski makoto shinkai kyoto animation key art feminine mid shot
houses in front, houses background, straight houses, digital art, smooth, sharp focus, gravity falls style, doraemon style, shinchan style, anime style
Simplified technical drawing, Leonardo da Vinci, Mechanical Dinosaur Skeleton, Minimalistic annotations, Hand-drawn illustrations, Basic design and engineering, Wonder and curiosity
High quality 8K painting impressionist style of a Japanese modern city street with a girl on the foreground wearing a traditional wedding dress with a fox mask, staring at the sky, daylight
a landscape from the Moon with the Earth setting on the horizon, realistic, detailed
Isometric Atlantis city,great architecture with columns, great details, ornaments,seaweed, blue ambiance, 3D cartoon style, soft light, 45° view
A hyper realistic avatar of a guy riding on a black honda cbr 650r in leather suit,high detail, high quality,8K,photo realism
the street of amedieval fantasy town, at dawn, dark, highly detailed
overwhelmingly beautiful eagle framed with vector flowers, long shiny wavy flowing hair, polished, ultra detailed vector floral illustration mixed with hyper realism, muted pastel colors, vector floral details in background, muted colors, hyper detailed ultra intricate overwhelming realism in detailed complex scene with magical fantasy atmosphere, no signature, no watermark
a highly detailed matte painting of a man on a hill watching a rocket launch in the distance by studio ghibli, makoto shinkai, by artgerm, by wlop, by greg rutkowski, volumetric lighting, octane render, 4 k resolution, trending on artstation, masterpiece | hyperrealism| highly detailed| insanely detailed| intricate| cinematic lighting| depth of field
electronik robot and ofice ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
exquisitely intricately detailed illustration, of a small world with a lake and a rainbow, inside a closed glass jar.
-143
View File
@@ -1,143 +0,0 @@
import os
import logging
import dataclasses
import base64
from typing import Optional, Union, Type
from aiohttp import web, ClientResponse
from anyio import open_file
from lib.backend import Backend, LogAction
from lib.data_types import EndpointHandler
from lib.server import start_server
from .data_types import DefaultComfyWorkflowData, CustomComfyWorkflowData
MODEL_SERVER_URL = "http://127.0.0.1:18288" # API Wrapper Service
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: http://127.0.0.1:18188"
MODEL_SERVER_ERROR_LOG_MSGS = [
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
"Value not in list: unet_name", # This error is emitted when the model file is not there at all
]
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
async def generate_client_response(
request: web.Request, response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
_ = request
match response.status:
case 200:
log.debug("SUCCESS")
res = await response.json()
if "output" not in res:
return web.json_response(
data=dict(error="there was an error in the workflow"),
status=422,
)
image_paths = [path["local_path"] for path in res["output"]["images"]]
if not image_paths:
return web.json_response(
data=dict(error="workflow did not produce any images"),
status=422,
)
images = []
for image_path in image_paths:
async with await open_file(image_path, mode="rb") as f:
contents = await f.read()
images.append(
f"data:image/png;base64,{base64.b64encode(contents).decode('utf-8')}"
)
return web.json_response(data=dict(images=images))
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
@dataclasses.dataclass
class DefaultComfyWorkflowHandler(EndpointHandler[DefaultComfyWorkflowData]):
@property
def endpoint(self) -> str:
return "/runsync"
@property
def healthcheck_endpoint(self) -> Optional[str]:
return None
@classmethod
def payload_cls(cls) -> Type[DefaultComfyWorkflowData]:
return DefaultComfyWorkflowData
def make_benchmark_payload(self) -> DefaultComfyWorkflowData:
return DefaultComfyWorkflowData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
return await generate_client_response(client_request, model_response)
@dataclasses.dataclass
class CustomComfyWorkflowHandler(EndpointHandler[CustomComfyWorkflowData]):
@property
def endpoint(self) -> str:
return "/runsync"
@property
def healthcheck_endpoint(self) -> Optional[str]:
return None
@classmethod
def payload_cls(cls) -> Type[CustomComfyWorkflowData]:
return CustomComfyWorkflowData
def make_benchmark_payload(self) -> CustomComfyWorkflowData:
return CustomComfyWorkflowData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
return await generate_client_response(client_request, model_response)
backend = Backend(
model_server_url=MODEL_SERVER_URL,
model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=False,
benchmark_handler=DefaultComfyWorkflowHandler(
benchmark_runs=3, benchmark_words=100
),
log_actions=[
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
(LogAction.Info, "Downloading:"),
*[
(LogAction.ModelError, error_msg)
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
],
],
)
async def handle_ping(_):
return web.Response(body="pong")
routes = [
web.post("/prompt", backend.create_handler(DefaultComfyWorkflowHandler())),
web.post("/custom-workflow", backend.create_handler(CustomComfyWorkflowHandler())),
web.get("/ping", handle_ping),
]
if __name__ == "__main__":
start_server(backend, routes)
-15
View File
@@ -1,15 +0,0 @@
from lib.test_utils import test_load_cmd, test_args
from .data_types import DefaultComfyWorkflowData, Model
WORKER_ENDPOINT = "/prompt"
if __name__ == "__main__":
test_args.add_argument(
"-m",
dest="comfy_model",
choices=list(map(lambda x: x.value, Model)),
required=True,
help="Image generation model name",
)
test_load_cmd(DefaultComfyWorkflowData, WORKER_ENDPOINT, arg_parser=test_args)
-321
View File
@@ -1,321 +0,0 @@
# Vast PyWorker
## Hello_world example
There is a hello_world PyWorker implementation under `workers/hello_world`. This PyWorker is
created for an LLM model server that runs on port 5001 has two API endpoints:
1. `/generate`: generates an full response to the prompt and sends a JSON response
2. `/generate_stream`: streams a response one token at a time
Both of these endpoints take the same API JSON payload:
```
{
"prompt": String,
"max_response_tokens": Number | null
}
```
We want the PyWorker to also expose two endpoints that correspond to the above endpoints.
### Structure
All PyWorkers have four files:
```
.
└── workers
└── hello_world
├── __init__.py
├── data_types.py # contains data types representing model API endpoints
├── server.py # contains endpoint handlers
└── test_load.py # script for load testing
```
All of the classes follow strict type hinting. It is recommended that you type hint all of your function.
This will allow your IDE or VSCode with `pyright` plugin to find any type errors in your implementation.
You can also install `pyright` with `sudo npm install -g pyright` and run `pyright` in the root of the project to find
any type errors.
### data_types.py: Contains data types representing model API endpoints
This file defines the structure of the data your model server expects (its API contract) and, critically, how PyWorker *interprets* that data for autoscaling purposes. You define Python data classes that mirror the JSON payloads your model's API uses.
These classes **must** inherit from `lib.data_types.ApiPayload`. This inheritance is not just for structure; it's how PyWorker knows how to:
* **Parse Incoming Requests:** Convert JSON from clients into usable Python objects.
* **Calculate Workload:** Determine the computational cost of a request.
* **Generate Test Data:** Create realistic inputs for benchmarking.
* **Format Requests for the Model Server:** Prepare data for the underlying model.
```python
import dataclasses
import random
from typing import Dict, Any
from transformers import OpenAIGPTTokenizer # used to count tokens in a prompt
import nltk # used to download a list of all words to generate a random prompt and benchmark the LLM model
from lib.data_types import ApiPayload
nltk.download("words")
WORD_LIST = nltk.corpus.words.words()
# you can use any tokenizer that fits your LLM. `openai-gpt` is free to use and is a good fit for most LLMs
tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
@dataclasses.dataclass
class InputData(ApiPayload):
prompt: str
max_response_tokens: int
@classmethod
def for_test(cls) -> "ApiPayload":
"""defines how create a payload for load testing"""
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
return cls(prompt=prompt, max_response_tokens=300)
def generate_payload_json(self) -> Dict[str, Any]:
"""defines how to convert an ApiPayload to JSON that will be sent to model API"""
return dataclasses.asdict(self)
def count_workload(self) -> float:
"""defines how to calculate workload for a payload"""
return len(tokenizer.tokenize(self.prompt))
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
"""
defines how to transform JSON data to AuthData and payload type,
in this case `InputData` defined above represents the data sent to the model API.
AuthData is data generated by autoscaler in order to authenticate payloads.
In this case, the transformation is simple and 1:1. That is not always the case. See comfyui's PyWorker
for more complicated examples
"""
errors = {}
for param in inspect.signature(cls).parameters:
if param not in json_msg:
errors[param] = "missing parameter"
if errors:
raise JsonDataException(errors)
return cls(
**{
k: v
for k, v in json_msg.items()
if k in inspect.signature(cls).parameters
}
)
```
### server.py: Creating Your Model's API Endpoints
This section guides you through creating the core of your custom model API: the `EndpointHandler`. Think of `EndpointHandler` as the bridge between incoming requests from users and your underlying model. It's the key to making your model accessible and scalable.
**Why use an `EndpointHandler`?**
* **Organized Request Handling:** It provides a structured way to handle different types of requests (like generating text, generating images, or performing other model-specific tasks).
* **Scalability:** By separating request handling from the model itself, you can easily scale your API to handle many concurrent users.
* **Flexibility:** You can customize how requests are processed, validated, and transformed before being sent to your model.
* **Standard Interface:** It provides a consistent interface for interacting with your model, regardless of the underlying implementation.
For every model API endpoint you want to expose (e.g., `/generate`, `/generate_stream`), you'll implement an `EndpointHandler`. This class is responsible for:
The `EndpointHandler` achieves this through several key methods:
* **Receiving and validating incoming requests (`get_data_from_request`):** This method ensures the request contains the necessary data (authentication and payload) and is in the correct format. It's the entry point for all requests.
* **Defining the endpoint (`endpoint`):** This method specifies the URL endpoint on the model API server where requests will be sent (e.g., `/generate`).
* **Specifying the payload type (`payload_cls`):** This method indicates the specific `ApiPayload` class used for this endpoint, defining the structure of the request data.
* **Creating benchmark payloads (`make_benchmark_payload`):** This method creates payloads specifically for benchmarking the model's performance.
* **Handling the model's response (`generate_client_response`):** This method takes the response from the model API server and transforms it into the format expected by the client making the request to your PyWorker. This allows you to customize the output as needed.
The `EndpointHandler` class has several abstract functions that you *must* implement to define the behavior of your specific endpoints. Here, we'll implement two common endpoints: `/generate` (for synchronous requests) and `/generate_stream` (for streaming responses):
```python
"""
AuthData is a dataclass that represents Authentication data sent from Autoscaler to client requesting a route.
When a user requests a route from autoscaler, see Vast's Serverless documentation for how routing and AuthData
work.
When a user receives a route for this PyWorker, they'll call PyWorkers API with the following JSON:
{
auth_data: AuthData,
payload : InputData # defined above
}
"""
from aiohttp import web
from lib.data_types import EndpointHandler, JsonDataException
from lib.server import start_server
from .data_types import InputData
# This class is the implementer for the '/generate' endpoint of model API
@dataclasses.dataclass
class GenerateHandler(EndpointHandler[InputData]):
@property
def endpoint(self) -> str:
# the API endpoint
return "/generate"
@classmethod
def payload_cls(cls) -> Type[InputData]:
"""this function should just return ApiPayload subclass used by this handler"""
return InputData
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
"""
defines how to convert `InputData` defined above, to
JSON data to be sent to the model API. This function too is a simple dataclass -> JSON, but
can be more complicated, See comfyui for an example
"""
return dataclasses.asdict(payload)
def make_benchmark_payload(self) -> InputData:
"""
defines how to generate an InputData for benchmarking. This needs to be defined in only
one EndpointHandler, the one passed to the backend as the benchmark handler. Here we use the .for_test()
method on InputData. However, in some cases you might need to fine tune your InputData used for
benchmarking to closely resemble the average request users call the endpoint with in order to get best
autoscaling performance
"""
return InputData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
"""
defines how to convert a model API response to a response to PyWorker client
"""
_ = client_request
match model_response.status:
case 200:
log.debug("SUCCESS")
data = await model_response.json()
return web.json_response(data=data)
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
```
We also handle `GenerateStreamHandler` for streaming responses. It is identical to `GenerateHandler`, except for
the endpoint name and how we create a web response, as it is a streaming response:
```python
class GenerateStreamHandler(EndpointHandler[InputData]):
@property
def endpoint(self) -> str:
return "/generate_stream"
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
return dataclasses.asdict(payload)
def make_benchmark_payload(self) -> InputData:
return InputData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
match model_response.status:
case 200:
log.debug("Streaming response...")
res = web.StreamResponse()
res.content_type = "text/event-stream"
await res.prepare(client_request)
async for chunk in model_response.content:
await res.write(chunk)
await res.write_eof()
log.debug("Done streaming response")
return res
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
```
You can now instantiate a Backend and use it to handle requests.
```python
from lib.backend import Backend, LogAction
# the url and port of model API
MODEL_SERVER_URL = "http://0.0.0.0:5001"
# This is the log line that is emitted once the server has started
MODEL_SERVER_START_LOG_MSG = "server has started"
MODEL_SERVER_ERROR_LOG_MSGS = [
"Exception: corrupted model file" # message in the logs indicating the unrecoverable error
]
backend = Backend(
model_server_url=MODEL_SERVER_URL,
# location of model log file
model_log_file=os.environ["MODEL_LOG"],
# for some model backends that can only handle one request at a time, be sure to set this to False to
# let PyWorker handling queueing requests.
allow_parallel_requests=True,
# give the backend an EndpointHandler instance that is used for benchmarking
# number of benchmark run and number of words for a random benchmark run are given
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
# defines how to handle specific log messages. See docstring of LogAction for details
log_actions=[
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
(LogAction.Info, '"message":"Download'),
*[
(LogAction.ModelError, error_msg)
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
],
],
)
# this is a simple ping handler for PyWorker
async def handle_ping(_: web.Request):
return web.Response(body="pong")
# this is a handler for forwarding a health check to model API
async def handle_healthcheck(_: web.Request):
healthcheck_res = await backend.session.get("/healthcheck")
return web.Response(body=healthcheck_res.content, status=healthcheck_res.status)
routes = [
web.post("/generate", backend.create_handler(GenerateHandler())),
web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
web.get("/ping", handle_ping),
web.get("/healthcheck", handle_healthcheck),
]
if __name__ == "__main__":
# start server, called from start_server.sh
start_server(backend, routes)
```
### test_load.py
Here you can create a script that allows you test an endpoint group running instances with this PyWorker
```python
from lib.test_harness import run
from .data_types import InputData
WORKER_ENDPOINT = "/generate"
if __name__ == "__main__":
run(InputData.for_test(), WORKER_ENDPOINT)
```
You can then run the following command from the root of this repo to load test endpoint group:
```sh
# sends 1000 requests at the rate of 0.5 requests per second
python3 workers.hello_world.test_load -n 1000 -rps 0.5 -k "$API_KEY" -e "$ENDPOINT_GROUP_NAME"
```
View File
-48
View File
@@ -1,48 +0,0 @@
import dataclasses
import random
import inspect
from typing import Dict, Any
from transformers import OpenAIGPTTokenizer
import nltk
from lib.data_types import ApiPayload, JsonDataException
nltk.download("words")
WORD_LIST = nltk.corpus.words.words()
# used to count to count tokens and workload for LLM
tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
@dataclasses.dataclass
class InputData(ApiPayload):
prompt: str
max_response_tokens: int
@classmethod
def for_test(cls) -> "InputData":
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
return cls(prompt=prompt, max_response_tokens=300)
def generate_payload_json(self) -> Dict[str, Any]:
return dataclasses.asdict(self)
def count_workload(self) -> int:
return len(tokenizer.tokenize(self.prompt))
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
errors = {}
for param in inspect.signature(cls).parameters:
if param not in json_msg:
errors[param] = "missing parameter"
if errors:
raise JsonDataException(errors)
return cls(
**{
k: v
for k, v in json_msg.items()
if k in inspect.signature(cls).parameters
}
)
-175
View File
@@ -1,175 +0,0 @@
"""
PyWorker works as a man-in-the-middle between the client and model API. It's function is:
1. receive request from client, update metrics such as workload of a request, number of pending requests, etc.
2a. transform the data and forward the transformed data to model API
2b. send updated metrics to autoscaler
3. transform response from model API(if needed) and forward the response to client
PyWorker forward requests to many model API endpoint. each endpoint must have an EndpointHandler. You can also
write function to just forward requests that don't generate anything with the model to model API without an
EndpointHandler. This is useful for endpoints such as healthchecks. See below for example
"""
import os
import logging
import dataclasses
from typing import Dict, Any, Optional, Union, Type
from aiohttp import web, ClientResponse
from lib.backend import Backend, LogAction
from lib.data_types import EndpointHandler
from lib.server import start_server
from .data_types import InputData
# the url and port of model API
MODEL_SERVER_URL = "http://0.0.0.0:5001"
# This is the log line that is emitted once the server has started
MODEL_SERVER_START_LOG_MSG = "infer server has started"
MODEL_SERVER_ERROR_LOG_MSGS = [
"Exception: corrupted model file" # message in the logs indicating the unrecoverable error
]
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
# This class is the implementer for the '/generate' endpoint of model API
@dataclasses.dataclass
class GenerateHandler(EndpointHandler[InputData]):
@property
def endpoint(self) -> str:
# the API endpoint
return "/generate"
@property
def healthcheck_endpoint(self) -> Optional[str]:
return None
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
"""
defines how to convert `InputData` defined above, to
json data to be sent to the model API
"""
return dataclasses.asdict(payload)
def make_benchmark_payload(self) -> InputData:
"""
defines how to generate an InputData for benchmarking. This needs to be defined in only
one EndpointHandler, the one passed to the backend as the benchmark handler
"""
return InputData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
"""
defines how to convert a model API response to a response to PyWorker client
"""
_ = client_request
match model_response.status:
case 200:
log.debug("SUCCESS")
data = await model_response.json()
return web.json_response(data=data)
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
# This is the same as GenerateHandler, except that it calls a streaming endpoint of the model API and streams the
# response, which itself is streaming, back to the client.
# it is nearly identical to handler as above, but it calls a different model API endpoint and it streams the
# streaming response from model API to client
class GenerateStreamHandler(EndpointHandler[InputData]):
@property
def endpoint(self) -> str:
return "/generate_stream"
@property
def healthcheck_endpoint(self) -> Optional[str]:
return None
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
return dataclasses.asdict(payload)
def make_benchmark_payload(self) -> InputData:
return InputData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
match model_response.status:
case 200:
log.debug("Streaming response...")
res = web.StreamResponse()
res.content_type = "text/event-stream"
await res.prepare(client_request)
async for chunk in model_response.content:
await res.write(chunk)
await res.write_eof()
log.debug("Done streaming response")
return res
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
# This is the backend instance of pyworker. Only one must be made which uses EndpointHandlers to process
# incoming requests
backend = Backend(
model_server_url=MODEL_SERVER_URL,
model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=True,
# give the backend a handler instance that is used for benchmarking
# number of benchmark run and number of words for a random benchmark run are given
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
# defines how to handle specific log messages. See docstring of LogAction for details
log_actions=[
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
(LogAction.Info, '"message":"Download'),
*[
(LogAction.ModelError, error_msg)
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
],
],
)
# this is a simple ping handler for pyworker
async def handle_ping(_: web.Request):
return web.Response(body="pong")
# this is a handler for forwarding a health check to modelAPI
async def handle_healthcheck(_: web.Request):
healthcheck_res = await backend.session.get("/healthcheck")
return web.Response(body=healthcheck_res.content, status=healthcheck_res.status)
routes = [
web.post("/generate", backend.create_handler(GenerateHandler())),
web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
web.get("/ping", handle_ping),
web.get("/healthcheck", handle_healthcheck),
]
if __name__ == "__main__":
# start the PyWorker server
start_server(backend, routes)
-7
View File
@@ -1,7 +0,0 @@
from lib.test_utils import test_load_cmd, test_args
from .data_types import InputData
WORKER_ENDPOINT = "/generate"
if __name__ == "__main__":
test_load_cmd(InputData, WORKER_ENDPOINT, arg_parser=test_args)
-77
View File
@@ -1,77 +0,0 @@
# <INFERENCE_SERVER> + <MODEL_NAME> (serverless)
Run <INFERENCE_SERVER> with our serverless autoscaling infrastructure.
See the [serverless documentation](https://docs.vast.ai/serverless) and the [Getting Started](https://docs.vast.ai/serverless/getting-started) guide for in-depth details about how to use these templates.
## Configuration
Two environment variables are provided to help you configure the <INFERENCE_SERVER> server:
| Variable | Default Value | Used For |
| --- | --- | --- |
| `MODEL_NAME` | `<MODEL_NAME>` | The model to load. Also accepts [hf.co/repo/model](#) links |
| `<ARGS_VAR>` | `<ARGS_VAL>` | Arguments to pass to the `<ARGS_RECEIVER>` command |
This template has been configured to work with <MIN_VRAM> VRAM. Setting alternative models and server arguments will change the VRAM requirements. Check model cards and <INFERENCE_SERVER_DOCS> for guidance.
## Usage
We have provided a demonstration client to help you implement this template into your own infrastructure
### Client Setup
Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
```bash
git clone https://github.com/vast-ai/pyworker
cd pyworker
pip install uv
uv venv -p 3.12
source .venv/bin/activate
uv pip install -r requirements.txt
```
### Completions
Call to `/v1/completions` with json response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
```
### Chat Completion (json)
Call to `/v1/chat/completions` with json response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
```
### Chat Completion (streaming)
Call to `/v1/chat/completions` with streaming response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
```
### Tool Use (json)
Call to `/v1/chat/completions` with tool and json response.
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
```
### Interactive Chat (streaming)
Interactive session with calls to `/v1/chat/completions`.
Type `clear` to clear the chat history or `quit` to exit.
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
```
+27 -35
View File
@@ -102,15 +102,13 @@ async def call_completions(client: Serverless, *, model: str, prompt: str, endpo
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
"model": model,
"prompt": prompt,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
}
"model": model,
"prompt": prompt,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
}
log.debug("POST /v1/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"])
resp = await endpoint.request("/v1/completions", payload, cost=payload["max_tokens"])
return resp["response"]
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs) -> Dict[str, Any]:
@@ -118,17 +116,15 @@ async def call_chat_completions(client: Serverless, *, model: str, messages: Lis
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
"model": model,
"messages": messages,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
}
"model": model,
"messages": messages,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
}
log.debug("POST /v1/chat/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"])
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["max_tokens"])
return resp["response"]
# ---- Streaming variants ----
@@ -137,17 +133,15 @@ async def stream_completions(client: Serverless, *, model: str, prompt: str, end
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
"model": model,
"prompt": prompt,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"stream": True,
**({"stop": kwargs["stop"]} if "stop" in kwargs else {}),
}
"model": model,
"prompt": prompt,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"stream": True,
**({"stop": kwargs["stop"]} if "stop" in kwargs else {}),
}
log.debug("STREAM /v1/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
resp = await endpoint.request("/v1/completions", payload, cost=payload["max_tokens"], stream=True)
return resp["response"] # async generator
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs):
@@ -155,18 +149,16 @@ async def stream_chat_completions(client: Serverless, *, model: str, messages: L
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
"model": model,
"messages": messages,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"stream": True,
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
}
"model": model,
"messages": messages,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"stream": True,
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
}
log.debug("STREAM /v1/chat/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"], stream=True)
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["max_tokens"], stream=True)
return resp["response"] # async generator
-58
View File
@@ -1,58 +0,0 @@
import json
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Optional, List, Dict, Any
class SerializableDataclass:
def _serialize_recursive(self, obj: Any) -> Any:
if is_dataclass(obj):
return {
field.name: self._serialize_recursive(getattr(obj, field.name))
for field in fields(obj)
}
elif isinstance(obj, dict):
return {key: self._serialize_recursive(value) for key, value in obj.items()}
elif isinstance(obj, (list, tuple)):
return [self._serialize_recursive(item) for item in obj]
elif isinstance(obj, set):
return [self._serialize_recursive(item) for item in obj]
else:
return obj
def to_dict(self) -> Dict[str, Any]:
return self._serialize_recursive(self)
def to_json(self, indent: int = 2) -> str:
return json.dumps(self.to_dict(), indent=indent)
@dataclass
class CompletionConfig(SerializableDataclass):
"""Configuration for completion requests"""
model: str
prompt: str = "Hello"
max_tokens: int = 256
temperature: float = 0.7
top_k: int = 20
top_p: float = 0.4
stream: bool = False
@dataclass
class ChatCompletionConfig(SerializableDataclass):
"""Configuration for chat completion requests"""
model: str
messages: list = field(default_factory=list)
max_tokens: int = 2096
temperature: float = 0.7
top_k: int = 20
top_p: float = 0.4
stream: bool = False
tools: Optional[List[Dict[str, Any]]] = field(default_factory=list)
tool_choice: str = "auto"
def __post_init__(self):
if self.messages is None:
self.messages = [{"role": "user", "content": "Hello"}]
-207
View File
@@ -1,207 +0,0 @@
import os, json, random
from abc import ABC, abstractmethod
from dataclasses import dataclass
from lib.data_types import EndpointHandler, ApiPayload, JsonDataException
from typing import Union, Type, Dict, Any, Optional
from aiohttp import web, ClientResponse
import nltk
import logging
nltk.download("words")
WORD_LIST = nltk.corpus.words.words()
log = logging.getLogger(__name__)
"""
Generic dataclass accepts any dictionary in input.
"""
@dataclass
class GenericData(ApiPayload, ABC):
input: Dict[str, Any]
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GenericData":
return cls(input=data["input"])
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "GenericData":
errors = {}
# Validate required parameters
required_params = ["input"]
for param in required_params:
if param not in json_msg:
errors[param] = "missing parameter"
if errors:
raise JsonDataException(errors)
try:
# Create clean data dict and delegate to from_dict
clean_data = {"input": json_msg["input"]}
return cls.from_dict(clean_data)
except (json.JSONDecodeError, JsonDataException) as e:
errors["parameters"] = str(e)
raise JsonDataException(errors)
@classmethod
@abstractmethod
def for_test(cls) -> "GenericData":
pass
def generate_payload_json(self) -> Dict[str, Any]:
return self.input
def count_workload(self) -> int:
return self.input.get("max_tokens", 0)
@dataclass
class GenericHandler(EndpointHandler[GenericData], ABC):
@property
@abstractmethod
def endpoint(self) -> str:
pass
@property
def healthcheck_endpoint(self) -> Optional[str]:
return os.environ.get("MODEL_HEALTH_ENDPOINT")
@classmethod
def payload_cls(cls) -> Type[GenericData]:
return GenericData
@abstractmethod
def make_benchmark_payload(self) -> GenericData:
pass
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
match model_response.status:
case 200:
# Check if the response is actually streaming based on response headers/content-type
is_streaming_response = (
model_response.content_type == "text/event-stream"
or model_response.content_type == "application/x-ndjson"
or model_response.headers.get("Transfer-Encoding") == "chunked"
or "stream" in model_response.content_type.lower()
)
if is_streaming_response:
log.debug("Detected streaming response...")
res = web.StreamResponse()
res.content_type = model_response.content_type
await res.prepare(client_request)
async for chunk in model_response.content:
await res.write(chunk)
await res.write_eof()
log.debug("Done streaming response")
return res
else:
log.debug("Detected non-streaming response...")
content = await model_response.read()
return web.Response(
body=content,
status=200,
content_type=model_response.content_type,
)
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
@dataclass
class CompletionsData(GenericData):
@classmethod
def for_test(cls) -> "CompletionsData":
system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
genus Equus with horses and asses, the three groups being the only living members of the family
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
woodlands, shrublands, and mountainous areas.
Please answer the following question based on the above context."""
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
model = os.environ.get("MODEL_NAME")
if not model:
raise ValueError("MODEL_NAME environment variable not set")
test_input = {
"model": model,
"prompt": f"{system_prompt}\n\n{unique_question}",
"temperature": 0.7,
"max_tokens": 500,
}
return cls(input=test_input)
@dataclass
class CompletionsHandler(GenericHandler):
@property
def endpoint(self) -> str:
return "/v1/completions"
@classmethod
def payload_cls(cls) -> Type[CompletionsData]:
return CompletionsData
def make_benchmark_payload(self) -> CompletionsData:
return CompletionsData.for_test()
@dataclass
class ChatCompletionsData(GenericData):
"""Chat completions-specific data implementation"""
@classmethod
def for_test(cls) -> "ChatCompletionsData":
system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
genus Equus with horses and asses, the three groups being the only living members of the family
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
woodlands, shrublands, and mountainous areas.
Please answer the following question based on the above context."""
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
model = os.environ.get("MODEL_NAME")
if not model:
raise ValueError("MODEL_NAME environment variable not set")
# Chat completions use messages format instead of prompt
test_input = {
"model": model,
"messages": [
{"role": "system", "content": system_prompt}, # Shared prefix
{"role": "user", "content": unique_question} # Unique per request
],
"temperature": 0.7,
"max_tokens": 500,
}
return cls(input=test_input)
@dataclass
class ChatCompletionsHandler(GenericHandler):
@property
def endpoint(self) -> str:
return "/v1/chat/completions"
@classmethod
def payload_cls(cls) -> Type[ChatCompletionsData]:
return ChatCompletionsData
def make_benchmark_payload(self) -> ChatCompletionsData:
return ChatCompletionsData.for_test()
-62
View File
@@ -1,62 +0,0 @@
import os
import logging
from .data_types.server import CompletionsHandler, ChatCompletionsHandler
from aiohttp import web
from lib.backend import Backend, LogAction
from lib.server import start_server
# This line indicates that the inference server is listening
MODEL_SERVER_START_LOG_MSG = [
"Application startup complete.", # vLLM
"llama runner started", # Ollama
'"message":"Connected","target":"text_generation_router"', # TGI
'"message":"Connected","target":"text_generation_router::server"', # TGI
"main: model loaded" # llama.cpp
]
MODEL_SERVER_ERROR_LOG_MSGS = [
"INFO exited: vllm", # vLLM
"RuntimeError: Engine", # vLLM
"Error: pull model manifest:", # Ollama
"stalled; retrying", # Ollama
"Error: WebserverFailed", # TGI
"Error: DownloadError", # TGI
"Error: ShardCannotStart", # TGI
]
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
backend = Backend(
model_server_url=os.environ["MODEL_SERVER_URL"],
model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=True,
max_wait_time=600.0,
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
log_actions=[
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
(LogAction.Info, '"message":"Download'),
*[
(LogAction.ModelError, error_msg)
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
],
],
)
async def handle_ping(_):
return web.Response(body="pong")
routes = [
web.post("/v1/completions", backend.create_handler(CompletionsHandler())),
web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())),
web.get("/ping", handle_ping),
]
if __name__ == "__main__":
start_server(backend, routes)
-434
View File
@@ -1,434 +0,0 @@
from lib.test_utils import test_args
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from lib.data_types import AuthData
from .data_types.server import CompletionsData
import os
import time
import threading
import requests
from dataclasses import dataclass
from collections import Counter
from urllib.parse import urljoin, urlparse
import re
# Headless plotting
import matplotlib
matplotlib.use("Agg")
import logging
logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)
import matplotlib.pyplot as plt
import numpy as np
from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED
from requests.adapters import HTTPAdapter
def get_incremented_path(path: str) -> str:
base, ext = os.path.splitext(path)
if not os.path.exists(path):
return path
i = 1
while os.path.exists(f"{base}-{i}{ext}"):
i += 1
return f"{base}-{i}{ext}"
WORKER_ENDPOINT = "/v1/completions" # This will return the full text output at once. Latency metrics reflect that (ie not measuring TTFT)
@dataclass
class ReqResult:
worker_url: str
route_ms: float
worker_ms: float
total_ms: float
ok: bool
error: str = ""
status_code: int = 0
t_start: float = 0.0
t_end: float = 0.0
workload: float = 0.0
def do_one(endpoint_name: str,
endpoint_id: int,
endpoint_api_key: str,
server_url: str,
worker_endpoint: str,
payload,
results_list,
t0,
status_samples,
route_session,
worker_session):
try:
workload = payload.count_workload()
route_payload = {"endpoint": endpoint_name, "api_key": endpoint_api_key, "cost": workload}
headers = {"Authorization": f"Bearer {endpoint_api_key}"}
start = time.time()
r0 = route_session.post(urljoin(server_url, "/route/"), json=route_payload, headers=headers, timeout=4)
t_after_route = time.time()
if r0.status_code != 200:
results_list.append(ReqResult(worker_url="",
route_ms=(t_after_route - start) * 1000.0,
worker_ms=0.0,
total_ms=(t_after_route - start) * 1000.0,
ok=False,
error=f"route error {r0.reason} {r0.text}",
status_code=r0.status_code,
t_start=start - t0,
t_end=t_after_route - t0,
workload=workload))
return
msg = r0.json()
# 1) Check if we got a worker back from route
worker_url = msg.get("url", "")
if not worker_url:
status = msg.get("status", "")
m = re.search(r"total workers:\s*(\d+).*loading workers:\s*(\d+).*standby workers:\s*(\d+).*error workers:\s*(\d+)", status, re.I | re.S)
if m:
tot, loading, standby, err = map(int, m.groups())
idle = max(tot - loading - standby - err, 0)
status_samples.append((time.time() - t0, idle))
# 2) If we got a worker, send the request
if worker_url:
req = dict(payload=payload.__dict__, auth_data=AuthData.from_json_msg(msg).__dict__)
t_before_worker = time.time()
r1 = worker_session.post(
urljoin(worker_url, worker_endpoint),
json=req,
verify=get_cert_file_path(),
timeout=(4, 120),
)
t_after_worker = time.time()
if r1.status_code != 200:
results_list.append(ReqResult(worker_url=worker_url,
route_ms=(t_after_route - start) * 1000.0,
worker_ms=(t_after_worker - t_before_worker) * 1000.0,
total_ms=(t_after_worker - start) * 1000.0,
ok=False,
error=f"worker inference error {r1.reason} {r1.text}",
status_code=r1.status_code,
t_start=start - t0,
t_end=t_after_worker - t0,
workload=workload))
return
# Success case
results_list.append(ReqResult(worker_url=worker_url,
route_ms=(t_after_route - start) * 1000.0,
worker_ms=(t_after_worker - t_before_worker) * 1000.0,
total_ms=(t_after_worker - start) * 1000.0,
ok=True,
error="",
status_code=200,
t_start=start - t0,
t_end=t_after_worker - t0,
workload=workload))
# 3) If so, sample via /get_endpoint_workers/ for eligible (idle) worker tracking
if worker_url:
try:
r_status = route_session.post(
urljoin(server_url, "/get_endpoint_workers/"),
json={"id": endpoint_id},
headers={"Authorization": f"Bearer {endpoint_api_key}"},
timeout=3,
)
if r_status.status_code == 200:
workers = r_status.json()
idle = 0
for w in workers:
st = str(w.get("status", "")).lower()
if (st in ("idle")):
idle += 1
status_samples.append((time.time() - t0, idle))
except Exception:
pass
except Exception as e:
t = time.time()
results_list.append(ReqResult(worker_url="",
route_ms=0.0,
worker_ms=0.0,
total_ms=0.0,
ok=False,
error=f"unknown error {e}",
status_code=0,
t_start=t - t0,
t_end=t - t0,
workload=0.0))
def run_load_with_metrics(num_requests: int,
requests_per_second: float,
endpoint_group_name: str,
account_api_key: str,
server_url: str,
worker_endpoint: str,
instance: str,
out_path: str):
ep_info = Endpoint.get_endpoint_info(endpoint_name=endpoint_group_name,
account_api_key=account_api_key,
instance=instance)
if not ep_info or not ep_info.get("api_key") or not ep_info.get("id"):
print(f"Endpoint {endpoint_group_name} not found for API key")
return
endpoint_id = int(ep_info["id"])
endpoint_api_key = ep_info["api_key"]
t0 = time.time()
results = []
status_samples = []
max_concurrency = int(os.environ.get("MAX_CONCURRENCY", "8192"))
submit_queue_factor = 2 # cap queued tasks to reduce memory
# Shared HTTP sessions with connection pooling (persistent connections)
def make_session(pool_connections: int, pool_maxsize: int) -> requests.Session:
sess = requests.Session()
adapter = HTTPAdapter(pool_connections=pool_connections, pool_maxsize=pool_maxsize, max_retries=0)
sess.mount("https://", adapter)
sess.mount("http://", adapter)
return sess
# Router: mostly single host, small connection pool is sufficient
route_session = make_session(pool_connections=1, pool_maxsize=max_concurrency)
# Workers: many hosts; allow many pools and per-host concurrency up to max_concurrency
worker_session = make_session(pool_connections=64, pool_maxsize=max_concurrency // 8)
# Fire requests using a thread pool, scheduling at requested RPS
inflight = set()
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
for i in range(num_requests):
# Pace submissions to RPS
target_time = t0 + i / max(requests_per_second, 1e-9)
sleep_s = target_time - time.time()
if sleep_s > 0:
time.sleep(min(sleep_s, 0.5)) # sleep in chunks to stay responsive
payload = CompletionsData.for_test()
fut = executor.submit(
do_one,
endpoint_group_name,
endpoint_id,
endpoint_api_key,
server_url,
worker_endpoint,
payload,
results,
t0,
status_samples,
route_session,
worker_session,
)
inflight.add(fut)
# Prevent unbounded queue growth
if len(inflight) >= max_concurrency * submit_queue_factor:
done, not_done = wait(inflight, return_when=FIRST_COMPLETED)
inflight = not_done
# Wait for all outstanding tasks
if inflight:
wait(inflight)
# Close sessions
try:
route_session.close()
finally:
worker_session.close()
# Aggregate results
oks = [r for r in results if r.ok]
errs = [r for r in results if not r.ok]
total_reqs = len(results)
succ = len(oks)
total_ms = np.array([r.total_ms for r in oks]) if succ else np.array([])
worker_ms = np.array([r.worker_ms for r in oks]) if succ else np.array([])
route_ms = np.array([r.route_ms for r in oks]) if succ else np.array([])
avg_total = float(np.mean(total_ms)) if succ else 0.0
avg_worker = float(np.mean(worker_ms)) if succ else 0.0
avg_route = float(np.mean(route_ms)) if succ else 0.0
p50_total, p95_total = (float(np.percentile(total_ms, 50)), float(np.percentile(total_ms, 95))) if succ else (0.0, 0.0)
# Distribution over workers (by host:port)
hosts = [urlparse(r.worker_url).netloc for r in oks if r.worker_url]
dist = Counter(hosts)
# Idle over time (mode per second)
idle_ts, idle_vals = [], []
if status_samples:
buckets = {}
for ts, idle in status_samples:
k = int(ts)
buckets.setdefault(k, []).append(idle)
keys = sorted(buckets.keys())
idle_ts = keys
# Use the most frequent sampled value per second (mode) to keep integer counts
idle_vals = []
for k in keys:
vals_k = [int(v) for v in buckets[k]]
if vals_k:
cnt = Counter(vals_k)
idle_vals.append(cnt.most_common(1)[0][0])
else:
idle_vals.append(0)
print(f"\nResults: total={total_reqs} success={succ} errors={len(errs)}")
print(f"Avg latency (ms): {avg_total:.1f} p50: {p50_total:.1f} p95: {p95_total:.1f}")
print(f"Avg route latency (ms): {avg_route:.1f} Avg worker latency (ms): {avg_worker:.1f}")
if errs:
print("Sample errors:")
for e in errs[:5]:
print(f" {e.status_code} {e.error}")
# Plot: 2x3 grid
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
fig.suptitle(f"Load test: {endpoint_group_name} n={total_reqs}, rps={requests_per_second}, success={succ}")
# Dist per worker
ax0 = axes[0, 0]
if dist:
items = sorted(dist.items(), key=lambda kv: kv[1], reverse=True)
labels, counts = zip(*items)
ax0.bar(range(len(labels)), counts)
ax0.set_xticks(range(len(labels)))
ax0.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
ax0.set_title("Request distribution over workers")
ax0.set_ylabel("count")
# Latency histogram (total)
ax1 = axes[0, 1]
if succ:
ax1.hist(total_ms, bins=30)
ax1.set_title("Total latency (ms)")
ax1.set_xlabel("ms")
ax1.set_ylabel("freq")
# Eligible workers over time
ax_idle = axes[0, 2]
if idle_ts:
ax_idle.plot(idle_ts, idle_vals, "-o", ms=3)
ax_idle.set_title("Eligible workers over time")
ax_idle.set_xlabel("time (s)")
ax_idle.set_ylabel("eligible count")
# Throughput over time (completions/sec)
ax_idle = axes[1, 0]
ax_idle.clear()
if succ:
per_sec = {}
for r in oks:
s = int(r.t_end)
per_sec[s] = per_sec.get(s, 0) + 1
ts = sorted(per_sec.keys())
vals = [per_sec[t] for t in ts]
ax_idle.plot(ts, vals, "-o", ms=3)
ax_idle.set_title("Completions per second")
ax_idle.set_xlabel("time (s)")
ax_idle.set_ylabel("completions / sec")
# Summary text
ax3 = axes[1, 1]
ax3.axis("off")
text = (
f"Total requests: {total_reqs}\n"
f"Success: {succ} Errors: {len(errs)}\n"
f"Avg total latency: {avg_total:.1f} ms\n"
f"p50: {p50_total:.1f} ms p95: {p95_total:.1f} ms\n"
f"Avg route latency: {avg_route:.1f} ms\n"
f"Avg worker latency: {avg_worker:.1f} ms\n"
f"300 errors: {len([r for r in errs if r.status_code >= 300 and r.status_code < 400])}\n"
f"429 errors: {len([r for r in errs if r.status_code == 429])}\n"
f"500 errors: {len([r for r in errs if r.status_code >= 500])}\n"
f"Other errors: {len([r for r in errs if r.status_code not in [300, 429, 500]])}\n"
)
ax3.set_title("Summary")
ax3.text(0.02, 0.98, text, va="top", ha="left", fontsize=11, transform=ax3.transAxes)
# Error count over time
ax_errors = axes[1, 2]
all_end_times = [int(r.t_end) for r in results if r.t_end > 0]
if all_end_times:
min_second = min(all_end_times)
max_second = max(all_end_times)
# Count errors per second
errors_per_second = {}
for result in errs:
second = int(result.t_end)
errors_per_second[second] = errors_per_second.get(second, 0) + 1
# Create complete timeline including zeros
time_seconds = list(range(min_second, max_second + 1))
error_counts = [errors_per_second.get(sec, 0) for sec in time_seconds]
ax_errors.plot(time_seconds, error_counts, "-o", ms=3)
ax_errors.set_title("Errors per second")
ax_errors.set_xlabel("time (s)")
ax_errors.set_ylabel("errors / sec")
# Ensure unique output path and create directory if needed
final_out_path = get_incremented_path(out_path)
out_dir = os.path.dirname(final_out_path)
if out_dir:
os.makedirs(out_dir, exist_ok=True)
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.savefig(final_out_path, dpi=120)
print(f"Saved report to: {final_out_path}")
# Per-worker latency boxplot (top 12 by volume)
groups = {}
for r in oks:
host = urlparse(r.worker_url).netloc
groups.setdefault(host, []).append(r.total_ms)
items = sorted(groups.items(), key=lambda kv: len(kv[1]), reverse=True)[:12]
if items:
labels, data = zip(*items)
fig2, axb = plt.subplots(1, 1, figsize=(12, 5))
axb.boxplot(data, showfliers=False)
axb.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
axb.set_title("Per-worker latency (ms)")
axb.set_ylabel("ms")
plt.tight_layout()
extra_out = get_incremented_path(os.path.splitext(out_path)[0] + "-workers.png")
plt.savefig(extra_out, dpi=120)
fig2.tight_layout()
fig2.savefig(extra_out, dpi=120)
print(f"Saved worker latency plot to: {extra_out}")
if __name__ == "__main__":
# Check if MODEL_NAME environment variable is set
model_name_set = os.environ.get("MODEL_NAME") is not None
# Add model argument - required only if MODEL_NAME is not set
test_args.add_argument(
"--model",
dest="model",
required=not model_name_set,
help="Model to use for completions request (required if MODEL_NAME env var not set)",
)
# Parse known args to get model early, before adding load args
known_args, _ = test_args.parse_known_args()
if hasattr(known_args, "model") and known_args.model:
os.environ["MODEL_NAME"] = known_args.model
print(f"Set MODEL_NAME environment variable to: {known_args.model}")
# Load test args
test_args.add_argument("-n", dest="num_requests", type=int, required=True, help="total number of requests")
test_args.add_argument("-rps", dest="requests_per_second", type=float, required=True, help="requests per second")
test_args.add_argument("--out", dest="out_path", type=str, default="load_test_report.png", help="path to save the report image")
args = test_args.parse_args()
server_url = {
"prod": "https://run.vast.ai",
"alpha": "https://run-alpha.vast.ai",
"candidate": "https://run-candidate.vast.ai",
"local": "http://localhost:8080"
}.get(args.instance, "http://localhost:8080")
run_load_with_metrics(
num_requests=args.num_requests,
requests_per_second=args.requests_per_second,
endpoint_group_name=args.endpoint_group_name,
account_api_key=args.api_key,
server_url=server_url,
worker_endpoint=WORKER_ENDPOINT,
instance=args.instance,
out_path=args.out_path,
)
+78
View File
@@ -0,0 +1,78 @@
import nltk
import random
import os
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
# vLLM model configuration
MODEL_SERVER_URL = 'http://127.0.0.1'
MODEL_SERVER_PORT = 18000
MODEL_LOG_FILE = '/var/log/portal/vllm.log'
MODEL_HEALTHCHECK_ENDPOINT = "/health"
# vLLM-specific log messages
MODEL_LOAD_LOG_MSG = [
"Application startup complete.",
]
MODEL_ERROR_LOG_MSGS = [
"INFO exited: vllm",
"RuntimeError: Engine",
"Traceback (most recent call last):"
]
MODEL_INFO_LOG_MSGS = [
'"message":"Download'
]
nltk.download("words")
WORD_LIST = nltk.corpus.words.words()
def completions_benchmark_generator() -> dict:
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
model = os.environ.get("MODEL_NAME")
if not model:
raise ValueError("MODEL_NAME environment variable not set")
benchmark_data = {
"model": model,
"prompt": prompt,
"temperature": 0.7,
"max_tokens": 500,
}
return benchmark_data
worker_config = WorkerConfig(
model_server_url=MODEL_SERVER_URL,
model_server_port=MODEL_SERVER_PORT,
model_log_file=MODEL_LOG_FILE,
model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
handlers=[
HandlerConfig(
route="/v1/completions",
workload_calculator= lambda data: data.get("max_tokens", 0),
allow_parallel_requests=True,
max_queue_time=60.0,
benchmark_config=BenchmarkConfig(
generator=completions_benchmark_generator,
concurrency=100,
runs=2
)
),
HandlerConfig(
route="/v1/chat/completions",
workload_calculator= lambda data: data.get("max_tokens", 0),
allow_parallel_requests=True,
max_queue_time=60.0,
)
],
log_action_config=LogActionConfig(
on_load=MODEL_LOAD_LOG_MSG,
on_error=MODEL_ERROR_LOG_MSGS,
on_info=MODEL_INFO_LOG_MSGS
)
)
Worker(worker_config).run()
-73
View File
@@ -1,73 +0,0 @@
import dataclasses
import random
import inspect
from typing import Dict, Any
from transformers import OpenAIGPTTokenizer
import nltk
from lib.data_types import ApiPayload, JsonDataException
nltk.download("words")
WORD_LIST = nltk.corpus.words.words()
tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
@dataclasses.dataclass
class InputParameters:
max_new_tokens: int = 256
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputParameters":
errors = {}
for param in inspect.signature(cls).parameters:
if param not in json_msg:
errors[param] = "missing parameter"
if errors:
raise JsonDataException(errors)
return cls(
**{
k: v
for k, v in json_msg.items()
if k in inspect.signature(cls).parameters
}
)
@dataclasses.dataclass
class InputData(ApiPayload):
inputs: str
parameters: InputParameters
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "InputData":
return cls(
inputs=data["inputs"], parameters=InputParameters(**data["parameters"])
)
@classmethod
def for_test(cls) -> "InputData":
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
return cls(inputs=prompt, parameters=InputParameters())
def generate_payload_json(self) -> Dict[str, Any]:
return dataclasses.asdict(self)
def count_workload(self) -> int:
return self.parameters.max_new_tokens
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
errors = {}
for param in inspect.signature(cls).parameters:
if param not in json_msg:
errors[param] = "missing parameter"
if errors:
raise JsonDataException(errors)
try:
parameters = InputParameters.from_json_msg(json_msg["parameters"])
return cls(inputs=json_msg["inputs"], parameters=parameters)
except JsonDataException as e:
errors["parameters"] = e.message
raise JsonDataException(errors)
-130
View File
@@ -1,130 +0,0 @@
import os
import logging
from typing import Union, Type
import dataclasses
from aiohttp import web, ClientResponse
from lib.backend import Backend, LogAction
from lib.data_types import EndpointHandler
from lib.server import start_server
from .data_types import InputData
MODEL_SERVER_URL = "http://0.0.0.0:5001"
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
MODEL_SERVER_START_LOG_MSG = [
'"message":"Connected","target":"text_generation_router"',
'"message":"Connected","target":"text_generation_router::server"',
]
MODEL_SERVER_ERROR_LOG_MSGS = [
"Error: WebserverFailed",
"Error: DownloadError",
"Error: ShardCannotStart",
]
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
@dataclasses.dataclass
class GenerateHandler(EndpointHandler[InputData]):
@property
def endpoint(self) -> str:
return "/generate"
@property
def healthcheck_endpoint(self) -> str:
return f"{MODEL_SERVER_URL}/health"
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
def make_benchmark_payload(self) -> InputData:
return InputData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
_ = client_request
match model_response.status:
case 200:
log.debug("SUCCESS")
data = await model_response.json()
return web.json_response(data=data)
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
class GenerateStreamHandler(EndpointHandler[InputData]):
@property
def endpoint(self) -> str:
return "/generate_stream"
@property
def healthcheck_endpoint(self) -> str:
return f"{MODEL_SERVER_URL}/health"
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
def make_benchmark_payload(self) -> InputData:
return InputData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
match model_response.status:
case 200:
log.debug("Streaming response...")
res = web.StreamResponse()
res.content_type = "text/event-stream"
await res.prepare(client_request)
async for chunk in model_response.content:
await res.write(chunk)
await res.write_eof()
log.debug("Done streaming response")
return res
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
backend = Backend(
model_server_url=MODEL_SERVER_URL,
model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=True,
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
log_actions=[
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
(LogAction.Info, '"message":"Download'),
*[
(LogAction.ModelError, error_msg)
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
],
],
)
async def handle_ping(_):
return web.Response(body="pong")
routes = [
web.post("/generate", backend.create_handler(GenerateHandler())),
web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
web.get("/ping", handle_ping),
]
if __name__ == "__main__":
start_server(backend, routes)
-7
View File
@@ -1,7 +0,0 @@
from lib.test_utils import test_load_cmd, test_args
from .data_types import InputData
WORKER_ENDPOINT = "/generate"
if __name__ == "__main__":
test_load_cmd(InputData, WORKER_ENDPOINT, arg_parser=test_args)
+76
View File
@@ -0,0 +1,76 @@
import nltk
import random
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
# TGI model configuration
MODEL_SERVER_URL = 'http://0.0.0.0'
MODEL_SERVER_PORT = 5001
MODEL_LOG_FILE = "/workspace/infer.log"
MODEL_HEALTHCHECK_ENDPOINT = "/health"
# TGI-specific log messages
MODEL_LOAD_LOG_MSG = [
'"message":"Connected","target":"text_generation_router"',
'"message":"Connected","target":"text_generation_router::server"',
]
MODEL_ERROR_LOG_MSGS = [
"Error: WebserverFailed",
"Error: DownloadError",
"Error: ShardCannotStart",
]
MODEL_INFO_LOG_MSGS = [
'"message":"Download'
]
nltk.download("words")
WORD_LIST = nltk.corpus.words.words()
def benchmark_generator() -> dict:
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
benchmark_data = {
"inputs": prompt,
"parameters": {
"max_new_tokens": 128,
"temperature": 0.7,
"return_full_text": False
}
}
return benchmark_data
worker_config = WorkerConfig(
model_server_url=MODEL_SERVER_URL,
model_server_port=MODEL_SERVER_PORT,
model_log_file=MODEL_LOG_FILE,
model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
handlers=[
HandlerConfig(
route="/generate",
allow_parallel_requests=True,
max_queue_time=60.0,
benchmark_config=BenchmarkConfig(
generator=benchmark_generator,
concurrency=50
),
workload_calculator= lambda x: x["parameters"]["max_new_tokens"]
),
HandlerConfig(
route="/generate_stream",
allow_parallel_requests=True,
max_queue_time=60.0,
workload_calculator= lambda x: x["parameters"]["max_new_tokens"]
)
],
log_action_config=LogActionConfig(
on_load=MODEL_LOAD_LOG_MSG,
on_error=MODEL_ERROR_LOG_MSGS,
on_info=MODEL_INFO_LOG_MSGS
)
)
Worker(worker_config).run()
+170
View File
@@ -0,0 +1,170 @@
# ComfyUI Wan 2.2 PyWorker
This is the PyWorker implementation for running **Wan 2.2 T2V A14B** text-to-video workflows in ComfyUI. It provides a unified interface for executing complete ComfyUI video-generation workflows through a proxy-based architecture and returning generated video assets.
Each request has a static cost of `10000`. ComfyUI does not support concurrent workloads, and there is no provision to run multiple ComfyUI instances per worker node.
## Requirements
This worker requires the following components:
- ComfyUI (https://github.com/comfyanonymous/ComfyUI)
- ComfyUI API Wrapper (https://github.com/ai-dock/comfyui-api-wrapper)
- Wan 2.2 T2V A14B models and required custom nodes
A Docker image is provided with all required Wan 2.2 models pre-installed, but any image may be used if the above requirements are met.
## Endpoint
The worker exposes a single synchronous endpoint:
- `/generate/sync`: Processes a complete ComfyUI workflow JSON and generates video output
## Request Format
The Wan 2.2 worker **only supports custom workflow mode**. Modifier-based workflows are not supported.
```json
{
"input": {
"request_id": "uuid-string",
"workflow_json": {
// Complete ComfyUI Wan 2.2 workflow JSON
},
"s3": { },
"webhook": { }
}
}
```
## Request Fields
### Required Fields
- `input`: Container for all request parameters
- `input.workflow_json`: Complete ComfyUI workflow graph for Wan 2.2 video generation
### Optional Fields
- `input.request_id`: Client-defined request identifier
- `input.s3`: S3-compatible storage configuration
- `input.webhook`: Webhook configuration for completion notifications
The special string `"__RANDOM_INT__"` may be used in the workflow JSON and will be replaced with a random integer before submission to ComfyUI.
## S3 Configuration
Generated video assets can be automatically uploaded to S3-compatible storage. Configuration can be supplied per request or via environment variables. Request-level values take precedence.
### Via Request JSON
```json
"s3": {
"access_key_id": "your-s3-access-key",
"secret_access_key": "your-s3-secret-access-key",
"endpoint_url": "https://s3.amazonaws.com",
"bucket_name": "your-bucket",
"region": "us-east-1"
}
```
### Via Environment Variables
```bash
S3_ACCESS_KEY_ID=your-key
S3_SECRET_ACCESS_KEY=your-secret
S3_BUCKET_NAME=your-bucket
S3_ENDPOINT_URL=https://s3.amazonaws.com
S3_REGION=us-east-1
```
## Webhook Configuration
Webhooks are triggered on request completion or failure.
### Via Request JSON
```json
"webhook": {
"url": "https://your-webhook-url",
"extra_params": {
"custom_field": "value"
}
}
```
### Via Environment Variables
```bash
WEBHOOK_URL=https://your-webhook-url
WEBHOOK_TIMEOUT=30
```
## Example Request
### Wan 2.2 Text-to-Video Workflow
```json
{
"input": {
"workflow_json": {
"90": {
"inputs": {
"clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors",
"type": "wan",
"device": "default"
},
"class_type": "CLIPLoader"
},
"99": {
"inputs": {
"text": "A cinematic slow-motion portrait of a woman turning her head",
"clip": ["90", 0]
},
"class_type": "CLIPTextEncode"
},
"104": {
"inputs": {
"width": 640,
"height": 640,
"length": 81,
"batch_size": 1
},
"class_type": "EmptyHunyuanLatentVideo"
}
}
}
}
```
## Response Format
A successful response includes execution metadata, ComfyUI output details, and generated video assets.
### Response Fields
- `id`: Unique request identifier
- `status`: `completed`, `failed`, `processing`, `generating`, or `queued`
- `message`: Human-readable status message
- `comfyui_response`: Raw response from ComfyUI, including execution status and progress
- `output`: Array of generated outputs
- `timings`: Timing information for the request
### Output Object
Each entry in `output` includes:
- `filename`: Generated file name (e.g., `.mp4`)
- `local_path`: File path on the worker
- `url`: Pre-signed download URL (if S3 is configured)
- `type`: Output type (`output`)
- `subfolder`: Output directory (e.g., `video`)
- `node_id`: ComfyUI node that produced the output
- `output_type`: Output category (e.g., `images`)
## Notes and Limitations
- Only full ComfyUI workflow JSONs are supported
- Concurrent requests are not supported per worker
- Wan 2.2 models must be installed before processing requests
- Video generation workflows may take several minutes depending on resolution, length, and GPU performance
+205
View File
@@ -0,0 +1,205 @@
from vastai import Serverless
import asyncio
async def main():
async with Serverless() as client:
endpoint = await client.get_endpoint(name="my-wan-endpoint")
# ComfyUI API compatible json workflow for Wan 2.2 T2V
workflow = {
"90": {
"inputs": {
"clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors",
"type": "wan",
"device": "default"
},
"class_type": "CLIPLoader",
"_meta": {
"title": "Load CLIP"
}
},
"91": {
"inputs": {
"text": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,裸露,NSFW",
"clip": ["90", 0]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Negative Prompt)"
}
},
"92": {
"inputs": {
"vae_name": "wan_2.1_vae.safetensors"
},
"class_type": "VAELoader",
"_meta": {
"title": "Load VAE"
}
},
"93": {
"inputs": {
"shift": 8.000000000000002,
"model": ["101", 0]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "ModelSamplingSD3"
}
},
"94": {
"inputs": {
"shift": 8,
"model": ["102", 0]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "ModelSamplingSD3"
}
},
"95": {
"inputs": {
"add_noise": "disable",
"noise_seed": 0,
"steps": 20,
"cfg": 3.5,
"sampler_name": "euler",
"scheduler": "simple",
"start_at_step": 10,
"end_at_step": 10000,
"return_with_leftover_noise": "disable",
"model": ["94", 0],
"positive": ["99", 0],
"negative": ["91", 0],
"latent_image": ["96", 0]
},
"class_type": "KSamplerAdvanced",
"_meta": {
"title": "KSampler (Advanced)"
}
},
"96": {
"inputs": {
"add_noise": "enable",
"noise_seed": "__RANDOM_INT__",
"steps": 20,
"cfg": 3.5,
"sampler_name": "euler",
"scheduler": "simple",
"start_at_step": 0,
"end_at_step": 10,
"return_with_leftover_noise": "enable",
"model": ["93", 0],
"positive": ["99", 0],
"negative": ["91", 0],
"latent_image": ["104", 0]
},
"class_type": "KSamplerAdvanced",
"_meta": {
"title": "KSampler (Advanced)"
}
},
"97": {
"inputs": {
"samples": ["95", 0],
"vae": ["92", 0]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"98": {
"inputs": {
"filename_prefix": "video/ComfyUI",
"format": "auto",
"codec": "auto",
"video": ["100", 0]
},
"class_type": "SaveVideo",
"_meta": {
"title": "Save Video"
}
},
"99": {
"inputs": {
"text": "Beautiful young European woman with honey blonde hair gracefully turning her head back over shoulder, gentle smile, bright eyes looking at camera. Hair flowing in slow motion as she turns. Soft natural lighting, clean background, cinematic portrait.",
"clip": ["90", 0]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Positive Prompt)"
}
},
"100": {
"inputs": {
"fps": 16,
"images": ["97", 0]
},
"class_type": "CreateVideo",
"_meta": {
"title": "Create Video"
}
},
"101": {
"inputs": {
"unet_name": "wan2.2_t2v_high_noise_14B_fp8_scaled.safetensors",
"weight_dtype": "default"
},
"class_type": "UNETLoader",
"_meta": {
"title": "Load Diffusion Model"
}
},
"102": {
"inputs": {
"unet_name": "wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors",
"weight_dtype": "default"
},
"class_type": "UNETLoader",
"_meta": {
"title": "Load Diffusion Model"
}
},
"104": {
"inputs": {
"width": 640,
"height": 640,
"length": 81,
"batch_size": 1
},
"class_type": "EmptyHunyuanLatentVideo",
"_meta": {
"title": "EmptyHunyuanLatentVideo"
}
}
}
payload = {
"input": {
"request_id": "",
"workflow_json": workflow,
"s3": {
"access_key_id": "",
"secret_access_key": "",
"endpoint_url": "",
"bucket_name": "",
"region": ""
},
"webhook": {
"url": "",
"extra_params": {
"user_id": "12345",
"project_id": "abc-def"
}
}
}
}
response = await endpoint.request("/generate/sync", payload)
# Response contains status, output, and any errors
print(response["response"])
if __name__ == "__main__":
asyncio.run(main())
+288
View File
@@ -0,0 +1,288 @@
import random
import sys
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
# ComyUI model configuration
MODEL_SERVER_URL = 'http://127.0.0.1'
MODEL_SERVER_PORT = 18288
MODEL_LOG_FILE = '/var/log/portal/comfyui.log'
MODEL_HEALTHCHECK_ENDPOINT = "/health"
# ComyUI-specific log messages
MODEL_LOAD_LOG_MSG = [
"To see the GUI go to: "
]
MODEL_ERROR_LOG_MSGS = [
"MetadataIncompleteBuffer",
"Value not in list: ",
"[ERROR] Provisioning Script failed"
]
MODEL_INFO_LOG_MSGS = [
'"message":"Downloading'
]
benchmark_prompts = [
"Cartoon hoodie hero; orc, anime cat, bunny; black goo; buff; vector on white.",
"Cozy farming-game scene with fine details.",
"2D vector child with soccer ball; airbrush chrome; swagger; antique copper.",
"Realistic futuristic downtown of low buildings at sunset.",
"Perfect wave front view; sunny seascape; ultra-detailed water; artful feel.",
"Clear cup with ice, fruit, mint; creamy swirls; fluid-sim CGI; warm glow.",
"Male biker with backpack on motorcycle; oilpunk; award-worthy magazine cover.",
"Collage for textile; surreal cartoon cat in cap/jeans before poster; crisp.",
"Medieval village inside glass sphere; volumetric light; macro focus.",
"Iron Man with glowing axe; mecha sci-fi; jungle scene; dynamic light.",
"Pope Francis DJ in leather jacket, mixing on giant console; dramatic.",
]
benchmark_dataset = [
{
"input": {
"workflow_json": {
"90": {
"inputs": {
"clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors",
"type": "wan",
"device": "default"
},
"class_type": "CLIPLoader",
"_meta": {
"title": "Load CLIP"
}
},
"91": {
"inputs": {
"text": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,裸露,NSFW",
"clip": [
"90",
0
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Negative Prompt)"
}
},
"92": {
"inputs": {
"vae_name": "wan_2.1_vae.safetensors"
},
"class_type": "VAELoader",
"_meta": {
"title": "Load VAE"
}
},
"93": {
"inputs": {
"shift": 8.000000000000002,
"model": [
"101",
0
]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "ModelSamplingSD3"
}
},
"94": {
"inputs": {
"shift": 8,
"model": [
"102",
0
]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "ModelSamplingSD3"
}
},
"95": {
"inputs": {
"add_noise": "disable",
"noise_seed": 0,
"steps": 20,
"cfg": 3.5,
"sampler_name": "euler",
"scheduler": "simple",
"start_at_step": 10,
"end_at_step": 10000,
"return_with_leftover_noise": "disable",
"model": [
"94",
0
],
"positive": [
"99",
0
],
"negative": [
"91",
0
],
"latent_image": [
"96",
0
]
},
"class_type": "KSamplerAdvanced",
"_meta": {
"title": "KSampler (Advanced)"
}
},
"96": {
"inputs": {
"add_noise": "enable",
"noise_seed": "__RANDOM_INT__",
"steps": 20,
"cfg": 3.5,
"sampler_name": "euler",
"scheduler": "simple",
"start_at_step": 0,
"end_at_step": 10,
"return_with_leftover_noise": "enable",
"model": [
"93",
0
],
"positive": [
"99",
0
],
"negative": [
"91",
0
],
"latent_image": [
"104",
0
]
},
"class_type": "KSamplerAdvanced",
"_meta": {
"title": "KSampler (Advanced)"
}
},
"97": {
"inputs": {
"samples": [
"95",
0
],
"vae": [
"92",
0
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"98": {
"inputs": {
"filename_prefix": "video/ComfyUI",
"format": "auto",
"codec": "auto",
"video": [
"100",
0
]
},
"class_type": "SaveVideo",
"_meta": {
"title": "Save Video"
}
},
"99": {
"inputs": {
"text":prompt,
"clip": [
"90",
0
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Positive Prompt)"
}
},
"100": {
"inputs": {
"fps": 16,
"images": [
"97",
0
]
},
"class_type": "CreateVideo",
"_meta": {
"title": "Create Video"
}
},
"101": {
"inputs": {
"unet_name": "wan2.2_t2v_high_noise_14B_fp8_scaled.safetensors",
"weight_dtype": "default"
},
"class_type": "UNETLoader",
"_meta": {
"title": "Load Diffusion Model"
}
},
"102": {
"inputs": {
"unet_name": "wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors",
"weight_dtype": "default"
},
"class_type": "UNETLoader",
"_meta": {
"title": "Load Diffusion Model"
}
},
"104": {
"inputs": {
"width": 640,
"height": 640,
"length": 81,
"batch_size": 1
},
"class_type": "EmptyHunyuanLatentVideo",
"_meta": {
"title": "EmptyHunyuanLatentVideo"
}
}
}
}
} for prompt in benchmark_prompts
]
worker_config = WorkerConfig(
model_server_url=MODEL_SERVER_URL,
model_server_port=MODEL_SERVER_PORT,
model_log_file=MODEL_LOG_FILE,
model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
handlers=[
HandlerConfig(
route="/generate/sync",
allow_parallel_requests=False,
max_queue_time=10.0,
benchmark_config=BenchmarkConfig(
dataset=benchmark_dataset,
runs=1
),
workload_calculator= lambda _ : 10000.0
)
],
log_action_config=LogActionConfig(
on_load=MODEL_LOAD_LOG_MSG,
on_error=MODEL_ERROR_LOG_MSGS,
on_info=MODEL_INFO_LOG_MSGS
)
)
Worker(worker_config).run()