@@ -0,0 +1,92 @@
|
||||
This is the base PyWorker for comfyui. It can be used to create PyWorker that use various models and
|
||||
workflows. It provides two endpoints:
|
||||
|
||||
1. `/prompt`: Uses the default comfy workflow defined under `misc/default_workflows`
|
||||
2. `/custom_workflow`: Allows the client to send their own comfy workflow with each API request.
|
||||
|
||||
To use the comfyui PyWorker, `$COMFY_MODEL` env variable must be set in the template. Current options are
|
||||
`sd3` and `flux`. Each have example clients.
|
||||
|
||||
To add new models, a JSON with name `$COMFY_MODEL.json` must be created under `misc/default_workflows`
|
||||
|
||||
NOTE: default workflows follow this format:
|
||||
|
||||
```json
|
||||
{
|
||||
"input": {
|
||||
"handler": "RawWorkflow",
|
||||
"aws_access_key_id": "your-s3-access-key",
|
||||
"aws_secret_access_key": "your-s3-secret-access-key",
|
||||
"aws_endpoint_url": "https://my-endpoint.backblaze.com",
|
||||
"aws_bucket_name": "your-bucket",
|
||||
"webhook_url": "your-webhook-url",
|
||||
"webhook_extra_params": {},
|
||||
"workflow_json": {}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
You can ignore all of these fields except for `workflow_json`.
|
||||
|
||||
Fields written as "{{FOO}}" will be replaced using data from a user request. For example, SD3's workflow has the
|
||||
following nodes:
|
||||
|
||||
```json
|
||||
"5": {
|
||||
"inputs": {
|
||||
"width": "{{WIDTH}}",
|
||||
"height": "{{HEIGHT}}",
|
||||
"batch_size": 1
|
||||
},
|
||||
|
||||
"6": {
|
||||
"inputs": {
|
||||
"text": "{{PROMPT}}",
|
||||
"clip": ["11", 0]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
...
|
||||
"17": {
|
||||
"inputs": {
|
||||
"scheduler": "simple",
|
||||
"steps": "{{STEPS}}",
|
||||
"denoise": 1,
|
||||
"model": ["12", 0]
|
||||
},
|
||||
"class_type": "BasicScheduler",
|
||||
"_meta": {
|
||||
"title": "BasicScheduler"
|
||||
}
|
||||
},
|
||||
...
|
||||
"25": {
|
||||
"inputs": {
|
||||
"noise_seed": "{{SEED}}"
|
||||
},
|
||||
"class_type": "RandomNoise",
|
||||
"_meta": {
|
||||
"title": "RandomNoise"
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
Incoming requests have the following JSON format:
|
||||
|
||||
```json
|
||||
{
|
||||
prompt: str
|
||||
width: int
|
||||
height: int
|
||||
steps: int
|
||||
seed: int
|
||||
}
|
||||
```
|
||||
|
||||
Each value in those fields with replace the placeholder of the same name in the default workflow.
|
||||
|
||||
See Vast's serverless documentation for more details on how to use comfyui with autoscaler
|
||||
@@ -0,0 +1,150 @@
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
|
||||
from lib.test_utils import print_truncate_res
|
||||
|
||||
"""
|
||||
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
|
||||
"""
|
||||
|
||||
|
||||
def call_default_workflow(
|
||||
endpoint_group_name: str, api_key: str, server_url: str
|
||||
) -> None:
|
||||
WORKER_ENDPOINT = "/prompt"
|
||||
COST = 100
|
||||
route_payload = {
|
||||
"endpoint": endpoint_group_name,
|
||||
"api_key": api_key,
|
||||
"cost": COST,
|
||||
}
|
||||
response = requests.post(
|
||||
urljoin(server_url, "/route/"),
|
||||
json=route_payload,
|
||||
timeout=4,
|
||||
)
|
||||
message = response.json()
|
||||
url = message["url"]
|
||||
auth_data = dict(
|
||||
signature=message["signature"],
|
||||
cost=message["cost"],
|
||||
endpoint=message["endpoint"],
|
||||
reqnum=message["reqnum"],
|
||||
url=message["url"],
|
||||
)
|
||||
payload = dict(
|
||||
prompt="a fat fluffy cat", width=1024, height=1024, steps=20, seed=123456789
|
||||
)
|
||||
req_data = dict(payload=payload, auth_data=auth_data)
|
||||
url = urljoin(url, WORKER_ENDPOINT)
|
||||
print(f"url: {url}")
|
||||
response = requests.post(
|
||||
url,
|
||||
json=req_data,
|
||||
)
|
||||
print_truncate_res(str(response.json()))
|
||||
|
||||
|
||||
def call_custom_workflow_for_sd3(
|
||||
endpoint_group_name: str, api_key: str, server_url: str
|
||||
) -> None:
|
||||
WORKER_ENDPOINT = "/custom-workflow"
|
||||
COST = 100
|
||||
route_payload = {
|
||||
"endpoint": endpoint_group_name,
|
||||
"api_key": api_key,
|
||||
"cost": COST,
|
||||
}
|
||||
response = requests.post(
|
||||
urljoin(server_url, "/route/"),
|
||||
json=route_payload,
|
||||
timeout=4,
|
||||
)
|
||||
message = response.json()
|
||||
url = message["url"]
|
||||
auth_data = dict(
|
||||
signature=message["signature"],
|
||||
cost=message["cost"],
|
||||
endpoint=message["endpoint"],
|
||||
reqnum=message["reqnum"],
|
||||
url=message["url"],
|
||||
)
|
||||
workflow = {
|
||||
"3": {
|
||||
"inputs": {
|
||||
"seed": 156680208700286,
|
||||
"steps": 20,
|
||||
"cfg": 8,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "normal",
|
||||
"denoise": 1,
|
||||
"model": ["4", 0],
|
||||
"positive": ["6", 0],
|
||||
"negative": ["7", 0],
|
||||
"latent_image": ["5", 0],
|
||||
},
|
||||
"class_type": "KSampler",
|
||||
},
|
||||
"4": {
|
||||
"inputs": {"ckpt_name": "sd3_medium_incl_clips_t5xxlfp16.safetensors"},
|
||||
"class_type": "CheckpointLoaderSimple",
|
||||
},
|
||||
"5": {
|
||||
"inputs": {"width": 512, "height": 512, "batch_size": 1},
|
||||
"class_type": "EmptyLatentImage",
|
||||
},
|
||||
"6": {
|
||||
"inputs": {
|
||||
"text": "beautiful scenery nature glass bottle landscape, purple galaxy bottle",
|
||||
"clip": ["4", 1],
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
},
|
||||
"7": {
|
||||
"inputs": {"text": "text, watermark", "clip": ["4", 1]},
|
||||
"class_type": "CLIPTextEncode",
|
||||
},
|
||||
"8": {
|
||||
"inputs": {"samples": ["3", 0], "vae": ["4", 2]},
|
||||
"class_type": "VAEDecode",
|
||||
},
|
||||
"9": {
|
||||
"inputs": {"filename_prefix": "ComfyUI", "images": ["8", 0]},
|
||||
"class_type": "SaveImage",
|
||||
},
|
||||
}
|
||||
# these values should match the values in the custom workflow above,
|
||||
# they are used to calculate workload
|
||||
custom_fields = dict(
|
||||
steps=20,
|
||||
width=512,
|
||||
height=512,
|
||||
)
|
||||
req_data = dict(
|
||||
payload=dict(custom_fields=custom_fields, workflow=workflow),
|
||||
auth_data=auth_data,
|
||||
)
|
||||
url = urljoin(url, WORKER_ENDPOINT)
|
||||
print(f"url: {url}")
|
||||
response = requests.post(
|
||||
url,
|
||||
json=req_data,
|
||||
)
|
||||
print_truncate_res(str(response.json()))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from lib.test_utils import test_args
|
||||
|
||||
args = test_args.parse_args()
|
||||
call_default_workflow(
|
||||
api_key=args.api_key,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
server_url=args.server_url,
|
||||
)
|
||||
call_custom_workflow_for_sd3(
|
||||
api_key=args.api_key,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
server_url=args.server_url,
|
||||
)
|
||||
@@ -0,0 +1,205 @@
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
import random
|
||||
import dataclasses
|
||||
import inspect
|
||||
from typing import Dict, Any
|
||||
from functools import cache
|
||||
from math import ceil
|
||||
from enum import Enum
|
||||
|
||||
from lib.data_types import ApiPayload, JsonDataException
|
||||
|
||||
|
||||
with open("workers/comfyui/misc/test_prompts.txt", "r") as f:
|
||||
test_prompts = f.readlines()
|
||||
|
||||
|
||||
class Model(Enum):
|
||||
Flux = "flux"
|
||||
Sd3 = "sd3"
|
||||
|
||||
def get_request_time(self) -> int:
|
||||
match self:
|
||||
case Model.Flux:
|
||||
return 23
|
||||
case Model.Sd3:
|
||||
return 6
|
||||
|
||||
|
||||
@cache
|
||||
def get_model() -> Model:
|
||||
match os.environ.get("COMFY_MODEL"):
|
||||
case "flux":
|
||||
return Model.Flux
|
||||
case "sd3":
|
||||
return Model.Sd3
|
||||
case None:
|
||||
raise Exception(
|
||||
"For comfyui pyworker, $COMFY_MODEL must be set in the vast template"
|
||||
)
|
||||
case model:
|
||||
raise Exception(f"Unsupported comfyui model: {model}")
|
||||
|
||||
|
||||
@cache
|
||||
def get_request_template() -> str:
|
||||
with open(f"workers/comfyui/misc/default_workflows/{get_model().value}.json") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def count_workload(width: int, height: int, steps: int) -> float:
|
||||
"""
|
||||
we want to normalize the workload is a number such that cur_perf(tokens/second) for 1024x1024 image with
|
||||
28 steps is 200 tokens on a 4090.
|
||||
|
||||
in order get that we calculate the
|
||||
|
||||
A = ( absolute workload based on given data )
|
||||
B = ( absolute workload for a 1024x1024 image with 28 steps )
|
||||
|
||||
and adjust the workload to 200 tokens by A/B.
|
||||
|
||||
we then adjust for difference between Flux and SD3 by multiplying this value by expected request time for a
|
||||
standard image(23s for Flux, 6s for SD3).
|
||||
On a 4090, this would give us a workload that would give a cur_perf(workload / request_time) of around 200
|
||||
"""
|
||||
|
||||
def _calculate_absolute_tokens(width_: int, height_: int, steps_: int) -> float:
|
||||
"""
|
||||
This is based on how openai counts image generation tokens, see: https://openai.com/api/pricing/
|
||||
|
||||
we count how many 512x512 grids are needed to cover the image.
|
||||
each tile is then counted as 175 tokens.
|
||||
each image generation also has constant of 85 base tokens.
|
||||
|
||||
we then adjust the count based on the number of steps. The baseline number of steps is assumed to be 28.
|
||||
Some testing with flux gave me this data:
|
||||
|
||||
steps(X) | request time(Y)
|
||||
__________|_________________
|
||||
07(0.25x) | 11s (0.47x)
|
||||
14(0.50x) | 15s (0.65x)
|
||||
21(0.75x) | 20s (0.86x)
|
||||
28(1.00x) | 23s (1.00x)
|
||||
35(1.25x) | 28s (1.21x)
|
||||
42(1.50x) | 32s (1.39x)
|
||||
49(1.75x) | 37s (1.60x)
|
||||
|
||||
this gives a linear regression of Y = 0.61*X + 6.57
|
||||
|
||||
we can use this as an adjustment_factor for token count
|
||||
|
||||
adjustment_factor = (0.61 * steps + 6.57)
|
||||
"""
|
||||
|
||||
width_grids = ceil(width_ / 512)
|
||||
height_grids = ceil(height_ / 512)
|
||||
tokens = 85 + width_grids * height_grids * 175
|
||||
adjustment_factor = 0.61 * steps_ + 6.57
|
||||
return tokens * adjustment_factor
|
||||
|
||||
REQUEST_TIME_FOR_STANDARD_IMAGE = get_model().get_request_time()
|
||||
|
||||
absolute_tokens = _calculate_absolute_tokens(
|
||||
width_=width, height_=height, steps_=steps
|
||||
)
|
||||
absolute_tokens_standard_image = _calculate_absolute_tokens(
|
||||
width_=1024, height_=1024, steps_=28
|
||||
)
|
||||
return REQUEST_TIME_FOR_STANDARD_IMAGE * (
|
||||
(absolute_tokens / absolute_tokens_standard_image) * 200
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DefaultComfyWorkflowData(ApiPayload):
|
||||
prompt: str
|
||||
width: int
|
||||
height: int
|
||||
steps: int
|
||||
seed: int
|
||||
|
||||
@classmethod
|
||||
def for_test(cls):
|
||||
|
||||
test_prompt = random.choice(test_prompts).rstrip()
|
||||
return cls(
|
||||
prompt=test_prompt,
|
||||
width=1024,
|
||||
height=1024,
|
||||
steps=28,
|
||||
seed=random.randint(0, sys.maxsize),
|
||||
)
|
||||
|
||||
def generate_payload_json(
|
||||
self,
|
||||
) -> Dict[str, Any]:
|
||||
return json.loads(
|
||||
get_request_template()
|
||||
.replace("{{PROMPT}}", self.prompt)
|
||||
# these values should be of int type. Since "{{VAR}}" is wrapped with " in the template
|
||||
# to make the JSON valid, we must replace the double quotes. i.e. "{{WIDTH}}" -> 1024 and not "1024"
|
||||
.replace('"{{WIDTH}}"', str(self.width))
|
||||
.replace('"{{HEIGHT}}"', str(self.height))
|
||||
.replace('"{{STEPS}}"', str(self.steps))
|
||||
.replace('"{{SEED}}"', str(self.seed))
|
||||
)
|
||||
|
||||
def count_workload(self) -> float:
|
||||
return count_workload(width=self.width, height=self.height, steps=self.steps)
|
||||
|
||||
@classmethod
|
||||
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "DefaultComfyWorkflowData":
|
||||
errors = {}
|
||||
for param in inspect.signature(cls).parameters:
|
||||
if param not in json_msg:
|
||||
errors[param] = "missing parameter"
|
||||
if errors:
|
||||
raise JsonDataException(errors)
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in json_msg.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CustomComfyWorkflowData(ApiPayload):
|
||||
custom_fields: Dict[str, int]
|
||||
workflow: Dict[str, Any]
|
||||
|
||||
@classmethod
|
||||
def for_test(cls):
|
||||
raise NotImplemented("Custom comfy workflow is not used for testing")
|
||||
|
||||
def count_workload(self) -> float:
|
||||
return count_workload(
|
||||
width=int(self.custom_fields.get("width", 1024)),
|
||||
height=int(self.custom_fields.get("height", 1024)),
|
||||
steps=int(self.custom_fields.get("steps", 28)),
|
||||
)
|
||||
|
||||
def generate_payload_json(self) -> Dict[str, Any]:
|
||||
template_json = json.loads(get_request_template())
|
||||
template_json["input"]["workflow_json"] = self.workflow
|
||||
return template_json
|
||||
|
||||
@classmethod
|
||||
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "CustomComfyWorkflowData":
|
||||
errors = {}
|
||||
for param in inspect.signature(cls).parameters:
|
||||
if param not in json_msg:
|
||||
errors[param] = "missing parameter"
|
||||
if errors:
|
||||
raise JsonDataException(errors)
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in json_msg.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,137 @@
|
||||
{
|
||||
"input": {
|
||||
"handler": "RawWorkflow",
|
||||
"aws_access_key_id": "your-s3-access-key",
|
||||
"aws_secret_access_key": "your-s3-secret-access-key",
|
||||
"aws_endpoint_url": "https://my-endpoint.backblaze.com",
|
||||
"aws_bucket_name": "your-bucket",
|
||||
"webhook_url": "your-webhook-url",
|
||||
"webhook_extra_params": {},
|
||||
"workflow_json": {
|
||||
"5": {
|
||||
"inputs": {
|
||||
"width": "{{WIDTH}}",
|
||||
"height": "{{HEIGHT}}",
|
||||
"batch_size": 1
|
||||
},
|
||||
"class_type": "EmptyLatentImage",
|
||||
"_meta": {
|
||||
"title": "Empty Latent Image"
|
||||
}
|
||||
},
|
||||
"6": {
|
||||
"inputs": {
|
||||
"text": "{{PROMPT}}",
|
||||
"clip": ["11", 0]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"8": {
|
||||
"inputs": {
|
||||
"samples": ["13", 0],
|
||||
"vae": ["10", 0]
|
||||
},
|
||||
"class_type": "VAEDecode",
|
||||
"_meta": {
|
||||
"title": "VAE Decode"
|
||||
}
|
||||
},
|
||||
"9": {
|
||||
"inputs": {
|
||||
"filename_prefix": "ComfyUI",
|
||||
"images": ["8", 0]
|
||||
},
|
||||
"class_type": "SaveImage",
|
||||
"_meta": {
|
||||
"title": "Save Image"
|
||||
}
|
||||
},
|
||||
"10": {
|
||||
"inputs": {
|
||||
"vae_name": "ae.safetensors"
|
||||
},
|
||||
"class_type": "VAELoader",
|
||||
"_meta": {
|
||||
"title": "Load VAE"
|
||||
}
|
||||
},
|
||||
"11": {
|
||||
"inputs": {
|
||||
"clip_name1": "t5xxl_fp16.safetensors",
|
||||
"clip_name2": "clip_l.safetensors",
|
||||
"type": "flux"
|
||||
},
|
||||
"class_type": "DualCLIPLoader",
|
||||
"_meta": {
|
||||
"title": "DualCLIPLoader"
|
||||
}
|
||||
},
|
||||
"12": {
|
||||
"inputs": {
|
||||
"unet_name": "flux1-dev.safetensors",
|
||||
"weight_dtype": "default"
|
||||
},
|
||||
"class_type": "UNETLoader",
|
||||
"_meta": {
|
||||
"title": "Load Diffusion Model"
|
||||
}
|
||||
},
|
||||
"13": {
|
||||
"inputs": {
|
||||
"noise": ["25", 0],
|
||||
"guider": ["22", 0],
|
||||
"sampler": ["16", 0],
|
||||
"sigmas": ["17", 0],
|
||||
"latent_image": ["5", 0]
|
||||
},
|
||||
"class_type": "SamplerCustomAdvanced",
|
||||
"_meta": {
|
||||
"title": "SamplerCustomAdvanced"
|
||||
}
|
||||
},
|
||||
"16": {
|
||||
"inputs": {
|
||||
"sampler_name": "euler"
|
||||
},
|
||||
"class_type": "KSamplerSelect",
|
||||
"_meta": {
|
||||
"title": "KSamplerSelect"
|
||||
}
|
||||
},
|
||||
"17": {
|
||||
"inputs": {
|
||||
"scheduler": "simple",
|
||||
"steps": "{{STEPS}}",
|
||||
"denoise": 1,
|
||||
"model": ["12", 0]
|
||||
},
|
||||
"class_type": "BasicScheduler",
|
||||
"_meta": {
|
||||
"title": "BasicScheduler"
|
||||
}
|
||||
},
|
||||
"22": {
|
||||
"inputs": {
|
||||
"model": ["12", 0],
|
||||
"conditioning": ["6", 0]
|
||||
},
|
||||
"class_type": "BasicGuider",
|
||||
"_meta": {
|
||||
"title": "BasicGuider"
|
||||
}
|
||||
},
|
||||
"25": {
|
||||
"inputs": {
|
||||
"noise_seed": "{{SEED}}"
|
||||
},
|
||||
"class_type": "RandomNoise",
|
||||
"_meta": {
|
||||
"title": "RandomNoise"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,142 @@
|
||||
{
|
||||
"input": {
|
||||
"handler": "RawWorkflow",
|
||||
"aws_access_key_id": "your-s3-access-key",
|
||||
"aws_secret_access_key": "your-s3-secret-access-key",
|
||||
"aws_endpoint_url": "https://my-endpoint.backblaze.com",
|
||||
"aws_bucket_name": "your-bucket",
|
||||
"webhook_url": "your-webhook-url",
|
||||
"webhook_extra_params": {},
|
||||
"workflow_json": {
|
||||
"6": {
|
||||
"inputs": {
|
||||
"text": "{{PROMPT}}",
|
||||
"clip": ["252", 1]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"13": {
|
||||
"inputs": {
|
||||
"shift": 3,
|
||||
"model": ["252", 0]
|
||||
},
|
||||
"class_type": "ModelSamplingSD3",
|
||||
"_meta": {
|
||||
"title": "ModelSamplingSD3"
|
||||
}
|
||||
},
|
||||
"67": {
|
||||
"inputs": {
|
||||
"conditioning": ["71", 0]
|
||||
},
|
||||
"class_type": "ConditioningZeroOut",
|
||||
"_meta": {
|
||||
"title": "ConditioningZeroOut"
|
||||
}
|
||||
},
|
||||
"68": {
|
||||
"inputs": {
|
||||
"start": 0.1,
|
||||
"end": 1,
|
||||
"conditioning": ["67", 0]
|
||||
},
|
||||
"class_type": "ConditioningSetTimestepRange",
|
||||
"_meta": {
|
||||
"title": "ConditioningSetTimestepRange"
|
||||
}
|
||||
},
|
||||
"69": {
|
||||
"inputs": {
|
||||
"conditioning_1": ["68", 0],
|
||||
"conditioning_2": ["70", 0]
|
||||
},
|
||||
"class_type": "ConditioningCombine",
|
||||
"_meta": {
|
||||
"title": "Conditioning (Combine)"
|
||||
}
|
||||
},
|
||||
"70": {
|
||||
"inputs": {
|
||||
"start": 0,
|
||||
"end": 0.1,
|
||||
"conditioning": ["71", 0]
|
||||
},
|
||||
"class_type": "ConditioningSetTimestepRange",
|
||||
"_meta": {
|
||||
"title": "ConditioningSetTimestepRange"
|
||||
}
|
||||
},
|
||||
"71": {
|
||||
"inputs": {
|
||||
"text": "bad quality, poor quality, doll, disfigured, jpg, toy, bad anatomy, missing limbs, missing fingers, 3d, cgi",
|
||||
"clip": ["252", 1]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Negative Prompt)"
|
||||
}
|
||||
},
|
||||
"135": {
|
||||
"inputs": {
|
||||
"width": "{{WIDTH}}",
|
||||
"height": "{{HEIGHT}}",
|
||||
"batch_size": 1
|
||||
},
|
||||
"class_type": "EmptySD3LatentImage",
|
||||
"_meta": {
|
||||
"title": "EmptySD3LatentImage"
|
||||
}
|
||||
},
|
||||
"231": {
|
||||
"inputs": {
|
||||
"samples": ["271", 0],
|
||||
"vae": ["252", 2]
|
||||
},
|
||||
"class_type": "VAEDecode",
|
||||
"_meta": {
|
||||
"title": "VAE Decode"
|
||||
}
|
||||
},
|
||||
"233": {
|
||||
"inputs": {
|
||||
"filename_prefix": "ComfyUI",
|
||||
"images": ["231", 0]
|
||||
},
|
||||
"class_type": "SaveImage",
|
||||
"_meta": {
|
||||
"title": "Save Image"
|
||||
}
|
||||
},
|
||||
"252": {
|
||||
"inputs": {
|
||||
"ckpt_name": "sd3_medium_incl_clips_t5xxlfp16.safetensors"
|
||||
},
|
||||
"class_type": "CheckpointLoaderSimple",
|
||||
"_meta": {
|
||||
"title": "Load Checkpoint"
|
||||
}
|
||||
},
|
||||
"271": {
|
||||
"inputs": {
|
||||
"seed": "{{SEED}}",
|
||||
"steps": "{{STEPS}}",
|
||||
"cfg": 4.5,
|
||||
"sampler_name": "dpmpp_2m",
|
||||
"scheduler": "sgm_uniform",
|
||||
"denoise": 1,
|
||||
"model": ["13", 0],
|
||||
"positive": ["6", 0],
|
||||
"negative": ["69", 0],
|
||||
"latent_image": ["135", 0]
|
||||
},
|
||||
"class_type": "KSampler",
|
||||
"_meta": {
|
||||
"title": "KSampler"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
cartoon character of a person with a hoodie , in style of cytus and deemo, ork, gold chains, realistic anime cat, dripping black goo, lineage revolution style, thug life, cute anthropomorphic bunny, balrog, arknights, aliased, very buff, black and red and yellow paint, painting illustration collage style, character composition in vector with white background
|
||||
stardew valley, fine details
|
||||
2D Vector Illustration of a child with soccer ball Art for Sublimation, Design Art, Chrome Art, Painting and Stunning Artwork, Highly Detailed Digital Painting, Airbrush Art, Highly Detailed Digital Artwork, Dramatic Artwork, stained antique yellow copper paint, digital airbrush art, detailed by Mark Brooks, Chicano airbrush art, Swagger! snake Culture
|
||||
realistic futuristic city-downtown with short buildings, sunset
|
||||
seascape by Ray Collins and artgerm, front view of a perfect wave, sunny background, ultra detailed water
|
||||
inspired by realflow-cinema4d editor features, create image of a transparent luxury cup with ice fruits and mint, connected with white, yellow and pink cream, Slow - High Speed MO Photography, YouTube Video Screenshot, Abstract Clay, Transparent Cup , molecular gastronomy, wheel, 3D fluid,Simulation rendering, still video, 4k polymer clay futras photography, very surreal, Houdini Fluid Simulation, hyperrealistic CGI and FLUIDS & MULTIPHYSICS SIMULATION effect, with Somali Stain Lurex, Metallic Jacquard, Gold Thread, Mulberry Silk, Toub Saree, Warm background, a fantastic image worthy of an award.
|
||||
biker with backpack on his back riding a motorcycle, Style by Ade Santora, Oilpunk, Cover photo, craig mullins style, on the cover of a magazine, Outdoor Magazine, inspired by Alex Petruk APe, image of a male biker, Cover of an award-winning magazine, the man has a backpack, photo for magazine, with a backpack, magazine cover
|
||||
generate a collage-style illustration inspired by the Procreate raster graphic editor, photographic illustration with the theme, 2D vector, art for textile sublimation, containing surrealistic cartoon cat wearing a baseball cap and jeans standing in front of a poster, inspired by Sadao Watanabe, Doraemon, Japanese cartoon style, Eichiro Oda, Iconic high detail character, Director: Nakahara Nantenbō, Kastuhiro Otomo, image detailed, by Miyamoto, Hidetaka Miyazaki, Katsuhiro illustration, 8k, masterpiece, Minimize noise and grain in photo quality without lose quality and increase brightness and lighting,Symmetry and Alignment, Avoid asymmetrical shapes and out-of-focus points. Focus and Sharpness: Make sure the image is focused and sharp and encourages the viewer to see it as a work of art printed on fabric.
|
||||
fantasy medieval village world inside a glass sphere , high detail, fantasy, realistic, light effect, hyper detail, volumetric lighting, cinematic, macro, depth of field, blur, red light and clouds from the back, highly detailed epic cinematic concept art cg render made in maya, blender and photoshop, octane render, excellent composition, dynamic dramatic cinematic lighting, aesthetic, very inspirational, world inside a glass sphere by james gurney by artgerm with james jean, joe fenton and tristan eaton by ross tran, fine details
|
||||
Iron Man, (Arnold Tsang, Toru Nakayama), Masterpiece, Studio Quality, 6k , toa, toaair, 1boy, glowing, axe, mecha, science_fiction, solo, weapon, jungle , green_background, nature, outdoors, solo, tree, weapon, mask, dynamic lighting, detailed shading, digital texture painting
|
||||
(Pope Francis) wearing leather jacket is a DJ in a nightclub, mixing live on stage, giant mixing table, a masterpiece
|
||||
Pope Francis wearing biker (leather jacket), a masterpiece
|
||||
Luke Skywalker ordering a burger and fries from the Death Star canteen.
|
||||
I want to generate a group avatar for a Feishu group chat. The role of this group is daily software technical communication. Now the subject technology stacks that members of this group discuss daily include: algorithms, data structures, optimization, functional programming, and the programming languages often discussed are: TypeScript, Java, python, etc. I hope this avatar has a simple aesthetic, this avatar is a single person avatar
|
||||
portrait Anime black girl cute-fine-face, pretty face, realistic shaded Perfect face, fine details. Anime. realistic shaded lighting by Ilya Kuvshinov Giuseppe Dangelico Pino and Michael Garmash and Rob Rey, IAMAG premiere, WLOP matte print, cute freckles, masterpiece
|
||||
young Disney socialite wearing a beige miniskirt, dark brown turtleneck sweater, small neckless, cute-fine-face, anime. illustration, realistic shaded perfect face, brown hair, grey eyes, fine details, realistic shaded lighting by ilya kuvshinov giuseppe dangelico pino and michael garmash and rob rey, iamag premiere, wlop matte print, a masterpiece
|
||||
Cute small cat sitting in a movie theater eating chicken wiggs watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
|
||||
Cute small dog sitting in a movie theater eating popcorn watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
|
||||
fox bracelet made of buckskin with fox features, rich details, fine carvings, studio lighting
|
||||
crane buckskin bracelet with crane features, rich details, fine carvings, studio lighting
|
||||
london luxurious interior living-room, light walls
|
||||
Parisian luxurious interior penthouse bedroom, dark walls, wooden panels
|
||||
cute girl, crop-top, blond hair, black glasses, stretching, with background by greg rutkowski makoto shinkai kyoto animation key art feminine mid shot
|
||||
houses in front, houses background, straight houses, digital art, smooth, sharp focus, gravity falls style, doraemon style, shinchan style, anime style
|
||||
Simplified technical drawing, Leonardo da Vinci, Mechanical Dinosaur Skeleton, Minimalistic annotations, Hand-drawn illustrations, Basic design and engineering, Wonder and curiosity
|
||||
High quality 8K painting impressionist style of a Japanese modern city street with a girl on the foreground wearing a traditional wedding dress with a fox mask, staring at the sky, daylight
|
||||
a landscape from the Moon with the Earth setting on the horizon, realistic, detailed
|
||||
Isometric Atlantis city,great architecture with columns, great details, ornaments,seaweed, blue ambiance, 3D cartoon style, soft light, 45° view
|
||||
A hyper realistic avatar of a guy riding on a black honda cbr 650r in leather suit,high detail, high quality,8K,photo realism
|
||||
the street of amedieval fantasy town, at dawn, dark, highly detailed
|
||||
overwhelmingly beautiful eagle framed with vector flowers, long shiny wavy flowing hair, polished, ultra detailed vector floral illustration mixed with hyper realism, muted pastel colors, vector floral details in background, muted colors, hyper detailed ultra intricate overwhelming realism in detailed complex scene with magical fantasy atmosphere, no signature, no watermark
|
||||
a highly detailed matte painting of a man on a hill watching a rocket launch in the distance by studio ghibli, makoto shinkai, by artgerm, by wlop, by greg rutkowski, volumetric lighting, octane render, 4 k resolution, trending on artstation, masterpiece | hyperrealism| highly detailed| insanely detailed| intricate| cinematic lighting| depth of field
|
||||
electronik robot and ofice ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
|
||||
exquisitely intricately detailed illustration, of a small world with a lake and a rainbow, inside a closed glass jar.
|
||||
@@ -0,0 +1,135 @@
|
||||
import os
|
||||
import logging
|
||||
import dataclasses
|
||||
import base64
|
||||
from typing import Union, Type
|
||||
|
||||
from aiohttp import web, ClientResponse
|
||||
from anyio import open_file
|
||||
|
||||
from lib.backend import Backend, LogAction
|
||||
from lib.data_types import EndpointHandler
|
||||
from lib.server import start_server
|
||||
from .data_types import DefaultComfyWorkflowData, CustomComfyWorkflowData
|
||||
|
||||
|
||||
MODEL_SERVER_URL = "http://0.0.0.0:38188"
|
||||
|
||||
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
||||
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: http://127.0.0.1:18188"
|
||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
|
||||
"Value not in list: unet_name", # This error is emitted when the model file is not there at all
|
||||
]
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
async def generate_client_response(
|
||||
request: web.Request, response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
_ = request
|
||||
match response.status:
|
||||
case 200:
|
||||
log.debug("SUCCESS")
|
||||
res = await response.json()
|
||||
if "output" not in res:
|
||||
return web.json_response(
|
||||
data=dict(error="there was an error in the workflow"),
|
||||
status=422,
|
||||
)
|
||||
image_paths = [path["local_path"] for path in res["output"]["images"]]
|
||||
if not image_paths:
|
||||
return web.json_response(
|
||||
data=dict(error="workflow did not produce any images"),
|
||||
status=422,
|
||||
)
|
||||
images = []
|
||||
for image_path in image_paths:
|
||||
async with await open_file(image_path, mode="rb") as f:
|
||||
contents = await f.read()
|
||||
images.append(
|
||||
f"data:image/png;base64,{base64.b64encode(contents).decode('utf-8')}"
|
||||
)
|
||||
return web.json_response(data=dict(images=images))
|
||||
case code:
|
||||
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||
return web.Response(status=code)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DefaultComfyWorkflowHandler(EndpointHandler[DefaultComfyWorkflowData]):
|
||||
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/runsync"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[DefaultComfyWorkflowData]:
|
||||
return DefaultComfyWorkflowData
|
||||
|
||||
def make_benchmark_payload(self) -> DefaultComfyWorkflowData:
|
||||
return DefaultComfyWorkflowData.for_test()
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
return await generate_client_response(client_request, model_response)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CustomComfyWorkflowHandler(EndpointHandler[CustomComfyWorkflowData]):
|
||||
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/runsync"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[CustomComfyWorkflowData]:
|
||||
return CustomComfyWorkflowData
|
||||
|
||||
def make_benchmark_payload(self) -> CustomComfyWorkflowData:
|
||||
return CustomComfyWorkflowData.for_test()
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
return await generate_client_response(client_request, model_response)
|
||||
|
||||
|
||||
backend = Backend(
|
||||
model_server_url=MODEL_SERVER_URL,
|
||||
model_log_file=os.environ["MODEL_LOG"],
|
||||
allow_parallel_requests=False,
|
||||
benchmark_handler=DefaultComfyWorkflowHandler(
|
||||
benchmark_runs=3, benchmark_words=100
|
||||
),
|
||||
log_actions=[
|
||||
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
|
||||
(LogAction.Info, "Downloading:"),
|
||||
*[
|
||||
(LogAction.ModelError, error_msg)
|
||||
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def handle_ping(_):
|
||||
return web.Response(body="pong")
|
||||
|
||||
|
||||
routes = [
|
||||
web.post("/prompt", backend.create_handler(DefaultComfyWorkflowHandler())),
|
||||
web.post("/custom-workflow", backend.create_handler(CustomComfyWorkflowHandler())),
|
||||
web.get("/ping", handle_ping),
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_server(backend, routes)
|
||||
@@ -0,0 +1,15 @@
|
||||
from lib.test_utils import test_load_cmd, test_args
|
||||
from .data_types import DefaultComfyWorkflowData, Model
|
||||
|
||||
WORKER_ENDPOINT = "/prompt"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_args.add_argument(
|
||||
"-m",
|
||||
dest="comfy_model",
|
||||
choices=list(map(lambda x: x.value, Model)),
|
||||
required=True,
|
||||
help="Image generation model name",
|
||||
)
|
||||
test_load_cmd(DefaultComfyWorkflowData, WORKER_ENDPOINT, arg_parser=test_args)
|
||||
@@ -0,0 +1,300 @@
|
||||
# Vast PyWorker
|
||||
|
||||
## Hello_world example
|
||||
|
||||
There is a hello_world PyWorker implantation under `workers/hello_world`. This PyWorker is
|
||||
created for an LLM model server that runs on port 5001 has two API endpoints:
|
||||
|
||||
1. `/generate`: generates an full response to the prompt and sends a JSON response
|
||||
2. `/generate_stream`: streams a response one token at a time
|
||||
|
||||
Both of these endpoints take the same API JSON payload:
|
||||
|
||||
```
|
||||
{
|
||||
"prompt": String,
|
||||
"max_response_tokens": Number | null
|
||||
}
|
||||
```
|
||||
|
||||
We want the PyWorker to also expose two endpoints, for each of the above endpoints.
|
||||
|
||||
### Structure
|
||||
|
||||
All PyWorkers have four files:
|
||||
|
||||
```
|
||||
.
|
||||
└── workers
|
||||
└── hello_world
|
||||
├── __init__.py
|
||||
├── data_types.py # contains data types representing model API endpoints
|
||||
├── server.py # contains endpoint handlers
|
||||
├── client.py # a script to call an endpoint through the autoscaler
|
||||
└── test_load.py # script for load testing
|
||||
|
||||
```
|
||||
|
||||
All of the classes follow strict type hinting. It is recommended that you type hint all of your function.
|
||||
This will allow your IDE or VSCode with `pyright` plugin to find any type errors in your implementation.
|
||||
You can also install `pyright` with `sudo npm install -g pyright` and run `pyright` in the root of the project to find
|
||||
any type errors.
|
||||
|
||||
#### data_Types.py
|
||||
|
||||
data classes representing the model API are defined here. They must inherit from
|
||||
`lib.data_types.ApiPayload`. `ApiPayload` is an abstract class and you need to define several functions for it:
|
||||
|
||||
```python
|
||||
import dataclasses
|
||||
import random
|
||||
from typing import Dict, Any
|
||||
|
||||
from transformers import AutoTokenizer # used to count tokens in a prompt
|
||||
import nltk # used to download a list of all words to generate a random prompt and benchmark the LLM model
|
||||
|
||||
from lib.data_types import ApiPayload
|
||||
|
||||
nltk.download("words")
|
||||
WORD_LIST = nltk.corpus.words.words()
|
||||
|
||||
# you can use any tokenizer that fits your LLM. `openai-gpt` is free to use and is a good fit for most LLMs
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")
|
||||
|
||||
@dataclasses.dataclass
|
||||
class InputData(ApiPayload):
|
||||
prompt: str
|
||||
max_response_tokens: int
|
||||
|
||||
@classmethod
|
||||
def for_test(cls) -> "ApiPayload":
|
||||
"""defines how create a payload for load testing"""
|
||||
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
||||
return cls(prompt=prompt, max_response_tokens=300)
|
||||
|
||||
def generate_payload_json(self) -> Dict[str, Any]:
|
||||
"""defines how to convert an ApiPayload to JSON that will be sent to model API"""
|
||||
return dataclasses.asdict(self)
|
||||
|
||||
def count_workload(self) -> float:
|
||||
"""defines how to calculate workload for a payload"""
|
||||
return len(tokenizer.tokenize(self.prompt))
|
||||
|
||||
@classmethod
|
||||
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
|
||||
"""
|
||||
defines how to transform JSON data to AuthData and payload type,
|
||||
in this case `InputData` defined above represents the data sent to the model API.
|
||||
AuthData is data generated by autoscaler in order to authenticate payloads.
|
||||
In this case, the transformation is simple and 1:1. That is not always the case. See comfyui's PyWorker
|
||||
for more complicated examples
|
||||
"""
|
||||
errors = {}
|
||||
for param in inspect.signature(cls).parameters:
|
||||
if param not in json_msg:
|
||||
errors[param] = "missing parameter"
|
||||
if errors:
|
||||
raise JsonDataException(errors)
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in json_msg.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
#### server.py
|
||||
|
||||
For every model API endpoint you want to use, you must implement an `EndpointHandler`. This class handles incoming
|
||||
requests, processes them, sends them to the model API server, and finally returns an HTTP response.
|
||||
`EndpointHandler` has several abstract functions that must be implemented. Here, we implement two, one
|
||||
for `/generate`, and one for `/generate_stream`:
|
||||
|
||||
```python
|
||||
|
||||
"""
|
||||
AuthData is a dataclass that represents Authentication data sent from Autoscaler to client requesting a route.
|
||||
When a user requests a route from autoscaler, see Vast's Serverless documentation for how routing and AuthData
|
||||
work.
|
||||
When a user receives a route for this PyWorker, they'll call PyWorkers API with the following JSON:
|
||||
{
|
||||
auth_data: AuthData,
|
||||
payload : InputData # defined above
|
||||
}
|
||||
"""
|
||||
from aiohttp import web
|
||||
|
||||
from lib.data_types import EndpointHandler, JsonDataException
|
||||
from lib.server import start_server
|
||||
from .data_types import InputData
|
||||
|
||||
# This class is the implementer for the '/generate' endpoint of model API
|
||||
@dataclasses.dataclass
|
||||
class GenerateHandler(EndpointHandler[InputData]):
|
||||
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
# the API endpoint
|
||||
return "/generate"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[InputData]:
|
||||
"""this function should just return ApiPayload subclass used by this handler"""
|
||||
return InputData
|
||||
|
||||
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
|
||||
"""
|
||||
defines how to convert `InputData` defined above, to
|
||||
JSON data to be sent to the model API. This function too is a simple dataclass -> JSON, but
|
||||
can be more complicated, See comfyui for an example
|
||||
"""
|
||||
return dataclasses.asdict(payload)
|
||||
|
||||
def make_benchmark_payload(self) -> InputData:
|
||||
"""
|
||||
defines how to generate an InputData for benchmarking. This needs to be defined in only
|
||||
one EndpointHandler, the one passed to the backend as the benchmark handler. Here we use the .for_test()
|
||||
method on InputData. However, in some cases you might need to fine tune your InputData used for
|
||||
benchmarking to closely resemble the average request users call the endpoint with in order to get best
|
||||
autoscaling performance
|
||||
"""
|
||||
return InputData.for_test()
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
"""
|
||||
defines how to convert a model API response to a response to PyWorker client
|
||||
"""
|
||||
_ = client_request
|
||||
match model_response.status:
|
||||
case 200:
|
||||
log.debug("SUCCESS")
|
||||
data = await model_response.json()
|
||||
return web.json_response(data=data)
|
||||
case code:
|
||||
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||
return web.Response(status=code)
|
||||
|
||||
|
||||
```
|
||||
|
||||
We also handle `GenerateStreamHandler` for streaming responses. It is identical to `GenerateHandler`, except for
|
||||
the endpoint name and how we create a web response, as it is a streaming response:
|
||||
|
||||
```python
|
||||
class GenerateStreamHandler(EndpointHandler[InputData]):
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/generate_stream"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[InputData]:
|
||||
return InputData
|
||||
|
||||
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
|
||||
return dataclasses.asdict(payload)
|
||||
|
||||
def make_benchmark_payload(self) -> InputData:
|
||||
return InputData.for_test()
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
match model_response.status:
|
||||
case 200:
|
||||
log.debug("Streaming response...")
|
||||
res = web.StreamResponse()
|
||||
res.content_type = "text/event-stream"
|
||||
await res.prepare(client_request)
|
||||
async for chunk in model_response.content:
|
||||
await res.write(chunk)
|
||||
await res.write_eof()
|
||||
log.debug("Done streaming response")
|
||||
return res
|
||||
case code:
|
||||
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||
return web.Response(status=code)
|
||||
|
||||
|
||||
```
|
||||
|
||||
You can now instantiate a Backend and use it to handle requests.
|
||||
|
||||
```python
|
||||
from lib.backend import Backend, LogAction
|
||||
|
||||
# the url and port of model API
|
||||
MODEL_SERVER_URL = "http://0.0.0.0:5001"
|
||||
|
||||
|
||||
# This is the log line that is emitted once the server has started
|
||||
MODEL_SERVER_START_LOG_MSG = "server has started"
|
||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||
"Exception: corrupted model file" # message in the logs indicating the unrecoverable error
|
||||
]
|
||||
|
||||
backend = Backend(
|
||||
model_server_url=MODEL_SERVER_URL,
|
||||
# location of model log file
|
||||
model_log_file=os.environ["MODEL_LOG"],
|
||||
# for some model backends that can only handle one request at a time, be sure to set this to False to
|
||||
# let PyWorker handling queueing requests.
|
||||
allow_parallel_requests=True,
|
||||
# give the backend an EndpointHandler instance that is used for benchmarking
|
||||
# number of benchmark run and number of words for a random benchmark run are given
|
||||
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
|
||||
# defines how to handle specific log messages. See docstring of LogAction for details
|
||||
log_actions=[
|
||||
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
|
||||
(LogAction.Info, '"message":"Download'),
|
||||
*[
|
||||
(LogAction.ModelError, error_msg)
|
||||
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
# this is a simple ping handler for PyWorker
|
||||
async def handle_ping(_: web.Request):
|
||||
return web.Response(body="pong")
|
||||
|
||||
# this is a handler for forwarding a health check to model API
|
||||
async def handle_healthcheck(_: web.Request):
|
||||
healthcheck_res = await backend.session.get("/healthcheck")
|
||||
return web.Response(body=healthcheck_res.content, status=healthcheck_res.status)
|
||||
|
||||
routes = [
|
||||
web.post("/generate", backend.create_handler(GenerateHandler())),
|
||||
web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
|
||||
web.get("/ping", handle_ping),
|
||||
web.get("/healthcheck", handle_healthcheck),
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
# start server, called from start_server.sh
|
||||
start_server(backend, routes)
|
||||
```
|
||||
|
||||
#### test_load.py
|
||||
|
||||
Here you can create a script that allows you test an endpoint group running instances with this PyWorker
|
||||
|
||||
```python
|
||||
from lib.test_harness import run
|
||||
from .data_types import InputData
|
||||
|
||||
WORKER_ENDPOINT = "/generate"
|
||||
|
||||
if __name__ == "__main__":
|
||||
run(InputData.for_test(), WORKER_ENDPOINT)
|
||||
```
|
||||
|
||||
You can then run the following command from the root of this repo to load test endpoint group:
|
||||
|
||||
```sh
|
||||
# sends 1000 requests at the rate of 0.5 requests per second
|
||||
python3 workers.hello_world.test_load -n 1000 -rps 0.5 -k "$API_KEY" -e "$ENDPOINT_GROUP_NAME"
|
||||
```
|
||||
@@ -0,0 +1,48 @@
|
||||
import dataclasses
|
||||
import random
|
||||
import inspect
|
||||
from typing import Dict, Any
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
import nltk
|
||||
|
||||
from lib.data_types import ApiPayload, JsonDataException
|
||||
|
||||
nltk.download("words")
|
||||
WORD_LIST = nltk.corpus.words.words()
|
||||
|
||||
# used to count to count tokens and workload for LLM
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class InputData(ApiPayload):
|
||||
prompt: str
|
||||
max_response_tokens: int
|
||||
|
||||
@classmethod
|
||||
def for_test(cls) -> "InputData":
|
||||
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
||||
return cls(prompt=prompt, max_response_tokens=300)
|
||||
|
||||
def generate_payload_json(self) -> Dict[str, Any]:
|
||||
return dataclasses.asdict(self)
|
||||
|
||||
def count_workload(self) -> int:
|
||||
return len(tokenizer.tokenize(self.prompt))
|
||||
|
||||
@classmethod
|
||||
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
|
||||
errors = {}
|
||||
for param in inspect.signature(cls).parameters:
|
||||
if param not in json_msg:
|
||||
errors[param] = "missing parameter"
|
||||
if errors:
|
||||
raise JsonDataException(errors)
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in json_msg.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
PyWorker works as a man-in-the-middle between the client and model API. It's function is:
|
||||
1. receive request from client, update metrics such as workload of a request, number of pending requests, etc.
|
||||
2a. transform the data and forward the transformed data to model API
|
||||
2b. send updated metrics to autoscaler
|
||||
3. transform response from model API(if needed) and forward the response to client
|
||||
|
||||
PyWorker forward requests to many model API endpoint. each endpoint must have an EndpointHandler. You can also
|
||||
write function to just forward requests that don't generate anything with the model to model API without an
|
||||
EndpointHandler. This is useful for endpoints such as healthchecks. See below for example
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import dataclasses
|
||||
from typing import Dict, Any, Union, Type
|
||||
|
||||
from aiohttp import web, ClientResponse
|
||||
|
||||
from lib.backend import Backend, LogAction
|
||||
from lib.data_types import EndpointHandler
|
||||
from lib.server import start_server
|
||||
from .data_types import InputData
|
||||
|
||||
# the url and port of model API
|
||||
MODEL_SERVER_URL = "http://0.0.0.0:5001"
|
||||
|
||||
|
||||
# This is the log line that is emitted once the server has started
|
||||
MODEL_SERVER_START_LOG_MSG = "infer server has started"
|
||||
MODEL_SERVER_ERROR_LOG_MSGS = [
|
||||
"Exception: corrupted model file" # message in the logs indicating the unrecoverable error
|
||||
]
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
# This class is the implementer for the '/generate' endpoint of model API
|
||||
@dataclasses.dataclass
|
||||
class GenerateHandler(EndpointHandler[InputData]):
|
||||
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
# the API endpoint
|
||||
return "/generate"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[InputData]:
|
||||
return InputData
|
||||
|
||||
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
|
||||
"""
|
||||
defines how to convert `InputData` defined above, to
|
||||
json data to be sent to the model API
|
||||
"""
|
||||
return dataclasses.asdict(payload)
|
||||
|
||||
def make_benchmark_payload(self) -> InputData:
|
||||
"""
|
||||
defines how to generate an InputData for benchmarking. This needs to be defined in only
|
||||
one EndpointHandler, the one passed to the backend as the benchmark handler
|
||||
"""
|
||||
return InputData.for_test()
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
"""
|
||||
defines how to convert a model API response to a response to PyWorker client
|
||||
"""
|
||||
_ = client_request
|
||||
match model_response.status:
|
||||
case 200:
|
||||
log.debug("SUCCESS")
|
||||
data = await model_response.json()
|
||||
return web.json_response(data=data)
|
||||
case code:
|
||||
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||
return web.Response(status=code)
|
||||
|
||||
|
||||
# This is the same as GenerateHandler, except that it calls a streaming endpoint of the model API and streams the
|
||||
# response, which itself is streaming, back to the client.
|
||||
# it is nearly identical to handler as above, but it calls a different model API endpoint and it streams the
|
||||
# streaming response from model API to client
|
||||
class GenerateStreamHandler(EndpointHandler[InputData]):
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/generate_stream"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[InputData]:
|
||||
return InputData
|
||||
|
||||
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
|
||||
return dataclasses.asdict(payload)
|
||||
|
||||
def make_benchmark_payload(self) -> InputData:
|
||||
return InputData.for_test()
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
match model_response.status:
|
||||
case 200:
|
||||
log.debug("Streaming response...")
|
||||
res = web.StreamResponse()
|
||||
res.content_type = "text/event-stream"
|
||||
await res.prepare(client_request)
|
||||
async for chunk in model_response.content:
|
||||
await res.write(chunk)
|
||||
await res.write_eof()
|
||||
log.debug("Done streaming response")
|
||||
return res
|
||||
case code:
|
||||
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||
return web.Response(status=code)
|
||||
|
||||
|
||||
# This is the backend instance of pyworker. Only one must be made which uses EndpointHandlers to process
|
||||
# incoming requests
|
||||
backend = Backend(
|
||||
model_server_url=MODEL_SERVER_URL,
|
||||
model_log_file=os.environ["MODEL_LOG"],
|
||||
allow_parallel_requests=True,
|
||||
# give the backend a handler instance that is used for benchmarking
|
||||
# number of benchmark run and number of words for a random benchmark run are given
|
||||
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
|
||||
# defines how to handle specific log messages. See docstring of LogAction for details
|
||||
log_actions=[
|
||||
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
|
||||
(LogAction.Info, '"message":"Download'),
|
||||
*[
|
||||
(LogAction.ModelError, error_msg)
|
||||
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
# this is a simple ping handler for pyworker
|
||||
async def handle_ping(_: web.Request):
|
||||
return web.Response(body="pong")
|
||||
|
||||
|
||||
# this is a handler for forwarding a health check to modelAPI
|
||||
async def handle_healthcheck(_: web.Request):
|
||||
healthcheck_res = await backend.session.get("/healthcheck")
|
||||
return web.Response(body=healthcheck_res.content, status=healthcheck_res.status)
|
||||
|
||||
|
||||
routes = [
|
||||
web.post("/generate", backend.create_handler(GenerateHandler())),
|
||||
web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
|
||||
web.get("/ping", handle_ping),
|
||||
web.get("/healthcheck", handle_healthcheck),
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
# start the PyWorker server
|
||||
start_server(backend, routes)
|
||||
@@ -0,0 +1,7 @@
|
||||
from lib.test_utils import test_load_cmd, test_args
|
||||
from .data_types import InputData
|
||||
|
||||
WORKER_ENDPOINT = "/generate"
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_load_cmd(InputData, WORKER_ENDPOINT, arg_parser=test_args)
|
||||
@@ -0,0 +1,19 @@
|
||||
This is the base PyWorker for TGI, designed to create PyWorkers that can utilize various LLMs. It offers two primary endpoints:
|
||||
|
||||
1. `generate`: Generates the LLM's response to a given prompt in a single request.
|
||||
2. `generate_stream`: Streams the LLM's response token by token.
|
||||
|
||||
Both endpoints use the following API payload format:
|
||||
|
||||
```json
|
||||
{
|
||||
"inputs": "PROMPT",
|
||||
"parameters": {
|
||||
"max_new_tokens": 250
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Note that the max_new_tokens parameter, rather than the prompt size, impacts performance. For example, if an
|
||||
instance is benchmarked to process 100 tokens per second, a request with max_new_tokens = 200 will take
|
||||
approximately 2 seconds to complete.
|
||||
@@ -0,0 +1,91 @@
|
||||
import sys
|
||||
import json
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> None:
|
||||
WORKER_ENDPOINT = "/generate"
|
||||
COST = 100
|
||||
route_payload = {
|
||||
"endpoint": endpoint_group_name,
|
||||
"api_key": api_key,
|
||||
"cost": COST,
|
||||
}
|
||||
response = requests.post(
|
||||
urljoin(server_url, "/route/"),
|
||||
json=route_payload,
|
||||
timeout=4,
|
||||
)
|
||||
message = response.json()
|
||||
url = message["url"]
|
||||
auth_data = dict(
|
||||
signature=message["signature"],
|
||||
cost=message["cost"],
|
||||
endpoint=message["endpoint"],
|
||||
reqnum=message["reqnum"],
|
||||
url=message["url"],
|
||||
)
|
||||
payload = dict(inputs="tell me about cats", parameters=dict(max_new_tokens=500))
|
||||
req_data = dict(payload=payload, auth_data=auth_data)
|
||||
url = urljoin(url, WORKER_ENDPOINT)
|
||||
print(f"url: {url}")
|
||||
response = requests.post(
|
||||
url,
|
||||
json=req_data,
|
||||
)
|
||||
res = response.json()
|
||||
print(res)
|
||||
|
||||
|
||||
def call_generate_stream(endpoint_group_name: str, api_key: str, server_url: str):
|
||||
WORKER_ENDPOINT = "/generate_stream"
|
||||
COST = 100
|
||||
route_payload = {
|
||||
"endpoint": endpoint_group_name,
|
||||
"api_key": api_key,
|
||||
"cost": COST,
|
||||
}
|
||||
response = requests.post(
|
||||
urljoin(server_url, "/route/"),
|
||||
json=route_payload,
|
||||
timeout=4,
|
||||
)
|
||||
message = response.json()
|
||||
url = message["url"]
|
||||
print(f"url: {url}")
|
||||
auth_data = dict(
|
||||
signature=message["signature"],
|
||||
cost=message["cost"],
|
||||
endpoint=message["endpoint"],
|
||||
reqnum=message["reqnum"],
|
||||
url=message["url"],
|
||||
)
|
||||
payload = dict(inputs="tell me about dogs", parameters=dict(max_new_tokens=500))
|
||||
req_data = dict(payload=payload, auth_data=auth_data)
|
||||
url = urljoin(url, WORKER_ENDPOINT)
|
||||
response = requests.post(url, json=req_data, stream=True)
|
||||
for line in response.iter_lines():
|
||||
payload = line.decode().lstrip("data:").rstrip()
|
||||
if payload:
|
||||
data = json.loads(payload)
|
||||
print(data["token"]["text"], end="")
|
||||
sys.stdout.flush()
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from lib.test_utils import test_args
|
||||
|
||||
args = test_args.parse_args()
|
||||
call_generate(
|
||||
api_key=args.api_key,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
server_url=args.server_url,
|
||||
)
|
||||
call_generate_stream(
|
||||
api_key=args.api_key,
|
||||
endpoint_group_name=args.endpoint_group_name,
|
||||
server_url=args.server_url,
|
||||
)
|
||||
@@ -0,0 +1,73 @@
|
||||
import dataclasses
|
||||
import random
|
||||
import inspect
|
||||
from typing import Dict, Any
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
import nltk
|
||||
|
||||
from lib.data_types import ApiPayload, JsonDataException
|
||||
|
||||
nltk.download("words")
|
||||
WORD_LIST = nltk.corpus.words.words()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class InputParameters:
|
||||
max_new_tokens: int = 256
|
||||
|
||||
@classmethod
|
||||
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputParameters":
|
||||
errors = {}
|
||||
for param in inspect.signature(cls).parameters:
|
||||
if param not in json_msg:
|
||||
errors[param] = "missing parameter"
|
||||
if errors:
|
||||
raise JsonDataException(errors)
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in json_msg.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class InputData(ApiPayload):
|
||||
inputs: str
|
||||
parameters: InputParameters
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "InputData":
|
||||
return cls(
|
||||
inputs=data["inputs"], parameters=InputParameters(**data["parameters"])
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def for_test(cls) -> "InputData":
|
||||
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
||||
return cls(inputs=prompt, parameters=InputParameters())
|
||||
|
||||
def generate_payload_json(self) -> Dict[str, Any]:
|
||||
return dataclasses.asdict(self)
|
||||
|
||||
def count_workload(self) -> int:
|
||||
return self.parameters.max_new_tokens
|
||||
|
||||
@classmethod
|
||||
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
|
||||
errors = {}
|
||||
for param in inspect.signature(cls).parameters:
|
||||
if param not in json_msg:
|
||||
errors[param] = "missing parameter"
|
||||
if errors:
|
||||
raise JsonDataException(errors)
|
||||
try:
|
||||
parameters = InputParameters.from_json_msg(json_msg["parameters"])
|
||||
return cls(inputs=json_msg["inputs"], parameters=parameters)
|
||||
except JsonDataException as e:
|
||||
errors["parameters"] = e.message
|
||||
raise JsonDataException(errors)
|
||||
@@ -0,0 +1,115 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Union, Type
|
||||
import dataclasses
|
||||
|
||||
from aiohttp import web, ClientResponse
|
||||
|
||||
from lib.backend import Backend, LogAction
|
||||
from lib.data_types import EndpointHandler
|
||||
from lib.server import start_server
|
||||
from .data_types import InputData
|
||||
|
||||
|
||||
MODEL_SERVER_URL = "http://0.0.0.0:5001"
|
||||
|
||||
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
|
||||
MODEL_SERVER_START_LOG_MSG = '"message":"Connected","target":"text_generation_router"'
|
||||
MODEL_SERVER_ERROR_LOG_MSGS = ["Error: WebserverFailed", "Error: DownloadError"]
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class GenerateHandler(EndpointHandler[InputData]):
|
||||
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/generate"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[InputData]:
|
||||
return InputData
|
||||
|
||||
def make_benchmark_payload(self) -> InputData:
|
||||
return InputData.for_test()
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
_ = client_request
|
||||
match model_response.status:
|
||||
case 200:
|
||||
log.debug("SUCCESS")
|
||||
data = await model_response.json()
|
||||
return web.json_response(data=data)
|
||||
case code:
|
||||
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||
return web.Response(status=code)
|
||||
|
||||
|
||||
class GenerateStreamHandler(EndpointHandler[InputData]):
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
return "/generate_stream"
|
||||
|
||||
@classmethod
|
||||
def payload_cls(cls) -> Type[InputData]:
|
||||
return InputData
|
||||
|
||||
def make_benchmark_payload(self) -> InputData:
|
||||
return InputData.for_test()
|
||||
|
||||
async def generate_client_response(
|
||||
self, client_request: web.Request, model_response: ClientResponse
|
||||
) -> Union[web.Response, web.StreamResponse]:
|
||||
match model_response.status:
|
||||
case 200:
|
||||
log.debug("Streaming response...")
|
||||
res = web.StreamResponse()
|
||||
res.content_type = "text/event-stream"
|
||||
await res.prepare(client_request)
|
||||
async for chunk in model_response.content:
|
||||
await res.write(chunk)
|
||||
await res.write_eof()
|
||||
log.debug("Done streaming response")
|
||||
return res
|
||||
case code:
|
||||
log.debug("SENDING RESPONSE: ERROR: unknown code")
|
||||
return web.Response(status=code)
|
||||
|
||||
|
||||
backend = Backend(
|
||||
model_server_url=MODEL_SERVER_URL,
|
||||
model_log_file=os.environ["MODEL_LOG"],
|
||||
allow_parallel_requests=True,
|
||||
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
|
||||
log_actions=[
|
||||
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
|
||||
(LogAction.Info, '"message":"Download'),
|
||||
*[
|
||||
(LogAction.ModelError, error_msg)
|
||||
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def handle_ping(_):
|
||||
return web.Response(body="pong")
|
||||
|
||||
|
||||
routes = [
|
||||
web.post("/generate", backend.create_handler(GenerateHandler())),
|
||||
web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
|
||||
web.get("/ping", handle_ping),
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_server(backend, routes)
|
||||
@@ -0,0 +1,7 @@
|
||||
from lib.test_utils import test_load_cmd, test_args
|
||||
from .data_types import InputData
|
||||
|
||||
WORKER_ENDPOINT = "/generate"
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_load_cmd(InputData, WORKER_ENDPOINT, arg_parser=test_args)
|
||||
Reference in New Issue
Block a user