Compare commits
61 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a81d3febe7 | |||
| 913e3a8782 | |||
| 47ad0ebe0a | |||
| 34fd21e76a | |||
| 1d2caaf554 | |||
| 01eff874d8 | |||
| d51f04a176 | |||
| ef248ef695 | |||
| 6a562a1376 | |||
| 6c2f194b28 | |||
| 2aada7b210 | |||
| 8df562e243 | |||
| 4eef5e22af | |||
| 9d969e376e | |||
| ef3f34a515 | |||
| 147bf2597a | |||
| dc423e2999 | |||
| 463f3de8ea | |||
| ed0db198c3 | |||
| 3668d948be | |||
| 254ccdf181 | |||
| 89761b378a | |||
| 18974873e5 | |||
| 9bc9ba11c5 | |||
| 48fdc65e3d | |||
| 2cd97315cd | |||
| 83c31e25a9 | |||
| fbe1dca6fa | |||
| 4c3120dbc5 | |||
| d7d9b915f6 | |||
| 4660b337fb | |||
| 7506ecb6b5 | |||
| 50633c5003 | |||
| 2e8f18276f | |||
| eba9c480eb | |||
| aaca1c9645 | |||
| f319db6bd5 | |||
| 4d786b4d17 | |||
| bd3e0032a1 | |||
| e02f4bc943 | |||
| bcb04b9a32 | |||
| 9daf171487 | |||
| 29f836eb1a | |||
| 4380d98c01 | |||
| 2ce741a8b7 | |||
| 4ecc07032f | |||
| df61e6e946 | |||
| 70f8a8f534 | |||
| 7be8aa6397 | |||
| 138fc3ac47 | |||
| 222ac2a0dd | |||
| 40aed9b5f8 | |||
| d4d36bf86e | |||
| e839cfc6e8 | |||
| f04138e13b | |||
| de3aa87c8f | |||
| 6b5b1341a7 | |||
| 8be92c03de | |||
| adedb8ba90 | |||
| 2f543c01ad | |||
| 0bcd2219ea |
+2
-11
@@ -1,11 +1,2 @@
|
|||||||
aiohttp[speedups]==3.10.1
|
vastai-sdk>=0.3.0
|
||||||
anyio~=4.4
|
nltk==3.9.4
|
||||||
lib~=4.0
|
|
||||||
nltk~=3.9
|
|
||||||
psutil~=6.0
|
|
||||||
pycryptodome~=3.20
|
|
||||||
Requests~=2.32
|
|
||||||
transformers~=4.52
|
|
||||||
utils==1.0.*
|
|
||||||
hf_transfer>=0.1.9
|
|
||||||
git+https://github.com/vast-ai/vast-sdk.git@session
|
|
||||||
+162
-18
@@ -2,10 +2,17 @@
|
|||||||
|
|
||||||
set -e -o pipefail
|
set -e -o pipefail
|
||||||
|
|
||||||
|
# Check for force update flag
|
||||||
|
FORCE_UPDATE=false
|
||||||
|
if [ -f "/.force_update" ]; then
|
||||||
|
echo "Force update flag detected at /.force_update"
|
||||||
|
FORCE_UPDATE=true
|
||||||
|
fi
|
||||||
|
|
||||||
WORKSPACE_DIR="${WORKSPACE_DIR:-/workspace}"
|
WORKSPACE_DIR="${WORKSPACE_DIR:-/workspace}"
|
||||||
|
|
||||||
SERVER_DIR="$WORKSPACE_DIR/vast-pyworker"
|
SERVER_DIR="$WORKSPACE_DIR/vast-pyworker"
|
||||||
ENV_PATH="$WORKSPACE_DIR/worker-env"
|
ENV_PATH="${ENV_PATH:-$WORKSPACE_DIR/worker-env}"
|
||||||
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
|
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
|
||||||
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
|
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
|
||||||
|
|
||||||
@@ -46,6 +53,42 @@ JSON
|
|||||||
exit 1
|
exit 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function install_vastai_sdk() {
|
||||||
|
local uv_flags=()
|
||||||
|
if [ "${USE_SYSTEM_PYTHON:-}" = "true" ]; then
|
||||||
|
uv_flags+=(--system --break-system-packages)
|
||||||
|
fi
|
||||||
|
if [ "$FORCE_UPDATE" = true ]; then
|
||||||
|
uv_flags+=(--force-reinstall)
|
||||||
|
echo "Force reinstalling vastai"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# If SDK_BRANCH is set, install vastai from the vast-cli repo at that branch/tag/commit.
|
||||||
|
if [ -n "${SDK_BRANCH:-}" ]; then
|
||||||
|
if [ -n "${SDK_VERSION:-}" ]; then
|
||||||
|
echo "WARNING: Both SDK_BRANCH and SDK_VERSION are set; using SDK_BRANCH=${SDK_BRANCH}"
|
||||||
|
fi
|
||||||
|
echo "Installing vastai from https://github.com/vast-ai/vast-cli/ @ ${SDK_BRANCH}"
|
||||||
|
if ! uv pip install "${uv_flags[@]}" "vastai @ git+https://github.com/vast-ai/vast-cli.git@${SDK_BRANCH}"; then
|
||||||
|
report_error_and_exit "Failed to install vastai from vast-ai/vast-cli@${SDK_BRANCH}"
|
||||||
|
fi
|
||||||
|
return 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -n "${SDK_VERSION:-}" ]; then
|
||||||
|
echo "Installing vastai version ${SDK_VERSION}"
|
||||||
|
if ! uv pip install "${uv_flags[@]}" "vastai==${SDK_VERSION}"; then
|
||||||
|
report_error_and_exit "Failed to install vastai==${SDK_VERSION}"
|
||||||
|
fi
|
||||||
|
return 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Installing default vastai"
|
||||||
|
if ! uv pip install "${uv_flags[@]}" vastai; then
|
||||||
|
report_error_and_exit "Failed to install vastai"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
[ -n "$BACKEND" ] && [ -z "$HF_TOKEN" ] && report_error_and_exit "HF_TOKEN must be set when BACKEND is set!"
|
[ -n "$BACKEND" ] && [ -z "$HF_TOKEN" ] && report_error_and_exit "HF_TOKEN must be set when BACKEND is set!"
|
||||||
[ -z "$CONTAINER_ID" ] && report_error_and_exit "CONTAINER_ID must be set!"
|
[ -z "$CONTAINER_ID" ] && report_error_and_exit "CONTAINER_ID must be set!"
|
||||||
[ "$BACKEND" = "comfyui" ] && [ -z "$COMFY_MODEL" ] && report_error_and_exit "For comfyui backends, COMFY_MODEL must be set!"
|
[ "$BACKEND" = "comfyui" ] && [ -z "$COMFY_MODEL" ] && report_error_and_exit "For comfyui backends, COMFY_MODEL must be set!"
|
||||||
@@ -63,7 +106,8 @@ echo_var DEBUG_LOG
|
|||||||
echo_var PYWORKER_LOG
|
echo_var PYWORKER_LOG
|
||||||
echo_var MODEL_LOG
|
echo_var MODEL_LOG
|
||||||
|
|
||||||
if [ -e "$MODEL_LOG" ]; then
|
ROTATE_MODEL_LOG="${ROTATE_MODEL_LOG:-false}"
|
||||||
|
if [ "$ROTATE_MODEL_LOG" = "true" ] && [ -e "$MODEL_LOG" ]; then
|
||||||
echo "Rotating model log at $MODEL_LOG to $MODEL_LOG.old"
|
echo "Rotating model log at $MODEL_LOG to $MODEL_LOG.old"
|
||||||
if ! cat "$MODEL_LOG" >> "$MODEL_LOG.old"; then
|
if ! cat "$MODEL_LOG" >> "$MODEL_LOG.old"; then
|
||||||
report_error_and_exit "Failed to rotate model log"
|
report_error_and_exit "Failed to rotate model log"
|
||||||
@@ -84,8 +128,21 @@ if ! grep -q "VAST" /etc/environment; then
|
|||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ ! -d "$ENV_PATH" ]
|
if [ "${USE_SYSTEM_PYTHON:-}" = "true" ]; then
|
||||||
then
|
echo "Using system Python: $(which python3)"
|
||||||
|
if ! which uv > /dev/null 2>&1; then
|
||||||
|
if ! curl -LsSf https://astral.sh/uv/install.sh | sh; then
|
||||||
|
report_error_and_exit "Failed to install uv package manager"
|
||||||
|
fi
|
||||||
|
if [[ -f ~/.local/bin/env ]]; then
|
||||||
|
if ! source ~/.local/bin/env; then
|
||||||
|
report_error_and_exit "Failed to source uv environment"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
install_vastai_sdk
|
||||||
|
touch ~/.no_auto_tmux
|
||||||
|
elif [ ! -d "$ENV_PATH" ]; then
|
||||||
echo "setting up venv"
|
echo "setting up venv"
|
||||||
if ! which uv; then
|
if ! which uv; then
|
||||||
if ! curl -LsSf https://astral.sh/uv/install.sh | sh; then
|
if ! curl -LsSf https://astral.sh/uv/install.sh | sh; then
|
||||||
@@ -104,17 +161,34 @@ then
|
|||||||
if ! git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"; then
|
if ! git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"; then
|
||||||
report_error_and_exit "Failed to clone pyworker repository"
|
report_error_and_exit "Failed to clone pyworker repository"
|
||||||
fi
|
fi
|
||||||
|
elif [ "$FORCE_UPDATE" = true ]; then
|
||||||
|
echo "Force updating pyworker repository"
|
||||||
|
if ! (cd "$SERVER_DIR" && git fetch --all); then
|
||||||
|
report_error_and_exit "Failed to fetch pyworker repository updates"
|
||||||
|
fi
|
||||||
fi
|
fi
|
||||||
if [[ -n ${PYWORKER_REF:-} ]]; then
|
if [[ -n ${PYWORKER_REF:-} ]]; then
|
||||||
if ! (cd "$SERVER_DIR" && git checkout "$PYWORKER_REF"); then
|
if [ "$FORCE_UPDATE" = true ]; then
|
||||||
report_error_and_exit "Failed to checkout pyworker reference: $PYWORKER_REF"
|
echo "Force updating to pyworker reference: $PYWORKER_REF"
|
||||||
|
if ! (cd "$SERVER_DIR" && git checkout "$PYWORKER_REF" && git pull); then
|
||||||
|
report_error_and_exit "Failed to force update pyworker reference: $PYWORKER_REF"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
if ! (cd "$SERVER_DIR" && git checkout "$PYWORKER_REF"); then
|
||||||
|
report_error_and_exit "Failed to checkout pyworker reference: $PYWORKER_REF"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
elif [ "$FORCE_UPDATE" = true ]; then
|
||||||
|
echo "Force updating pyworker to latest"
|
||||||
|
if ! (cd "$SERVER_DIR" && git pull); then
|
||||||
|
report_error_and_exit "Failed to pull latest pyworker changes"
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if ! uv venv --python-preference only-managed "$ENV_PATH" -p 3.10; then
|
if ! uv venv --python-preference only-managed "$ENV_PATH" -p 3.10; then
|
||||||
report_error_and_exit "Failed to create virtual environment"
|
report_error_and_exit "Failed to create virtual environment"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if ! source "$ENV_PATH/bin/activate"; then
|
if ! source "$ENV_PATH/bin/activate"; then
|
||||||
report_error_and_exit "Failed to activate virtual environment"
|
report_error_and_exit "Failed to activate virtual environment"
|
||||||
fi
|
fi
|
||||||
@@ -123,6 +197,8 @@ then
|
|||||||
report_error_and_exit "Failed to install Python requirements"
|
report_error_and_exit "Failed to install Python requirements"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
install_vastai_sdk
|
||||||
|
|
||||||
if ! touch ~/.no_auto_tmux; then
|
if ! touch ~/.no_auto_tmux; then
|
||||||
report_error_and_exit "Failed to create ~/.no_auto_tmux"
|
report_error_and_exit "Failed to create ~/.no_auto_tmux"
|
||||||
fi
|
fi
|
||||||
@@ -132,11 +208,44 @@ else
|
|||||||
report_error_and_exit "Failed to source uv environment"
|
report_error_and_exit "Failed to source uv environment"
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
if ! source "$WORKSPACE_DIR/worker-env/bin/activate"; then
|
if ! source "$ENV_PATH/bin/activate"; then
|
||||||
report_error_and_exit "Failed to activate existing virtual environment"
|
report_error_and_exit "Failed to activate existing virtual environment"
|
||||||
fi
|
fi
|
||||||
echo "environment activated"
|
echo "environment activated"
|
||||||
echo "venv: $VIRTUAL_ENV"
|
echo "venv: $VIRTUAL_ENV"
|
||||||
|
|
||||||
|
# Handle force update for existing environment
|
||||||
|
if [ "$FORCE_UPDATE" = true ]; then
|
||||||
|
echo "Performing force update on existing environment"
|
||||||
|
|
||||||
|
if [[ -d $SERVER_DIR ]]; then
|
||||||
|
echo "Force updating pyworker repository"
|
||||||
|
if ! (cd "$SERVER_DIR" && git fetch --all); then
|
||||||
|
report_error_and_exit "Failed to fetch pyworker repository updates"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ -n ${PYWORKER_REF:-} ]]; then
|
||||||
|
echo "Force updating to pyworker reference: $PYWORKER_REF"
|
||||||
|
if ! (cd "$SERVER_DIR" && git checkout "$PYWORKER_REF" && git pull); then
|
||||||
|
report_error_and_exit "Failed to force update pyworker reference: $PYWORKER_REF"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo "Force updating pyworker to latest"
|
||||||
|
if ! (cd "$SERVER_DIR" && git pull); then
|
||||||
|
report_error_and_exit "Failed to pull latest pyworker changes"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
install_vastai_sdk
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Remove force update flag after successful update
|
||||||
|
if [ "$FORCE_UPDATE" = true ]; then
|
||||||
|
echo "Removing force update flag"
|
||||||
|
rm -f "/.force_update"
|
||||||
|
echo "Force update completed successfully"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ "$USE_SSL" = true ]; then
|
if [ "$USE_SSL" = true ]; then
|
||||||
@@ -174,16 +283,51 @@ EOF
|
|||||||
report_error_and_exit "Failed to generate SSL certificate request"
|
report_error_and_exit "Failed to generate SSL certificate request"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if ! curl --header 'Content-Type: application/octet-stream' \
|
max_retries=5
|
||||||
--data-binary @/etc/instance.csr \
|
retry_delay=2
|
||||||
-X \
|
for attempt in $(seq 1 "$max_retries"); do
|
||||||
POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt; then
|
http_code=$(curl -sS -o /etc/instance.crt -w '%{http_code}' \
|
||||||
report_error_and_exit "Failed to sign SSL certificate"
|
--header 'Content-Type: application/octet-stream' \
|
||||||
fi
|
--data-binary @/etc/instance.csr \
|
||||||
|
-X POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID")
|
||||||
|
if [ "$http_code" -ge 200 ] && [ "$http_code" -lt 300 ]; then
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
echo "SSL cert signing attempt $attempt/$max_retries failed (HTTP $http_code)"
|
||||||
|
if [ "$attempt" -eq "$max_retries" ]; then
|
||||||
|
report_error_and_exit "Failed to sign SSL certificate after $max_retries attempts (HTTP $http_code)"
|
||||||
|
fi
|
||||||
|
sleep "$retry_delay"
|
||||||
|
retry_delay=$((retry_delay * 2))
|
||||||
|
done
|
||||||
fi
|
fi
|
||||||
|
|
||||||
export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED
|
export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED
|
||||||
|
|
||||||
|
# ─── SDK Deployment Mode ───────────────────────────────────────────────
|
||||||
|
if [ "$IS_DEPLOYMENT" = "true" ]; then
|
||||||
|
echo "=== SDK Deployment Mode ==="
|
||||||
|
echo "DEPLOYMENT_ID: $DEPLOYMENT_ID"
|
||||||
|
|
||||||
|
DEPLOY_DIR="/workspace/deployment"
|
||||||
|
mkdir -p "$DEPLOY_DIR"
|
||||||
|
|
||||||
|
VAST_API_BASE="${VAST_API_BASE:-https://console.vast.ai}"
|
||||||
|
|
||||||
|
# Download deployment code, retrying until the blob is available on S3.
|
||||||
|
# The s3_key exists in the DB as soon as the deployment is created, but the
|
||||||
|
# actual upload may still be in flight from the client side.
|
||||||
|
|
||||||
|
# Install SDK (uses the install_vastai_sdk function which supports SDK_BRANCH/SDK_VERSION)
|
||||||
|
install_vastai_sdk
|
||||||
|
# Run deployment in serve mode
|
||||||
|
export VAST_DEPLOYMENT_MODE=serve
|
||||||
|
echo "Starting deployment: python3 $DEPLOY_DIR/deployment.py"
|
||||||
|
serve-vast-deployment
|
||||||
|
exit $?
|
||||||
|
fi
|
||||||
|
# ─── End SDK Deployment Mode ───────────────────────────────────────────
|
||||||
|
|
||||||
if ! cd "$SERVER_DIR"; then
|
if ! cd "$SERVER_DIR"; then
|
||||||
report_error_and_exit "Failed to cd into SERVER_DIR: $SERVER_DIR"
|
report_error_and_exit "Failed to cd into SERVER_DIR: $SERVER_DIR"
|
||||||
fi
|
fi
|
||||||
@@ -195,19 +339,19 @@ set +e
|
|||||||
PY_STATUS=1
|
PY_STATUS=1
|
||||||
|
|
||||||
if [ -f "$SERVER_DIR/worker.py" ]; then
|
if [ -f "$SERVER_DIR/worker.py" ]; then
|
||||||
echo "trying worker.py"
|
echo "Running worker.py"
|
||||||
python3 -m "worker" |& tee -a "$PYWORKER_LOG"
|
python3 -m "worker" |& tee -a "$PYWORKER_LOG"
|
||||||
PY_STATUS=${PIPESTATUS[0]}
|
PY_STATUS=${PIPESTATUS[0]}
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ "${PY_STATUS}" -ne 0 ] && [ -f "$SERVER_DIR/workers/$BACKEND/worker.py" ]; then
|
if [ "${PY_STATUS}" -ne 0 ] && [ -f "$SERVER_DIR/workers/$BACKEND/worker.py" ]; then
|
||||||
echo "trying workers.${BACKEND}.worker"
|
echo "Running workers.${BACKEND}.worker"
|
||||||
python3 -m "workers.${BACKEND}.worker" |& tee -a "$PYWORKER_LOG"
|
python3 -m "workers.${BACKEND}.worker" |& tee -a "$PYWORKER_LOG"
|
||||||
PY_STATUS=${PIPESTATUS[0]}
|
PY_STATUS=${PIPESTATUS[0]}
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ "${PY_STATUS}" -ne 0 ] && [ -f "$SERVER_DIR/workers/$BACKEND/server.py" ]; then
|
if [ "${PY_STATUS}" -ne 0 ] && [ -f "$SERVER_DIR/workers/$BACKEND/server.py" ]; then
|
||||||
echo "trying workers.${BACKEND}.server"
|
echo "Running workers.${BACKEND}.server"
|
||||||
python3 -m "workers.${BACKEND}.server" |& tee -a "$PYWORKER_LOG"
|
python3 -m "workers.${BACKEND}.server" |& tee -a "$PYWORKER_LOG"
|
||||||
PY_STATUS=${PIPESTATUS[0]}
|
PY_STATUS=${PIPESTATUS[0]}
|
||||||
fi
|
fi
|
||||||
@@ -221,4 +365,4 @@ if [ "${PY_STATUS}" -ne 0 ]; then
|
|||||||
report_error_and_exit "PyWorker exited with status ${PY_STATUS}"
|
report_error_and_exit "PyWorker exited with status ${PY_STATUS}"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
echo "launching PyWorker server done"
|
echo "PyWorker bootstrap complete"
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
This is the PyWorker implementation for running **ACE Step v1 3.5B** text-to-music workflows in ComfyUI. It provides a unified interface for executing complete ComfyUI audio-generation workflows through a proxy-based architecture and returning generated audio assets.
|
This is the PyWorker implementation for running **ACE Step v1 3.5B** text-to-music workflows in ComfyUI. It provides a unified interface for executing complete ComfyUI audio-generation workflows through a proxy-based architecture and returning generated audio assets.
|
||||||
|
|
||||||
Each request has a static cost of `100`. ComfyUI does not support concurrent workloads, and there is no provision to run multiple ComfyUI instances per worker node.
|
Each request has a static cost of `1000`. ComfyUI does not support concurrent workloads, and there is no provision to run multiple ComfyUI instances per worker node.
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,16 @@
|
|||||||
# ComfyUI PyWorker
|
# ComfyUI PyWorker
|
||||||
|
|
||||||
This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture.
|
This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
|
||||||
|
|
||||||
The cost for each request has a static value of `100`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node.
|
The cost for each request has a static value of `100`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node.
|
||||||
|
|
||||||
|
## Instance Setup
|
||||||
|
|
||||||
|
1. Pick a template
|
||||||
|
|
||||||
|
- [ComfyUI (Serverless)](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=ComfyUI%20(Serverless))
|
||||||
|
|
||||||
|
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
|
|
||||||
@@ -10,6 +18,137 @@ This worker requires both [ComfyUI](https://github.com/comfyanonymous/ComfyUI) a
|
|||||||
|
|
||||||
A docker image is provided but you may use any if the above requirements are met.
|
A docker image is provided but you may use any if the above requirements are met.
|
||||||
|
|
||||||
|
## Client
|
||||||
|
|
||||||
|
The client demonstrates how to use the Vast Serverless SDK to generate images, save them locally, and optionally upload to S3-compatible storage.
|
||||||
|
|
||||||
|
### Setup
|
||||||
|
|
||||||
|
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/vast-ai/pyworker
|
||||||
|
cd pyworker
|
||||||
|
pip install uv
|
||||||
|
uv venv -p 3.12
|
||||||
|
source .venv/bin/activate
|
||||||
|
uv pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Set your API key:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export VAST_API_KEY=<your_api_key>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Default prompt
|
||||||
|
python -m workers.comfyui-json.client
|
||||||
|
|
||||||
|
# Custom prompt
|
||||||
|
python -m workers.comfyui-json.client --prompt "a cat sitting on a rainbow"
|
||||||
|
|
||||||
|
# With options
|
||||||
|
python -m workers.comfyui-json.client --prompt "sunset" --width 1024 --height 1024 --steps 30
|
||||||
|
|
||||||
|
# Using a custom workflow file
|
||||||
|
python -m workers.comfyui-json.client --workflow my_workflow.json
|
||||||
|
|
||||||
|
# With S3 upload
|
||||||
|
python -m workers.comfyui-json.client --s3
|
||||||
|
```
|
||||||
|
|
||||||
|
### CLI Flags
|
||||||
|
|
||||||
|
| Flag | Default | Description |
|
||||||
|
|------|---------|-------------|
|
||||||
|
| `--endpoint` | `my-comfyui-endpoint` | Vast endpoint name |
|
||||||
|
| `--prompt` | (default) | Text prompt for image generation |
|
||||||
|
| `--workflow` | (none) | Path to custom workflow JSON file |
|
||||||
|
| `--width` | 512 | Image width in pixels |
|
||||||
|
| `--height` | 512 | Image height in pixels |
|
||||||
|
| `--steps` | 20 | Number of denoising steps |
|
||||||
|
| `--seed` | (random) | Random seed for reproducibility |
|
||||||
|
| `--s3` | (disabled) | Upload generated images to S3 |
|
||||||
|
|
||||||
|
### Output
|
||||||
|
|
||||||
|
Images are saved to `./generated_images/comfy_{seed}.png`.
|
||||||
|
|
||||||
|
### S3 Upload (Optional)
|
||||||
|
|
||||||
|
You can optionally upload generated images to an S3-compatible storage service (AWS S3, Cloudflare R2, Backblaze B2, etc.) by using the `--s3` flag.
|
||||||
|
|
||||||
|
**1. Set environment variables:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export S3_ENDPOINT_URL="https://your-account.r2.cloudflarestorage.com"
|
||||||
|
export S3_BUCKET_NAME="my-bucket"
|
||||||
|
export S3_ACCESS_KEY_ID="your-access-key-id"
|
||||||
|
export S3_SECRET_ACCESS_KEY="your-secret-access-key"
|
||||||
|
```
|
||||||
|
|
||||||
|
**2. Run with S3 upload enabled:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m workers.comfyui-json.client --prompt "a beautiful landscape" --s3
|
||||||
|
```
|
||||||
|
|
||||||
|
Images will be saved locally AND uploaded to `s3://{bucket}/comfyui/{filename}`.
|
||||||
|
|
||||||
|
**Note:** Requires `boto3` (`pip install boto3`).
|
||||||
|
|
||||||
|
## Benchmarking
|
||||||
|
|
||||||
|
### Custom Benchmark Workflows
|
||||||
|
|
||||||
|
You can provide a custom ComfyUI workflow for benchmarking by creating `workers/comfyui-json/misc/benchmark.json`. This allows you to test performance using your preferred models and workflow complexity.
|
||||||
|
|
||||||
|
**Ways to provide the benchmark file:**
|
||||||
|
- Fork this repository and add your `benchmark.json` file
|
||||||
|
- Write the file during worker provisioning (onstart script or setup phase)
|
||||||
|
|
||||||
|
An example file is provided in the repository. To ensure varied generations, use the placeholder `__RANDOM_INT__` in place of static seed values - it will be replaced with a random integer for each benchmark run.
|
||||||
|
|
||||||
|
### Default Benchmark (Fallback)
|
||||||
|
|
||||||
|
If `benchmark.json` is not available, a simple image generation benchmark runs when each worker initializes. This validates GPU performance and helps identify underperforming machines.
|
||||||
|
|
||||||
|
The default benchmark uses Stable Diffusion v1.5 with ComfyUI's standard text-to-image workflow. Configure it using these environment variables:
|
||||||
|
|
||||||
|
| Environment Variable | Default Value | Description |
|
||||||
|
| -------------------- | ------------- | ----------- |
|
||||||
|
| BENCHMARK_TEST_WIDTH | 512 | Image width (pixels) |
|
||||||
|
| BENCHMARK_TEST_HEIGHT | 512 | Image height (pixels) |
|
||||||
|
| BENCHMARK_TEST_STEPS | 20 | Number of denoising steps |
|
||||||
|
|
||||||
|
Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns.
|
||||||
|
|
||||||
|
#### Calibrating Fallback Benchmark Duration
|
||||||
|
|
||||||
|
To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements.
|
||||||
|
|
||||||
|
**Example:** If your typical workflow should complete in 90 seconds on acceptable hardware:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 1. Measure it/sec on your reference machine
|
||||||
|
# RTX 4090 typically achieves ~43 it/sec with SD1.5
|
||||||
|
|
||||||
|
# 2. Calculate required steps
|
||||||
|
# 90 seconds × 43 it/sec = 3870 steps
|
||||||
|
|
||||||
|
# 3. Configure benchmark
|
||||||
|
export BENCHMARK_TEST_STEPS=3870
|
||||||
|
|
||||||
|
# 4. Machines completing significantly slower than 90s indicate hardware issues
|
||||||
|
```
|
||||||
|
|
||||||
|
**Performance expectations:**
|
||||||
|
- Benchmark duration should remain consistent across identical GPU models
|
||||||
|
- Significant variation (>20%) may indicate thermal, power, or configuration issues
|
||||||
|
|
||||||
## Endpoint
|
## Endpoint
|
||||||
|
|
||||||
The worker provides a single endpoint:
|
The worker provides a single endpoint:
|
||||||
@@ -170,4 +309,4 @@ See the client example for implementation details on how to integrate with the C
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
See Vast's serverless documentation for more details on how to use ComfyUI with autoscaler.
|
See Vast's serverless documentation for more details on how to use ComfyUI with autoscaler.
|
||||||
|
|||||||
+301
-23
@@ -1,34 +1,312 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
import random
|
import random
|
||||||
import asyncio
|
import asyncio
|
||||||
import random
|
import logging
|
||||||
|
import argparse
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
from vastai import Serverless
|
from vastai import Serverless
|
||||||
|
|
||||||
async def main():
|
# ---------------------- Config ----------------------
|
||||||
async with Serverless() as client:
|
DEFAULT_PROMPT = "a beautiful sunset over mountains, digital art, highly detailed"
|
||||||
endpoint = await client.get_endpoint(name="my-comfy-endpoint") # Change this to your endpoint name
|
ENDPOINT_NAME = "my-comfyui-endpoint"
|
||||||
|
DEFAULT_WIDTH = 512
|
||||||
|
DEFAULT_HEIGHT = 512
|
||||||
|
DEFAULT_STEPS = 20
|
||||||
|
COST = 100 # Fixed cost for ComfyUI requests
|
||||||
|
|
||||||
payload = {
|
# Optional S3 Configuration (from environment variables)
|
||||||
"input": {
|
S3_ENDPOINT_URL = os.getenv("S3_ENDPOINT_URL")
|
||||||
"request_id": str(uuid.uuid4()),
|
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
|
||||||
"modifier": "Text2Image",
|
S3_ACCESS_KEY_ID = os.getenv("S3_ACCESS_KEY_ID")
|
||||||
"modifications": {
|
S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY")
|
||||||
"prompt": "a beautiful landscape with mountains and lakes",
|
|
||||||
"width": 1024,
|
logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")
|
||||||
"height": 1024,
|
log = logging.getLogger(__name__)
|
||||||
"steps": 20,
|
|
||||||
"seed": random.randint(0, 2**32 - 1)
|
|
||||||
},
|
def get_s3_client():
|
||||||
"workflow_json": {} # Empty since using modifier approach
|
"""Create and return an S3 client configured for the S3-compatible endpoint"""
|
||||||
}
|
try:
|
||||||
|
import boto3
|
||||||
|
from botocore.config import Config
|
||||||
|
except ImportError:
|
||||||
|
log.error("boto3 is required for S3 uploads. Install with: pip install boto3")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not all([S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY]):
|
||||||
|
log.error("S3 environment variables not fully configured. Required:")
|
||||||
|
log.error(" S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return boto3.client(
|
||||||
|
"s3",
|
||||||
|
endpoint_url=S3_ENDPOINT_URL,
|
||||||
|
aws_access_key_id=S3_ACCESS_KEY_ID,
|
||||||
|
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
|
||||||
|
config=Config(signature_version="s3v4"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------- API Functions ----------------------
|
||||||
|
async def call_generate(
|
||||||
|
client: Serverless,
|
||||||
|
*,
|
||||||
|
endpoint_name: str,
|
||||||
|
prompt: str,
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
steps: int,
|
||||||
|
seed: int,
|
||||||
|
) -> dict:
|
||||||
|
"""Generate image using Text2Image modifier"""
|
||||||
|
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||||
|
payload = {
|
||||||
|
"input": {
|
||||||
|
"request_id": str(uuid.uuid4()),
|
||||||
|
"modifier": "Text2Image",
|
||||||
|
"modifications": {
|
||||||
|
"prompt": prompt,
|
||||||
|
"width": width,
|
||||||
|
"height": height,
|
||||||
|
"steps": steps,
|
||||||
|
"seed": seed,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
}
|
||||||
response = await endpoint.request("/generate/sync", payload, cost=100)
|
return await endpoint.request("/generate/sync", payload, cost=COST)
|
||||||
|
|
||||||
|
|
||||||
|
async def call_generate_workflow(
|
||||||
|
client: Serverless,
|
||||||
|
*,
|
||||||
|
endpoint_name: str,
|
||||||
|
workflow_json: dict,
|
||||||
|
) -> dict:
|
||||||
|
"""Generate using custom workflow JSON"""
|
||||||
|
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||||
|
payload = {
|
||||||
|
"input": {
|
||||||
|
"request_id": str(uuid.uuid4()),
|
||||||
|
"workflow_json": workflow_json,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return await endpoint.request("/generate/sync", payload, cost=COST)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------- Demo Class ----------------------
|
||||||
|
class APIDemo:
|
||||||
|
def __init__(self, client: Serverless, endpoint_name: str, upload_s3: bool = False):
|
||||||
|
self.client = client
|
||||||
|
self.endpoint_name = endpoint_name
|
||||||
|
self.upload_s3 = upload_s3
|
||||||
|
self.s3_client = get_s3_client() if upload_s3 else None
|
||||||
|
|
||||||
|
if upload_s3 and not self.s3_client:
|
||||||
|
log.warning("S3 upload requested but client creation failed. Images will only be saved locally.")
|
||||||
|
|
||||||
|
def extract_filename(self, response: dict) -> str | None:
|
||||||
|
"""Extract the generated image filename from ComfyUI response"""
|
||||||
|
if "comfyui_response" in response:
|
||||||
|
for data in response["comfyui_response"].values():
|
||||||
|
if isinstance(data, dict) and "outputs" in data:
|
||||||
|
for node_output in data["outputs"].values():
|
||||||
|
if "images" in node_output and node_output["images"]:
|
||||||
|
return node_output["images"][0].get("filename")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def save_image(self, worker_url: str, filename: str, local_name: str) -> str | None:
|
||||||
|
"""Fetch and save image locally from the worker, optionally upload to S3"""
|
||||||
|
os.makedirs("generated_images", exist_ok=True)
|
||||||
|
return await self._fetch_image(worker_url, filename, local_name)
|
||||||
|
|
||||||
|
def _upload_to_s3(self, local_path: str, s3_key: str) -> str | None:
|
||||||
|
"""Upload a local file to S3 and return the S3 URL"""
|
||||||
|
if not self.s3_client:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.s3_client.upload_file(
|
||||||
|
local_path,
|
||||||
|
S3_BUCKET_NAME,
|
||||||
|
s3_key,
|
||||||
|
ExtraArgs={"ContentType": "image/png"}
|
||||||
|
)
|
||||||
|
s3_url = f"{S3_ENDPOINT_URL}/{S3_BUCKET_NAME}/{s3_key}"
|
||||||
|
print(f" ☁️ Uploaded to S3: {s3_key}")
|
||||||
|
return s3_url
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Failed to upload to S3: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _fetch_image(self, worker_url: str, filename: str, local_name: str) -> str | None:
|
||||||
|
"""Fetch image from worker's /view endpoint and save locally"""
|
||||||
|
if not worker_url:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
url = f"{worker_url}/view"
|
||||||
|
params = {"filename": filename, "type": "output"}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(url, params=params, ssl=False) as resp:
|
||||||
|
if resp.status == 200:
|
||||||
|
path = f"generated_images/{local_name}"
|
||||||
|
image_data = await resp.read()
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
f.write(image_data)
|
||||||
|
print(f" 💾 Saved: {path}")
|
||||||
|
|
||||||
|
# Upload to S3 if enabled
|
||||||
|
if self.upload_s3 and self.s3_client:
|
||||||
|
s3_key = f"comfyui/{local_name}"
|
||||||
|
self._upload_to_s3(path, s3_key)
|
||||||
|
|
||||||
|
return path
|
||||||
|
return None
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def demo_prompt(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
steps: int,
|
||||||
|
seed: int | None,
|
||||||
|
):
|
||||||
|
"""Demo: Generate image from text prompt"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("COMFYUI TEXT-TO-IMAGE DEMO")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
if seed is None:
|
||||||
|
seed = random.randint(0, 2**32 - 1)
|
||||||
|
|
||||||
|
print(f"Prompt: {prompt[:100]}..." if len(prompt) > 100 else f"Prompt: {prompt}")
|
||||||
|
print(f"Size: {width}x{height}, Steps: {steps}, Seed: {seed}")
|
||||||
|
print("\n🎨 Generating image...")
|
||||||
|
|
||||||
|
response = await call_generate(
|
||||||
|
self.client,
|
||||||
|
endpoint_name=self.endpoint_name,
|
||||||
|
prompt=prompt,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
steps=steps,
|
||||||
|
seed=seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n✅ Generation complete!")
|
||||||
|
|
||||||
|
# Get worker URL for fetching images
|
||||||
|
worker_url = response.get("url", "")
|
||||||
|
print(f"Worker URL: {worker_url}")
|
||||||
|
|
||||||
|
# Fetch and save image
|
||||||
|
if "response" in response:
|
||||||
|
filename = self.extract_filename(response["response"])
|
||||||
|
if filename:
|
||||||
|
path = await self.save_image(worker_url, filename, f"comfy_{seed}.png")
|
||||||
|
if not path:
|
||||||
|
print(f"❌ Failed to fetch image")
|
||||||
|
else:
|
||||||
|
print("❌ No image in response")
|
||||||
|
else:
|
||||||
|
print("❌ Unexpected response format")
|
||||||
|
|
||||||
|
async def demo_workflow(self, workflow_file: str):
|
||||||
|
"""Demo: Generate using custom workflow file"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("COMFYUI CUSTOM WORKFLOW DEMO")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
if not os.path.exists(workflow_file):
|
||||||
|
log.error(f"Workflow file not found: {workflow_file}")
|
||||||
|
return
|
||||||
|
|
||||||
|
with open(workflow_file, "r") as f:
|
||||||
|
workflow_json = json.load(f)
|
||||||
|
|
||||||
|
print(f"Workflow: {workflow_file}")
|
||||||
|
print("\n🎨 Generating...")
|
||||||
|
|
||||||
|
response = await call_generate_workflow(
|
||||||
|
self.client,
|
||||||
|
endpoint_name=self.endpoint_name,
|
||||||
|
workflow_json=workflow_json,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n✅ Generation complete!")
|
||||||
|
|
||||||
|
worker_url = response.get("url", "")
|
||||||
|
|
||||||
|
if "response" in response:
|
||||||
|
filename = self.extract_filename(response["response"])
|
||||||
|
if filename:
|
||||||
|
path = await self.save_image(worker_url, filename, "workflow.png")
|
||||||
|
if not path:
|
||||||
|
print(f"❌ Failed to fetch image")
|
||||||
|
else:
|
||||||
|
print("❌ No image in response")
|
||||||
|
else:
|
||||||
|
print("❌ Unexpected response format")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------- CLI ----------------------
|
||||||
|
def build_arg_parser() -> argparse.ArgumentParser:
|
||||||
|
p = argparse.ArgumentParser(description="Vast ComfyUI-JSON Demo (Serverless SDK)")
|
||||||
|
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
|
||||||
|
p.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, metavar="TEXT",
|
||||||
|
help=f"Prompt text (default: '{DEFAULT_PROMPT[:30]}...')")
|
||||||
|
p.add_argument("--workflow", type=str, metavar="FILE", help="Use custom workflow JSON file instead")
|
||||||
|
p.add_argument("--width", type=int, default=DEFAULT_WIDTH, help=f"Image width (default: {DEFAULT_WIDTH})")
|
||||||
|
p.add_argument("--height", type=int, default=DEFAULT_HEIGHT, help=f"Image height (default: {DEFAULT_HEIGHT})")
|
||||||
|
p.add_argument("--steps", type=int, default=DEFAULT_STEPS, help=f"Steps (default: {DEFAULT_STEPS})")
|
||||||
|
p.add_argument("--seed", type=int, default=None, help="Seed (default: random)")
|
||||||
|
p.add_argument("--s3", action="store_true",
|
||||||
|
help="Upload generated images to S3 (requires S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY env vars)")
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
async def main_async():
|
||||||
|
args = build_arg_parser().parse_args()
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"Using endpoint: {args.endpoint}")
|
||||||
|
if args.s3:
|
||||||
|
print(f"S3 upload: enabled (bucket: {S3_BUCKET_NAME})")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with Serverless() as client:
|
||||||
|
demo = APIDemo(client, args.endpoint, upload_s3=args.s3)
|
||||||
|
|
||||||
|
if args.workflow:
|
||||||
|
await demo.demo_workflow(workflow_file=args.workflow)
|
||||||
|
else:
|
||||||
|
await demo.demo_prompt(
|
||||||
|
prompt=args.prompt,
|
||||||
|
width=args.width,
|
||||||
|
height=args.height,
|
||||||
|
steps=args.steps,
|
||||||
|
seed=args.seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
except AttributeError as e:
|
||||||
|
if "API key" in str(e):
|
||||||
|
log.error("API key missing. Set VAST_API_KEY environment variable.")
|
||||||
|
else:
|
||||||
|
log.error(f"Error: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
# Get the file from the path on the local machine using SCP or SFTP
|
|
||||||
# or configure S3 to upload to cloud storage.
|
|
||||||
print(response["response"]["output"][0]["local_path"])
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main_async())
|
||||||
|
|||||||
@@ -0,0 +1,88 @@
|
|||||||
|
# Null PyWorker
|
||||||
|
|
||||||
|
Holds Vast Serverless reservations open without forwarding any work to a
|
||||||
|
model. Use it when your real workload (a queue consumer in any language)
|
||||||
|
runs as a separate process on the instance and you just want to drive
|
||||||
|
Vast autoscaling: **one POST reserves a worker, one POST releases it.**
|
||||||
|
|
||||||
|
## Use case
|
||||||
|
|
||||||
|
You have a job queue on your own infrastructure (Redis, SQS, NATS, etc.)
|
||||||
|
and a consumer (node, golang, python, a binary — anything) that pulls
|
||||||
|
from it. You want one Vast worker per unit of in-flight work, scaling
|
||||||
|
elastically from zero. The null PyWorker is the autoscaling driver; your
|
||||||
|
consumer does the work.
|
||||||
|
|
||||||
|
## How it works
|
||||||
|
|
||||||
|
Reservations use the framework's session API. The SDK's
|
||||||
|
`endpoint.session(...)` POSTs `/session/create` to reserve a worker;
|
||||||
|
`session.close()` POSTs `/session/end` to release it. `max_sessions=1`
|
||||||
|
means each worker holds exactly one reservation — the next reservation
|
||||||
|
either lands on a free worker or triggers a scale-up.
|
||||||
|
|
||||||
|
The PyWorker itself does nothing functional:
|
||||||
|
|
||||||
|
- One trivial `/ping` route to satisfy the framework's benchmark
|
||||||
|
requirement (its `max_perf` is pinned to 100).
|
||||||
|
- An internal `/release` endpoint on `127.0.0.1:18999` for the local
|
||||||
|
consumer to end the session without needing `session_auth`.
|
||||||
|
|
||||||
|
## Endpoint parameters
|
||||||
|
|
||||||
|
Tested working configuration:
|
||||||
|
|
||||||
|
| Parameter | Value | Why |
|
||||||
|
|---|---|---|
|
||||||
|
| `target_util` | `1.0` | One session = one worker. Default `0.9` rounds up to an extra worker. |
|
||||||
|
| `min_load` | `0` | Scale-to-zero floor. |
|
||||||
|
| `max_queue_time` | `1` | Stop routing to an occupied worker after ~1s of implied queue. |
|
||||||
|
| `target_queue_time` | `0.5` | Trigger scale-up promptly once anything queues. |
|
||||||
|
| `inactivity_timeout` | `10` (seconds) | Permit scale-to-zero after 10s idle. |
|
||||||
|
|
||||||
|
## API
|
||||||
|
|
||||||
|
| Route | Where | Use |
|
||||||
|
|---|---|---|
|
||||||
|
| `POST /session/create` | endpoint, signed | Reserve a worker (`endpoint.session(...)`) |
|
||||||
|
| `POST /session/end` | endpoint, signed | Release (`session.close()`) |
|
||||||
|
| `POST /release` | `127.0.0.1:18999`, no auth | Local consumer release, no `session_auth` needed |
|
||||||
|
|
||||||
|
## Healthcheck
|
||||||
|
|
||||||
|
Default: stub on `127.0.0.1:18999/health` returning `200`. Set
|
||||||
|
`BACKEND_HEALTH_URL=http://127.0.0.1:9090/health` (absolute URL) to point
|
||||||
|
the framework at your queue consumer's health endpoint instead — if the
|
||||||
|
consumer dies, the autoscaler sees the worker as broken.
|
||||||
|
|
||||||
|
## Deploying
|
||||||
|
|
||||||
|
1. Point `PYWORKER_REPO` at this repo (or your fork).
|
||||||
|
2. Set `BACKEND=null` in the template.
|
||||||
|
3. Run your queue consumer alongside the PyWorker. When it's done with
|
||||||
|
a unit of work:
|
||||||
|
```bash
|
||||||
|
curl -X POST http://127.0.0.1:18999/release
|
||||||
|
```
|
||||||
|
|
||||||
|
## Client demo
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Single reservation, hold 180s
|
||||||
|
python -m workers.null.client --endpoint <NAME> --instance alpha
|
||||||
|
|
||||||
|
# Three concurrent reservations, started 30s apart, each held 360s
|
||||||
|
python -m workers.null.client --endpoint <NAME> --instance alpha --count 3 --hold 360
|
||||||
|
```
|
||||||
|
|
||||||
|
Flags: `--count` (number of concurrent sessions, default 1), `--hold`
|
||||||
|
(seconds each session is held, default 180), `--interval` (seconds
|
||||||
|
between starts when `--count > 1`, default 30), `--cost` (cost reported
|
||||||
|
at session-create, default 100 = `max_perf`), `--instance` (`prod` |
|
||||||
|
`alpha` | `candidate` | `local`).
|
||||||
|
|
||||||
|
## Environment variables
|
||||||
|
|
||||||
|
- `BACKEND_HEALTH_URL` — absolute URL the framework healthchecks. Stub
|
||||||
|
is used when unset.
|
||||||
|
- `NULL_CONTROL_PORT` — internal control server port. Defaults to `18999`.
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from vastai import Serverless
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
async def reserve(client: Serverless, endpoint_name: str, hold: float, cost: int, label: str):
|
||||||
|
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||||
|
async with await endpoint.session(cost=cost, lifetime=hold + 60) as s:
|
||||||
|
sid = s.session_id
|
||||||
|
log.info("[%s] %s open, holding %.0fs", label, sid, hold)
|
||||||
|
await asyncio.sleep(hold)
|
||||||
|
log.info("[%s] %s closed", label, sid)
|
||||||
|
|
||||||
|
|
||||||
|
async def main_async():
|
||||||
|
p = argparse.ArgumentParser(description="Vast Null PyWorker demo client")
|
||||||
|
p.add_argument("--endpoint", default=os.environ.get("VAST_ENDPOINT", "null-prod"))
|
||||||
|
p.add_argument("--instance", choices=("prod", "alpha", "candidate", "local"),
|
||||||
|
default=os.environ.get("VAST_INSTANCE", "prod"))
|
||||||
|
p.add_argument("--count", type=int, default=1,
|
||||||
|
help="concurrent sessions to open (default: 1)")
|
||||||
|
p.add_argument("--interval", type=float, default=30.0,
|
||||||
|
help="seconds between session starts when count>1 (default: 30)")
|
||||||
|
p.add_argument("--hold", type=float, default=180.0,
|
||||||
|
help="seconds to hold each session (default: 180)")
|
||||||
|
p.add_argument("--cost", type=int, default=100,
|
||||||
|
help="cost reported at session-create (default: 100)")
|
||||||
|
args = p.parse_args()
|
||||||
|
|
||||||
|
print(f"endpoint={args.endpoint} instance={args.instance} "
|
||||||
|
f"count={args.count} hold={args.hold}s cost={args.cost}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with Serverless(instance=args.instance) as client:
|
||||||
|
tasks = []
|
||||||
|
for i in range(args.count):
|
||||||
|
label = f"res-{i+1}" if args.count > 1 else "reservation"
|
||||||
|
tasks.append(asyncio.create_task(
|
||||||
|
reserve(client, args.endpoint, args.hold, args.cost, label),
|
||||||
|
name=label,
|
||||||
|
))
|
||||||
|
if i + 1 < args.count:
|
||||||
|
await asyncio.sleep(args.interval)
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
log.info("Interrupted")
|
||||||
|
except Exception as e:
|
||||||
|
log.error("Error: %s", e, exc_info=True)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main_async())
|
||||||
@@ -0,0 +1,143 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from urllib.parse import urlsplit
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from vastai import (
|
||||||
|
Worker,
|
||||||
|
WorkerConfig,
|
||||||
|
HandlerConfig,
|
||||||
|
BenchmarkConfig,
|
||||||
|
LogActionConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
TARGET_PERF = 100.0
|
||||||
|
BENCHMARK_SENTINEL = "__null_worker_benchmark__"
|
||||||
|
|
||||||
|
INTERNAL_HOST = "127.0.0.1"
|
||||||
|
INTERNAL_PORT = int(os.environ.get("NULL_CONTROL_PORT", 18999))
|
||||||
|
STUB_HEALTH_PATH = "/health"
|
||||||
|
|
||||||
|
BACKEND_HEALTH_URL = os.environ.get("BACKEND_HEALTH_URL", "").strip()
|
||||||
|
if BACKEND_HEALTH_URL:
|
||||||
|
_p = urlsplit(BACKEND_HEALTH_URL)
|
||||||
|
if not _p.scheme or not _p.hostname:
|
||||||
|
raise ValueError(f"BACKEND_HEALTH_URL must be absolute, got: {BACKEND_HEALTH_URL!r}")
|
||||||
|
HEALTH_BASE_URL = f"{_p.scheme}://{_p.hostname}"
|
||||||
|
HEALTH_PORT = _p.port or (443 if _p.scheme == "https" else 80)
|
||||||
|
HEALTH_PATH = _p.path or "/"
|
||||||
|
USE_STUB_HEALTH = False
|
||||||
|
else:
|
||||||
|
HEALTH_BASE_URL = f"http://{INTERNAL_HOST}"
|
||||||
|
HEALTH_PORT = INTERNAL_PORT
|
||||||
|
HEALTH_PATH = STUB_HEALTH_PATH
|
||||||
|
USE_STUB_HEALTH = True
|
||||||
|
|
||||||
|
|
||||||
|
_backend_ref: dict = {"backend": None}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_internal_app() -> web.Application:
|
||||||
|
app = web.Application()
|
||||||
|
|
||||||
|
async def release_handler(_request: web.Request) -> web.Response:
|
||||||
|
# Closes the singleton session. Uses name-mangled __close_session
|
||||||
|
# to bypass the session_auth check — safe because this server is
|
||||||
|
# bound to 127.0.0.1, and it spares the consumer from threading
|
||||||
|
# session_auth through its queue.
|
||||||
|
backend = _backend_ref.get("backend")
|
||||||
|
if backend is None:
|
||||||
|
return web.json_response({"released": False, "reason": "backend not ready"}, status=503)
|
||||||
|
sids = list(backend.sessions.keys())
|
||||||
|
if not sids:
|
||||||
|
return web.json_response({"released": False, "reason": "no active session"}, status=200)
|
||||||
|
closed = []
|
||||||
|
for sid in sids:
|
||||||
|
try:
|
||||||
|
if await backend._Backend__close_session(sid):
|
||||||
|
closed.append(sid)
|
||||||
|
except Exception as e:
|
||||||
|
log.warning(f"Error closing session {sid}: {e}")
|
||||||
|
return web.json_response({"released": bool(closed), "session_ids": closed}, status=200)
|
||||||
|
|
||||||
|
app.router.add_post("/release", release_handler)
|
||||||
|
|
||||||
|
if USE_STUB_HEALTH:
|
||||||
|
async def stub_health(_request: web.Request) -> web.Response:
|
||||||
|
return web.Response(status=200, text="ok")
|
||||||
|
app.router.add_get(STUB_HEALTH_PATH, stub_health)
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def null_lifecycle():
|
||||||
|
# Pin max_throughput to TARGET_PERF exactly — the framework's
|
||||||
|
# __run_benchmark short-circuits to float(file_contents) if this exists.
|
||||||
|
try:
|
||||||
|
with open(".has_benchmark", "w") as fh:
|
||||||
|
fh.write(str(int(TARGET_PERF)))
|
||||||
|
except OSError as e:
|
||||||
|
log.warning(f"Could not pin benchmark cache: {e}")
|
||||||
|
|
||||||
|
runner = web.AppRunner(_build_internal_app())
|
||||||
|
await runner.setup()
|
||||||
|
await web.TCPSite(runner, INTERNAL_HOST, INTERNAL_PORT).start()
|
||||||
|
|
||||||
|
log.info(
|
||||||
|
"Null pyworker control server: http://%s:%d (POST /release%s)",
|
||||||
|
INTERNAL_HOST,
|
||||||
|
INTERNAL_PORT,
|
||||||
|
f", GET {STUB_HEALTH_PATH}" if USE_STUB_HEALTH else "",
|
||||||
|
)
|
||||||
|
if not USE_STUB_HEALTH:
|
||||||
|
log.info("Framework healthcheck → %s", BACKEND_HEALTH_URL)
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
await runner.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
async def ping(**params: object) -> dict:
|
||||||
|
# Exists only to satisfy the framework's "at least one handler with a
|
||||||
|
# BenchmarkConfig" requirement. Sleep 1s on the benchmark path as a
|
||||||
|
# fallback in case the .has_benchmark cache pin failed; otherwise the
|
||||||
|
# benchmark cache short-circuits and this never runs.
|
||||||
|
if params.get(BENCHMARK_SENTINEL):
|
||||||
|
await asyncio.sleep(1.0)
|
||||||
|
return {"ok": True, "benchmark": True}
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
worker_config = WorkerConfig(
|
||||||
|
model_server_url=HEALTH_BASE_URL,
|
||||||
|
model_server_port=HEALTH_PORT,
|
||||||
|
model_healthcheck_url=HEALTH_PATH,
|
||||||
|
lifecycle=null_lifecycle(),
|
||||||
|
max_sessions=1,
|
||||||
|
handlers=[
|
||||||
|
HandlerConfig(
|
||||||
|
route="/ping",
|
||||||
|
allow_parallel_requests=True,
|
||||||
|
remote_function=ping,
|
||||||
|
workload_calculator=lambda _payload: TARGET_PERF,
|
||||||
|
benchmark_config=BenchmarkConfig(
|
||||||
|
generator=lambda: {BENCHMARK_SENTINEL: True},
|
||||||
|
runs=1,
|
||||||
|
concurrency=1,
|
||||||
|
do_warmup=False,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
log_action_config=LogActionConfig(),
|
||||||
|
)
|
||||||
|
|
||||||
|
_worker = Worker(worker_config)
|
||||||
|
_backend_ref["backend"] = _worker.backend
|
||||||
|
_worker.run()
|
||||||
+33
-26
@@ -8,14 +8,13 @@ This is the base PyWorker for OpenAI compatible inference servers. See the [Ser
|
|||||||
|
|
||||||
This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker.
|
This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker.
|
||||||
|
|
||||||
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20%2B%20Qwen%2FQwen3-8B%20(Serverless)) (recommended)
|
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20(Serverless)) (recommended)
|
||||||
- [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless))
|
- [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless))
|
||||||
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20%2B%20Qwen3-8B%20(Serverless))
|
|
||||||
|
|
||||||
|
|
||||||
All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected.
|
All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected.
|
||||||
|
|
||||||
2. Follow the [getting started guide](https://docs.vast.ai/serverless/getting-started) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
||||||
|
|
||||||
## Client Setup (Demo)
|
## Client Setup (Demo)
|
||||||
|
|
||||||
@@ -34,38 +33,20 @@ uv pip install -r requirements.txt
|
|||||||
|
|
||||||
Several examples have been provided in the client to help you get started with your own implementation.
|
Several examples have been provided in the client to help you get started with your own implementation.
|
||||||
|
|
||||||
### Completions
|
First, set your API key as an environment variable:
|
||||||
|
|
||||||
Call to `/v1/completions` with json response
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
|
export VAST_API_KEY=<your_api_key>
|
||||||
```
|
```
|
||||||
|
|
||||||
### Chat Completion (json)
|
The `--model` and `--endpoint` flags are optional. If not provided, they default to `Qwen/Qwen3-8B` and `my-vllm-endpoint` respectively.
|
||||||
|
|
||||||
Call to `/v1/chat/completions` with json response
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
|
|
||||||
```
|
|
||||||
|
|
||||||
### Chat Completion (streaming)
|
### Chat Completion (streaming)
|
||||||
|
|
||||||
Call to `/v1/chat/completions` with streaming response
|
Call to `/v1/chat/completions` with streaming response
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
|
python -m workers.openai.client --chat-stream --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||||
```
|
|
||||||
|
|
||||||
### Tool Use (json)
|
|
||||||
|
|
||||||
Call to `/v1/chat/completions` with tool and json response.
|
|
||||||
|
|
||||||
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Interactive Chat (streaming)
|
### Interactive Chat (streaming)
|
||||||
@@ -75,6 +56,32 @@ Interactive session with calls to `/v1/chat/completions`.
|
|||||||
Type `clear` to clear the chat history or `quit` to exit.
|
Type `clear` to clear the chat history or `quit` to exit.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
|
python -m workers.openai.client --interactive --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Chat Completion (json)
|
||||||
|
|
||||||
|
Call to `/v1/chat/completions` with json response
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m workers.openai.client --chat --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Tool Use (json)
|
||||||
|
|
||||||
|
Call to `/v1/chat/completions` with tool and json response.
|
||||||
|
|
||||||
|
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m workers.openai.client --tools --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Completions
|
||||||
|
|
||||||
|
Call to `/v1/completions` with json response
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m workers.openai.client --completion --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
+32
-16
@@ -18,7 +18,7 @@ logging.basicConfig(
|
|||||||
log = logging.getLogger(__file__)
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
# ---------------------- Prompts ----------------------
|
# ---------------------- Prompts ----------------------
|
||||||
COMPLETIONS_PROMPT = "the capital of USA is"
|
COMPLETIONS_PROMPT = "Zebras are primarily grazers and can subsist on lower-quality vegetation. They are preyed on mainly by"
|
||||||
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
|
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
|
||||||
TOOLS_PROMPT = (
|
TOOLS_PROMPT = (
|
||||||
"Can you list the files in the current working directory and tell me what you see? "
|
"Can you list the files in the current working directory and tell me what you see? "
|
||||||
@@ -97,9 +97,9 @@ def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[
|
|||||||
|
|
||||||
|
|
||||||
# ---- OpenAI-compatible calls (non-streaming) ----
|
# ---- OpenAI-compatible calls (non-streaming) ----
|
||||||
async def call_completions(client: Serverless, *, model: str, prompt: str, **kwargs) -> Dict[str, Any]:
|
async def call_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs) -> Dict[str, Any]:
|
||||||
|
|
||||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": model,
|
"model": model,
|
||||||
@@ -111,9 +111,9 @@ async def call_completions(client: Serverless, *, model: str, prompt: str, **kwa
|
|||||||
resp = await endpoint.request("/v1/completions", payload, cost=payload["max_tokens"])
|
resp = await endpoint.request("/v1/completions", payload, cost=payload["max_tokens"])
|
||||||
return resp["response"]
|
return resp["response"]
|
||||||
|
|
||||||
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
|
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs) -> Dict[str, Any]:
|
||||||
|
|
||||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": model,
|
"model": model,
|
||||||
@@ -128,9 +128,9 @@ async def call_chat_completions(client: Serverless, *, model: str, messages: Lis
|
|||||||
return resp["response"]
|
return resp["response"]
|
||||||
|
|
||||||
# ---- Streaming variants ----
|
# ---- Streaming variants ----
|
||||||
async def stream_completions(client: Serverless, *, model: str, prompt: str, **kwargs):
|
async def stream_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs):
|
||||||
|
|
||||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": model,
|
"model": model,
|
||||||
@@ -144,9 +144,9 @@ async def stream_completions(client: Serverless, *, model: str, prompt: str, **k
|
|||||||
resp = await endpoint.request("/v1/completions", payload, cost=payload["max_tokens"], stream=True)
|
resp = await endpoint.request("/v1/completions", payload, cost=payload["max_tokens"], stream=True)
|
||||||
return resp["response"] # async generator
|
return resp["response"] # async generator
|
||||||
|
|
||||||
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs):
|
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs):
|
||||||
|
|
||||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": model,
|
"model": model,
|
||||||
@@ -166,9 +166,10 @@ async def stream_chat_completions(client: Serverless, *, model: str, messages: L
|
|||||||
class APIDemo:
|
class APIDemo:
|
||||||
"""Demo and testing functionality for the API client"""
|
"""Demo and testing functionality for the API client"""
|
||||||
|
|
||||||
def __init__(self, client: Serverless, model: str, tool_manager: Optional[ToolManager] = None):
|
def __init__(self, client: Serverless, model: str, endpoint_name: str, tool_manager: Optional[ToolManager] = None):
|
||||||
self.client = client
|
self.client = client
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.endpoint_name = endpoint_name
|
||||||
self.tool_manager = tool_manager or ToolManager()
|
self.tool_manager = tool_manager or ToolManager()
|
||||||
|
|
||||||
# ----- Streaming handler -----
|
# ----- Streaming handler -----
|
||||||
@@ -177,10 +178,15 @@ class APIDemo:
|
|||||||
reasoning_content = ""
|
reasoning_content = ""
|
||||||
printed_reasoning = False
|
printed_reasoning = False
|
||||||
printed_answer = False
|
printed_answer = False
|
||||||
|
finish_reason = None
|
||||||
|
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
choice = (chunk.get("choices") or [{}])[0]
|
choice = (chunk.get("choices") or [{}])[0]
|
||||||
delta = choice.get("delta", {})
|
delta = choice.get("delta", {})
|
||||||
|
|
||||||
|
# Track finish reason
|
||||||
|
if choice.get("finish_reason"):
|
||||||
|
finish_reason = choice.get("finish_reason")
|
||||||
|
|
||||||
# reasoning tokens
|
# reasoning tokens
|
||||||
rc = delta.get("reasoning_content")
|
rc = delta.get("reasoning_content")
|
||||||
@@ -211,6 +217,8 @@ class APIDemo:
|
|||||||
print(f"Reasoning tokens: {len(reasoning_content.split())}")
|
print(f"Reasoning tokens: {len(reasoning_content.split())}")
|
||||||
if printed_answer:
|
if printed_answer:
|
||||||
print(f"Response tokens: {len(full_response.split())}")
|
print(f"Response tokens: {len(full_response.split())}")
|
||||||
|
if finish_reason:
|
||||||
|
print(f"Finish reason: {finish_reason}")
|
||||||
|
|
||||||
return full_response
|
return full_response
|
||||||
|
|
||||||
@@ -223,6 +231,7 @@ class APIDemo:
|
|||||||
client=self.client,
|
client=self.client,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
prompt=COMPLETIONS_PROMPT,
|
prompt=COMPLETIONS_PROMPT,
|
||||||
|
endpoint_name=self.endpoint_name,
|
||||||
max_tokens=MAX_TOKENS,
|
max_tokens=MAX_TOKENS,
|
||||||
temperature=DEFAULT_TEMPERATURE,
|
temperature=DEFAULT_TEMPERATURE,
|
||||||
)
|
)
|
||||||
@@ -241,6 +250,7 @@ class APIDemo:
|
|||||||
client=self.client,
|
client=self.client,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
endpoint_name=self.endpoint_name,
|
||||||
max_tokens=MAX_TOKENS,
|
max_tokens=MAX_TOKENS,
|
||||||
temperature=DEFAULT_TEMPERATURE
|
temperature=DEFAULT_TEMPERATURE
|
||||||
)
|
)
|
||||||
@@ -253,6 +263,7 @@ class APIDemo:
|
|||||||
client=self.client,
|
client=self.client,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
endpoint_name=self.endpoint_name,
|
||||||
max_tokens=MAX_TOKENS,
|
max_tokens=MAX_TOKENS,
|
||||||
temperature=DEFAULT_TEMPERATURE
|
temperature=DEFAULT_TEMPERATURE
|
||||||
)
|
)
|
||||||
@@ -279,6 +290,7 @@ class APIDemo:
|
|||||||
client=self.client,
|
client=self.client,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
endpoint_name=self.endpoint_name,
|
||||||
tools=minimal_tool,
|
tools=minimal_tool,
|
||||||
tool_choice="none",
|
tool_choice="none",
|
||||||
max_tokens=10
|
max_tokens=10
|
||||||
@@ -304,6 +316,7 @@ class APIDemo:
|
|||||||
client=self.client,
|
client=self.client,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
endpoint_name=self.endpoint_name,
|
||||||
tools=self.tool_manager.get_ls_tool_definition(),
|
tools=self.tool_manager.get_ls_tool_definition(),
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
max_tokens=MAX_TOKENS,
|
max_tokens=MAX_TOKENS,
|
||||||
@@ -381,6 +394,7 @@ class APIDemo:
|
|||||||
client=self.client,
|
client=self.client,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
endpoint_name=self.endpoint_name,
|
||||||
max_tokens=MAX_TOKENS,
|
max_tokens=MAX_TOKENS,
|
||||||
temperature=DEFAULT_TEMPERATURE,
|
temperature=DEFAULT_TEMPERATURE,
|
||||||
)
|
)
|
||||||
@@ -419,7 +433,6 @@ class APIDemo:
|
|||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("INTERACTIVE STREAMING CHAT")
|
print("INTERACTIVE STREAMING CHAT")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print(f"Using model: {self.model}")
|
|
||||||
print("Type 'quit' to exit, 'clear' to clear history")
|
print("Type 'quit' to exit, 'clear' to clear history")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
@@ -445,7 +458,8 @@ class APIDemo:
|
|||||||
stream = await stream_chat_completions(
|
stream = await stream_chat_completions(
|
||||||
client=self.client,
|
client=self.client,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
endpoint_name=self.endpoint_name,
|
||||||
max_tokens=MAX_TOKENS,
|
max_tokens=MAX_TOKENS,
|
||||||
temperature=0.7
|
temperature=0.7
|
||||||
)
|
)
|
||||||
@@ -465,8 +479,8 @@ class APIDemo:
|
|||||||
# ---------------------- CLI ----------------------
|
# ---------------------- CLI ----------------------
|
||||||
def build_arg_parser() -> argparse.ArgumentParser:
|
def build_arg_parser() -> argparse.ArgumentParser:
|
||||||
p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
|
p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
|
||||||
p.add_argument("--model", required=True, help="Model to use for requests (required)")
|
p.add_argument("--model", default=DEFAULT_MODEL, help=f"Model to use for requests (default: {DEFAULT_MODEL})")
|
||||||
p.add_argument("--endpoint", default="my-vllm-endpoint", help="Vast endpoint name (default: my-vllm-endpoint)")
|
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
|
||||||
|
|
||||||
modes = p.add_mutually_exclusive_group(required=False)
|
modes = p.add_mutually_exclusive_group(required=False)
|
||||||
modes.add_argument("--completion", action="store_true", help="Test completions endpoint")
|
modes.add_argument("--completion", action="store_true", help="Test completions endpoint")
|
||||||
@@ -494,12 +508,14 @@ async def main_async():
|
|||||||
print("Please specify exactly one test mode")
|
print("Please specify exactly one test mode")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
print(f"Using model: {args.model}")
|
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
print(f"Using model: {args.model}")
|
||||||
|
print(f"Using endpoint: {args.endpoint}")
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with Serverless() as client:
|
async with Serverless() as client:
|
||||||
demo = APIDemo(client, args.model, ToolManager())
|
demo = APIDemo(client, args.model, args.endpoint, ToolManager())
|
||||||
|
|
||||||
if args.completion:
|
if args.completion:
|
||||||
await demo.demo_completions()
|
await demo.demo_completions()
|
||||||
|
|||||||
@@ -28,6 +28,12 @@ MODEL_INFO_LOG_MSGS = [
|
|||||||
nltk.download("words")
|
nltk.download("words")
|
||||||
WORD_LIST = nltk.corpus.words.words()
|
WORD_LIST = nltk.corpus.words.words()
|
||||||
|
|
||||||
|
def request_parser(request):
|
||||||
|
data = request
|
||||||
|
if request.get("input") is not None:
|
||||||
|
data = request.get("input")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
def completions_benchmark_generator() -> dict:
|
def completions_benchmark_generator() -> dict:
|
||||||
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
||||||
@@ -54,18 +60,20 @@ worker_config = WorkerConfig(
|
|||||||
route="/v1/completions",
|
route="/v1/completions",
|
||||||
workload_calculator= lambda data: data.get("max_tokens", 0),
|
workload_calculator= lambda data: data.get("max_tokens", 0),
|
||||||
allow_parallel_requests=True,
|
allow_parallel_requests=True,
|
||||||
max_queue_time=60.0,
|
request_parser=request_parser,
|
||||||
|
max_queue_time=600.0,
|
||||||
benchmark_config=BenchmarkConfig(
|
benchmark_config=BenchmarkConfig(
|
||||||
generator=completions_benchmark_generator,
|
generator=completions_benchmark_generator,
|
||||||
concurrency=100,
|
concurrency=10,
|
||||||
runs=2
|
runs=3
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
HandlerConfig(
|
HandlerConfig(
|
||||||
route="/v1/chat/completions",
|
route="/v1/chat/completions",
|
||||||
workload_calculator= lambda data: data.get("max_tokens", 0),
|
workload_calculator= lambda data: data.get("max_tokens", 0),
|
||||||
allow_parallel_requests=True,
|
allow_parallel_requests=True,
|
||||||
max_queue_time=60.0,
|
request_parser=request_parser,
|
||||||
|
max_queue_time=600.0,
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
log_action_config=LogActionConfig(
|
log_action_config=LogActionConfig(
|
||||||
|
|||||||
+93
-9
@@ -1,19 +1,103 @@
|
|||||||
This is the base PyWorker for TGI, designed to create PyWorkers that can utilize various LLMs. It offers two primary endpoints:
|
# HuggingFace TGI PyWorker
|
||||||
|
|
||||||
1. `generate`: Generates the LLM's response to a given prompt in a single request.
|
This is the base PyWorker for HuggingFace Text Generation Inference (TGI) servers. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
|
||||||
2. `generate_stream`: Streams the LLM's response token by token.
|
|
||||||
|
|
||||||
Both endpoints use the following API payload format:
|
## Instance Setup
|
||||||
|
|
||||||
|
1. Pick a template
|
||||||
|
|
||||||
|
This worker is compatible with any TGI backend. We have a template you can use or you can create your own.
|
||||||
|
|
||||||
|
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20(Serverless))
|
||||||
|
|
||||||
|
The template can be configured via the template interface. You may want to change the model or startup arguments.
|
||||||
|
|
||||||
|
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
|
||||||
|
|
||||||
|
## Client Setup (Demo)
|
||||||
|
|
||||||
|
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/vast-ai/pyworker
|
||||||
|
cd pyworker
|
||||||
|
pip install uv
|
||||||
|
uv venv -p 3.12
|
||||||
|
source .venv/bin/activate
|
||||||
|
uv pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
## Using the Test Client
|
||||||
|
|
||||||
|
The test client demonstrates both streaming and non-streaming generation using TGI's native API.
|
||||||
|
|
||||||
|
First, set your API key as an environment variable:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export VAST_API_KEY=<your_api_key>
|
||||||
|
```
|
||||||
|
|
||||||
|
The `--endpoint` flag is optional. If not provided, it defaults to `my-tgi-endpoint`.
|
||||||
|
|
||||||
|
### Generate (Streaming)
|
||||||
|
|
||||||
|
Call to `/generate_stream` with streaming response:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m workers.tgi.client --generate-stream --endpoint <ENDPOINT_NAME>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Generate (Non-Streaming)
|
||||||
|
|
||||||
|
Call to `/generate` with json response:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m workers.tgi.client --generate --endpoint <ENDPOINT_NAME>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Interactive Session (Streaming)
|
||||||
|
|
||||||
|
Interactive session with streaming responses. Type `quit` to exit.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m workers.tgi.client --interactive --endpoint <ENDPOINT_NAME>
|
||||||
|
```
|
||||||
|
|
||||||
|
## API Endpoints
|
||||||
|
|
||||||
|
TGI provides two primary endpoints:
|
||||||
|
|
||||||
|
### Generate (Non-Streaming)
|
||||||
|
|
||||||
|
`/generate` - Returns the complete response in a single request.
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"inputs": "PROMPT",
|
"inputs": "Your prompt here",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"max_new_tokens": 250
|
"max_new_tokens": 1024,
|
||||||
|
"temperature": 0.7,
|
||||||
|
"return_full_text": false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Note that the max_new_tokens parameter, rather than the prompt size, impacts performance. For example, if an
|
### Generate Stream (Streaming)
|
||||||
instance is benchmarked to process 100 tokens per second, a request with max_new_tokens = 200 will take
|
|
||||||
approximately 2 seconds to complete.
|
`/generate_stream` - Streams the response token by token.
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"inputs": "Your prompt here",
|
||||||
|
"parameters": {
|
||||||
|
"max_new_tokens": 1024,
|
||||||
|
"temperature": 0.7,
|
||||||
|
"do_sample": true,
|
||||||
|
"return_full_text": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Performance Notes
|
||||||
|
|
||||||
|
The `max_new_tokens` parameter (not the prompt size) primarily impacts performance. For example, if an instance is benchmarked to process 100 tokens per second, a request with `max_new_tokens = 200` will take approximately 2 seconds to complete.
|
||||||
|
|||||||
+195
-34
@@ -1,61 +1,222 @@
|
|||||||
|
import logging
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
|
||||||
from vastai import Serverless
|
from vastai import Serverless
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
ENDPOINT_NAME = "my-tgi-endpoint" # Change this to match your endpoint name
|
# ---------------------- Logging ----------------------
|
||||||
MAX_TOKENS = 1024
|
logging.basicConfig(
|
||||||
PROMPT = "Think step by step: Tell me about the Python programming language."
|
level=logging.DEBUG,
|
||||||
|
format="%(asctime)s[%(levelname)-5s] %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
log = logging.getLogger(__file__)
|
||||||
|
|
||||||
async def call_generate(client: Serverless) -> None:
|
# ---------------------- Defaults ----------------------
|
||||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
DEFAULT_PROMPT = "Think step by step: Tell me about the Python programming language."
|
||||||
|
|
||||||
|
ENDPOINT_NAME = "TGI-Prod2" # change this to your TGI endpoint name
|
||||||
|
MAX_TOKENS = 1024
|
||||||
|
DEFAULT_TEMPERATURE = 0.7
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------- API Calls ----------------------
|
||||||
|
async def call_generate(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs) -> dict:
|
||||||
|
"""Non-streaming generation via /generate endpoint"""
|
||||||
|
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"inputs": PROMPT,
|
"inputs": prompt,
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"max_new_tokens": MAX_TOKENS,
|
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||||
"temperature": 0.7,
|
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||||
"return_full_text": False
|
"return_full_text": False,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
log.debug("POST /generate %s", json.dumps(payload)[:500])
|
||||||
resp = await endpoint.request("/generate", payload, cost=MAX_TOKENS)
|
resp = await endpoint.request("/generate", payload, cost=payload["parameters"]["max_new_tokens"])
|
||||||
|
return resp["response"]
|
||||||
print(resp["response"]["generated_text"])
|
|
||||||
|
|
||||||
|
|
||||||
async def call_generate_stream(client: Serverless) -> None:
|
async def call_generate_stream(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs):
|
||||||
endpoint = await client.get_endpoint(name=ENDPOINT_NAME)
|
"""Streaming generation via /generate_stream endpoint"""
|
||||||
|
endpoint = await client.get_endpoint(name=endpoint_name)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"inputs": PROMPT,
|
"inputs": prompt,
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"max_new_tokens": MAX_TOKENS,
|
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
|
||||||
"temperature": 0.7,
|
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
|
||||||
"do_sample": True,
|
"do_sample": True,
|
||||||
"return_full_text": False,
|
"return_full_text": False,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
log.debug("STREAM /generate_stream %s", json.dumps(payload)[:500])
|
||||||
resp = await endpoint.request(
|
resp = await endpoint.request(
|
||||||
"/generate_stream",
|
"/generate_stream",
|
||||||
payload,
|
payload,
|
||||||
cost=MAX_TOKENS,
|
cost=payload["parameters"]["max_new_tokens"],
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
stream = resp["response"]
|
return resp["response"] # async generator
|
||||||
|
|
||||||
printed_answer = False
|
|
||||||
async for event in stream:
|
|
||||||
tok = (event.get("token") or {}).get("text")
|
|
||||||
if tok:
|
|
||||||
if not printed_answer:
|
|
||||||
printed_answer = True
|
|
||||||
print("Answer:\n", end="", flush=True)
|
|
||||||
print(tok, end="", flush=True)
|
|
||||||
|
|
||||||
async def main():
|
# ---------------------- Demo Runner ----------------------
|
||||||
async with Serverless() as client:
|
class APIDemo:
|
||||||
await call_generate(client)
|
"""Demo and testing functionality for the TGI API client"""
|
||||||
await call_generate_stream(client)
|
|
||||||
|
def __init__(self, client: Serverless, endpoint_name: str):
|
||||||
|
self.client = client
|
||||||
|
self.endpoint_name = endpoint_name
|
||||||
|
|
||||||
|
async def handle_streaming_response(self, stream) -> str:
|
||||||
|
"""Process streaming response and print tokens"""
|
||||||
|
full_response = ""
|
||||||
|
printed_answer = False
|
||||||
|
|
||||||
|
async for event in stream:
|
||||||
|
tok = (event.get("token") or {}).get("text")
|
||||||
|
if tok:
|
||||||
|
if not printed_answer:
|
||||||
|
printed_answer = True
|
||||||
|
print("\n💬 Response: ", end="", flush=True)
|
||||||
|
print(tok, end="", flush=True)
|
||||||
|
full_response += tok
|
||||||
|
|
||||||
|
print() # newline
|
||||||
|
if printed_answer:
|
||||||
|
print(f"\nStreaming completed. Response tokens: {len(full_response.split())}")
|
||||||
|
|
||||||
|
return full_response
|
||||||
|
|
||||||
|
async def demo_generate(self) -> None:
|
||||||
|
"""Demo non-streaming generation"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("GENERATE DEMO (NON-STREAMING)")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
response = await call_generate(
|
||||||
|
client=self.client,
|
||||||
|
endpoint_name=self.endpoint_name,
|
||||||
|
prompt=DEFAULT_PROMPT,
|
||||||
|
max_tokens=MAX_TOKENS,
|
||||||
|
temperature=DEFAULT_TEMPERATURE,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n💬 Response: {response.get('generated_text', '')}")
|
||||||
|
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
|
||||||
|
|
||||||
|
async def demo_generate_stream(self) -> None:
|
||||||
|
"""Demo streaming generation"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("GENERATE DEMO (STREAMING)")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
stream = await call_generate_stream(
|
||||||
|
client=self.client,
|
||||||
|
endpoint_name=self.endpoint_name,
|
||||||
|
prompt=DEFAULT_PROMPT,
|
||||||
|
max_tokens=MAX_TOKENS,
|
||||||
|
temperature=DEFAULT_TEMPERATURE,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.handle_streaming_response(stream)
|
||||||
|
except Exception as e:
|
||||||
|
log.error("\nError during streaming: %s", e, exc_info=True)
|
||||||
|
|
||||||
|
async def interactive_chat(self) -> None:
|
||||||
|
"""Interactive session with streaming generation"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("INTERACTIVE STREAMING SESSION")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"Using endpoint: {self.endpoint_name}")
|
||||||
|
print("Type 'quit' to exit")
|
||||||
|
print()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
user_input = input("You: ").strip()
|
||||||
|
|
||||||
|
if user_input.lower() == "quit":
|
||||||
|
print("👋 Goodbye!")
|
||||||
|
break
|
||||||
|
elif not user_input:
|
||||||
|
continue
|
||||||
|
|
||||||
|
print("Assistant: ", end="", flush=True)
|
||||||
|
stream = await call_generate_stream(
|
||||||
|
client=self.client,
|
||||||
|
endpoint_name=self.endpoint_name,
|
||||||
|
prompt=user_input,
|
||||||
|
max_tokens=MAX_TOKENS,
|
||||||
|
temperature=DEFAULT_TEMPERATURE,
|
||||||
|
)
|
||||||
|
|
||||||
|
full_response = ""
|
||||||
|
async for event in stream:
|
||||||
|
tok = (event.get("token") or {}).get("text")
|
||||||
|
if tok:
|
||||||
|
print(tok, end="", flush=True)
|
||||||
|
full_response += tok
|
||||||
|
print() # newline
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n👋 Session interrupted. Goodbye!")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
log.error("\nError: %s", e)
|
||||||
|
continue
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------- CLI ----------------------
|
||||||
|
def build_arg_parser() -> argparse.ArgumentParser:
|
||||||
|
p = argparse.ArgumentParser(description="Vast TGI Demo (Serverless SDK)")
|
||||||
|
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
|
||||||
|
|
||||||
|
modes = p.add_mutually_exclusive_group(required=False)
|
||||||
|
modes.add_argument("--generate", action="store_true", help="Test generate endpoint (non-streaming)")
|
||||||
|
modes.add_argument("--generate-stream", action="store_true", help="Test generate endpoint with streaming")
|
||||||
|
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming session")
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
async def main_async():
|
||||||
|
args = build_arg_parser().parse_args()
|
||||||
|
|
||||||
|
selected = sum([args.generate, args.generate_stream, args.interactive])
|
||||||
|
if selected == 0:
|
||||||
|
print("Please specify exactly one test mode:")
|
||||||
|
print(" --generate : Test generate endpoint (non-streaming)")
|
||||||
|
print(" --generate-stream : Test generate endpoint with streaming")
|
||||||
|
print(" --interactive : Start interactive streaming session")
|
||||||
|
print(f"\nExample: python {os.path.basename(sys.argv[0])} --generate-stream --endpoint my-tgi-endpoint")
|
||||||
|
sys.exit(1)
|
||||||
|
elif selected > 1:
|
||||||
|
print("Please specify exactly one test mode")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"Using endpoint: {args.endpoint}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with Serverless() as client:
|
||||||
|
demo = APIDemo(client, args.endpoint)
|
||||||
|
|
||||||
|
if args.generate:
|
||||||
|
await demo.demo_generate()
|
||||||
|
elif args.generate_stream:
|
||||||
|
await demo.demo_generate_stream()
|
||||||
|
elif args.interactive:
|
||||||
|
await demo.interactive_chat()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log.error("Error during test: %s", e, exc_info=True)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main_async())
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ def benchmark_generator() -> dict:
|
|||||||
benchmark_data = {
|
benchmark_data = {
|
||||||
"inputs": prompt,
|
"inputs": prompt,
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"max_new_tokens": 128,
|
"max_new_tokens": 500,
|
||||||
"temperature": 0.7,
|
"temperature": 0.7,
|
||||||
"return_full_text": False
|
"return_full_text": False
|
||||||
}
|
}
|
||||||
@@ -52,17 +52,18 @@ worker_config = WorkerConfig(
|
|||||||
HandlerConfig(
|
HandlerConfig(
|
||||||
route="/generate",
|
route="/generate",
|
||||||
allow_parallel_requests=True,
|
allow_parallel_requests=True,
|
||||||
max_queue_time=60.0,
|
max_queue_time=600.0,
|
||||||
benchmark_config=BenchmarkConfig(
|
benchmark_config=BenchmarkConfig(
|
||||||
generator=benchmark_generator,
|
generator=benchmark_generator,
|
||||||
concurrency=50
|
concurrency=10,
|
||||||
|
runs=3
|
||||||
),
|
),
|
||||||
workload_calculator= lambda x: x["parameters"]["max_new_tokens"]
|
workload_calculator= lambda x: x["parameters"]["max_new_tokens"]
|
||||||
),
|
),
|
||||||
HandlerConfig(
|
HandlerConfig(
|
||||||
route="/generate_stream",
|
route="/generate_stream",
|
||||||
allow_parallel_requests=True,
|
allow_parallel_requests=True,
|
||||||
max_queue_time=60.0,
|
max_queue_time=600.0,
|
||||||
workload_calculator= lambda x: x["parameters"]["max_new_tokens"]
|
workload_calculator= lambda x: x["parameters"]["max_new_tokens"]
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
This is the PyWorker implementation for running **Wan 2.2 T2V A14B** text-to-video workflows in ComfyUI. It provides a unified interface for executing complete ComfyUI video-generation workflows through a proxy-based architecture and returning generated video assets.
|
This is the PyWorker implementation for running **Wan 2.2 T2V A14B** text-to-video workflows in ComfyUI. It provides a unified interface for executing complete ComfyUI video-generation workflows through a proxy-based architecture and returning generated video assets.
|
||||||
|
|
||||||
Each request has a static cost of `100`. ComfyUI does not support concurrent workloads, and there is no provision to run multiple ComfyUI instances per worker node.
|
Each request has a static cost of `10000`. ComfyUI does not support concurrent workloads, and there is no provision to run multiple ComfyUI instances per worker node.
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user