Compare commits

..

150 Commits

Author SHA1 Message Date
Rob Ballantyne a81d3febe7 Collapse null pyworker client to a single mode parameterized by --count
Now that the session model means no HTTP connection is held during the
reservation, the dichotomy between "single reserve" and "trapezoid demo"
collapses — both are "open N sessions, each held for H seconds, started
I seconds apart, close." Replace --reserve/--demo/--duration/--plateau
with --count/--hold/--interval. --session-cost becomes --cost.

Client is now 64 lines (down from 120).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 12:18:33 +01:00
Rob Ballantyne 913e3a8782 Simplify null pyworker code and docs
Pass over all three files to drop verbose expository commentary that
duplicated either the code or the README. Net: -284 lines.

README now reads top-to-bottom in roughly the order someone would need
the info: use case → how it works → endpoint params → API → healthcheck
→ deploy → demo. Endpoint params table uses the values actually tested
on alpha (min_load=0, target_util=1, max_queue_time=1,
target_queue_time=0.5, inactivity_timeout=10). Dropped the
"known autoscaler quirk" section now that alpha addresses it; kept the
--session-cost flag as a debugging knob.

worker.py and client.py keep the same behavior but trim long block
comments and multi-line docstrings the code didn't need.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 11:50:03 +01:00
Rob Ballantyne 47ad0ebe0a Add --instance flag to null pyworker client
Lets the demo target run-alpha.vast.ai (or candidate/local) without
editing code. Defaults to prod; respects VAST_INSTANCE env var.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 11:40:51 +01:00
Rob Ballantyne 34fd21e76a Revert default session cost to 100; document the over-provision as a workaround
cost = max_perf = 100 is the intended steady-state semantics: one
session = one worker, scaling elastically from zero. Reverting the
default so the design reads correctly even where current autoscaler
bugs make it misbehave (2→3 scale-up not firing reliably,
scale-to-zero issues — fixes pending on the Vast side).

README now describes the intended model first (clean unit occupancy,
scale-to-zero via inactivity_timeout + min_load=0), then flags the
known autoscaler quirk and presents --session-cost 200 as a temporary
band-aid until the Vast fixes land.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 11:34:52 +01:00
Rob Ballantyne 1d2caaf554 Default null pyworker session cost to 2x max_perf
Reporting cost == max_perf puts an occupied worker at exactly 100%
utilization, which the autoscaler reads as "at target, no action."
The 3rd session_create then 429s on both active workers and stalls in
the global queue instead of triggering a cold-worker activation
(observed: 1→2 active scales fine, 2→3 does not).

Bumping cost to 2 * max_perf makes each session look like more than
one worker's work, so the autoscaler always keeps an extra active
worker hot. Slight over-provisioning, but the 3rd reservation lands
directly on a free worker rather than queueing.

Expose --session-cost on the client so the value can be swept without
edits. README documents the trade-off.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 11:31:26 +01:00
Rob Ballantyne 01eff874d8 Correct queue-time guidance for null pyworker endpoints
Earlier note claimed max_queue_time / target_queue_time were no-ops
because the worker's internal wait_time property filters sessions out.
That filter only affects per-worker rejection on a given handler — the
autoscaler doesn't see the property and computes its own queue-time
estimate from cur_load / max_perf, which *does* include sessions.

With defaults around 30s, an occupied null worker (cur_load=100,
max_perf=100, implied queue=1s) still looks "available" to the
autoscaler, so a third reservation gets queued on an existing worker
via repeated 429-retries instead of triggering scale-up.

Fix: set max_queue_time = 0 and target_queue_time = 0 on the endpoint.
Any in-flight load marks the worker "full" for routing, and any
observed queue time triggers immediate scale-up.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 11:14:20 +01:00
Rob Ballantyne d51f04a176 Await endpoint.session() in null pyworker client
endpoint.session() forwards to start_endpoint_session, which is async
def — so the call returns a coroutine, not a Session, despite the
SDK's return-type annotation. Use 'async with await endpoint.session(...)'
so the coroutine resolves to a Session before entering the context.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 11:07:32 +01:00
Rob Ballantyne ef248ef695 Document endpoint scaling parameters for null pyworker
Add a scaling-parameters section to the README covering target_util=1.0
(the critical one — the default 0.9 silently rounds up to one extra
worker), min_load math, and why max_queue_time / target_queue_time
don't matter here (sessions are filtered from wait_time so both signals
stay at zero).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 11:06:04 +01:00
Rob Ballantyne 6a562a1376 Rewrite null pyworker on the framework session model
Drop the held-/reserve approach in favour of the framework's session
primitive (max_sessions=1 + /session/create). Sessions are excluded from
the autoscaler's queue-wait math and don't suffer the cur_perf=0
degradation that a long-held request did, so this naturally produces the
"one request comes in and you get a worker; release and it scales back
down" model we were hand-rolling.

Server side:
  - max_sessions=1; framework auto-registers /session/* routes
  - Drop custom /reserve handler, _active_reservation event, max_queue_
    time=0.0, MAX_RESERVATION_SECONDS, _perf_heartbeat
  - Trivial /ping handler exists only to satisfy the framework's
    "at least one handler with BenchmarkConfig" requirement (and to give
    clients an extension/keepalive route)
  - /release on the internal control port is kept as a convenience for
    queue consumers that don't carry session_auth — calls the framework's
    __close_session via name-mangling, which bypasses the session_auth
    check but is fine for a localhost-only endpoint
  - Workload/perf back to 100 (conventional)

Client side:
  - Uses endpoint.session(cost, lifetime) instead of POST /reserve
  - async with the SDK Session; close on exit posts /session/end with
    proper auth → 200 success in metrics
  - Demo and single modes both ride the same reserve() helper

Sessions landed in vastai-sdk 0.4.2 (commit ec9ef59, 2026-01-20).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 10:51:24 +01:00
Rob Ballantyne 6c2f194b28 Add perf heartbeat to keep null pyworker reporting peak throughput
While a /reserve is held, no requests complete so workload_served stays
at 0 each metrics tick. The autoscaler sees cur_perf=0 against
max_perf=150, concludes the worker can't deliver claimed throughput,
downgrades it, and gets cautious about scaling up — so additional
/reserve requests pile up behind the held one instead of triggering a
new worker.

Add a 1Hz heartbeat coroutine that, while anything is in flight, sets
workload_served back to TARGET_PERF (150) and flags update_pending. The
metrics tick reads 150 and resets to 0; the heartbeat re-pins it before
the next tick. Net effect: the autoscaler sees a saturated worker
delivering at peak rate, which is the signal it needs to scale a new
worker up rather than queue.

The heartbeat needs the backend instance, which is only created inside
Worker(...) — stash a reference in a module-level dict between Worker()
and .run() so the lifecycle coroutine can reach it.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 10:35:18 +01:00
Rob Ballantyne 2aada7b210 Add --plateau to null pyworker demo (default 5min)
Previously the first release fired only 30s after the third reservation
started, so the autoscaler often hadn't even finished provisioning the
third worker yet. Default plateau to 300s so all three workers are
visibly running before scale-down begins; configurable via --plateau.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-11 18:26:31 +01:00
Rob Ballantyne 8df562e243 Standardize null pyworker load/perf on 150
Bump workload_calculator, benchmark cache value, and client cost from 100
to 150.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-11 18:17:57 +01:00
Rob Ballantyne 4eef5e22af Pin null pyworker max_throughput to exactly 100
asyncio.sleep(1.0) takes slightly more than 1s due to event loop
scheduling, so workload/time landed at ~99.x instead of 100. Pre-populate
the framework's .has_benchmark cache file with "100" before the benchmark
runs — __run_benchmark short-circuits to the cached value and skips the
time-based calculation entirely.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-11 18:13:16 +01:00
Rob Ballantyne 9d969e376e Standardize null pyworker load/perf on 100
Using 1 confused the serverless capacity math. Set workload_calculator,
benchmark target throughput, and client cost all to 100 — the conventional
default the rest of the system expects.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-11 18:09:16 +01:00
Rob Ballantyne ef3f34a515 Restructure null pyworker --demo as a clean trapezoid
Three reservations 30s apart, each with a 90s duration. They end one at
a time, also 30s apart, then the client exits. Each reservation ends
via its duration cap (200 success) rather than the previous "cancel one,
leave two open" pattern that left two 499s pending.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-11 18:00:46 +01:00
Rob Ballantyne 147bf2597a Set null pyworker client cost to 1
Match the server-side workload_calculator (1.0) so the autoscaler routing
hint is consistent with what the worker reports. A null reservation is a
unitless slot — no reason for client cost to be 100.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-11 17:47:19 +01:00
Rob Ballantyne dc423e2999 Pin null pyworker benchmark to ~1.0 throughput
The startup benchmark previously returned instantly, producing
max_throughput around 339895. A null worker has no real throughput
concept (each reservation is a unitless slot), so sleep 1s during the
benchmark with workload=1 to record max_throughput ~= 1.0.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-11 17:22:45 +01:00
Rob Ballantyne 463f3de8ea Add staggered --demo mode to null pyworker client
Three concurrent /reserve calls 30s apart, then cancel the first to show
the early-release path. The remaining two run until their duration cap.
Useful for watching scale-up/scale-down behaviour in the autoscaler
dashboard.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-11 17:08:44 +01:00
Rob Ballantyne ed0db198c3 Reject queued /reserve immediately on busy null workers
A held reservation runs for up to MAX_RESERVATION_SECONDS (default 1h), so
queueing a second /reserve behind it makes no sense — the wait would dwarf
any sane timeout. Set max_queue_time=0.0 so the framework rejects 429 as
soon as another reservation is in flight, and serverless routes the request
to a free worker or scales a new one up.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-11 17:05:02 +01:00
Rob Ballantyne 3668d948be Simplify null pyworker README intro to serverless terminology
Drop the "autoscaler provisions a worker if none is free" phrasing in
favor of the simpler "request comes in and you get a worker; release and
it scales back down."

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-11 17:02:41 +01:00
Rob Ballantyne 254ccdf181 Add /release control endpoint to null pyworker
The held /reserve now waits on an asyncio.Event and resolves when the local
queue consumer POSTs /release on the internal control port (127.0.0.1:18999
by default). This produces a 200 success in metrics instead of the 499
cancellation you got from disconnecting the client. The duration cap stays
as a safety net for stuck consumers.

The internal aiohttp server is now unconditional and hosts /release always;
the stub /health route is added only when BACKEND_HEALTH_URL is unset.
NULL_STUB_HEALTH_PORT is renamed to NULL_CONTROL_PORT to reflect the
broader role.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-11 16:59:46 +01:00
Rob Ballantyne 89761b378a Wire null pyworker healthcheck to a stub (and optional user URL)
Adds an in-process aiohttp stub on 127.0.0.1:18999/health so the framework's
periodic healthcheck has something live to talk to. Operators can override
with BACKEND_HEALTH_URL to point at their queue consumer's /health
endpoint, so the autoscaler marks the worker errored if the consumer dies.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-11 16:53:26 +01:00
Rob Ballantyne 18974873e5 Add null pyworker for queue-driven autoscaling
A PyWorker that does not forward to any model server. POST /reserve holds
the worker busy until the client disconnects (or the duration cap elapses),
so users with their own job queue can drive Vast autoscaling without
exposing inbound model traffic on the instance.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-11 16:48:52 +01:00
Lucas Armand 9bc9ba11c5 Increase TGI benchmark tokens to 500 2026-04-30 14:04:39 -07:00
LucasArmandVast 48fdc65e3d Update to vastai package (#84) 2026-04-14 10:41:31 -07:00
LucasArmandVast 2cd97315cd Add nltk requirement for openai worker (#83)
* Add nltk requirement for openai worker

* pin version
2026-04-13 11:30:06 -07:00
Lucas Armand 83c31e25a9 Add force update detection 2026-03-31 13:46:22 -07:00
Lucas Armand fbe1dca6fa more env_path fixes 2026-03-30 16:28:51 -07:00
Lucas Armand 4c3120dbc5 allow override env_path 2026-03-30 16:25:01 -07:00
Lucas Armand d7d9b915f6 allow break system packages 2026-03-30 16:09:17 -07:00
Lucas Armand 4660b337fb Check for USE_SYSTEM_PYTHON 2026-03-30 14:46:38 -07:00
edgaratvast 7506ecb6b5 directly invoke one stop shop setup executable exported by vastai pip package for deployments (#82) 2026-03-26 10:59:49 -07:00
LucasArmandVast 50633c5003 Update deployments script with retries. (#81) 2026-03-23 14:58:32 -07:00
LucasArmandVast 2e8f18276f Add beta deployments script (#80) 2026-03-23 14:14:06 -07:00
Scott Darden eba9c480eb Merge pull request #79 from vast-ai/update-requirements
Updated requirements to only require vastai-sdk
2026-01-14 12:07:33 -08:00
Lucas Armand aaca1c9645 Updated requirements to only require vastai-sdk 2026-01-14 10:47:07 -08:00
LucasArmandVast f319db6bd5 flag for model log rotate (#78) 2026-01-12 17:03:18 -08:00
LucasArmandVast 4d786b4d17 SDK Versioning Improvements (#77)
* Add SDK_BRANCH
2026-01-02 10:23:07 -08:00
LucasArmandVast bd3e0032a1 Add SDK version checking (#76) 2025-12-17 21:01:52 -08:00
Lucas Armand e02f4bc943 Lowered concurrency of vLLM and TGI benchmarks 2025-12-17 11:55:33 -08:00
Lucas Armand bcb04b9a32 add missing comma 2025-12-17 11:40:40 -08:00
Lucas Armand 9daf171487 Increase queue limits for vLLM and TGI 2025-12-17 11:38:55 -08:00
LucasArmandVast 29f836eb1a Backwards compatible vLLM payload (#75)
* Support old vLLM payloads
2025-12-15 19:58:02 -08:00
LucasArmandVast 4380d98c01 Use PyWorker SDK (#67)
* Change PyWorker to Worker SDK
* Moved /lib to vast-sdk (https://github.com/vast-ai/vast-sdk)
2025-12-15 19:33:03 -08:00
Abiola Akinnubi 2ce741a8b7 Merge pull request #74 from vast-ai/AUTO-912
Mark pyworkers as "Error" if startup script fails. to avoid silent fail that waits for autoscaler.
2025-12-11 17:05:13 -08:00
Abiola Akinnubi 4ecc07032f Mark pyworkers as "Error" if startup script fails. to avoid silent fail that waits for autoscaler. 2025-12-11 12:51:56 -08:00
edgaratvast df61e6e946 correct version pin for aiohttp (#73)
Co-authored-by: Edgar Lin <edgarlin2000@gmail.com>
2025-12-10 19:34:52 -08:00
LucasArmandVast 70f8a8f534 Merge pull request #72 from vast-ai/hotfix-pin-pycares
Hotfix: pin pycares
2025-12-10 20:41:44 -05:00
Lucas Armand 7be8aa6397 pin pycares 2025-12-10 17:38:03 -08:00
Colter-Downing 138fc3ac47 Merge pull request #71 from vast-ai/AUTO-comfyui-updates
Auto comfyui updates
2025-12-04 10:55:12 -08:00
Colter Downing 222ac2a0dd default endpoint name 2025-12-04 10:54:55 -08:00
Colter Downing 40aed9b5f8 adding s3 as an option 2025-12-04 10:52:57 -08:00
Colter Downing d4d36bf86e done with comfy updates 2025-12-03 20:45:55 -08:00
Colter Downing e839cfc6e8 include view in API wrapper 2025-12-03 20:22:45 -08:00
Colter Downing f04138e13b update to be able to get images 2025-12-03 20:16:25 -08:00
Colter-Downing de3aa87c8f Merge pull request #70 from vast-ai/AUTO-tgi-client-edits
update tgi client
2025-12-03 18:40:01 -08:00
Colter Downing 6b5b1341a7 update tgi client 2025-12-03 18:38:42 -08:00
Colter-Downing 8be92c03de Merge pull request #69 from vast-ai/AUTO-874--fix-openai-worker-client
defaults to ENDPOINT_NAME and DEFAULT_MODEL but uses the flag first
2025-12-03 16:59:56 -08:00
Colter Downing adedb8ba90 defaults to ENDPOINT_NAME and DEFAULT_MODEL but uses the flag first if present 2025-12-03 16:57:28 -08:00
LucasArmandVast 2f543c01ad Merge pull request #68 from vast-ai/fix-vllm-concurrency
Increase model wait time for vLLM
2025-12-03 16:13:51 -05:00
Lucas Armand 0bcd2219ea Increase model wait time for vLLM 2025-12-03 12:38:52 -08:00
LucasArmandVast 0339b471c5 Merge pull request #66 from vast-ai/synthesis
PyWorker Error Handling
2025-11-25 16:02:26 -08:00
Lucas Armand e143162438 bumpy pyworker version 2025-11-25 16:01:23 -08:00
Lucas Armand 7986e51e9e early errors 2025-11-24 15:24:06 -08:00
Lucas Armand 9c6ab78503 Move model log line 2025-11-24 15:22:23 -08:00
Lucas Armand 45e0c7d9ca Move model log rotate to top 2025-11-24 15:02:33 -08:00
LucasArmandVast 7a792fd176 Merge pull request #64 from vast-ai/add-llama-log
add llama log
2025-11-21 10:24:27 -08:00
Lucas Armand e0449cb3c7 add llama log 2025-11-21 10:22:16 -08:00
Lucas Armand a4339bd3f1 hotfix: add f 2025-11-12 16:10:55 -08:00
Lucas Armand 2b26e5e20c hotfix: remove g 2025-11-12 16:01:57 -08:00
LucasArmandVast d3727d4fd7 Merge pull request #58 from vast-ai/update-client-scripts
Update client scripts
2025-11-12 10:22:42 -08:00
Lucas Armand a47c9d1ed0 remove test bugs 2025-11-11 18:13:46 -08:00
Lucas Armand 0b14562a63 dont exit on pyworker fail 2025-11-11 17:57:08 -08:00
Lucas Armand de9b50abb9 use set +e 2025-11-11 17:53:36 -08:00
Lucas Armand c510801723 fix 2025-11-11 17:49:34 -08:00
Lucas Armand a12523b1d2 Added bad code to tgi server to test 2025-11-11 17:41:12 -08:00
Lucas Armand eedf81c0a3 Updated readme and .gitignore 2025-11-11 17:18:40 -08:00
Lucas Armand 3adec1826d minor changes 2025-11-11 17:11:38 -08:00
Lucas Armand b55bfa9611 Updated clients, include vastai-sdk, handle non-UTF-8 2025-11-11 17:09:28 -08:00
LucasArmandVast 7db54f3bd7 Merge pull request #55 from vast-ai/use-mtoken
Use mtoken
2025-11-10 11:54:04 -08:00
LucasArmandVast d63a060202 Merge pull request #56 from vast-ai/obfuscate-mtoken
Obfuscate mtoken in logs
2025-11-10 11:53:17 -08:00
Lucas Armand c6521cb6d4 add ... 2025-11-07 10:10:35 -08:00
Lucas Armand b7fe4ebb91 Obfuscate mtoken in logs 2025-11-07 10:02:39 -08:00
Lucas Armand 8ae7b74605 bump version to 0.2.0 2025-11-05 13:32:21 -08:00
Lucas Armand 106067d716 bump version to 0.1.1 2025-11-04 17:15:59 -08:00
Lucas Armand f5134d4bf5 Fix spelling mistake 2025-11-04 16:59:39 -08:00
Lucas Armand 47e5460532 added mtoken 2025-11-04 15:55:14 -08:00
Colter-Downing ec2ac0a21a Merge pull request #52 from vast-ai/remove-sleeps-and-delays
Remove sleeps and delays
2025-10-30 11:53:39 -07:00
Abiola Akinnubi 2cde573c56 Merge pull request #48 from vast-ai/comfy-request-idx
Added request_idx to comfy auth_data
2025-10-30 11:27:35 -07:00
Abiola Akinnubi b2e4a5db0c Merge pull request #49 from vast-ai/unsecure_report_addr
Added caller for REPORT_ADDR to backend.py to use the report add
2025-10-30 10:39:46 -07:00
Abiola Akinnubi 7437028cb2 Added caller for REPORT_ADDR to backend.py 2025-10-29 18:02:17 -07:00
edgaratvast 02c8307af7 remove redis pubsub from pyworker (#53)
Co-authored-by: Edgar Lin <edgarlin2000@gmail.com>
2025-10-29 17:07:56 -07:00
Colter Downing 7c0f316eeb leave the env vars alone! 2025-10-29 11:36:46 -07:00
Colter Downing b4025a744f remove env var writing 2025-10-29 09:58:09 -07:00
Colter Downing d190308329 removed 5 sec sleep and warmup request on load 2025-10-29 09:57:46 -07:00
LucasArmandVast 9f5a432513 Merge pull request #51 from vast-ai/delete-reqs-hotfix
Redis subscriber queue patch
2025-10-28 16:07:28 -07:00
Lucas Armand e09f1fa953 patch for redis queue 2025-10-28 16:03:50 -07:00
edgaratvast ba6f1c2e4b Fix signature (#50)
* change order of fields in auth_data to match autoscaler for signature verification

* also ignore __request_id

* Revert "change order of fields in auth_data to match autoscaler for signature verification" so that it's alphabetical again

This reverts commit b8223879c9.

* enforce alphabetical json dumping of message for signature verification

---------

Co-authored-by: Edgar Lin <edgarlin2000@gmail.com>
2025-10-28 16:01:32 -07:00
Abiola Akinnubi 944f83fc03 Removed extra spaces from operator assignment 2025-10-28 21:03:52 +00:00
edgaratvast 298590fb88 Merge pull request #45 from vast-ai/new-pyworker
New PyWorker
2025-10-28 14:02:53 -07:00
Lucas Armand 814c3acd4c remove unused code 2025-10-28 13:43:57 -07:00
Lucas Armand 22bca74087 Prevent load time race 2025-10-27 18:25:21 -07:00
Lucas Armand 9c795e2a01 removed bad code 2025-10-27 17:03:13 -07:00
Lucas Armand 830b532781 Trying unified delete 2025-10-27 16:57:52 -07:00
LucasArmandVast d6a6e34c6b Merge branch 'main' into new-pyworker 2025-10-27 12:43:49 -07:00
Colter-Downing ac1e109c48 Merge pull request #47 from vast-ai/new-pyworker-vllm-prefix-cache
vLLM Prefix caching, benchmark bug fix, test load script
2025-10-27 12:30:34 -07:00
Colter Downing d6eb498ee4 catch the case where all benchmarks fail (sets error) 2025-10-27 12:01:55 -07:00
Abiola Akinnubi f56bbc0ebe Added request_idx to comfy auth_data 2025-10-27 03:17:06 +00:00
Colter Downing bcecd6df40 Suppress matplot debug logs 2025-10-25 16:18:02 -07:00
Lucas Armand 4d9bf2048c Fix 2025-10-24 15:44:38 -07:00
Lucas Armand 7788bc4a62 Added some debug logs 2025-10-24 15:41:00 -07:00
Lucas Armand 37ad3f8d46 asyncio in metrics 2025-10-23 10:18:31 -07:00
Rob Ballantyne 70d51bafe1 Merge pull request #36 from robballantyne/feat/comfyui-json-benchmark-workflow-from-file 2025-10-23 17:05:48 +01:00
Rob Ballantyne 63909736bb Merge pull request #4 from robballantyne/feat/comfyui-json-benchmark-workflow-from-file-no-silent-fail
Feat/comfyui json benchmark workflow from file no silent fail
2025-10-23 17:02:12 +01:00
Rob Ballantyne f4f7080df1 Re-add comment 2025-10-23 17:00:28 +01:00
Rob Ballantyne d51a338e8f log when benchmark file not used 2025-10-23 16:41:02 +01:00
Rob Ballantyne 92a04bd7af No silent fail if benchmark file is missing 2025-10-23 13:41:03 +01:00
Lucas Armand 0f13506938 Send success param 2025-10-22 10:18:59 -07:00
Lucas Armand 01e752d31f use more asyncio sleep 2025-10-21 18:52:13 -07:00
Lucas Armand 5edfa968ca async sleep 2025-10-21 18:49:48 -07:00
Lucas Armand 5b5ef7227a nvm moved it here 2025-10-21 18:20:11 -07:00
Lucas Armand 16990ff8ff move start request 2025-10-21 18:18:44 -07:00
Lucas Armand 9748176366 fixed semaphore acquire bool 2025-10-21 18:12:23 -07:00
Lucas Armand b39193ae70 check for sem acquire 2025-10-21 18:02:14 -07:00
Lucas Armand 9a6ca5d412 added versioning 2025-10-21 15:42:43 -07:00
Lucas Armand e9ba1b03e4 Use delete_requests and track request_idxs 2025-10-21 11:59:35 -07:00
LucasArmandVast c98d661513 Merge pull request #39 from vast-ai/remove-time-divide
PyWorker fixes for cur_load and acks bug
2025-10-13 10:06:22 -07:00
Lucas Armand f6fd1c6ac1 merge 2025-10-09 18:15:55 -07:00
Lucas Armand 055e346c8c Send metrics on request start 2025-10-09 10:13:50 -07:00
Lucas Armand 1cedb28acf Removed division by elapsed time, since autoscaler cur_load in units of workload 2025-10-08 16:54:18 -07:00
Rob Ballantyne ec25dda3ad Merge branch 'vast-ai:main' into feat/comfyui-json-benchmark-workflow-from-file 2025-10-08 14:49:32 +01:00
Colter-Downing 0397af719d Merge pull request #37 from robballantyne/bugfix/healthcheck-endpoint
Fix healthcheck endpoint URL

Tested and merged by Colter
2025-10-06 15:11:27 -07:00
Rob Ballantyne 4fdc314fd9 Fix healthcheck endpoint URL 2025-10-06 22:16:09 +01:00
Rob Ballantyne 3786cf978d Add awareness of errors thrown by the provisioning script 2025-10-05 23:14:59 +01:00
Rob Ballantyne a86d4bcf9c Import json 2025-10-05 23:05:33 +01:00
Rob Ballantyne e9b6a14a5e Import Path 2025-10-05 22:59:19 +01:00
Rob Ballantyne cadac033e1 Enables use of custom workflow for benchmarking
Retains existing method is misc/benchmark.json is nopt present
2025-10-05 22:53:22 +01:00
Colter-Downing 639d82f5b4 Merge pull request #35 from vast-ai/AUTO-664--Healthcheck-error
Fix healthcheck with separate session
2025-10-02 12:51:19 -07:00
Colter Downing 25db78e39d Fix healthcheck with separate session 2025-10-01 18:04:31 -07:00
Scott-Laytart 4e2f2311d0 Merge pull request #33 from vast-ai/comfy-blind-fix-override
undo the fix for comfy yesterday.
2025-09-03 11:50:07 -07:00
abiola-vastai 38782d89bc undo the fix for comfy yesterday. 2025-09-03 17:12:35 +00:00
Scott-Laytart 0185216ccb Merge pull request #32 from vast-ai/blindhotfix_comfy_ui_default_port
Blind hotfix to see if comfy UI default is needed. if it does work we…
2025-09-02 18:26:25 -07:00
abiola-vastai b20d9e714c Blind hotfix to see if comfy UI default is needed. if it does work we would revert back. 2025-09-03 01:20:09 +00:00
Rob Ballantyne b1eb65d75d Merge pull request #31 from vast-ai/bugfix/startup-script-20250901
Update uv venv creation command
2025-09-01 18:19:17 +01:00
Rob Ballantyne 1d09d7fe96 Update uv venv creation command 2025-09-01 16:55:20 +01:00
Colter-Downing 1b37054dec Merge pull request #28 from vast-ai/bugfix/backend-timeout-infinite
Bugfix/backend timeout infinite
2025-08-28 11:22:33 -07:00
Colter-Downing 1a1e4174b8 Merge pull request #29 from vast-ai/bugfix/comfyui-json-cost-fix
Set cost to 100
2025-08-28 11:22:21 -07:00
Rob Ballantyne 1e4fa87437 Prevent timeout and allow long running connections 2025-08-28 15:48:57 +01:00
Rob Ballantyne 4c5fa03c7b adds import for ClientTimeout 2025-08-27 20:54:27 +01:00
Rob Ballantyne a8fe74f771 Remove default 300s timeout 2025-08-27 18:34:45 +01:00
61 changed files with 3234 additions and 4441 deletions
+1
View File
@@ -3,3 +3,4 @@
__pycache__
bin/
lib64
.venv
+133 -70
View File
@@ -1,89 +1,152 @@
# Vast PyWorker
# Vast PyWorker Examples
Vast PyWorker is a Python web server designed to run alongside a LLM or image generation models running on vast,
enabling autoscaler integration.
It serves as the primary entry point for API requests, forwarding them to the model's API hosted on the
same instance. Additionally, it monitors performance metrics and estimates current workload based on factors
such as the number of tokens processed for LLMs or image resolution and steps for image generation models,
reporting these metrics to the autoscaler.
This repository contains **example PyWorkers** used by Vast.ais default Serverless templates (e.g., vLLM, TGI, ComfyUI, Wan, ACE). A PyWorker is a lightweight Python HTTP proxy that runs alongside your model server and:
- Exposes one or more HTTP routes (e.g., `/v1/completions`, `/generate/sync`)
- Optionally validates/transforms request payloads
- Computes per-request **workload** for autoscaling
- Forwards requests to the local model server
- Optionally supports FIFO queueing when the backend cannot process concurrent requests
- Detects readiness/failure from model logs and runs a benchmark to estimate throughput
> Important: The **core PyWorker framework** (Worker, WorkerConfig, HandlerConfig, BenchmarkConfig, LogActionConfig) is provided by the **`vastai` / `vastai-sdk`** Python package (https://github.com/vast-ai/vast-sdk). This repo focuses on *worker implementations and examples*, not the framework internals.
## Repository Purpose
Use this repository as:
- A reference for how Vast templates wire up `worker.py`
- A starting point for implementing your own custom Serverless PyWorker
- A collection of working examples for common model backends
If you are looking for the framework code itself, refer to the Vast.ai SDK.
## Project Structure
* `lib/`: Contains the core PyWorker framework code (server logic, data types, metrics).
* `workers/`: Contains specific implementations (PyWorkers) for different model servers. Each subdirectory represents a worker for a particular model type.
Typical layout:
## Getting Started
- `workers/`
- Example worker implementations (each worker is usually a self-contained folder)
- Each example typically includes:
- `worker.py` (the entrypoint used by Serverless)
- Optional sample workflows / payloads (for ComfyUI-based workers)
- Optional local test harness scripts
1. **Install Dependencies:**
```bash
pip install -r requirements.txt
```
You may also need `pyright` for type checking:
```bash
sudo npm install -g pyright
# or use your preferred method to install pyright
```
## How Serverless launches worker.py
2. **Configure Environment:** Set any necessary environment variables (e.g., `MODEL_LOG` path, API keys if needed by your worker).
On each worker instance, the templates startup script typically:
3. **Run the Server:** Use the provided script. You'll need to specify which worker to run.
```bash
# Example for hello_world worker (assuming MODEL_LOG is set)
./start_server.sh workers.hello_world.server
```
Replace `workers.hello_world.server` with the path to the `server.py` module of the worker you want to run.
1. Clones your repository from `PYWORKER_REPO`
2. Installs dependencies from `requirements.txt`
3. Starts the **model server** (vLLM, TGI, ComfyUI, etc.)
4. Runs:
```bash
python worker.py
```
## How to Use
Your `worker.py` builds a `WorkerConfig`, constructs a `Worker`, and starts the PyWorker HTTP server.
### Using Existing Workers
## worker.py
If you are using a Vast.ai template that includes PyWorker integration (marked as autoscaler compatible), it should work out of the box. The template will typically start the appropriate PyWorker server automatically. Here's a few:
A PyWorker is usually a single `worker.py` that uses SDK configuration objects:
* **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=72d8dcb41ea3a58e06c741e2c725bc00)
* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=ad72c8bf7cf695c3c9ddf0eaf6da0447)
```python
from vastai import (
Worker,
WorkerConfig,
HandlerConfig,
BenchmarkConfig,
LogActionConfig,
)
Currently available workers:
* `hello_world`: A simple example worker for a basic LLM server.
* `comfyui`: A worker for the ComfyUI image generation backend.
* `tgi`: A worker for the Text Generation Inference backend.
worker_config = WorkerConfig(
model_server_url="http://127.0.0.1",
model_server_port=18000,
model_log_file="/var/log/model/server.log",
handlers=[
HandlerConfig(
route="/v1/completions",
allow_parallel_requests=True,
max_queue_time=60.0,
workload_calculator=lambda payload: float(payload.get("max_tokens", 0)),
benchmark_config=BenchmarkConfig(
generator=lambda: {"prompt": "hello", "max_tokens": 128},
runs=8,
concurrency=10,
),
)
],
log_action_config=LogActionConfig(
on_load=["Application startup complete."],
on_error=["Traceback (most recent call last):", "RuntimeError:"],
on_info=['"message":"Download'],
),
)
### Implementing a New Worker
To integrate PyWorker with a model server not already supported, you need to create a new worker implementation under the `workers/` directory. Follow these general steps:
1. **Create Worker Directory:** Add a new directory under `workers/` (e.g., `workers/my_model/`).
2. **Define Data Types (`data_types.py`):**
* Create a class inheriting from `lib.data_types.ApiPayload`.
* Implement methods like `for_test`, `generate_payload_json`, `count_workload`, and `from_json_msg` to handle request data, testing, and workload calculation specific to your model's API.
3. **Implement Endpoint Handlers (`server.py`):**
* For each model API endpoint you want PyWorker to proxy, create a class inheriting from `lib.data_types.EndpointHandler`.
* Implement methods like `endpoint`, `payload_cls`, `generate_payload_json`, `make_benchmark_payload` (for one handler), and `generate_client_response`.
* Instantiate `lib.backend.Backend` with your model server details, log file path, benchmark handler, and log actions.
* Define `aiohttp` routes, mapping paths to your handlers using `backend.create_handler()`.
* Use `lib.server.start_server` to run the application.
4. **Add `__init__.py`:** Create an empty `__init__.py` file in your worker directory.
5. **(Optional) Add Load Testing (`test_load.py`):** Create a script using `lib.test_harness.run` to test your worker against a Vast.ai endpoint group.
6. **(Optional) Add Client Example (`client.py`):** Provide a script demonstrating how to call your worker's endpoints.
**For a detailed walkthrough, refer to the `hello_world` example:** [workers/hello_world/README.md](workers/hello_world/README.md)
**Type Hinting:** It is strongly recommended to use strict type hinting throughout your implementation. Use `pyright` to check for type errors.
## Testing Your Worker
If you implement a `test_load.py` script for your worker, you can use it to load test a Vast.ai endpoint group running your instance image.
```bash
# Example for hello_world worker
python3 -m workers.hello_world.test_load -n 1000 -rps 0.5 -k "$API_KEY" -e "$ENDPOINT_GROUP_NAME"
Worker(worker_config).run()
```
Replace `workers.hello_world.test_load` with the path to your worker's test script and provide your Vast.ai API Key (`-k`) and the target Endpoint Group Name (`-e`). Adjust the number of requests (`-n`) and requests per second (`-rps`) as needed.
## Included Examples
This repository contains example PyWorkers corresponding to common Vast templates, including:
- **vLLM**: OpenAI-compatible completions/chat endpoints with parallel request support
- **TGI (Text Generation Inference)**: OpenAI-compatible endpoints and log-based readiness
- **ComfyUI (Image / JSON workflows)**: `/generate/sync` for ComfyUI workflow execution
- **ComfyUI Wan 2.2 (T2V)**: ComfyUI workflow execution producing video outputs
- **ComfyUI ACE Step (Text-to-Music)**: ComfyUI workflow execution producing audio outputs
Exact worker paths and naming may vary by template; use the `workers/` directory as the source of truth.
## Getting Started (Local)
1. Install Python dependencies for the examples you plan to run:
```bash
pip install -r requirements.txt
```
2. Start your model server locally (vLLM, TGI, ComfyUI, etc.) and ensure:
- You know the model server URL/port
- You have a log file path you can tail for readiness/error detection
3. Run the worker:
```bash
python worker.py
```
or, if running an example from a subfolder:
```bash
python workers/<example>/worker.py
```
> Note: Many examples assume they are running inside Vast templates (ports, log paths, model locations). You may need to adjust `model_server_port` and `model_log_file` for local usage.
## Deploying on Vast Serverless
To use a custom PyWorker with Serverless:
1. Create a public Git repository containing:
- `worker.py`
- `requirements.txt`
2. In your Serverless template / endpoint configuration, set:
- `PYWORKER_REPO` to your Git repository URL
- (Optional) `PYWORKER_REF` to a git ref (branch, tag, or commit)
3. The template startup script will clone/install and run your `worker.py`.
## Guidance for Custom Workers
When implementing your own worker:
- Define one `HandlerConfig` per route you want to expose.
- Choose a workload function that correlates with compute cost:
- LLMs: prompt tokens + max output tokens (or `max_tokens` as a simpler proxy)
- Non-LLMs: constant cost per request (e.g., `100.0`) is often sufficient
- Set `allow_parallel_requests=False` for backends that cannot handle concurrency (e.g., many ComfyUI deployments).
- Configure exactly **one** `BenchmarkConfig` across all handlers to enable capacity estimation.
- Use `LogActionConfig` to reliably detect “model loaded” and “fatal error” log lines.
## Community & Support
Join the conversation and get help:
* **Vast.ai Discord:** [https://discord.gg/Pa9M29FFye](https://discord.gg/Pa9M29FFye)
* **Vast.ai Subreddit:** [https://reddit.com/r/vastai/](https://reddit.com/r/vastai/)
- Vast.ai Discord: https://discord.gg/Pa9M29FFye
- Vast.ai Subreddit: https://reddit.com/r/vastai/
-381
View File
@@ -1,381 +0,0 @@
import os
import json
import time
import base64
import subprocess
import dataclasses
import logging
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
from functools import cached_property
from distutils.util import strtobool
from anyio import open_file
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError
import requests
from Crypto.Signature import pkcs1_15
from Crypto.Hash import SHA256
from Crypto.PublicKey import RSA
from lib.metrics import Metrics
from lib.data_types import (
AuthData,
EndpointHandler,
LogAction,
ApiPayload_T,
JsonDataException,
)
MSG_HISTORY_LEN = 100
log = logging.getLogger(__file__)
# defines the minimum wait time between sending updates to autoscaler
LOG_POLL_INTERVAL = 0.1
BENCHMARK_INDICATOR_FILE = ".has_benchmark"
MAX_PUBKEY_FETCH_ATTEMPTS = 3
@dataclasses.dataclass
class Backend:
"""
This class is responsible for:
1. Tailing logs and updating load time metrics
2. Taking an EndpointHandler alongside incoming payload, preparing a json to be sent to the model, and
sending the request. It also updates metrics as it makes those requests.
3. Running a benchmark from an EndpointHandler
"""
model_server_url: str
model_log_file: str
allow_parallel_requests: bool
benchmark_handler: (
EndpointHandler # this endpoint handler will be used for benchmarking
)
log_actions: List[Tuple[LogAction, str]]
reqnum = -1
msg_history = []
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
unsecured: bool = dataclasses.field(
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
)
def __post_init__(self):
self.metrics = Metrics()
self._total_pubkey_fetch_errors = 0
self._pubkey = self._fetch_pubkey()
self.__start_healthcheck: bool = False
@property
def pubkey(self) -> Optional[RSA.RsaKey]:
if self._pubkey is None:
self._pubkey = self._fetch_pubkey()
return self._pubkey
@cached_property
def session(self):
log.debug(f"starting session with {self.model_server_url}")
return ClientSession(self.model_server_url)
def create_handler(
self,
handler: EndpointHandler[ApiPayload_T],
) -> Callable[[web.Request], Awaitable[Union[web.Response, web.StreamResponse]]]:
async def handler_fn(
request: web.Request,
) -> Union[web.Response, web.StreamResponse]:
return await self.__handle_request(handler=handler, request=request)
return handler_fn
#######################################Private#######################################
def _fetch_pubkey(self):
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
result = subprocess.check_output(command, universal_newlines=True)
log.debug("public key:")
log.debug(result)
key = None
for _ in range(5):
try:
key = RSA.import_key(result)
break
except ValueError as e:
log.debug(f"Error downloading key: {e}")
time.sleep(15)
if key is None:
self._total_pubkey_fetch_errors += 1
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
self.backend_errored("Failed to get autoscaler pubkey")
return key
async def __handle_request(
self,
handler: EndpointHandler[ApiPayload_T],
request: web.Request,
) -> Union[web.Response, web.StreamResponse]:
"""use this function to forward requests to the model endpoint"""
try:
data = await request.json()
auth_data, payload = handler.get_data_from_request(data)
except JsonDataException as e:
return web.json_response(data=e.message, status=422)
except json.JSONDecodeError:
return web.json_response(dict(error="invalid JSON"), status=422)
workload = payload.count_workload()
async def cancel_api_call_if_disconnected() -> web.Response:
await request.wait_for_disconnection()
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
self.metrics._request_canceled(workload=workload)
return web.Response(status=500)
async def make_request() -> Union[web.Response, web.StreamResponse]:
log.debug(f"got request, {auth_data.reqnum}")
self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum)
if self.allow_parallel_requests is False:
log.debug(f"Waiting to aquire Sem for reqnum:{auth_data.reqnum}")
await self.sem.acquire()
log.debug(
f"Sem acquired for reqnum:{auth_data.reqnum}, starting request..."
)
else:
log.debug(f"Starting request for reqnum:{auth_data.reqnum}")
try:
response = await self.__call_api(handler=handler, payload=payload)
status_code = response.status
log.debug(
" ".join(
[
f"request with reqnum:{auth_data.reqnum}",
f"returned status code: {status_code},",
]
)
)
res = await handler.generate_client_response(request, response)
self.metrics._request_success(workload=workload)
return res
except requests.exceptions.RequestException as e:
log.debug(f"[backend] Request error: {e}")
self.metrics._request_errored(workload=workload)
return web.Response(status=500)
finally:
self.metrics._request_end(
workload=workload,
reqnum=auth_data.reqnum,
)
self.sem.release()
###########
if self.__check_signature(auth_data) is False:
return web.Response(status=401)
try:
done, pending = await wait(
[
create_task(make_request()),
create_task(cancel_api_call_if_disconnected()),
],
return_when=FIRST_COMPLETED,
)
[task.cancel() for task in pending]
return done.pop().result()
except Exception as e:
log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500)
async def __healthcheck(self):
health_check_url = self.benchmark_handler.healthcheck_endpoint
if health_check_url is None:
log.debug("No healthcheck endpoint defined, skipping healthcheck")
return
while True:
await sleep(10)
if self.__start_healthcheck is False:
continue
try:
log.debug(f"Performing healthcheck on {health_check_url}")
async with self.session.get(health_check_url) as response:
if response.status == 200:
log.debug("Healthcheck successful")
elif response.status == 503:
log.debug(f"Healthcheck failed with status: {response.status}")
self.backend_errored(
f"Healthcheck failed with status: {response.status}"
)
else:
# endpoint not ready yet so bail
log.debug(f"Healthcheck Endpoint not ready: {response.status}")
except Exception as e:
log.debug(f"Healthcheck failed with exception: {e}")
self.backend_errored(str(e))
async def _start_tracking(self) -> None:
await gather(
self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck()
)
def backend_errored(self, msg: str) -> None:
self.metrics._model_errored(msg)
async def __call_api(
self, handler: EndpointHandler[ApiPayload_T], payload: ApiPayload_T
) -> ClientResponse:
api_payload = payload.generate_payload_json()
log.debug(f"posting to endpoint: '{handler.endpoint}', payload: {api_payload}")
return await self.session.post(url=handler.endpoint, json=api_payload)
def __check_signature(self, auth_data: AuthData) -> bool:
if self.unsecured is True:
return True
def verify_signature(message, signature):
if self.pubkey is None:
log.debug(f"No Public Key!")
return False
h = SHA256.new(message.encode())
try:
pkcs1_15.new(self.pubkey).verify(h, base64.b64decode(signature))
return True
except (ValueError, TypeError):
return False
message = {
key: value
for (key, value) in (dataclasses.asdict(auth_data).items())
if key != "signature"
}
if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
log.debug(
f"reqnum failure, got {auth_data.reqnum}, current_reqnum: {self.reqnum}"
)
return False
elif message in self.msg_history:
log.debug(f"message: {message} already in message history")
return False
elif verify_signature(json.dumps(message, indent=4), auth_data.signature):
self.reqnum = max(auth_data.reqnum, self.reqnum)
self.msg_history.append(message)
self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
return True
else:
log.debug(
f"signature verification failed, sig:{auth_data.signature}, message: {message}"
)
return False
async def __read_logs(self) -> Awaitable[NoReturn]:
async def run_benchmark() -> float:
log.debug("starting benchmark")
try:
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
log.debug("already ran benchmark")
# trigger model load
payload = self.benchmark_handler.make_benchmark_payload()
_ = await self.__call_api(
handler=self.benchmark_handler, payload=payload
)
return float(f.readline())
except FileNotFoundError:
pass
log.debug("Initial run to trigger model loading...")
payload = self.benchmark_handler.make_benchmark_payload()
await self.__call_api(handler=self.benchmark_handler, payload=payload)
max_throughput = 0
sum_throughput = 0
concurrent_requests = 10 if self.allow_parallel_requests else 1
for run in range(1, self.benchmark_handler.benchmark_runs + 1):
start = time.time()
tasks = []
total_workload = 0
for _ in range(concurrent_requests):
payload = self.benchmark_handler.make_benchmark_payload()
total_workload += payload.count_workload()
tasks.append(
self.__call_api(handler=self.benchmark_handler, payload=payload)
)
responses = await gather(*tasks)
time_elapsed = time.time() - start
throughput = total_workload / time_elapsed
sum_throughput += throughput
max_throughput = max(max_throughput, throughput)
# Log results for debugging
log.debug(
"\n".join(
[
"#" * 60,
f"Run: {run}, concurrent_requests: {concurrent_requests}",
f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s",
f"Throughput: {throughput} workload/s",
f"Successful responses: {len([r for r in responses if r.status == 200])}",
"#" * 60,
]
)
)
average_throughput = sum_throughput / self.benchmark_handler.benchmark_runs
log.debug(
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
)
with open(BENCHMARK_INDICATOR_FILE, "w") as f:
f.write(str(max_throughput))
return max_throughput
async def handle_log_line(log_line: str) -> None:
"""
Implement this function to handle each log line for your model.
This function should mutate self.system_metrics and self.model_metrics
"""
for action, msg in self.log_actions:
match action:
case LogAction.ModelLoaded if msg in log_line:
log.debug(
f"Got log line indicating model is loaded: {log_line}"
)
# some backends need a few seconds after logging successful startup before
# they can begin accepting requests
await sleep(5)
try:
max_throughput = await run_benchmark()
self.__start_healthcheck = True
self.metrics._model_loaded(
max_throughput=max_throughput,
)
except ClientConnectorError as e:
log.debug(
f"failed to connect to comfyui api during benchmark"
)
self.backend_errored(str(e))
case LogAction.ModelError if msg in log_line:
log.debug(f"Got log line indicating error: {log_line}")
self.backend_errored(msg)
break
case LogAction.Info if msg in log_line:
log.debug(f"Info from model logs: {log_line}")
async def tail_log():
log.debug(f"tailing file: {self.model_log_file}")
async with await open_file(self.model_log_file) as f:
while True:
line = await f.readline()
if line:
await handle_log_line(line.rstrip())
else:
time.sleep(LOG_POLL_INTERVAL)
###########
while True:
if os.path.isfile(self.model_log_file) is True:
return await tail_log()
else:
await sleep(1)
-283
View File
@@ -1,283 +0,0 @@
import time
import logging
from dataclasses import dataclass, field
from enum import Enum
from abc import ABC, abstractmethod
from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type
from aiohttp import web, ClientResponse
import inspect
import psutil
"""
type variable representing an incoming payload to pyworker that will used to calculate load and will then
be forwarded to the model
"""
log = logging.getLogger(__file__)
class JsonDataException(Exception):
def __init__(self, json_msg: Dict[str, Any]):
self.message = json_msg
ApiPayload_T = TypeVar("ApiPayload_T", bound="ApiPayload")
@dataclass
class ApiPayload(ABC):
@classmethod
@abstractmethod
def for_test(cls: Type[ApiPayload_T]) -> ApiPayload_T:
"""defines how create a payload for load testing"""
pass
@abstractmethod
def generate_payload_json(self) -> Dict[str, Any]:
"""defines how to convert an ApiPayload to JSON that will be sent to model API"""
pass
@abstractmethod
def count_workload(self) -> float:
"""defines how to calculate workload for a payload"""
pass
@classmethod
@abstractmethod
def from_json_msg(
cls: Type[ApiPayload_T], json_msg: Dict[str, Any]
) -> ApiPayload_T:
"""
defines how to create an API payload from a JSON message,
it should throw an JsonDataException if there are issues with some fields
or they are missing in the format of
{
"field": "error msg"
}
"""
pass
@dataclass
class AuthData:
"""data used to authenticate requester"""
signature: str
cost: str
endpoint: str
reqnum: int
url: str
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]):
errors = {}
for param in inspect.signature(cls).parameters:
if param not in json_msg:
errors[param] = "missing parameter"
if errors:
raise JsonDataException(errors)
return cls(
**{
k: v
for k, v in json_msg.items()
if k in inspect.signature(cls).parameters
}
)
@dataclass
class EndpointHandler(ABC, Generic[ApiPayload_T]):
"""
Each model endpoint will have a handler responsible for counting workload from the incoming ApiPayload
and converting it to json to be forwarded to model API
"""
benchmark_runs: int = 8
benchmark_words: int = 100
@property
@abstractmethod
def endpoint(self) -> str:
"""the endpoint on the model API"""
pass
@property
@abstractmethod
def healthcheck_endpoint(self) -> Optional[str]:
"""the endpoint on the model API that is used for healthchecks"""
pass
@classmethod
@abstractmethod
def payload_cls(cls) -> Type[ApiPayload_T]:
"""ApiPayload class"""
pass
@abstractmethod
def make_benchmark_payload(self) -> ApiPayload_T:
"""defines how to create an ApiPayload for benchmarking."""
pass
@abstractmethod
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
"""
defines how to convert a model API response to a response to PyWorker client
"""
pass
@classmethod
def get_data_from_request(
cls, req_data: Dict[str, Any]
) -> Tuple[AuthData, ApiPayload_T]:
errors = {}
auth_data: Optional[AuthData] = None
payload: Optional[ApiPayload_T] = None
try:
if "auth_data" in req_data:
auth_data = AuthData.from_json_msg(req_data["auth_data"])
else:
errors["auth_data"] = "field missing"
except JsonDataException as e:
errors["auth_data"] = e.message
try:
if "payload" in req_data:
payload_cls = cls.payload_cls()
payload = payload_cls.from_json_msg(req_data["payload"])
else:
errors["payload"] = "field missing"
except JsonDataException as e:
errors["payload"] = e.message
if errors:
raise JsonDataException(errors)
if auth_data and payload:
return (auth_data, payload)
else:
raise Exception("error deserializing request data")
@dataclass
class SystemMetrics:
"""General system metrics"""
model_loading_start: float
model_loading_time: Union[float, None]
last_disk_usage: float
additional_disk_usage: float
model_is_loaded: bool
@staticmethod
def get_disk_usage_GB():
return psutil.disk_usage("/").used / (2**30) # want units of GB
@classmethod
def empty(cls):
return cls(
model_loading_start=time.time(),
model_loading_time=None,
last_disk_usage=SystemMetrics.get_disk_usage_GB(),
additional_disk_usage=0.0,
model_is_loaded=False,
)
def update_disk_usage(self):
disk_usage = SystemMetrics.get_disk_usage_GB()
self.additional_disk_usage = disk_usage - self.last_disk_usage
self.last_disk_usage = disk_usage
def reset(self):
# autoscaler excepts model_loading_time to be populated only once, when the instance has
# finished benchmarking and is ready to receive requests. This applies to restarted instances
# as well: they should send model_loading_time once when they are done loading
self.model_loading_time = None
@dataclass
class ModelMetrics:
"""Model specific metrics"""
# these are reset after being sent to autoscaler
workload_served: float
workload_received: float
workload_cancelled: float
workload_errored: float
# these are not
workload_pending: float
error_msg: Optional[str]
max_throughput: float
requests_recieved: Set[int] = field(default_factory=set)
requests_working: Set[int] = field(default_factory=set)
last_update: float = field(default_factory=time.time)
@classmethod
def empty(cls):
return cls(
workload_pending=0.0,
workload_served=0.0,
workload_cancelled=0.0,
workload_errored=0.0,
workload_received=0.0,
error_msg=None,
max_throughput=0.0,
)
@property
def cur_perf(self) -> float:
return max(self.workload_served / (time.time() - self.last_update), 0.0)
@property
def workload_processing(self) -> float:
return max(self.workload_received - self.workload_cancelled, 0.0)
def set_errored(self, error_msg):
self.reset()
self.error_msg = error_msg
def reset(self):
self.workload_served = 0
self.workload_received = 0
self.workload_cancelled = 0
self.workload_errored = 0
self.last_update = time.time()
@dataclass
class AutoScalaerData:
"""Data that is reported to autoscaler"""
id: int
loadtime: float
cur_load: float
error_msg: str
max_perf: float
cur_perf: float
cur_capacity: float
max_capacity: float
num_requests_working: int
num_requests_recieved: int
additional_disk_usage: float
url: str
class LogAction(Enum):
"""
These actions tell the backend what a log value means, for example:
actions [
# this marks the model server as loaded
(LogAction.ModelLoaded, "Starting server"),
# these mark the model server as errored
(LogAction.ModelError, "Exception loading model"),
(LogAction.ModelError, "Server failed to bind to port"),
# this tells the backend to print any logs containing the string into its own logs
# which are visible in the vast console instance logs
(LogAction.Info, "Starting model download"),
]
"""
ModelLoaded = 1
ModelError = 2
Info = 3
-155
View File
@@ -1,155 +0,0 @@
import os
import time
import logging
import json
from asyncio import sleep
from dataclasses import dataclass, asdict, field
from functools import cache
import requests
from lib.data_types import AutoScalaerData, SystemMetrics, ModelMetrics
from typing import Awaitable, NoReturn, List
METRICS_UPDATE_INTERVAL = 1
log = logging.getLogger(__file__)
@cache
def get_url() -> str:
use_ssl = os.environ.get("USE_SSL", "false") == "true"
worker_port = os.environ[f"VAST_TCP_PORT_{os.environ['WORKER_PORT']}"]
public_ip = os.environ["PUBLIC_IPADDR"]
return f"http{'s' if use_ssl else ''}://{public_ip}:{worker_port}"
@dataclass
class Metrics:
last_metric_update: float = 0.0
update_pending: bool = False
id: int = field(default_factory=lambda: int(os.environ["CONTAINER_ID"]))
report_addr: List[str] = field(
default_factory=lambda: os.environ["REPORT_ADDR"].split(",")
)
url: str = field(default_factory=get_url)
system_metrics: SystemMetrics = field(default_factory=SystemMetrics.empty)
model_metrics: ModelMetrics = field(default_factory=ModelMetrics.empty)
def _request_start(self, workload: float, reqnum: int) -> None:
"""
this function is called prior to forwarding a request to a model API.
"""
log.debug("request start")
self.model_metrics.workload_pending += workload
self.model_metrics.workload_received += workload
self.model_metrics.requests_recieved.add(reqnum)
self.model_metrics.requests_working.add(reqnum)
def _request_end(self, workload: float, reqnum: int) -> None:
"""
this function is called after handling of a request ends, regardless of the outcome
"""
self.model_metrics.workload_pending -= workload
self.model_metrics.requests_working.discard(reqnum)
def _request_success(self, workload: float) -> None:
"""
this function is called after a response from model API is received and forwarded.
"""
self.model_metrics.workload_served += workload
self.update_pending = True
def _request_errored(self, workload: float) -> None:
"""
this function is called if model API returns an error
"""
self.model_metrics.workload_errored += workload
def _request_canceled(self, workload: float) -> None:
"""
this function is called if client drops connection before model API has responded
"""
self.model_metrics.workload_cancelled += workload
async def _send_metrics_loop(self) -> Awaitable[NoReturn]:
while True:
await sleep(METRICS_UPDATE_INTERVAL)
elapsed = time.time() - self.last_metric_update
if self.system_metrics.model_is_loaded is False and elapsed >= 10:
log.debug(f"sending loading model metrics after {int(elapsed)}s wait")
self.__send_metrics_and_reset(elapsed)
elif self.update_pending or elapsed > 10:
log.debug(f"sending loaded model metrics after {int(elapsed)}s wait")
self.__send_metrics_and_reset(elapsed)
def _model_loaded(self, max_throughput: float) -> None:
self.system_metrics.model_loading_time = (
time.time() - self.system_metrics.model_loading_start
)
self.system_metrics.model_is_loaded = True
self.model_metrics.max_throughput = max_throughput
def _model_errored(self, error_msg: str) -> None:
self.model_metrics.set_errored(error_msg)
self.system_metrics.model_is_loaded = True
#######################################Private#######################################
def __send_metrics_and_reset(self, elapsed):
def compute_autoscaler_data() -> AutoScalaerData:
return AutoScalaerData(
id=self.id,
loadtime=(self.system_metrics.model_loading_time or 0.0),
cur_load=(self.model_metrics.workload_processing / elapsed),
max_perf=self.model_metrics.max_throughput,
cur_perf=self.model_metrics.cur_perf,
error_msg=self.model_metrics.error_msg or "",
num_requests_working=len(self.model_metrics.requests_working),
num_requests_recieved=len(self.model_metrics.requests_recieved),
additional_disk_usage=self.system_metrics.additional_disk_usage,
cur_capacity=0,
max_capacity=0,
url=self.url,
)
def send_data(report_addr: str) -> bool:
data = compute_autoscaler_data()
full_path = report_addr.rstrip("/") + "/worker_status/"
log.debug(
"\n".join(
[
"#" * 60,
f"sending data to autoscaler",
f"{json.dumps((asdict(data)), indent=2)}",
"#" * 60,
]
)
)
for attempt in range(1, 4):
try:
res = requests.post(full_path, json=asdict(data), timeout=1)
res.raise_for_status()
return True
except requests.Timeout:
log.debug(f"autoscaler status update timed out")
except Exception as e:
log.debug(f"autoscaler status update failed with error: {e}")
time.sleep(2)
log.debug(f"retrying autoscaler status update, attempt: {attempt}")
log.debug(f"failed to send update through {report_addr}")
return False
###########
self.system_metrics.update_disk_usage()
for report_addr in self.report_addr:
success = send_data(report_addr)
if success is True:
break
self.update_pending = False
self.model_metrics.reset()
self.system_metrics.reset()
self.last_metric_update = time.time()
-40
View File
@@ -1,40 +0,0 @@
import os
import logging
from typing import List
import ssl
from asyncio import run, gather
from lib.backend import Backend
from aiohttp import web
log = logging.getLogger(__file__)
def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
log.debug("getting certificate...")
use_ssl = os.environ.get("USE_SSL", "false") == "true"
if use_ssl is True:
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context.load_cert_chain(
certfile="/etc/instance.crt",
keyfile="/etc/instance.key",
)
else:
ssl_context = None
async def main():
log.debug("starting server...")
app = web.Application()
app.add_routes(routes)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(
runner,
ssl_context=ssl_context,
port=int(os.environ["WORKER_PORT"]),
**kwargs
)
await gather(site.start(), backend._start_tracking())
run(main())
-310
View File
@@ -1,310 +0,0 @@
import logging
import os
import time
import argparse
from typing import Callable, List, Dict, Tuple, Dict, Any, Type
from time import sleep
import threading
from enum import Enum
from collections import Counter
from dataclasses import dataclass, field, asdict
from urllib.parse import urljoin
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
import requests
from lib.data_types import AuthData, ApiPayload
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
class ClientStatus(Enum):
FetchEndpoint = 1
Generating = 2
Done = 3
Error = 4
total_success = 0
last_res = []
stop_event = threading.Event()
start_time = time.time()
test_args = argparse.ArgumentParser(description="Test inference endpoint")
test_args.add_argument(
"-k", dest="api_key", type=str, required=True, help="Your vast account API key"
)
test_args.add_argument(
"-e",
dest="endpoint_group_name",
type=str,
required=True,
help="Endpoint group name",
)
test_args.add_argument(
"-l",
dest="server_url",
action="store_const",
const="http://localhost:8081",
default="https://run.vast.ai",
help="Call local autoscaler instead of prod, for dev use only",
)
test_args.add_argument(
"-i",
dest="instance",
type=str,
default="prod",
help="Autoscaler shard to run the command against, default: prod",
)
GetPayloadAndWorkload = Callable[[], Tuple[Dict[str, Any], float]]
def print_truncate_res(res: str):
if len(res) > 150:
print(f"{res[:50]}....{res[-100:]}")
else:
print(res)
@dataclass
class ClientState:
endpoint_group_name: str
api_key: str
server_url: str
worker_endpoint: str
instance: str
payload: ApiPayload
url: str = ""
status: ClientStatus = ClientStatus.FetchEndpoint
as_error: List[str] = field(default_factory=list)
infer_error: List[str] = field(default_factory=list)
conn_errors: Counter = field(default_factory=Counter)
def make_call(self):
self.status = ClientStatus.FetchEndpoint
if not self.api_key:
self.as_error.append(
f"Endpoint {self.endpoint_group_name} not found for API key",
)
self.status = ClientStatus.Error
return
route_payload = {
"endpoint": self.endpoint_group_name,
"api_key": self.api_key,
"cost": self.payload.count_workload(),
}
headers = {"Authorization": f"Bearer {self.api_key}"}
response = requests.post(
urljoin(self.server_url, "/route/"),
json=route_payload,
headers=headers,
timeout=4,
)
if response.status_code != 200:
self.as_error.append(
f"code: {response.status_code}, body: {response.text}",
)
self.status = ClientStatus.Error
return
message = response.json()
worker_address = message["url"]
req_data = dict(
payload=asdict(self.payload),
auth_data=asdict(AuthData.from_json_msg(message)),
)
self.url = worker_address
url = urljoin(worker_address, self.worker_endpoint)
self.status = ClientStatus.Generating
response = requests.post(
url,
json=req_data,
verify=get_cert_file_path(),
)
if response.status_code != 200:
self.infer_error.append(
f"code: {response.status_code}, body: {response.text}, url: {url}",
)
self.status = ClientStatus.Error
return
res = str(response.json())
global total_success
global last_res
total_success += 1
last_res.append(res)
self.status = ClientStatus.Done
def simulate_user(self) -> None:
try:
self.make_call()
except Exception as e:
print(e)
self.status = ClientStatus.Error
_ = e
self.conn_errors[self.url] += 1
def print_state(clients: List[ClientState], num_clients: int) -> None:
print("starting up...")
sleep(2)
center_size = 14
global start_time
while len(clients) < num_clients or (
any(
map(
lambda client: client.status
in [ClientStatus.FetchEndpoint, ClientStatus.Generating],
clients,
)
)
):
sleep(0.5)
os.system("clear")
print(
" | ".join(
[member.name.center(center_size) for member in ClientStatus]
+ [
item.center(center_size)
for item in [
"urls",
"as_error",
"infer_error",
"conn_error",
"total_success",
]
]
)
)
unique_urls = len(set([c.url for c in clients if c.url != ""]))
as_errors = sum(
map(
lambda client: len(client.as_error),
[client for client in clients],
)
)
infer_errors = sum(
map(
lambda client: len(client.infer_error),
[client for client in clients],
)
)
conn_errors = sum([client.conn_errors for client in clients], start=Counter())
conn_errors_str = ",".join(map(str, conn_errors.values())) or "0"
elapsed = time.time() - start_time
print(
" | ".join(
map(
lambda item: str(item).center(center_size),
[
len(list(filter(lambda x: x.status == member, clients)))
for member in ClientStatus
]
+ [
unique_urls,
as_errors,
infer_errors,
conn_errors_str,
f"{total_success}({((total_success/elapsed) * 60):.2f}/minute)",
],
)
)
)
if conn_errors:
print("conn_errors:")
for url, count in conn_errors.items():
print(url.ljust(28), ": ", str(count))
elapsed = time.time() - start_time
print(f"\n elapsed: {int(elapsed // 60)}:{int(elapsed % 60)}")
if last_res:
for i, res in enumerate(last_res[-10:]):
print_truncate_res(f"res #{1+i+max(len(last_res )-10,0)}: {res}")
if stop_event.is_set():
print("\n### waiting for existing connections to close ###")
def run_test(
num_requests: int,
requests_per_second: int,
endpoint_group_name: str,
api_key: str,
server_url: str,
worker_endpoint: str,
payload_cls: Type[ApiPayload],
instance: str,
):
threads = []
clients = []
print_thread = threading.Thread(target=print_state, args=(clients, num_requests))
print_thread.daemon = True # makes threads get killed on program exit
print_thread.start()
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=endpoint_group_name, account_api_key=api_key, instance=instance
)
if not endpoint_api_key:
log.debug(f"Endpoint {endpoint_group_name} not found for API key")
return
try:
for _ in range(num_requests):
client = ClientState(
endpoint_group_name=endpoint_group_name,
api_key=endpoint_api_key,
server_url=server_url,
worker_endpoint=worker_endpoint,
payload=payload_cls.for_test(),
instance=instance,
)
clients.append(client)
thread = threading.Thread(target=client.simulate_user, args=())
threads.append(thread)
thread.start()
sleep(1 / requests_per_second)
for thread in threads:
thread.join()
print("done spawning workers")
except KeyboardInterrupt:
stop_event.set()
def test_load_cmd(
payload_cls: Type[ApiPayload], endpoint: str, arg_parser: argparse.ArgumentParser
):
arg_parser.add_argument(
"-n",
dest="num_requests",
type=int,
required=True,
help="total number of requests",
)
arg_parser.add_argument(
"-rps",
dest="requests_per_second",
type=float,
required=True,
help="requests per second",
)
args = arg_parser.parse_args()
if hasattr(args, "comfy_model"):
os.environ["COMFY_MODEL"] = args.comfy_model
server_url = dict(
prod="https://run.vast.ai",
alpha="https://run-alpha.vast.ai",
candidate="https://run-candidate.vast.ai",
local="http://localhost:8080",
)[args.instance]
run_test(
num_requests=args.num_requests,
requests_per_second=args.requests_per_second,
api_key=args.api_key,
server_url=server_url,
endpoint_group_name=args.endpoint_group_name,
worker_endpoint=endpoint,
payload_cls=payload_cls,
instance=args.instance,
)
+22
View File
@@ -0,0 +1,22 @@
# Where did the PyWorker code go?
We have moved the PyWorker source code into the `vastai-sdk` Python SDK.
You can install it with
```
pip install vastai-sdk
```
All of the source code can be found here:
https://github.com/vast-ai/vast-sdk
And can be imported from vastai.serverless.server.lib
Serverless instances automatically run the start_server.sh script, which installs the vastai-sdk.
This is how the PyWorker source code makes it onto your serverless instances.
You provide a worker.py file in your PYWORKER_REPO, and the start_server.sh will
create and run a PyWorker according to your configuration defined in the file.
While you can still create and run PyWorkers for serverless using your old PyWorker code,
we **strongly** encourage you to use the new worker.py configuration method, since
we can guarantee backwards compatibility for all your worker definitions. No more forking pyworker :)
If you encounter and issues with using PyWorker, please create a GitHub issue and we will be happy to assist.
+2 -10
View File
@@ -1,10 +1,2 @@
aiohttp[speedups]==3.10.1
anyio~=4.4
lib~=4.0
nltk~=3.9
psutil~=6.0
pycryptodome~=3.20
Requests~=2.32
transformers~=4.52
utils==1.0.*
hf_transfer>=0.1.9
vastai-sdk>=0.3.0
nltk==3.9.4
+275 -39
View File
@@ -2,31 +2,96 @@
set -e -o pipefail
# Check for force update flag
FORCE_UPDATE=false
if [ -f "/.force_update" ]; then
echo "Force update flag detected at /.force_update"
FORCE_UPDATE=true
fi
WORKSPACE_DIR="${WORKSPACE_DIR:-/workspace}"
SERVER_DIR="$WORKSPACE_DIR/vast-pyworker"
ENV_PATH="$WORKSPACE_DIR/worker-env"
ENV_PATH="${ENV_PATH:-$WORKSPACE_DIR/worker-env}"
DEBUG_LOG="$WORKSPACE_DIR/debug.log"
PYWORKER_LOG="$WORKSPACE_DIR/pyworker.log"
REPORT_ADDR="${REPORT_ADDR:-https://cloud.vast.ai/api/v0,https://run.vast.ai}"
REPORT_ADDR="${REPORT_ADDR:-https://run.vast.ai}"
USE_SSL="${USE_SSL:-true}"
WORKER_PORT="${WORKER_PORT:-3000}"
mkdir -p "$WORKSPACE_DIR"
cd "$WORKSPACE_DIR"
# make all output go to $DEBUG_LOG and stdout without having to add `... | tee -a $DEBUG_LOG` to every command
exec &> >(tee -a "$DEBUG_LOG")
function echo_var(){
echo "$1: ${!1}"
}
[ -z "$BACKEND" ] && echo "BACKEND must be set!" && exit 1
[ -z "$MODEL_LOG" ] && echo "MODEL_LOG must be set!" && exit 1
[ -z "$HF_TOKEN" ] && echo "HF_TOKEN must be set!" && exit 1
[ "$BACKEND" = "comfyui" ] && [ -z "$COMFY_MODEL" ] && echo "For comfyui backends, COMFY_MODEL must be set!" && exit 1
function report_error_and_exit(){
local error_msg="$1"
echo "ERROR: $error_msg"
MTOKEN="${MASTER_TOKEN:-}"
VERSION="${PYWORKER_VERSION:-0}"
IFS=',' read -r -a REPORT_ADDRS <<< "${REPORT_ADDR}"
for addr in "${REPORT_ADDRS[@]}"; do
curl -sS -X POST -H 'Content-Type: application/json' \
-d "$(cat <<JSON
{
"id": ${CONTAINER_ID:-0},
"mtoken": "${MTOKEN}",
"version": "${VERSION}",
"error_msg": "${error_msg}",
"url": "${URL:-}"
}
JSON
)" "${addr%/}/worker_status/" || true
done
exit 1
}
function install_vastai_sdk() {
local uv_flags=()
if [ "${USE_SYSTEM_PYTHON:-}" = "true" ]; then
uv_flags+=(--system --break-system-packages)
fi
if [ "$FORCE_UPDATE" = true ]; then
uv_flags+=(--force-reinstall)
echo "Force reinstalling vastai"
fi
# If SDK_BRANCH is set, install vastai from the vast-cli repo at that branch/tag/commit.
if [ -n "${SDK_BRANCH:-}" ]; then
if [ -n "${SDK_VERSION:-}" ]; then
echo "WARNING: Both SDK_BRANCH and SDK_VERSION are set; using SDK_BRANCH=${SDK_BRANCH}"
fi
echo "Installing vastai from https://github.com/vast-ai/vast-cli/ @ ${SDK_BRANCH}"
if ! uv pip install "${uv_flags[@]}" "vastai @ git+https://github.com/vast-ai/vast-cli.git@${SDK_BRANCH}"; then
report_error_and_exit "Failed to install vastai from vast-ai/vast-cli@${SDK_BRANCH}"
fi
return 0
fi
if [ -n "${SDK_VERSION:-}" ]; then
echo "Installing vastai version ${SDK_VERSION}"
if ! uv pip install "${uv_flags[@]}" "vastai==${SDK_VERSION}"; then
report_error_and_exit "Failed to install vastai==${SDK_VERSION}"
fi
return 0
fi
echo "Installing default vastai"
if ! uv pip install "${uv_flags[@]}" vastai; then
report_error_and_exit "Failed to install vastai"
fi
}
[ -n "$BACKEND" ] && [ -z "$HF_TOKEN" ] && report_error_and_exit "HF_TOKEN must be set when BACKEND is set!"
[ -z "$CONTAINER_ID" ] && report_error_and_exit "CONTAINER_ID must be set!"
[ "$BACKEND" = "comfyui" ] && [ -z "$COMFY_MODEL" ] && report_error_and_exit "For comfyui backends, COMFY_MODEL must be set!"
echo "start_server.sh"
date
@@ -41,47 +106,151 @@ echo_var DEBUG_LOG
echo_var PYWORKER_LOG
echo_var MODEL_LOG
ROTATE_MODEL_LOG="${ROTATE_MODEL_LOG:-false}"
if [ "$ROTATE_MODEL_LOG" = "true" ] && [ -e "$MODEL_LOG" ]; then
echo "Rotating model log at $MODEL_LOG to $MODEL_LOG.old"
if ! cat "$MODEL_LOG" >> "$MODEL_LOG.old"; then
report_error_and_exit "Failed to rotate model log"
fi
if ! : > "$MODEL_LOG"; then
report_error_and_exit "Failed to truncate model log"
fi
fi
# Populate /etc/environment with quoted values
if ! grep -q "VAST" /etc/environment; then
env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do
if ! env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do
name=${line%%=*}
value=${line#*=}
printf '%s="%s"\n' "$name" "$value"
done > /etc/environment
done > /etc/environment; then
echo "WARNING: Failed to populate /etc/environment, continuing anyway"
fi
fi
if [ ! -d "$ENV_PATH" ]
then
if [ "${USE_SYSTEM_PYTHON:-}" = "true" ]; then
echo "Using system Python: $(which python3)"
if ! which uv > /dev/null 2>&1; then
if ! curl -LsSf https://astral.sh/uv/install.sh | sh; then
report_error_and_exit "Failed to install uv package manager"
fi
if [[ -f ~/.local/bin/env ]]; then
if ! source ~/.local/bin/env; then
report_error_and_exit "Failed to source uv environment"
fi
fi
fi
install_vastai_sdk
touch ~/.no_auto_tmux
elif [ ! -d "$ENV_PATH" ]; then
echo "setting up venv"
if ! which uv; then
curl -LsSf https://astral.sh/uv/install.sh | sh
source ~/.local/bin/env
if ! curl -LsSf https://astral.sh/uv/install.sh | sh; then
report_error_and_exit "Failed to install uv package manager"
fi
if [[ -f ~/.local/bin/env ]]; then
if ! source ~/.local/bin/env; then
report_error_and_exit "Failed to source uv environment"
fi
else
echo "WARNING: ~/.local/bin/env not found after uv installation"
fi
fi
# Fork testing
git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"
if [[ ! -d $SERVER_DIR ]]; then
if ! git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"; then
report_error_and_exit "Failed to clone pyworker repository"
fi
elif [ "$FORCE_UPDATE" = true ]; then
echo "Force updating pyworker repository"
if ! (cd "$SERVER_DIR" && git fetch --all); then
report_error_and_exit "Failed to fetch pyworker repository updates"
fi
fi
if [[ -n ${PYWORKER_REF:-} ]]; then
(cd "$SERVER_DIR" && git checkout "$PYWORKER_REF")
if [ "$FORCE_UPDATE" = true ]; then
echo "Force updating to pyworker reference: $PYWORKER_REF"
if ! (cd "$SERVER_DIR" && git checkout "$PYWORKER_REF" && git pull); then
report_error_and_exit "Failed to force update pyworker reference: $PYWORKER_REF"
fi
else
if ! (cd "$SERVER_DIR" && git checkout "$PYWORKER_REF"); then
report_error_and_exit "Failed to checkout pyworker reference: $PYWORKER_REF"
fi
fi
elif [ "$FORCE_UPDATE" = true ]; then
echo "Force updating pyworker to latest"
if ! (cd "$SERVER_DIR" && git pull); then
report_error_and_exit "Failed to pull latest pyworker changes"
fi
fi
uv venv --managed-python "$ENV_PATH" -p 3.10
source "$ENV_PATH/bin/activate"
if ! uv venv --python-preference only-managed "$ENV_PATH" -p 3.10; then
report_error_and_exit "Failed to create virtual environment"
fi
uv pip install -r "${SERVER_DIR}/requirements.txt"
if ! source "$ENV_PATH/bin/activate"; then
report_error_and_exit "Failed to activate virtual environment"
fi
touch ~/.no_auto_tmux
if ! uv pip install -r "${SERVER_DIR}/requirements.txt"; then
report_error_and_exit "Failed to install Python requirements"
fi
install_vastai_sdk
if ! touch ~/.no_auto_tmux; then
report_error_and_exit "Failed to create ~/.no_auto_tmux"
fi
else
[[ -f ~/.local/bin/env ]] && source ~/.local/bin/env
source "$WORKSPACE_DIR/worker-env/bin/activate"
if [[ -f ~/.local/bin/env ]]; then
if ! source ~/.local/bin/env; then
report_error_and_exit "Failed to source uv environment"
fi
fi
if ! source "$ENV_PATH/bin/activate"; then
report_error_and_exit "Failed to activate existing virtual environment"
fi
echo "environment activated"
echo "venv: $VIRTUAL_ENV"
# Handle force update for existing environment
if [ "$FORCE_UPDATE" = true ]; then
echo "Performing force update on existing environment"
if [[ -d $SERVER_DIR ]]; then
echo "Force updating pyworker repository"
if ! (cd "$SERVER_DIR" && git fetch --all); then
report_error_and_exit "Failed to fetch pyworker repository updates"
fi
if [[ -n ${PYWORKER_REF:-} ]]; then
echo "Force updating to pyworker reference: $PYWORKER_REF"
if ! (cd "$SERVER_DIR" && git checkout "$PYWORKER_REF" && git pull); then
report_error_and_exit "Failed to force update pyworker reference: $PYWORKER_REF"
fi
else
echo "Force updating pyworker to latest"
if ! (cd "$SERVER_DIR" && git pull); then
report_error_and_exit "Failed to pull latest pyworker changes"
fi
fi
fi
install_vastai_sdk
fi
fi
[ ! -d "$SERVER_DIR/workers/$BACKEND" ] && echo "$BACKEND not supported!" && exit 1
# Remove force update flag after successful update
if [ "$FORCE_UPDATE" = true ]; then
echo "Removing force update flag"
rm -f "/.force_update"
echo "Force update completed successfully"
fi
if [ "$USE_SSL" = true ]; then
cat << EOF > /etc/openssl-san.cnf
if ! cat << EOF > /etc/openssl-san.cnf
[req]
default_bits = 2048
distinguished_name = req_distinguished_name
@@ -101,32 +270,99 @@ if [ "$USE_SSL" = true ]; then
[alt_names]
IP.1 = 0.0.0.0
EOF
then
report_error_and_exit "Failed to write OpenSSL config"
fi
openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \
if ! openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \
-nodes \
-sha256 \
-keyout /etc/instance.key \
-out /etc/instance.csr \
-config /etc/openssl-san.cnf
-config /etc/openssl-san.cnf; then
report_error_and_exit "Failed to generate SSL certificate request"
fi
curl --header 'Content-Type: application/octet-stream' \
--data-binary @//etc/instance.csr \
-X \
POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt;
max_retries=5
retry_delay=2
for attempt in $(seq 1 "$max_retries"); do
http_code=$(curl -sS -o /etc/instance.crt -w '%{http_code}' \
--header 'Content-Type: application/octet-stream' \
--data-binary @/etc/instance.csr \
-X POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID")
if [ "$http_code" -ge 200 ] && [ "$http_code" -lt 300 ]; then
break
fi
echo "SSL cert signing attempt $attempt/$max_retries failed (HTTP $http_code)"
if [ "$attempt" -eq "$max_retries" ]; then
report_error_and_exit "Failed to sign SSL certificate after $max_retries attempts (HTTP $http_code)"
fi
sleep "$retry_delay"
retry_delay=$((retry_delay * 2))
done
fi
export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED
cd "$SERVER_DIR"
# ─── SDK Deployment Mode ───────────────────────────────────────────────
if [ "$IS_DEPLOYMENT" = "true" ]; then
echo "=== SDK Deployment Mode ==="
echo "DEPLOYMENT_ID: $DEPLOYMENT_ID"
DEPLOY_DIR="/workspace/deployment"
mkdir -p "$DEPLOY_DIR"
VAST_API_BASE="${VAST_API_BASE:-https://console.vast.ai}"
# Download deployment code, retrying until the blob is available on S3.
# The s3_key exists in the DB as soon as the deployment is created, but the
# actual upload may still be in flight from the client side.
# Install SDK (uses the install_vastai_sdk function which supports SDK_BRANCH/SDK_VERSION)
install_vastai_sdk
# Run deployment in serve mode
export VAST_DEPLOYMENT_MODE=serve
echo "Starting deployment: python3 $DEPLOY_DIR/deployment.py"
serve-vast-deployment
exit $?
fi
# ─── End SDK Deployment Mode ───────────────────────────────────────────
if ! cd "$SERVER_DIR"; then
report_error_and_exit "Failed to cd into SERVER_DIR: $SERVER_DIR"
fi
echo "launching PyWorker server"
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
set +e
(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
echo "launching PyWorker server done"
PY_STATUS=1
if [ -f "$SERVER_DIR/worker.py" ]; then
echo "Running worker.py"
python3 -m "worker" |& tee -a "$PYWORKER_LOG"
PY_STATUS=${PIPESTATUS[0]}
fi
if [ "${PY_STATUS}" -ne 0 ] && [ -f "$SERVER_DIR/workers/$BACKEND/worker.py" ]; then
echo "Running workers.${BACKEND}.worker"
python3 -m "workers.${BACKEND}.worker" |& tee -a "$PYWORKER_LOG"
PY_STATUS=${PIPESTATUS[0]}
fi
if [ "${PY_STATUS}" -ne 0 ] && [ -f "$SERVER_DIR/workers/$BACKEND/server.py" ]; then
echo "Running workers.${BACKEND}.server"
python3 -m "workers.${BACKEND}.server" |& tee -a "$PYWORKER_LOG"
PY_STATUS=${PIPESTATUS[0]}
fi
set -e
if [ "${PY_STATUS}" -ne 0 ]; then
if [ ! -f "$SERVER_DIR/worker.py" ] && [ ! -f "$SERVER_DIR/workers/$BACKEND/worker.py" ] && [ ! -f "$SERVER_DIR/workers/$BACKEND/server.py" ]; then
report_error_and_exit "Failed to find PyWorker"
fi
report_error_and_exit "PyWorker exited with status ${PY_STATUS}"
fi
echo "PyWorker bootstrap complete"
-98
View File
@@ -1,98 +0,0 @@
import logging
from typing import Any, Dict, Optional
import requests
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
class Endpoint:
"""
Utility class for handling endpoint operations.
"""
@staticmethod
def get_autoscaler_server_url(instance: str) -> str:
endpoints = {
"alpha": "run-alpha",
"candidate": "run-candidate",
"prod": "run",
}
return f"https://{endpoints[instance]}.vast.ai/"
@staticmethod
def get_server_url(instance: str) -> str:
endpoints = {
"alpha": "alpha",
"candidate": "candidate",
"prod": "console",
}
return f"https://{endpoints[instance]}.vast.ai/api/v0/endptjobs/"
@staticmethod
def get_endpoint_api_key(
endpoint_name: str, account_api_key: str, instance: str
) -> Optional[str]:
"""
Fetch endpoint API key from VastAI console following the healthcheck pattern.
Args:
endpoint_name: Name of the endpoint
account_api_key: Account API key for authentication
Returns:
Endpoint API key if successful, None otherwise
"""
headers = {"Authorization": f"Bearer {account_api_key}"}
try:
log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}")
response = requests.get(
f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}",
headers=headers,
)
if response.status_code != 200:
error_msg = f"Failed to fetch endpoint API key: {response.status_code} - {response.text}"
log.debug(error_msg)
return None
try:
data = response.json()
except requests.exceptions.JSONDecodeError as e:
log.debug(f"Failed to parse JSON response: {e}")
return None
result = data.get("results", [])
endpoint: Optional[Dict[str, Any]] = next(
(item for item in result if item["endpoint_name"] == endpoint_name),
None,
)
if not endpoint:
error_msg = f"Endpoint '{endpoint_name}' not found."
log.debug(error_msg)
return None
endpoint_api_key = endpoint.get("api_key")
if not endpoint_api_key:
error_msg = f"API key for endpoint '{endpoint_name}' not found."
log.debug(error_msg)
return None
log.debug(f"Successfully retrieved API key for endpoint: {endpoint_name}")
return endpoint_api_key
except requests.exceptions.RequestException as e:
error_msg = f"Request error while fetching endpoint API key: {e}"
log.debug(error_msg)
return None
except Exception as e:
error_msg = f"Unexpected error while fetching endpoint API key: {e}"
log.debug(error_msg)
return None
-15
View File
@@ -1,15 +0,0 @@
import tempfile
from functools import cache
import requests
@cache
def get_cert_file_path():
cert_url = "https://console.vast.ai/static/jvastai_root.cer"
response = requests.get(cert_url)
response.raise_for_status()
# Use a temporary file that is not deleted on close
with tempfile.NamedTemporaryFile(delete=False, suffix=".cer", mode="wb") as f:
f.write(response.content)
return f.name
+168
View File
@@ -0,0 +1,168 @@
# ComfyUI ACE Step PyWorker
This is the PyWorker implementation for running **ACE Step v1 3.5B** text-to-music workflows in ComfyUI. It provides a unified interface for executing complete ComfyUI audio-generation workflows through a proxy-based architecture and returning generated audio assets.
Each request has a static cost of `1000`. ComfyUI does not support concurrent workloads, and there is no provision to run multiple ComfyUI instances per worker node.
## Requirements
This worker requires the following components:
- ComfyUI (https://github.com/comfyanonymous/ComfyUI)
- ComfyUI API Wrapper (https://github.com/ai-dock/comfyui-api-wrapper)
- ACE Step v1 3.5B model and required custom nodes
A Docker image is provided with the ACE Step model pre-installed, but any image may be used if the above requirements are met.
## Endpoint
The worker exposes a single synchronous endpoint:
- `/generate/sync`: Processes a complete ComfyUI workflow JSON and generates audio output
## Request Format
The ACE Step worker **only supports custom workflow mode**. Modifier-based workflows are not supported.
```json
{
"input": {
"request_id": "uuid-string",
"workflow_json": {
// Complete ComfyUI ACE Step workflow JSON
},
"s3": { },
"webhook": { }
}
}
```
## Request Fields
### Required Fields
- `input`: Container for all request parameters
- `input.workflow_json`: Complete ComfyUI workflow graph for ACE Step audio generation
### Optional Fields
- `input.request_id`: Client-defined request identifier
- `input.s3`: S3-compatible storage configuration
- `input.webhook`: Webhook configuration for completion notifications
The special string `"__RANDOM_INT__"` may be used in the workflow JSON and will be replaced with a random integer before submission to ComfyUI.
## S3 Configuration
Generated audio assets can be automatically uploaded to S3-compatible storage. Configuration can be supplied per request or via environment variables. Request-level values take precedence.
### Via Request JSON
```json
"s3": {
"access_key_id": "your-s3-access-key",
"secret_access_key": "your-s3-secret-access-key",
"endpoint_url": "https://s3.amazonaws.com",
"bucket_name": "your-bucket",
"region": "us-east-1"
}
```
### Via Environment Variables
```bash
S3_ACCESS_KEY_ID=your-key
S3_SECRET_ACCESS_KEY=your-secret
S3_BUCKET_NAME=your-bucket
S3_ENDPOINT_URL=https://s3.amazonaws.com
S3_REGION=us-east-1
```
## Webhook Configuration
Webhooks are triggered on request completion or failure.
### Via Request JSON
```json
"webhook": {
"url": "https://your-webhook-url",
"extra_params": {
"custom_field": "value"
}
}
```
### Via Environment Variables
```bash
WEBHOOK_URL=https://your-webhook-url
WEBHOOK_TIMEOUT=30
```
## Example Request
### ACE Step Text-to-Music Workflow
```json
{
"input": {
"workflow_json": {
"14": {
"inputs": {
"tags": "funk, pop, upbeat, 105 BPM",
"lyrics": "Turn it up and let it flow",
"lyrics_strength": 0.99,
"clip": ["40", 1]
},
"class_type": "TextEncodeAceStepAudio"
},
"17": {
"inputs": {
"seconds": 180,
"batch_size": 1
},
"class_type": "EmptyAceStepLatentAudio"
},
"40": {
"inputs": {
"ckpt_name": "ace_step_v1_3.5b.safetensors"
},
"class_type": "CheckpointLoaderSimple"
}
}
}
}
```
## Response Format
A successful response includes execution metadata, ComfyUI output details, and generated audio assets.
### Response Fields
- `id`: Unique request identifier
- `status`: `completed`, `failed`, `processing`, `generating`, or `queued`
- `message`: Human-readable status message
- `comfyui_response`: Raw response from ComfyUI, including execution status and progress
- `output`: Array of generated outputs
- `timings`: Timing information for the request
### Output Object
Each entry in `output` includes:
- `filename`: Generated file name (e.g., `.mp3`)
- `local_path`: File path on the worker
- `url`: Pre-signed download URL (if S3 is configured)
- `type`: Output type (`output`)
- `subfolder`: Output directory (e.g., `audio`)
- `node_id`: ComfyUI node that produced the output
- `output_type`: Output category (e.g., `audio`)
## Notes and Limitations
- Only full ComfyUI workflow JSONs are supported
- Concurrent requests are not supported per worker
- ACE Step model must be installed before processing requests
- Audio generation duration and runtime depend on workflow configuration
+149
View File
@@ -0,0 +1,149 @@
from vastai import Serverless
import asyncio
async def main():
async with Serverless() as client:
endpoint = await client.get_endpoint(name="my-ace-endpoint")
# ComfyUI API compatible json workflow for ACE Step
workflow = {
"14": {
"inputs": {
"tags": "funk, pop, soul, rock, melodic, guitar, drums, bass, keyboard, percussion, 105 BPM, energetic, upbeat, groovy, vibrant, dynamic",
"lyrics": "[verse]\nNeon lights they flicker bright\nCity hums in dead of night\nRhythms pulse through concrete veins\nLost in echoes of refrains\n\n[verse]\nBassline groovin in my chest\nHeartbeats match the citys zest\nElectric whispers fill the air\nSynthesized dreams everywhere\n\n[chorus]\nTurn it up and let it flow\nFeel the fire let it grow\nIn this rhythm we belong\nHear the night sing out our song",
"lyrics_strength": 0.99,
"clip": ["40", 1]
},
"class_type": "TextEncodeAceStepAudio",
"_meta": {
"title": "TextEncodeAceStepAudio"
}
},
"17": {
"inputs": {
"seconds": 180,
"batch_size": 1
},
"class_type": "EmptyAceStepLatentAudio",
"_meta": {
"title": "EmptyAceStepLatentAudio"
}
},
"18": {
"inputs": {
"samples": ["52", 0],
"vae": ["40", 2]
},
"class_type": "VAEDecodeAudio",
"_meta": {
"title": "VAE Decode Audio"
}
},
"40": {
"inputs": {
"ckpt_name": "ace_step_v1_3.5b.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"44": {
"inputs": {
"conditioning": ["14", 0]
},
"class_type": "ConditioningZeroOut",
"_meta": {
"title": "ConditioningZeroOut"
}
},
"49": {
"inputs": {
"model": ["51", 0],
"operation": ["50", 0]
},
"class_type": "LatentApplyOperationCFG",
"_meta": {
"title": "LatentApplyOperationCFG"
}
},
"50": {
"inputs": {
"multiplier": 1.15
},
"class_type": "LatentOperationTonemapReinhard",
"_meta": {
"title": "LatentOperationTonemapReinhard"
}
},
"51": {
"inputs": {
"shift": 6,
"model": ["40", 0]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "ModelSamplingSD3"
}
},
"52": {
"inputs": {
"seed": "__RANDOM_INT__",
"steps": 65,
"cfg": 4,
"sampler_name": "er_sde",
"scheduler": "linear_quadratic",
"denoise": 1,
"model": ["49", 0],
"positive": ["14", 0],
"negative": ["44", 0],
"latent_image": ["17", 0]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"59": {
"inputs": {
"filename_prefix": "audio/ComfyUI",
"quality": "V0",
"audioUI": "",
"audio": ["18", 0]
},
"class_type": "SaveAudioMP3",
"_meta": {
"title": "Save Audio (MP3)"
}
}
}
payload = {
"input": {
"request_id": "",
"workflow_json": workflow,
"s3": {
"access_key_id": "",
"secret_access_key": "",
"endpoint_url": "",
"bucket_name": "",
"region": ""
},
"webhook": {
"url": "",
"extra_params": {
"user_id": "12345",
"project_id": "abc-def"
}
}
}
}
response = await endpoint.request("/generate/sync", payload)
# Response contains status, output, and any errors
print(response["response"])
if __name__ == "__main__":
asyncio.run(main())
+184
View File
@@ -0,0 +1,184 @@
import random
import sys
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
# ComyUI model configuration
MODEL_SERVER_URL = 'http://127.0.0.1'
MODEL_SERVER_PORT = 18288
MODEL_LOG_FILE = '/var/log/portal/comfyui.log'
MODEL_HEALTHCHECK_ENDPOINT = "/health"
# ComyUI-specific log messages
MODEL_LOAD_LOG_MSG = [
"To see the GUI go to: "
]
MODEL_ERROR_LOG_MSGS = [
"MetadataIncompleteBuffer",
"Value not in list: ",
"[ERROR] Provisioning Script failed"
]
MODEL_INFO_LOG_MSGS = [
'"message":"Downloading'
]
benchmark_lyrics = [
"[verse]\nGuardian cloaked in twilight hue\nShadows melt where he breaks through\nEchoes swirl in mystic flight\nHooded hero owns the night\n\n[verse]\nThrough the chaos shapes arise\nFeral whispers, glowing eyes\nOrcs and creatures side by side\nMarch within the inky tide\n\n[chorus]\nRise above the fear and gloom\nLet your courage fully bloom\nIn the darkness stand your ground\nHear the night proclaim your sound",
"[verse]\nMorning sun on fields of gold\nGentle stories unfold\nEvery breeze a quiet song\nWhere the peaceful hearts belong\n\n[verse]\nLanterns glow at stable doors\nRustling leaves on orchard floors\nSimple joys in every hand\nLife grows soft in fertile land\n\n[chorus]\nLet the day drift slow and free\nRoot your soul where you can be\nIn this haven warm and bright\nFeel the earth breathe pure delight",
"[verse]\nLittle feet on dusty ground\nChasing dreams without a sound\nSoccer ball in morning light\nHopes take wing in youthful flight\n\n[verse]\nChrome reflections paint the day\nSwagger in the steps that play\nCopper tones in shining air\nChildhood gleaming everywhere\n\n[chorus]\nKick the world with boundless cheer\nHold the magic close and near\nIn each moment bold and true\nLet the sky belong to you",
"[verse]\nSunset bleeds across the street\nGilded calm in summer heat\nLow-rise towers rimmed with fire\nDreams ignite as lights climb higher\n\n[verse]\nFootsteps scatter through the haze\nFutures shimmer in the blaze\nEvery window tells a tale\nFloating through a tangerine veil\n\n[chorus]\nLet the neon softly glow\nLet your restless heartbeat slow\nIn this city forged in light\nCarry hope into the night",
"[verse]\nOcean breathes in rolling arcs\nSprays of diamond, glowing sparks\nWaves unfold a perfect line\nNatures rhythm feels divine\n\n[verse]\nSun above in golden sweep\nPaints the rise of every deep\nShimmer drifting through the blue\nWorld reborn in every view\n\n[chorus]\nLet the tide pull you along\nHear the waters ancient song\nIn the cresting waves youll find\nQuiet peace for heart and mind",
"[verse]\nGlass aglow with swirling light\nFruits and mints in colors bright\nIcy whispers clink and chime\nFlowing forms suspend in time\n\n[verse]\nCreamy spirals drift within\nGentle currents slowly spin\nWarm reflections lingering sweet\nMixing flavors at your feet\n\n[chorus]\nSip the glow and let it rise\nTaste the sunset in disguise\nIn this moment clear and true\nLet the warmth flow into you",
"[verse]\nEngines rumble down the lane\nCopper clouds of steam and rain\nOilpunk dreams in metal shine\nRider drifting down the line\n\n[verse]\nLeather jacket, steady glare\nStories sparking in the air\nMagazine lights frame his face\nKing of roads in timeless grace\n\n[chorus]\nThrottle up beyond the bend\nFeel the force of steel ascend\nRide the night and hold on tight\nClaim the world in streaks of light",
"[verse]\nCut-out shapes in swirling play\nTextures dance in bold array\nCats in denim, grinning wide\nStrut across the patterned tide\n\n[verse]\nPosters hum with neon glow\nSurreal scenes begin to grow\nColors crisp as folded art\nPatchwork beating like a heart\n\n[chorus]\nLet the collage come alive\nWatch the vibrant pieces thrive\nIn this joyful, crafted space\nEvery shape finds its own place",
"[verse]\nTiny world in crystal glass\nAncient tales behind the mass\nVillage lights in winter gleam\nFrozen in a mystic dream\n\n[verse]\nLantern beams in swirling air\nSoft enchantment everywhere\nShadows drift with gentle grace\nMagic sealed within the space\n\n[chorus]\nHold the sphere and you will see\nEchoes of a memory\nIn the glow of fragile light\nLives a realm of pure delight",
"[verse]\nArmor hums with power bright\nChopping sparks in jungle night\nMecha spirits shift and scream\nThrough the ferns like shattered beams\n\n[verse]\nAxes blaze in glowing arcs\nLighting up the shadowed marks\nNature roars in trembling air\nClash of steel and cosmic flare\n\n[chorus]\nRaise the fire, strike the ground\nLet your legend shake the sound\nIn the wild where echoes roam\nForge the fight and carve your home",
"[verse]\nCrowds ignite in vibrant flare\nBeats explode through smoky air\nDJ robes replaced with flame\nPope on decks in holy frame\n\n[verse]\nLeather gleams in blinding light\nTurntables spin with sacred might\nChoirs echo in the bass\nHeaven pulses through the place\n\n[chorus]\nLift the roof and shake the floor\nSacred rhythm evermore\nLet the music take control\nFeel the blessing in your soul",
]
benchmark_dataset = [
{
"input": {
"request_id": "",
"workflow_json": {
"14": {
"inputs": {
"tags": "funk, pop, soul, rock, melodic, guitar, drums, bass, keyboard, percussion, 105 BPM, energetic, upbeat, groovy, vibrant, dynamic",
"lyrics": lyrics,
"lyrics_strength": 0.99,
"clip": ["40", 1]
},
"class_type": "TextEncodeAceStepAudio",
"_meta": {
"title": "TextEncodeAceStepAudio"
}
},
"17": {
"inputs": {
"seconds": 180,
"batch_size": 1
},
"class_type": "EmptyAceStepLatentAudio",
"_meta": {
"title": "EmptyAceStepLatentAudio"
}
},
"18": {
"inputs": {
"samples": ["52", 0],
"vae": ["40", 2]
},
"class_type": "VAEDecodeAudio",
"_meta": {
"title": "VAE Decode Audio"
}
},
"40": {
"inputs": {
"ckpt_name": "ace_step_v1_3.5b.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"44": {
"inputs": {
"conditioning": ["14", 0]
},
"class_type": "ConditioningZeroOut",
"_meta": {
"title": "ConditioningZeroOut"
}
},
"49": {
"inputs": {
"model": ["51", 0],
"operation": ["50", 0]
},
"class_type": "LatentApplyOperationCFG",
"_meta": {
"title": "LatentApplyOperationCFG"
}
},
"50": {
"inputs": {
"multiplier": 1.15
},
"class_type": "LatentOperationTonemapReinhard",
"_meta": {
"title": "LatentOperationTonemapReinhard"
}
},
"51": {
"inputs": {
"shift": 6,
"model": ["40", 0]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "ModelSamplingSD3"
}
},
"52": {
"inputs": {
"seed": "__RANDOM_INT__",
"steps": 65,
"cfg": 4,
"sampler_name": "er_sde",
"scheduler": "linear_quadratic",
"denoise": 1,
"model": ["49", 0],
"positive": ["14", 0],
"negative": ["44", 0],
"latent_image": ["17", 0]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"59": {
"inputs": {
"filename_prefix": "audio/ComfyUI",
"quality": "V0",
"audioUI": "",
"audio": ["18", 0]
},
"class_type": "SaveAudioMP3",
"_meta": {
"title": "Save Audio (MP3)"
}
}
}
}
} for lyrics in benchmark_lyrics
]
worker_config = WorkerConfig(
model_server_url=MODEL_SERVER_URL,
model_server_port=MODEL_SERVER_PORT,
model_log_file=MODEL_LOG_FILE,
model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
handlers=[
HandlerConfig(
route="/generate/sync",
allow_parallel_requests=False,
max_queue_time=10.0,
benchmark_config=BenchmarkConfig(
dataset=benchmark_dataset,
runs=1
),
workload_calculator= lambda _ : 1000.0
)
],
log_action_config=LogActionConfig(
on_load=MODEL_LOAD_LOG_MSG,
on_error=MODEL_ERROR_LOG_MSGS,
on_info=MODEL_INFO_LOG_MSGS
)
)
Worker(worker_config).run()
+108 -6
View File
@@ -1,8 +1,16 @@
# ComfyUI PyWorker
This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture.
This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
The cost for each request has a static value of `1`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node.
The cost for each request has a static value of `100`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node.
## Instance Setup
1. Pick a template
- [ComfyUI (Serverless)](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=ComfyUI%20(Serverless))
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
## Requirements
@@ -10,11 +18,105 @@ This worker requires both [ComfyUI](https://github.com/comfyanonymous/ComfyUI) a
A docker image is provided but you may use any if the above requirements are met.
## Client
The client demonstrates how to use the Vast Serverless SDK to generate images, save them locally, and optionally upload to S3-compatible storage.
### Setup
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
```bash
git clone https://github.com/vast-ai/pyworker
cd pyworker
pip install uv
uv venv -p 3.12
source .venv/bin/activate
uv pip install -r requirements.txt
```
2. Set your API key:
```bash
export VAST_API_KEY=<your_api_key>
```
### Usage
```bash
# Default prompt
python -m workers.comfyui-json.client
# Custom prompt
python -m workers.comfyui-json.client --prompt "a cat sitting on a rainbow"
# With options
python -m workers.comfyui-json.client --prompt "sunset" --width 1024 --height 1024 --steps 30
# Using a custom workflow file
python -m workers.comfyui-json.client --workflow my_workflow.json
# With S3 upload
python -m workers.comfyui-json.client --s3
```
### CLI Flags
| Flag | Default | Description |
|------|---------|-------------|
| `--endpoint` | `my-comfyui-endpoint` | Vast endpoint name |
| `--prompt` | (default) | Text prompt for image generation |
| `--workflow` | (none) | Path to custom workflow JSON file |
| `--width` | 512 | Image width in pixels |
| `--height` | 512 | Image height in pixels |
| `--steps` | 20 | Number of denoising steps |
| `--seed` | (random) | Random seed for reproducibility |
| `--s3` | (disabled) | Upload generated images to S3 |
### Output
Images are saved to `./generated_images/comfy_{seed}.png`.
### S3 Upload (Optional)
You can optionally upload generated images to an S3-compatible storage service (AWS S3, Cloudflare R2, Backblaze B2, etc.) by using the `--s3` flag.
**1. Set environment variables:**
```bash
export S3_ENDPOINT_URL="https://your-account.r2.cloudflarestorage.com"
export S3_BUCKET_NAME="my-bucket"
export S3_ACCESS_KEY_ID="your-access-key-id"
export S3_SECRET_ACCESS_KEY="your-secret-access-key"
```
**2. Run with S3 upload enabled:**
```bash
python -m workers.comfyui-json.client --prompt "a beautiful landscape" --s3
```
Images will be saved locally AND uploaded to `s3://{bucket}/comfyui/{filename}`.
**Note:** Requires `boto3` (`pip install boto3`).
## Benchmarking
A simple image generation benchmark runs when each worker initializes to validate GPU performance and identify underperforming machines.
### Custom Benchmark Workflows
The benchmark uses Stable Diffusion v1.5 with ComfyUI's default text-to-image workflow. Configure the benchmark complexity and duration using these variables:
You can provide a custom ComfyUI workflow for benchmarking by creating `workers/comfyui-json/misc/benchmark.json`. This allows you to test performance using your preferred models and workflow complexity.
**Ways to provide the benchmark file:**
- Fork this repository and add your `benchmark.json` file
- Write the file during worker provisioning (onstart script or setup phase)
An example file is provided in the repository. To ensure varied generations, use the placeholder `__RANDOM_INT__` in place of static seed values - it will be replaced with a random integer for each benchmark run.
### Default Benchmark (Fallback)
If `benchmark.json` is not available, a simple image generation benchmark runs when each worker initializes. This validates GPU performance and helps identify underperforming machines.
The default benchmark uses Stable Diffusion v1.5 with ComfyUI's standard text-to-image workflow. Configure it using these environment variables:
| Environment Variable | Default Value | Description |
| -------------------- | ------------- | ----------- |
@@ -24,7 +126,7 @@ The benchmark uses Stable Diffusion v1.5 with ComfyUI's default text-to-image wo
Each benchmark run uses a random prompt from `misc/test_prompts.txt` and a random seed to ensure consistent GPU load patterns.
### Calibrating Benchmark Duration
#### Calibrating Fallback Benchmark Duration
To screen for underperforming hardware, set `BENCHMARK_TEST_STEPS` to match your expected production workflow duration. This allows you to identify machines that won't meet performance requirements.
@@ -203,7 +305,7 @@ WEBHOOK_TIMEOUT=30 # Webhook timeout in seconds
## Client Libraries
See the test client examples for implementation details on how to integrate with the ComfyUI worker.
See the client example for implementation details on how to integrate with the ComfyUI worker.
---
+287 -130
View File
@@ -1,155 +1,312 @@
import logging
import os
import sys
import json
import uuid
import random
from urllib.parse import urljoin
import json
import asyncio
import logging
import argparse
import aiohttp
import requests
from vastai import Serverless
from lib.test_utils import print_truncate_res
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from .data_types import count_workload
# ---------------------- Config ----------------------
DEFAULT_PROMPT = "a beautiful sunset over mountains, digital art, highly detailed"
ENDPOINT_NAME = "my-comfyui-endpoint"
DEFAULT_WIDTH = 512
DEFAULT_HEIGHT = 512
DEFAULT_STEPS = 20
COST = 100 # Fixed cost for ComfyUI requests
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
# Optional S3 Configuration (from environment variables)
S3_ENDPOINT_URL = os.getenv("S3_ENDPOINT_URL")
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
S3_ACCESS_KEY_ID = os.getenv("S3_ACCESS_KEY_ID")
S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY")
logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")
log = logging.getLogger(__name__)
def call_text2image_workflow(
endpoint_group_name: str, api_key: str, server_url: str
) -> None:
"""Simple Text2Image using the new modifier-based approach"""
def get_s3_client():
"""Create and return an S3 client configured for the S3-compatible endpoint"""
try:
import boto3
from botocore.config import Config
except ImportError:
log.error("boto3 is required for S3 uploads. Install with: pip install boto3")
return None
def make_request(url: str, payload: dict, timeout: int = None, verify=True, context: str = "request"):
"""Helper function for making requests with consistent error handling"""
try:
response = requests.post(
url,
json=payload,
timeout=timeout,
verify=verify
)
if not all([S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY]):
log.error("S3 environment variables not fully configured. Required:")
log.error(" S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY")
return None
response.raise_for_status()
return response.json()
except requests.exceptions.HTTPError as http_err:
log.error(f"HTTP error occurred during {context}: {http_err}")
log.error(f"Status Code: {response.status_code}")
log.error("Response content:", response.text)
return None
except requests.exceptions.Timeout:
log.error(f"Timeout occurred during {context}: {url}")
return None
except requests.exceptions.ConnectionError:
log.error(f"Connection error occurred during {context}: {url}")
return None
except json.JSONDecodeError as json_err:
log.error(f"Failed to decode JSON response during {context}: {json_err}")
if 'response' in locals():
print("Response content:", response.text)
return None
except Exception as err:
log.error(f"An unexpected error occurred during {context}: {err}")
if 'response' in locals():
log.error("Response content (if available):", response.text)
return None
WORKER_ENDPOINT = "/generate/sync"
# This worker has concurrency = 1. All workloads have cost value 1.0
COST = count_workload()
# Route to get worker URL
route_payload = {
"endpoint": endpoint_group_name,
"api_key": api_key,
"cost": COST,
}
# First request - get routing information
route_response = make_request(
url=urljoin(server_url, "/route/"),
payload=route_payload,
timeout=4,
context="route request"
return boto3.client(
"s3",
endpoint_url=S3_ENDPOINT_URL,
aws_access_key_id=S3_ACCESS_KEY_ID,
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
config=Config(signature_version="s3v4"),
)
if route_response is None:
return None
if "url" not in route_response or not route_response["url"]:
log.error("Error: No worker in 'Ready' state. Please wait while the serverless engine removes errored workers or finishes loading new workers.")
return None
if "status" in route_response:
print(f"Autoscaler status: {route_response['status']}")
return None
# Extract data from route response
url = route_response["url"]
auth_data = dict(
signature=route_response["signature"],
cost=route_response["cost"],
endpoint=route_response["endpoint"],
reqnum=route_response["reqnum"],
url=route_response["url"],
)
# Build the payload for the worker request
worker_payload = {
# ---------------------- API Functions ----------------------
async def call_generate(
client: Serverless,
*,
endpoint_name: str,
prompt: str,
width: int,
height: int,
steps: int,
seed: int,
) -> dict:
"""Generate image using Text2Image modifier"""
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
"request_id": str(uuid.uuid4()),
"modifier": "Text2Image",
"modifications": {
"prompt": "a beautiful landscape with mountains and lakes",
"width": 1024,
"height": 1024,
"steps": 20,
"seed": random.randint(0, 2**32 - 1)
"prompt": prompt,
"width": width,
"height": height,
"steps": steps,
"seed": seed,
},
"workflow_json": {} # Empty since using modifier approach
}
}
return await endpoint.request("/generate/sync", payload, cost=COST)
req_data = dict(payload=worker_payload, auth_data=auth_data)
worker_url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {worker_url}")
# Second request - call the worker endpoint
worker_response = make_request(
url=worker_url,
payload=req_data,
verify=get_cert_file_path(),
context="worker request"
)
async def call_generate_workflow(
client: Serverless,
*,
endpoint_name: str,
workflow_json: dict,
) -> dict:
"""Generate using custom workflow JSON"""
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"input": {
"request_id": str(uuid.uuid4()),
"workflow_json": workflow_json,
}
}
return await endpoint.request("/generate/sync", payload, cost=COST)
return worker_response
# ---------------------- Demo Class ----------------------
class APIDemo:
def __init__(self, client: Serverless, endpoint_name: str, upload_s3: bool = False):
self.client = client
self.endpoint_name = endpoint_name
self.upload_s3 = upload_s3
self.s3_client = get_s3_client() if upload_s3 else None
if upload_s3 and not self.s3_client:
log.warning("S3 upload requested but client creation failed. Images will only be saved locally.")
def extract_filename(self, response: dict) -> str | None:
"""Extract the generated image filename from ComfyUI response"""
if "comfyui_response" in response:
for data in response["comfyui_response"].values():
if isinstance(data, dict) and "outputs" in data:
for node_output in data["outputs"].values():
if "images" in node_output and node_output["images"]:
return node_output["images"][0].get("filename")
return None
async def save_image(self, worker_url: str, filename: str, local_name: str) -> str | None:
"""Fetch and save image locally from the worker, optionally upload to S3"""
os.makedirs("generated_images", exist_ok=True)
return await self._fetch_image(worker_url, filename, local_name)
def _upload_to_s3(self, local_path: str, s3_key: str) -> str | None:
"""Upload a local file to S3 and return the S3 URL"""
if not self.s3_client:
return None
try:
self.s3_client.upload_file(
local_path,
S3_BUCKET_NAME,
s3_key,
ExtraArgs={"ContentType": "image/png"}
)
s3_url = f"{S3_ENDPOINT_URL}/{S3_BUCKET_NAME}/{s3_key}"
print(f" ☁️ Uploaded to S3: {s3_key}")
return s3_url
except Exception as e:
log.error(f"Failed to upload to S3: {e}")
return None
async def _fetch_image(self, worker_url: str, filename: str, local_name: str) -> str | None:
"""Fetch image from worker's /view endpoint and save locally"""
if not worker_url:
return None
try:
url = f"{worker_url}/view"
params = {"filename": filename, "type": "output"}
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params, ssl=False) as resp:
if resp.status == 200:
path = f"generated_images/{local_name}"
image_data = await resp.read()
with open(path, "wb") as f:
f.write(image_data)
print(f" 💾 Saved: {path}")
# Upload to S3 if enabled
if self.upload_s3 and self.s3_client:
s3_key = f"comfyui/{local_name}"
self._upload_to_s3(path, s3_key)
return path
return None
except Exception:
return None
async def demo_prompt(
self,
prompt: str,
width: int,
height: int,
steps: int,
seed: int | None,
):
"""Demo: Generate image from text prompt"""
print("=" * 60)
print("COMFYUI TEXT-TO-IMAGE DEMO")
print("=" * 60)
if seed is None:
seed = random.randint(0, 2**32 - 1)
print(f"Prompt: {prompt[:100]}..." if len(prompt) > 100 else f"Prompt: {prompt}")
print(f"Size: {width}x{height}, Steps: {steps}, Seed: {seed}")
print("\n🎨 Generating image...")
response = await call_generate(
self.client,
endpoint_name=self.endpoint_name,
prompt=prompt,
width=width,
height=height,
steps=steps,
seed=seed,
)
print("\n✅ Generation complete!")
# Get worker URL for fetching images
worker_url = response.get("url", "")
print(f"Worker URL: {worker_url}")
# Fetch and save image
if "response" in response:
filename = self.extract_filename(response["response"])
if filename:
path = await self.save_image(worker_url, filename, f"comfy_{seed}.png")
if not path:
print(f"❌ Failed to fetch image")
else:
print("❌ No image in response")
else:
print("❌ Unexpected response format")
async def demo_workflow(self, workflow_file: str):
"""Demo: Generate using custom workflow file"""
print("=" * 60)
print("COMFYUI CUSTOM WORKFLOW DEMO")
print("=" * 60)
if not os.path.exists(workflow_file):
log.error(f"Workflow file not found: {workflow_file}")
return
with open(workflow_file, "r") as f:
workflow_json = json.load(f)
print(f"Workflow: {workflow_file}")
print("\n🎨 Generating...")
response = await call_generate_workflow(
self.client,
endpoint_name=self.endpoint_name,
workflow_json=workflow_json,
)
print("\n✅ Generation complete!")
worker_url = response.get("url", "")
if "response" in response:
filename = self.extract_filename(response["response"])
if filename:
path = await self.save_image(worker_url, filename, "workflow.png")
if not path:
print(f"❌ Failed to fetch image")
else:
print("❌ No image in response")
else:
print("❌ Unexpected response format")
# ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast ComfyUI-JSON Demo (Serverless SDK)")
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
p.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, metavar="TEXT",
help=f"Prompt text (default: '{DEFAULT_PROMPT[:30]}...')")
p.add_argument("--workflow", type=str, metavar="FILE", help="Use custom workflow JSON file instead")
p.add_argument("--width", type=int, default=DEFAULT_WIDTH, help=f"Image width (default: {DEFAULT_WIDTH})")
p.add_argument("--height", type=int, default=DEFAULT_HEIGHT, help=f"Image height (default: {DEFAULT_HEIGHT})")
p.add_argument("--steps", type=int, default=DEFAULT_STEPS, help=f"Steps (default: {DEFAULT_STEPS})")
p.add_argument("--seed", type=int, default=None, help="Seed (default: random)")
p.add_argument("--s3", action="store_true",
help="Upload generated images to S3 (requires S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY env vars)")
return p
async def main_async():
args = build_arg_parser().parse_args()
print("=" * 60)
print(f"Using endpoint: {args.endpoint}")
if args.s3:
print(f"S3 upload: enabled (bucket: {S3_BUCKET_NAME})")
try:
async with Serverless() as client:
demo = APIDemo(client, args.endpoint, upload_s3=args.s3)
if args.workflow:
await demo.demo_workflow(workflow_file=args.workflow)
else:
await demo.demo_prompt(
prompt=args.prompt,
width=args.width,
height=args.height,
steps=args.steps,
seed=args.seed,
)
except AttributeError as e:
if "API key" in str(e):
log.error("API key missing. Set VAST_API_KEY environment variable.")
else:
log.error(f"Error: {e}")
sys.exit(1)
except Exception as e:
log.error(f"Error: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
from lib.test_utils import test_args
args = test_args.parse_args()
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if endpoint_api_key:
result = call_text2image_workflow(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
if result is None:
log.error("Text2Image workflow failed")
else:
print(result)
else:
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name}")
asyncio.run(main_async())
-60
View File
@@ -1,60 +0,0 @@
import os
import sys
import random
import dataclasses
from typing import Dict, Any
from functools import cache
from math import ceil
from lib.data_types import ApiPayload, JsonDataException
with open("workers/comfyui/misc/test_prompts.txt", "r") as f:
test_prompts = f.readlines()
def count_workload() -> float:
# Always 100.0 where there is a single instance of ComfyUI handling requests
# Results will indicate % or a job completed per second. Avoids sub 0.1 sec performance indication
return 100.0
@dataclasses.dataclass
class ComfyWorkflowData(ApiPayload):
input: dict
@classmethod
def for_test(cls):
"""
Use the variables available to simulate workflows of the required running time
Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090)
"""
test_prompt = random.choice(test_prompts).rstrip()
return cls(
input={
"request_id": f"test-{random.randint(1000, 99999)}",
"modifier": "Text2Image",
"modifications": {
"prompt": test_prompt,
"width": os.getenv('BENCHMARK_TEST_WIDTH', 512),
"height": os.getenv('BENCHMARK_TEST_HEIGHT', 512),
"steps": os.getenv('BENCHMARK_TEST_STEPS', 20),
"seed": random.randint(0, sys.maxsize),
}
}
)
def generate_payload_json(self) -> Dict[str, Any]:
# input is already a dict, just return it wrapped in the expected structure
return {"input": self.input}
def count_workload(self) -> float:
return count_workload()
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "ComfyWorkflowData":
# Extract required fields
if "input" not in json_msg:
raise JsonDataException({"input": "missing parameter"})
return cls(
input=json_msg["input"]
)
+1
View File
@@ -0,0 +1 @@
# This folder is required for the provisioning scripts of ace and wan to complete.
@@ -1,34 +0,0 @@
cartoon character of a person with a hoodie , in style of cytus and deemo, ork, gold chains, realistic anime cat, dripping black goo, lineage revolution style, thug life, cute anthropomorphic bunny, balrog, arknights, aliased, very buff, black and red and yellow paint, painting illustration collage style, character composition in vector with white background
stardew valley, fine details
2D Vector Illustration of a child with soccer ball Art for Sublimation, Design Art, Chrome Art, Painting and Stunning Artwork, Highly Detailed Digital Painting, Airbrush Art, Highly Detailed Digital Artwork, Dramatic Artwork, stained antique yellow copper paint, digital airbrush art, detailed by Mark Brooks, Chicano airbrush art, Swagger! snake Culture
realistic futuristic city-downtown with short buildings, sunset
seascape by Ray Collins and artgerm, front view of a perfect wave, sunny background, ultra detailed water
inspired by realflow-cinema4d editor features, create image of a transparent luxury cup with ice fruits and mint, connected with white, yellow and pink cream, Slow - High Speed MO Photography, YouTube Video Screenshot, Abstract Clay, Transparent Cup , molecular gastronomy, wheel, 3D fluid,Simulation rendering, still video, 4k polymer clay futras photography, very surreal, Houdini Fluid Simulation, hyperrealistic CGI and FLUIDS & MULTIPHYSICS SIMULATION effect, with Somali Stain Lurex, Metallic Jacquard, Gold Thread, Mulberry Silk, Toub Saree, Warm background, a fantastic image worthy of an award.
biker with backpack on his back riding a motorcycle, Style by Ade Santora, Oilpunk, Cover photo, craig mullins style, on the cover of a magazine, Outdoor Magazine, inspired by Alex Petruk APe, image of a male biker, Cover of an award-winning magazine, the man has a backpack, photo for magazine, with a backpack, magazine cover
generate a collage-style illustration inspired by the Procreate raster graphic editor, photographic illustration with the theme, 2D vector, art for textile sublimation, containing surrealistic cartoon cat wearing a baseball cap and jeans standing in front of a poster, inspired by Sadao Watanabe, Doraemon, Japanese cartoon style, Eichiro Oda, Iconic high detail character, Director: Nakahara Nantenbō, Kastuhiro Otomo, image detailed, by Miyamoto, Hidetaka Miyazaki, Katsuhiro illustration, 8k, masterpiece, Minimize noise and grain in photo quality without lose quality and increase brightness and lighting,Symmetry and Alignment, Avoid asymmetrical shapes and out-of-focus points. Focus and Sharpness: Make sure the image is focused and sharp and encourages the viewer to see it as a work of art printed on fabric.
fantasy medieval village world inside a glass sphere , high detail, fantasy, realistic, light effect, hyper detail, volumetric lighting, cinematic, macro, depth of field, blur, red light and clouds from the back, highly detailed epic cinematic concept art cg render made in maya, blender and photoshop, octane render, excellent composition, dynamic dramatic cinematic lighting, aesthetic, very inspirational, world inside a glass sphere by james gurney by artgerm with james jean, joe fenton and tristan eaton by ross tran, fine details
Iron Man, (Arnold Tsang, Toru Nakayama), Masterpiece, Studio Quality, 6k , toa, toaair, 1boy, glowing, axe, mecha, science_fiction, solo, weapon, jungle , green_background, nature, outdoors, solo, tree, weapon, mask, dynamic lighting, detailed shading, digital texture painting
(Pope Francis) wearing leather jacket is a DJ in a nightclub, mixing live on stage, giant mixing table, a masterpiece
Pope Francis wearing biker (leather jacket), a masterpiece
Luke Skywalker ordering a burger and fries from the Death Star canteen.
I want to generate a group avatar for a Feishu group chat. The role of this group is daily software technical communication. Now the subject technology stacks that members of this group discuss daily include: algorithms, data structures, optimization, functional programming, and the programming languages often discussed are: TypeScript, Java, python, etc. I hope this avatar has a simple aesthetic, this avatar is a single person avatar
portrait Anime black girl cute-fine-face, pretty face, realistic shaded Perfect face, fine details. Anime. realistic shaded lighting by Ilya Kuvshinov Giuseppe Dangelico Pino and Michael Garmash and Rob Rey, IAMAG premiere, WLOP matte print, cute freckles, masterpiece
young Disney socialite wearing a beige miniskirt, dark brown turtleneck sweater, small neckless, cute-fine-face, anime. illustration, realistic shaded perfect face, brown hair, grey eyes, fine details, realistic shaded lighting by ilya kuvshinov giuseppe dangelico pino and michael garmash and rob rey, iamag premiere, wlop matte print, a masterpiece
Cute small cat sitting in a movie theater eating chicken wiggs watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
Cute small dog sitting in a movie theater eating popcorn watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
fox bracelet made of buckskin with fox features, rich details, fine carvings, studio lighting
crane buckskin bracelet with crane features, rich details, fine carvings, studio lighting
london luxurious interior living-room, light walls
Parisian luxurious interior penthouse bedroom, dark walls, wooden panels
cute girl, crop-top, blond hair, black glasses, stretching, with background by greg rutkowski makoto shinkai kyoto animation key art feminine mid shot
houses in front, houses background, straight houses, digital art, smooth, sharp focus, gravity falls style, doraemon style, shinchan style, anime style
Simplified technical drawing, Leonardo da Vinci, Mechanical Dinosaur Skeleton, Minimalistic annotations, Hand-drawn illustrations, Basic design and engineering, Wonder and curiosity
High quality 8K painting impressionist style of a Japanese modern city street with a girl on the foreground wearing a traditional wedding dress with a fox mask, staring at the sky, daylight
a landscape from the Moon with the Earth setting on the horizon, realistic, detailed
Isometric Atlantis city,great architecture with columns, great details, ornaments,seaweed, blue ambiance, 3D cartoon style, soft light, 45° view
A hyper realistic avatar of a guy riding on a black honda cbr 650r in leather suit,high detail, high quality,8K,photo realism
the street of amedieval fantasy town, at dawn, dark, highly detailed
overwhelmingly beautiful eagle framed with vector flowers, long shiny wavy flowing hair, polished, ultra detailed vector floral illustration mixed with hyper realism, muted pastel colors, vector floral details in background, muted colors, hyper detailed ultra intricate overwhelming realism in detailed complex scene with magical fantasy atmosphere, no signature, no watermark
a highly detailed matte painting of a man on a hill watching a rocket launch in the distance by studio ghibli, makoto shinkai, by artgerm, by wlop, by greg rutkowski, volumetric lighting, octane render, 4 k resolution, trending on artstation, masterpiece | hyperrealism| highly detailed| insanely detailed| intricate| cinematic lighting| depth of field
electronik robot and ofice ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
exquisitely intricately detailed illustration, of a small world with a lake and a rainbow, inside a closed glass jar.
-116
View File
@@ -1,116 +0,0 @@
import os
import logging
import dataclasses
import base64
from typing import Optional, Union, Type
from aiohttp import web, ClientResponse
from lib.backend import Backend, LogAction
from lib.data_types import EndpointHandler
from lib.server import start_server
from .data_types import ComfyWorkflowData
MODEL_SERVER_URL = os.getenv("MODEL_SERVER_URL", "http://127.0.0.1:18288")
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: "
MODEL_SERVER_ERROR_LOG_MSGS = [
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
"Value not in list: ", # This error is emitted when the model file is not there at all
]
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
async def generate_client_response(
client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
# Check if the response is actually streaming based on response headers/content-type
is_streaming_response = (
model_response.content_type == "text/event-stream"
or model_response.content_type == "application/x-ndjson"
or model_response.headers.get("Transfer-Encoding") == "chunked"
or "stream" in model_response.content_type.lower()
)
if is_streaming_response:
log.debug("Detected streaming response...")
res = web.StreamResponse()
res.content_type = model_response.content_type
await res.prepare(client_request)
async for chunk in model_response.content:
await res.write(chunk)
await res.write_eof()
log.debug("Done streaming response")
return res
else:
log.debug("Detected non-streaming response...")
content = await model_response.read()
return web.Response(
body=content,
status=model_response.status,
content_type=model_response.content_type
)
@dataclasses.dataclass
class ComfyWorkflowHandler(EndpointHandler[ComfyWorkflowData]):
@property
def endpoint(self) -> str:
return "/generate/sync"
@property
def healthcheck_endpoint(self) -> Optional[str]:
return "/health"
@classmethod
def payload_cls(cls) -> Type[ComfyWorkflowData]:
return ComfyWorkflowData
def make_benchmark_payload(self) -> ComfyWorkflowData:
return ComfyWorkflowData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
return await generate_client_response(client_request, model_response)
backend = Backend(
model_server_url=MODEL_SERVER_URL,
model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=False,
benchmark_handler=ComfyWorkflowHandler(
benchmark_runs=3, benchmark_words=100
),
log_actions=[
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
(LogAction.Info, "Downloading:"),
*[
(LogAction.ModelError, error_msg)
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
],
],
)
async def handle_ping(_):
return web.Response(body="pong")
routes = [
web.post("/generate/sync", backend.create_handler(ComfyWorkflowHandler())),
web.get("/ping", handle_ping),
]
if __name__ == "__main__":
start_server(backend, routes)
-8
View File
@@ -1,8 +0,0 @@
from lib.test_utils import test_load_cmd, test_args
from .data_types import ComfyWorkflowData
WORKER_ENDPOINT = "/generate/sync"
if __name__ == "__main__":
test_load_cmd(ComfyWorkflowData, WORKER_ENDPOINT, arg_parser=test_args)
+81
View File
@@ -0,0 +1,81 @@
import random
import sys
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
# ComyUI model configuration
MODEL_SERVER_URL = 'http://127.0.0.1'
MODEL_SERVER_PORT = 18288
MODEL_LOG_FILE = '/var/log/portal/comfyui.log'
MODEL_HEALTHCHECK_ENDPOINT = "/health"
# ComyUI-specific log messages
MODEL_LOAD_LOG_MSG = [
"To see the GUI go to: "
]
MODEL_ERROR_LOG_MSGS = [
"MetadataIncompleteBuffer",
"Value not in list: ",
"[ERROR] Provisioning Script failed"
]
MODEL_INFO_LOG_MSGS = [
'"message":"Downloading'
]
benchmark_prompts = [
"Cartoon hoodie hero; orc, anime cat, bunny; black goo; buff; vector on white.",
"Cozy farming-game scene with fine details.",
"2D vector child with soccer ball; airbrush chrome; swagger; antique copper.",
"Realistic futuristic downtown of low buildings at sunset.",
"Perfect wave front view; sunny seascape; ultra-detailed water; artful feel.",
"Clear cup with ice, fruit, mint; creamy swirls; fluid-sim CGI; warm glow.",
"Male biker with backpack on motorcycle; oilpunk; award-worthy magazine cover.",
"Collage for textile; surreal cartoon cat in cap/jeans before poster; crisp.",
"Medieval village inside glass sphere; volumetric light; macro focus.",
"Iron Man with glowing axe; mecha sci-fi; jungle scene; dynamic light.",
"Pope Francis DJ in leather jacket, mixing on giant console; dramatic.",
]
benchmark_dataset = [
{
"input": {
"request_id": f"test-{random.randint(1000, 99999)}",
"modifier": "Text2Image",
"modifications": {
"prompt": prompt,
"width": 512,
"height": 512,
"steps": 20,
"seed": random.randint(0, sys.maxsize)
}
}
} for prompt in benchmark_prompts
]
worker_config = WorkerConfig(
model_server_url=MODEL_SERVER_URL,
model_server_port=MODEL_SERVER_PORT,
model_log_file=MODEL_LOG_FILE,
model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
handlers=[
HandlerConfig(
route="/generate/sync",
allow_parallel_requests=False,
max_queue_time=10.0,
benchmark_config=BenchmarkConfig(
dataset=benchmark_dataset,
)
)
],
log_action_config=LogActionConfig(
on_load=MODEL_LOAD_LOG_MSG,
on_error=MODEL_ERROR_LOG_MSGS,
on_info=MODEL_INFO_LOG_MSGS
)
)
Worker(worker_config).run()
-92
View File
@@ -1,92 +0,0 @@
This is the base PyWorker for comfyui. It can be used to create PyWorker that use various models and
workflows. It provides two endpoints:
1. `/prompt`: Uses the default comfy workflow defined under `misc/default_workflows`
2. `/custom_workflow`: Allows the client to send their own comfy workflow with each API request.
To use the comfyui PyWorker, `$COMFY_MODEL` env variable must be set in the template. Current options are
`sd3` and `flux`. Each have example clients.
To add new models, a JSON with name `$COMFY_MODEL.json` must be created under `misc/default_workflows`
NOTE: default workflows follow this format:
```json
{
"input": {
"handler": "RawWorkflow",
"aws_access_key_id": "your-s3-access-key",
"aws_secret_access_key": "your-s3-secret-access-key",
"aws_endpoint_url": "https://my-endpoint.backblaze.com",
"aws_bucket_name": "your-bucket",
"webhook_url": "your-webhook-url",
"webhook_extra_params": {},
"workflow_json": {}
}
}
```
You can ignore all of these fields except for `workflow_json`.
Fields written as "{{FOO}}" will be replaced using data from a user request. For example, SD3's workflow has the
following nodes:
```json
"5": {
"inputs": {
"width": "{{WIDTH}}",
"height": "{{HEIGHT}}",
"batch_size": 1
},
"6": {
"inputs": {
"text": "{{PROMPT}}",
"clip": ["11", 0]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
...
"17": {
"inputs": {
"scheduler": "simple",
"steps": "{{STEPS}}",
"denoise": 1,
"model": ["12", 0]
},
"class_type": "BasicScheduler",
"_meta": {
"title": "BasicScheduler"
}
},
...
"25": {
"inputs": {
"noise_seed": "{{SEED}}"
},
"class_type": "RandomNoise",
"_meta": {
"title": "RandomNoise"
}
}
```
Incoming requests have the following JSON format:
```json
{
prompt: str
width: int
height: int
steps: int
seed: int
}
```
Each value in those fields with replace the placeholder of the same name in the default workflow.
See Vast's serverless documentation for more details on how to use comfyui with autoscaler
-176
View File
@@ -1,176 +0,0 @@
import logging
from urllib.parse import urljoin
import requests
from lib.test_utils import print_truncate_res
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
"""
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
"""
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
def call_default_workflow(
endpoint_group_name: str, api_key: str, server_url: str
) -> None:
WORKER_ENDPOINT = "/prompt"
COST = 100
route_payload = {
"endpoint": endpoint_group_name,
"api_key": api_key,
"cost": COST,
}
response = requests.post(
urljoin(server_url, "/route/"),
json=route_payload,
timeout=4,
)
response.raise_for_status()
message = response.json()
url = message["url"]
auth_data = dict(
signature=message["signature"],
cost=message["cost"],
endpoint=message["endpoint"],
reqnum=message["reqnum"],
url=message["url"],
)
payload = dict(
prompt="a fat fluffy cat", width=1024, height=1024, steps=20, seed=123456789
)
req_data = dict(payload=payload, auth_data=auth_data)
url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {url}")
response = requests.post(
url,
json=req_data,
verify=get_cert_file_path(),
)
response.raise_for_status()
print_truncate_res(str(response.json()))
def call_custom_workflow_for_sd3(
endpoint_group_name: str, api_key: str, server_url: str
) -> None:
WORKER_ENDPOINT = "/custom-workflow"
COST = 100
route_payload = {
"endpoint": endpoint_group_name,
"api_key": api_key,
"cost": COST,
}
response = requests.post(
urljoin(server_url, "/route/"),
json=route_payload,
timeout=4,
)
response.raise_for_status()
message = response.json()
url = message["url"]
auth_data = dict(
signature=message["signature"],
cost=message["cost"],
endpoint=message["endpoint"],
reqnum=message["reqnum"],
url=message["url"],
)
workflow = {
"3": {
"inputs": {
"seed": 156680208700286,
"steps": 20,
"cfg": 8,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1,
"model": ["4", 0],
"positive": ["6", 0],
"negative": ["7", 0],
"latent_image": ["5", 0],
},
"class_type": "KSampler",
},
"4": {
"inputs": {"ckpt_name": "sd3_medium_incl_clips_t5xxlfp16.safetensors"},
"class_type": "CheckpointLoaderSimple",
},
"5": {
"inputs": {"width": 512, "height": 512, "batch_size": 1},
"class_type": "EmptyLatentImage",
},
"6": {
"inputs": {
"text": "beautiful scenery nature glass bottle landscape, purple galaxy bottle",
"clip": ["4", 1],
},
"class_type": "CLIPTextEncode",
},
"7": {
"inputs": {"text": "text, watermark", "clip": ["4", 1]},
"class_type": "CLIPTextEncode",
},
"8": {
"inputs": {"samples": ["3", 0], "vae": ["4", 2]},
"class_type": "VAEDecode",
},
"9": {
"inputs": {"filename_prefix": "ComfyUI", "images": ["8", 0]},
"class_type": "SaveImage",
},
}
# these values should match the values in the custom workflow above,
# they are used to calculate workload
custom_fields = dict(
steps=20,
width=512,
height=512,
)
req_data = dict(
payload=dict(custom_fields=custom_fields, workflow=workflow),
auth_data=auth_data,
)
url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {url}")
response = requests.post(
url,
json=req_data,
verify=get_cert_file_path(),
)
response.raise_for_status()
print_truncate_res(str(response.json()))
if __name__ == "__main__":
from lib.test_utils import test_args
args = test_args.parse_args()
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if endpoint_api_key:
try:
call_default_workflow(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
call_custom_workflow_for_sd3(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
except Exception as e:
log.error(f"Error during API call: {e}")
else:
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ")
-205
View File
@@ -1,205 +0,0 @@
import sys
import os
import json
import random
import dataclasses
import inspect
from typing import Dict, Any
from functools import cache
from math import ceil
from enum import Enum
from lib.data_types import ApiPayload, JsonDataException
with open("workers/comfyui/misc/test_prompts.txt", "r") as f:
test_prompts = f.readlines()
class Model(Enum):
Flux = "flux"
Sd3 = "sd3"
def get_request_time(self) -> int:
match self:
case Model.Flux:
return 23
case Model.Sd3:
return 6
@cache
def get_model() -> Model:
match os.environ.get("COMFY_MODEL"):
case "flux":
return Model.Flux
case "sd3":
return Model.Sd3
case None:
raise Exception(
"For comfyui pyworker, $COMFY_MODEL must be set in the vast template"
)
case model:
raise Exception(f"Unsupported comfyui model: {model}")
@cache
def get_request_template() -> str:
with open(f"workers/comfyui/misc/default_workflows/{get_model().value}.json") as f:
return f.read()
def count_workload(width: int, height: int, steps: int) -> float:
"""
we want to normalize the workload is a number such that cur_perf(tokens/second) for 1024x1024 image with
28 steps is 200 tokens on a 4090.
in order get that we calculate the
A = ( absolute workload based on given data )
B = ( absolute workload for a 1024x1024 image with 28 steps )
and adjust the workload to 200 tokens by A/B.
we then adjust for difference between Flux and SD3 by multiplying this value by expected request time for a
standard image(23s for Flux, 6s for SD3).
On a 4090, this would give us a workload that would give a cur_perf(workload / request_time) of around 200
"""
def _calculate_absolute_tokens(width_: int, height_: int, steps_: int) -> float:
"""
This is based on how openai counts image generation tokens, see: https://openai.com/api/pricing/
we count how many 512x512 grids are needed to cover the image.
each tile is then counted as 175 tokens.
each image generation also has constant of 85 base tokens.
we then adjust the count based on the number of steps. The baseline number of steps is assumed to be 28.
Some testing with flux gave me this data:
steps(X) | request time(Y)
__________|_________________
07(0.25x) | 11s (0.47x)
14(0.50x) | 15s (0.65x)
21(0.75x) | 20s (0.86x)
28(1.00x) | 23s (1.00x)
35(1.25x) | 28s (1.21x)
42(1.50x) | 32s (1.39x)
49(1.75x) | 37s (1.60x)
this gives a linear regression of Y = 0.61*X + 6.57
we can use this as an adjustment_factor for token count
adjustment_factor = (0.61 * steps + 6.57)
"""
width_grids = ceil(width_ / 512)
height_grids = ceil(height_ / 512)
tokens = 85 + width_grids * height_grids * 175
adjustment_factor = 0.61 * steps_ + 6.57
return tokens * adjustment_factor
REQUEST_TIME_FOR_STANDARD_IMAGE = get_model().get_request_time()
absolute_tokens = _calculate_absolute_tokens(
width_=width, height_=height, steps_=steps
)
absolute_tokens_standard_image = _calculate_absolute_tokens(
width_=1024, height_=1024, steps_=28
)
return REQUEST_TIME_FOR_STANDARD_IMAGE * (
(absolute_tokens / absolute_tokens_standard_image) * 200
)
@dataclasses.dataclass
class DefaultComfyWorkflowData(ApiPayload):
prompt: str
width: int
height: int
steps: int
seed: int
@classmethod
def for_test(cls):
test_prompt = random.choice(test_prompts).rstrip()
return cls(
prompt=test_prompt,
width=1024,
height=1024,
steps=28,
seed=random.randint(0, sys.maxsize),
)
def generate_payload_json(
self,
) -> Dict[str, Any]:
return json.loads(
get_request_template()
.replace("{{PROMPT}}", self.prompt)
# these values should be of int type. Since "{{VAR}}" is wrapped with " in the template
# to make the JSON valid, we must replace the double quotes. i.e. "{{WIDTH}}" -> 1024 and not "1024"
.replace('"{{WIDTH}}"', str(self.width))
.replace('"{{HEIGHT}}"', str(self.height))
.replace('"{{STEPS}}"', str(self.steps))
.replace('"{{SEED}}"', str(self.seed))
)
def count_workload(self) -> float:
return count_workload(width=self.width, height=self.height, steps=self.steps)
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "DefaultComfyWorkflowData":
errors = {}
for param in inspect.signature(cls).parameters:
if param not in json_msg:
errors[param] = "missing parameter"
if errors:
raise JsonDataException(errors)
return cls(
**{
k: v
for k, v in json_msg.items()
if k in inspect.signature(cls).parameters
}
)
@dataclasses.dataclass
class CustomComfyWorkflowData(ApiPayload):
custom_fields: Dict[str, int]
workflow: Dict[str, Any]
@classmethod
def for_test(cls):
raise NotImplementedError("Custom comfy workflow is not used for testing")
def count_workload(self) -> float:
return count_workload(
width=int(self.custom_fields.get("width", 1024)),
height=int(self.custom_fields.get("height", 1024)),
steps=int(self.custom_fields.get("steps", 28)),
)
def generate_payload_json(self) -> Dict[str, Any]:
template_json = json.loads(get_request_template())
template_json["input"]["workflow_json"] = self.workflow
return template_json
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "CustomComfyWorkflowData":
errors = {}
for param in inspect.signature(cls).parameters:
if param not in json_msg:
errors[param] = "missing parameter"
if errors:
raise JsonDataException(errors)
return cls(
**{
k: v
for k, v in json_msg.items()
if k in inspect.signature(cls).parameters
}
)
@@ -1,137 +0,0 @@
{
"input": {
"handler": "RawWorkflow",
"aws_access_key_id": "your-s3-access-key",
"aws_secret_access_key": "your-s3-secret-access-key",
"aws_endpoint_url": "https://my-endpoint.backblaze.com",
"aws_bucket_name": "your-bucket",
"webhook_url": "your-webhook-url",
"webhook_extra_params": {},
"workflow_json": {
"5": {
"inputs": {
"width": "{{WIDTH}}",
"height": "{{HEIGHT}}",
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"6": {
"inputs": {
"text": "{{PROMPT}}",
"clip": ["11", 0]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"8": {
"inputs": {
"samples": ["13", 0],
"vae": ["10", 0]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": ["8", 0]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
},
"10": {
"inputs": {
"vae_name": "ae.safetensors"
},
"class_type": "VAELoader",
"_meta": {
"title": "Load VAE"
}
},
"11": {
"inputs": {
"clip_name1": "t5xxl_fp16.safetensors",
"clip_name2": "clip_l.safetensors",
"type": "flux"
},
"class_type": "DualCLIPLoader",
"_meta": {
"title": "DualCLIPLoader"
}
},
"12": {
"inputs": {
"unet_name": "flux1-dev.safetensors",
"weight_dtype": "default"
},
"class_type": "UNETLoader",
"_meta": {
"title": "Load Diffusion Model"
}
},
"13": {
"inputs": {
"noise": ["25", 0],
"guider": ["22", 0],
"sampler": ["16", 0],
"sigmas": ["17", 0],
"latent_image": ["5", 0]
},
"class_type": "SamplerCustomAdvanced",
"_meta": {
"title": "SamplerCustomAdvanced"
}
},
"16": {
"inputs": {
"sampler_name": "euler"
},
"class_type": "KSamplerSelect",
"_meta": {
"title": "KSamplerSelect"
}
},
"17": {
"inputs": {
"scheduler": "simple",
"steps": "{{STEPS}}",
"denoise": 1,
"model": ["12", 0]
},
"class_type": "BasicScheduler",
"_meta": {
"title": "BasicScheduler"
}
},
"22": {
"inputs": {
"model": ["12", 0],
"conditioning": ["6", 0]
},
"class_type": "BasicGuider",
"_meta": {
"title": "BasicGuider"
}
},
"25": {
"inputs": {
"noise_seed": "{{SEED}}"
},
"class_type": "RandomNoise",
"_meta": {
"title": "RandomNoise"
}
}
}
}
}
@@ -1,142 +0,0 @@
{
"input": {
"handler": "RawWorkflow",
"aws_access_key_id": "your-s3-access-key",
"aws_secret_access_key": "your-s3-secret-access-key",
"aws_endpoint_url": "https://my-endpoint.backblaze.com",
"aws_bucket_name": "your-bucket",
"webhook_url": "your-webhook-url",
"webhook_extra_params": {},
"workflow_json": {
"6": {
"inputs": {
"text": "{{PROMPT}}",
"clip": ["252", 1]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"13": {
"inputs": {
"shift": 3,
"model": ["252", 0]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "ModelSamplingSD3"
}
},
"67": {
"inputs": {
"conditioning": ["71", 0]
},
"class_type": "ConditioningZeroOut",
"_meta": {
"title": "ConditioningZeroOut"
}
},
"68": {
"inputs": {
"start": 0.1,
"end": 1,
"conditioning": ["67", 0]
},
"class_type": "ConditioningSetTimestepRange",
"_meta": {
"title": "ConditioningSetTimestepRange"
}
},
"69": {
"inputs": {
"conditioning_1": ["68", 0],
"conditioning_2": ["70", 0]
},
"class_type": "ConditioningCombine",
"_meta": {
"title": "Conditioning (Combine)"
}
},
"70": {
"inputs": {
"start": 0,
"end": 0.1,
"conditioning": ["71", 0]
},
"class_type": "ConditioningSetTimestepRange",
"_meta": {
"title": "ConditioningSetTimestepRange"
}
},
"71": {
"inputs": {
"text": "bad quality, poor quality, doll, disfigured, jpg, toy, bad anatomy, missing limbs, missing fingers, 3d, cgi",
"clip": ["252", 1]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Negative Prompt)"
}
},
"135": {
"inputs": {
"width": "{{WIDTH}}",
"height": "{{HEIGHT}}",
"batch_size": 1
},
"class_type": "EmptySD3LatentImage",
"_meta": {
"title": "EmptySD3LatentImage"
}
},
"231": {
"inputs": {
"samples": ["271", 0],
"vae": ["252", 2]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"233": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": ["231", 0]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
},
"252": {
"inputs": {
"ckpt_name": "sd3_medium_incl_clips_t5xxlfp16.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"271": {
"inputs": {
"seed": "{{SEED}}",
"steps": "{{STEPS}}",
"cfg": 4.5,
"sampler_name": "dpmpp_2m",
"scheduler": "sgm_uniform",
"denoise": 1,
"model": ["13", 0],
"positive": ["6", 0],
"negative": ["69", 0],
"latent_image": ["135", 0]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
}
}
}
}
-34
View File
@@ -1,34 +0,0 @@
cartoon character of a person with a hoodie , in style of cytus and deemo, ork, gold chains, realistic anime cat, dripping black goo, lineage revolution style, thug life, cute anthropomorphic bunny, balrog, arknights, aliased, very buff, black and red and yellow paint, painting illustration collage style, character composition in vector with white background
stardew valley, fine details
2D Vector Illustration of a child with soccer ball Art for Sublimation, Design Art, Chrome Art, Painting and Stunning Artwork, Highly Detailed Digital Painting, Airbrush Art, Highly Detailed Digital Artwork, Dramatic Artwork, stained antique yellow copper paint, digital airbrush art, detailed by Mark Brooks, Chicano airbrush art, Swagger! snake Culture
realistic futuristic city-downtown with short buildings, sunset
seascape by Ray Collins and artgerm, front view of a perfect wave, sunny background, ultra detailed water
inspired by realflow-cinema4d editor features, create image of a transparent luxury cup with ice fruits and mint, connected with white, yellow and pink cream, Slow - High Speed MO Photography, YouTube Video Screenshot, Abstract Clay, Transparent Cup , molecular gastronomy, wheel, 3D fluid,Simulation rendering, still video, 4k polymer clay futras photography, very surreal, Houdini Fluid Simulation, hyperrealistic CGI and FLUIDS & MULTIPHYSICS SIMULATION effect, with Somali Stain Lurex, Metallic Jacquard, Gold Thread, Mulberry Silk, Toub Saree, Warm background, a fantastic image worthy of an award.
biker with backpack on his back riding a motorcycle, Style by Ade Santora, Oilpunk, Cover photo, craig mullins style, on the cover of a magazine, Outdoor Magazine, inspired by Alex Petruk APe, image of a male biker, Cover of an award-winning magazine, the man has a backpack, photo for magazine, with a backpack, magazine cover
generate a collage-style illustration inspired by the Procreate raster graphic editor, photographic illustration with the theme, 2D vector, art for textile sublimation, containing surrealistic cartoon cat wearing a baseball cap and jeans standing in front of a poster, inspired by Sadao Watanabe, Doraemon, Japanese cartoon style, Eichiro Oda, Iconic high detail character, Director: Nakahara Nantenbō, Kastuhiro Otomo, image detailed, by Miyamoto, Hidetaka Miyazaki, Katsuhiro illustration, 8k, masterpiece, Minimize noise and grain in photo quality without lose quality and increase brightness and lighting,Symmetry and Alignment, Avoid asymmetrical shapes and out-of-focus points. Focus and Sharpness: Make sure the image is focused and sharp and encourages the viewer to see it as a work of art printed on fabric.
fantasy medieval village world inside a glass sphere , high detail, fantasy, realistic, light effect, hyper detail, volumetric lighting, cinematic, macro, depth of field, blur, red light and clouds from the back, highly detailed epic cinematic concept art cg render made in maya, blender and photoshop, octane render, excellent composition, dynamic dramatic cinematic lighting, aesthetic, very inspirational, world inside a glass sphere by james gurney by artgerm with james jean, joe fenton and tristan eaton by ross tran, fine details
Iron Man, (Arnold Tsang, Toru Nakayama), Masterpiece, Studio Quality, 6k , toa, toaair, 1boy, glowing, axe, mecha, science_fiction, solo, weapon, jungle , green_background, nature, outdoors, solo, tree, weapon, mask, dynamic lighting, detailed shading, digital texture painting
(Pope Francis) wearing leather jacket is a DJ in a nightclub, mixing live on stage, giant mixing table, a masterpiece
Pope Francis wearing biker (leather jacket), a masterpiece
Luke Skywalker ordering a burger and fries from the Death Star canteen.
I want to generate a group avatar for a Feishu group chat. The role of this group is daily software technical communication. Now the subject technology stacks that members of this group discuss daily include: algorithms, data structures, optimization, functional programming, and the programming languages often discussed are: TypeScript, Java, python, etc. I hope this avatar has a simple aesthetic, this avatar is a single person avatar
portrait Anime black girl cute-fine-face, pretty face, realistic shaded Perfect face, fine details. Anime. realistic shaded lighting by Ilya Kuvshinov Giuseppe Dangelico Pino and Michael Garmash and Rob Rey, IAMAG premiere, WLOP matte print, cute freckles, masterpiece
young Disney socialite wearing a beige miniskirt, dark brown turtleneck sweater, small neckless, cute-fine-face, anime. illustration, realistic shaded perfect face, brown hair, grey eyes, fine details, realistic shaded lighting by ilya kuvshinov giuseppe dangelico pino and michael garmash and rob rey, iamag premiere, wlop matte print, a masterpiece
Cute small cat sitting in a movie theater eating chicken wiggs watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
Cute small dog sitting in a movie theater eating popcorn watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
fox bracelet made of buckskin with fox features, rich details, fine carvings, studio lighting
crane buckskin bracelet with crane features, rich details, fine carvings, studio lighting
london luxurious interior living-room, light walls
Parisian luxurious interior penthouse bedroom, dark walls, wooden panels
cute girl, crop-top, blond hair, black glasses, stretching, with background by greg rutkowski makoto shinkai kyoto animation key art feminine mid shot
houses in front, houses background, straight houses, digital art, smooth, sharp focus, gravity falls style, doraemon style, shinchan style, anime style
Simplified technical drawing, Leonardo da Vinci, Mechanical Dinosaur Skeleton, Minimalistic annotations, Hand-drawn illustrations, Basic design and engineering, Wonder and curiosity
High quality 8K painting impressionist style of a Japanese modern city street with a girl on the foreground wearing a traditional wedding dress with a fox mask, staring at the sky, daylight
a landscape from the Moon with the Earth setting on the horizon, realistic, detailed
Isometric Atlantis city,great architecture with columns, great details, ornaments,seaweed, blue ambiance, 3D cartoon style, soft light, 45° view
A hyper realistic avatar of a guy riding on a black honda cbr 650r in leather suit,high detail, high quality,8K,photo realism
the street of amedieval fantasy town, at dawn, dark, highly detailed
overwhelmingly beautiful eagle framed with vector flowers, long shiny wavy flowing hair, polished, ultra detailed vector floral illustration mixed with hyper realism, muted pastel colors, vector floral details in background, muted colors, hyper detailed ultra intricate overwhelming realism in detailed complex scene with magical fantasy atmosphere, no signature, no watermark
a highly detailed matte painting of a man on a hill watching a rocket launch in the distance by studio ghibli, makoto shinkai, by artgerm, by wlop, by greg rutkowski, volumetric lighting, octane render, 4 k resolution, trending on artstation, masterpiece | hyperrealism| highly detailed| insanely detailed| intricate| cinematic lighting| depth of field
electronik robot and ofice ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render
exquisitely intricately detailed illustration, of a small world with a lake and a rainbow, inside a closed glass jar.
-143
View File
@@ -1,143 +0,0 @@
import os
import logging
import dataclasses
import base64
from typing import Optional, Union, Type
from aiohttp import web, ClientResponse
from anyio import open_file
from lib.backend import Backend, LogAction
from lib.data_types import EndpointHandler
from lib.server import start_server
from .data_types import DefaultComfyWorkflowData, CustomComfyWorkflowData
MODEL_SERVER_URL = "http://127.0.0.1:18288" # API Wrapper Service
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: http://127.0.0.1:18188"
MODEL_SERVER_ERROR_LOG_MSGS = [
"MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted
"Value not in list: unet_name", # This error is emitted when the model file is not there at all
]
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
async def generate_client_response(
request: web.Request, response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
_ = request
match response.status:
case 200:
log.debug("SUCCESS")
res = await response.json()
if "output" not in res:
return web.json_response(
data=dict(error="there was an error in the workflow"),
status=422,
)
image_paths = [path["local_path"] for path in res["output"]["images"]]
if not image_paths:
return web.json_response(
data=dict(error="workflow did not produce any images"),
status=422,
)
images = []
for image_path in image_paths:
async with await open_file(image_path, mode="rb") as f:
contents = await f.read()
images.append(
f"data:image/png;base64,{base64.b64encode(contents).decode('utf-8')}"
)
return web.json_response(data=dict(images=images))
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
@dataclasses.dataclass
class DefaultComfyWorkflowHandler(EndpointHandler[DefaultComfyWorkflowData]):
@property
def endpoint(self) -> str:
return "/runsync"
@property
def healthcheck_endpoint(self) -> Optional[str]:
return None
@classmethod
def payload_cls(cls) -> Type[DefaultComfyWorkflowData]:
return DefaultComfyWorkflowData
def make_benchmark_payload(self) -> DefaultComfyWorkflowData:
return DefaultComfyWorkflowData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
return await generate_client_response(client_request, model_response)
@dataclasses.dataclass
class CustomComfyWorkflowHandler(EndpointHandler[CustomComfyWorkflowData]):
@property
def endpoint(self) -> str:
return "/runsync"
@property
def healthcheck_endpoint(self) -> Optional[str]:
return None
@classmethod
def payload_cls(cls) -> Type[CustomComfyWorkflowData]:
return CustomComfyWorkflowData
def make_benchmark_payload(self) -> CustomComfyWorkflowData:
return CustomComfyWorkflowData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
return await generate_client_response(client_request, model_response)
backend = Backend(
model_server_url=MODEL_SERVER_URL,
model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=False,
benchmark_handler=DefaultComfyWorkflowHandler(
benchmark_runs=3, benchmark_words=100
),
log_actions=[
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
(LogAction.Info, "Downloading:"),
*[
(LogAction.ModelError, error_msg)
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
],
],
)
async def handle_ping(_):
return web.Response(body="pong")
routes = [
web.post("/prompt", backend.create_handler(DefaultComfyWorkflowHandler())),
web.post("/custom-workflow", backend.create_handler(CustomComfyWorkflowHandler())),
web.get("/ping", handle_ping),
]
if __name__ == "__main__":
start_server(backend, routes)
-15
View File
@@ -1,15 +0,0 @@
from lib.test_utils import test_load_cmd, test_args
from .data_types import DefaultComfyWorkflowData, Model
WORKER_ENDPOINT = "/prompt"
if __name__ == "__main__":
test_args.add_argument(
"-m",
dest="comfy_model",
choices=list(map(lambda x: x.value, Model)),
required=True,
help="Image generation model name",
)
test_load_cmd(DefaultComfyWorkflowData, WORKER_ENDPOINT, arg_parser=test_args)
-321
View File
@@ -1,321 +0,0 @@
# Vast PyWorker
## Hello_world example
There is a hello_world PyWorker implementation under `workers/hello_world`. This PyWorker is
created for an LLM model server that runs on port 5001 has two API endpoints:
1. `/generate`: generates an full response to the prompt and sends a JSON response
2. `/generate_stream`: streams a response one token at a time
Both of these endpoints take the same API JSON payload:
```
{
"prompt": String,
"max_response_tokens": Number | null
}
```
We want the PyWorker to also expose two endpoints that correspond to the above endpoints.
### Structure
All PyWorkers have four files:
```
.
└── workers
└── hello_world
├── __init__.py
├── data_types.py # contains data types representing model API endpoints
├── server.py # contains endpoint handlers
└── test_load.py # script for load testing
```
All of the classes follow strict type hinting. It is recommended that you type hint all of your function.
This will allow your IDE or VSCode with `pyright` plugin to find any type errors in your implementation.
You can also install `pyright` with `sudo npm install -g pyright` and run `pyright` in the root of the project to find
any type errors.
### data_types.py: Contains data types representing model API endpoints
This file defines the structure of the data your model server expects (its API contract) and, critically, how PyWorker *interprets* that data for autoscaling purposes. You define Python data classes that mirror the JSON payloads your model's API uses.
These classes **must** inherit from `lib.data_types.ApiPayload`. This inheritance is not just for structure; it's how PyWorker knows how to:
* **Parse Incoming Requests:** Convert JSON from clients into usable Python objects.
* **Calculate Workload:** Determine the computational cost of a request.
* **Generate Test Data:** Create realistic inputs for benchmarking.
* **Format Requests for the Model Server:** Prepare data for the underlying model.
```python
import dataclasses
import random
from typing import Dict, Any
from transformers import OpenAIGPTTokenizer # used to count tokens in a prompt
import nltk # used to download a list of all words to generate a random prompt and benchmark the LLM model
from lib.data_types import ApiPayload
nltk.download("words")
WORD_LIST = nltk.corpus.words.words()
# you can use any tokenizer that fits your LLM. `openai-gpt` is free to use and is a good fit for most LLMs
tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
@dataclasses.dataclass
class InputData(ApiPayload):
prompt: str
max_response_tokens: int
@classmethod
def for_test(cls) -> "ApiPayload":
"""defines how create a payload for load testing"""
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
return cls(prompt=prompt, max_response_tokens=300)
def generate_payload_json(self) -> Dict[str, Any]:
"""defines how to convert an ApiPayload to JSON that will be sent to model API"""
return dataclasses.asdict(self)
def count_workload(self) -> float:
"""defines how to calculate workload for a payload"""
return len(tokenizer.tokenize(self.prompt))
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
"""
defines how to transform JSON data to AuthData and payload type,
in this case `InputData` defined above represents the data sent to the model API.
AuthData is data generated by autoscaler in order to authenticate payloads.
In this case, the transformation is simple and 1:1. That is not always the case. See comfyui's PyWorker
for more complicated examples
"""
errors = {}
for param in inspect.signature(cls).parameters:
if param not in json_msg:
errors[param] = "missing parameter"
if errors:
raise JsonDataException(errors)
return cls(
**{
k: v
for k, v in json_msg.items()
if k in inspect.signature(cls).parameters
}
)
```
### server.py: Creating Your Model's API Endpoints
This section guides you through creating the core of your custom model API: the `EndpointHandler`. Think of `EndpointHandler` as the bridge between incoming requests from users and your underlying model. It's the key to making your model accessible and scalable.
**Why use an `EndpointHandler`?**
* **Organized Request Handling:** It provides a structured way to handle different types of requests (like generating text, generating images, or performing other model-specific tasks).
* **Scalability:** By separating request handling from the model itself, you can easily scale your API to handle many concurrent users.
* **Flexibility:** You can customize how requests are processed, validated, and transformed before being sent to your model.
* **Standard Interface:** It provides a consistent interface for interacting with your model, regardless of the underlying implementation.
For every model API endpoint you want to expose (e.g., `/generate`, `/generate_stream`), you'll implement an `EndpointHandler`. This class is responsible for:
The `EndpointHandler` achieves this through several key methods:
* **Receiving and validating incoming requests (`get_data_from_request`):** This method ensures the request contains the necessary data (authentication and payload) and is in the correct format. It's the entry point for all requests.
* **Defining the endpoint (`endpoint`):** This method specifies the URL endpoint on the model API server where requests will be sent (e.g., `/generate`).
* **Specifying the payload type (`payload_cls`):** This method indicates the specific `ApiPayload` class used for this endpoint, defining the structure of the request data.
* **Creating benchmark payloads (`make_benchmark_payload`):** This method creates payloads specifically for benchmarking the model's performance.
* **Handling the model's response (`generate_client_response`):** This method takes the response from the model API server and transforms it into the format expected by the client making the request to your PyWorker. This allows you to customize the output as needed.
The `EndpointHandler` class has several abstract functions that you *must* implement to define the behavior of your specific endpoints. Here, we'll implement two common endpoints: `/generate` (for synchronous requests) and `/generate_stream` (for streaming responses):
```python
"""
AuthData is a dataclass that represents Authentication data sent from Autoscaler to client requesting a route.
When a user requests a route from autoscaler, see Vast's Serverless documentation for how routing and AuthData
work.
When a user receives a route for this PyWorker, they'll call PyWorkers API with the following JSON:
{
auth_data: AuthData,
payload : InputData # defined above
}
"""
from aiohttp import web
from lib.data_types import EndpointHandler, JsonDataException
from lib.server import start_server
from .data_types import InputData
# This class is the implementer for the '/generate' endpoint of model API
@dataclasses.dataclass
class GenerateHandler(EndpointHandler[InputData]):
@property
def endpoint(self) -> str:
# the API endpoint
return "/generate"
@classmethod
def payload_cls(cls) -> Type[InputData]:
"""this function should just return ApiPayload subclass used by this handler"""
return InputData
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
"""
defines how to convert `InputData` defined above, to
JSON data to be sent to the model API. This function too is a simple dataclass -> JSON, but
can be more complicated, See comfyui for an example
"""
return dataclasses.asdict(payload)
def make_benchmark_payload(self) -> InputData:
"""
defines how to generate an InputData for benchmarking. This needs to be defined in only
one EndpointHandler, the one passed to the backend as the benchmark handler. Here we use the .for_test()
method on InputData. However, in some cases you might need to fine tune your InputData used for
benchmarking to closely resemble the average request users call the endpoint with in order to get best
autoscaling performance
"""
return InputData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
"""
defines how to convert a model API response to a response to PyWorker client
"""
_ = client_request
match model_response.status:
case 200:
log.debug("SUCCESS")
data = await model_response.json()
return web.json_response(data=data)
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
```
We also handle `GenerateStreamHandler` for streaming responses. It is identical to `GenerateHandler`, except for
the endpoint name and how we create a web response, as it is a streaming response:
```python
class GenerateStreamHandler(EndpointHandler[InputData]):
@property
def endpoint(self) -> str:
return "/generate_stream"
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
return dataclasses.asdict(payload)
def make_benchmark_payload(self) -> InputData:
return InputData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
match model_response.status:
case 200:
log.debug("Streaming response...")
res = web.StreamResponse()
res.content_type = "text/event-stream"
await res.prepare(client_request)
async for chunk in model_response.content:
await res.write(chunk)
await res.write_eof()
log.debug("Done streaming response")
return res
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
```
You can now instantiate a Backend and use it to handle requests.
```python
from lib.backend import Backend, LogAction
# the url and port of model API
MODEL_SERVER_URL = "http://0.0.0.0:5001"
# This is the log line that is emitted once the server has started
MODEL_SERVER_START_LOG_MSG = "server has started"
MODEL_SERVER_ERROR_LOG_MSGS = [
"Exception: corrupted model file" # message in the logs indicating the unrecoverable error
]
backend = Backend(
model_server_url=MODEL_SERVER_URL,
# location of model log file
model_log_file=os.environ["MODEL_LOG"],
# for some model backends that can only handle one request at a time, be sure to set this to False to
# let PyWorker handling queueing requests.
allow_parallel_requests=True,
# give the backend an EndpointHandler instance that is used for benchmarking
# number of benchmark run and number of words for a random benchmark run are given
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
# defines how to handle specific log messages. See docstring of LogAction for details
log_actions=[
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
(LogAction.Info, '"message":"Download'),
*[
(LogAction.ModelError, error_msg)
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
],
],
)
# this is a simple ping handler for PyWorker
async def handle_ping(_: web.Request):
return web.Response(body="pong")
# this is a handler for forwarding a health check to model API
async def handle_healthcheck(_: web.Request):
healthcheck_res = await backend.session.get("/healthcheck")
return web.Response(body=healthcheck_res.content, status=healthcheck_res.status)
routes = [
web.post("/generate", backend.create_handler(GenerateHandler())),
web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
web.get("/ping", handle_ping),
web.get("/healthcheck", handle_healthcheck),
]
if __name__ == "__main__":
# start server, called from start_server.sh
start_server(backend, routes)
```
### test_load.py
Here you can create a script that allows you test an endpoint group running instances with this PyWorker
```python
from lib.test_harness import run
from .data_types import InputData
WORKER_ENDPOINT = "/generate"
if __name__ == "__main__":
run(InputData.for_test(), WORKER_ENDPOINT)
```
You can then run the following command from the root of this repo to load test endpoint group:
```sh
# sends 1000 requests at the rate of 0.5 requests per second
python3 workers.hello_world.test_load -n 1000 -rps 0.5 -k "$API_KEY" -e "$ENDPOINT_GROUP_NAME"
```
View File
View File
-48
View File
@@ -1,48 +0,0 @@
import dataclasses
import random
import inspect
from typing import Dict, Any
from transformers import OpenAIGPTTokenizer
import nltk
from lib.data_types import ApiPayload, JsonDataException
nltk.download("words")
WORD_LIST = nltk.corpus.words.words()
# used to count to count tokens and workload for LLM
tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
@dataclasses.dataclass
class InputData(ApiPayload):
prompt: str
max_response_tokens: int
@classmethod
def for_test(cls) -> "InputData":
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
return cls(prompt=prompt, max_response_tokens=300)
def generate_payload_json(self) -> Dict[str, Any]:
return dataclasses.asdict(self)
def count_workload(self) -> int:
return len(tokenizer.tokenize(self.prompt))
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
errors = {}
for param in inspect.signature(cls).parameters:
if param not in json_msg:
errors[param] = "missing parameter"
if errors:
raise JsonDataException(errors)
return cls(
**{
k: v
for k, v in json_msg.items()
if k in inspect.signature(cls).parameters
}
)
-175
View File
@@ -1,175 +0,0 @@
"""
PyWorker works as a man-in-the-middle between the client and model API. It's function is:
1. receive request from client, update metrics such as workload of a request, number of pending requests, etc.
2a. transform the data and forward the transformed data to model API
2b. send updated metrics to autoscaler
3. transform response from model API(if needed) and forward the response to client
PyWorker forward requests to many model API endpoint. each endpoint must have an EndpointHandler. You can also
write function to just forward requests that don't generate anything with the model to model API without an
EndpointHandler. This is useful for endpoints such as healthchecks. See below for example
"""
import os
import logging
import dataclasses
from typing import Dict, Any, Optional, Union, Type
from aiohttp import web, ClientResponse
from lib.backend import Backend, LogAction
from lib.data_types import EndpointHandler
from lib.server import start_server
from .data_types import InputData
# the url and port of model API
MODEL_SERVER_URL = "http://0.0.0.0:5001"
# This is the log line that is emitted once the server has started
MODEL_SERVER_START_LOG_MSG = "infer server has started"
MODEL_SERVER_ERROR_LOG_MSGS = [
"Exception: corrupted model file" # message in the logs indicating the unrecoverable error
]
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
# This class is the implementer for the '/generate' endpoint of model API
@dataclasses.dataclass
class GenerateHandler(EndpointHandler[InputData]):
@property
def endpoint(self) -> str:
# the API endpoint
return "/generate"
@property
def healthcheck_endpoint(self) -> Optional[str]:
return None
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
"""
defines how to convert `InputData` defined above, to
json data to be sent to the model API
"""
return dataclasses.asdict(payload)
def make_benchmark_payload(self) -> InputData:
"""
defines how to generate an InputData for benchmarking. This needs to be defined in only
one EndpointHandler, the one passed to the backend as the benchmark handler
"""
return InputData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
"""
defines how to convert a model API response to a response to PyWorker client
"""
_ = client_request
match model_response.status:
case 200:
log.debug("SUCCESS")
data = await model_response.json()
return web.json_response(data=data)
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
# This is the same as GenerateHandler, except that it calls a streaming endpoint of the model API and streams the
# response, which itself is streaming, back to the client.
# it is nearly identical to handler as above, but it calls a different model API endpoint and it streams the
# streaming response from model API to client
class GenerateStreamHandler(EndpointHandler[InputData]):
@property
def endpoint(self) -> str:
return "/generate_stream"
@property
def healthcheck_endpoint(self) -> Optional[str]:
return None
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
return dataclasses.asdict(payload)
def make_benchmark_payload(self) -> InputData:
return InputData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
match model_response.status:
case 200:
log.debug("Streaming response...")
res = web.StreamResponse()
res.content_type = "text/event-stream"
await res.prepare(client_request)
async for chunk in model_response.content:
await res.write(chunk)
await res.write_eof()
log.debug("Done streaming response")
return res
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
# This is the backend instance of pyworker. Only one must be made which uses EndpointHandlers to process
# incoming requests
backend = Backend(
model_server_url=MODEL_SERVER_URL,
model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=True,
# give the backend a handler instance that is used for benchmarking
# number of benchmark run and number of words for a random benchmark run are given
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
# defines how to handle specific log messages. See docstring of LogAction for details
log_actions=[
(LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
(LogAction.Info, '"message":"Download'),
*[
(LogAction.ModelError, error_msg)
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
],
],
)
# this is a simple ping handler for pyworker
async def handle_ping(_: web.Request):
return web.Response(body="pong")
# this is a handler for forwarding a health check to modelAPI
async def handle_healthcheck(_: web.Request):
healthcheck_res = await backend.session.get("/healthcheck")
return web.Response(body=healthcheck_res.content, status=healthcheck_res.status)
routes = [
web.post("/generate", backend.create_handler(GenerateHandler())),
web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
web.get("/ping", handle_ping),
web.get("/healthcheck", handle_healthcheck),
]
if __name__ == "__main__":
# start the PyWorker server
start_server(backend, routes)
-7
View File
@@ -1,7 +0,0 @@
from lib.test_utils import test_load_cmd, test_args
from .data_types import InputData
WORKER_ENDPOINT = "/generate"
if __name__ == "__main__":
test_load_cmd(InputData, WORKER_ENDPOINT, arg_parser=test_args)
+88
View File
@@ -0,0 +1,88 @@
# Null PyWorker
Holds Vast Serverless reservations open without forwarding any work to a
model. Use it when your real workload (a queue consumer in any language)
runs as a separate process on the instance and you just want to drive
Vast autoscaling: **one POST reserves a worker, one POST releases it.**
## Use case
You have a job queue on your own infrastructure (Redis, SQS, NATS, etc.)
and a consumer (node, golang, python, a binary — anything) that pulls
from it. You want one Vast worker per unit of in-flight work, scaling
elastically from zero. The null PyWorker is the autoscaling driver; your
consumer does the work.
## How it works
Reservations use the framework's session API. The SDK's
`endpoint.session(...)` POSTs `/session/create` to reserve a worker;
`session.close()` POSTs `/session/end` to release it. `max_sessions=1`
means each worker holds exactly one reservation — the next reservation
either lands on a free worker or triggers a scale-up.
The PyWorker itself does nothing functional:
- One trivial `/ping` route to satisfy the framework's benchmark
requirement (its `max_perf` is pinned to 100).
- An internal `/release` endpoint on `127.0.0.1:18999` for the local
consumer to end the session without needing `session_auth`.
## Endpoint parameters
Tested working configuration:
| Parameter | Value | Why |
|---|---|---|
| `target_util` | `1.0` | One session = one worker. Default `0.9` rounds up to an extra worker. |
| `min_load` | `0` | Scale-to-zero floor. |
| `max_queue_time` | `1` | Stop routing to an occupied worker after ~1s of implied queue. |
| `target_queue_time` | `0.5` | Trigger scale-up promptly once anything queues. |
| `inactivity_timeout` | `10` (seconds) | Permit scale-to-zero after 10s idle. |
## API
| Route | Where | Use |
|---|---|---|
| `POST /session/create` | endpoint, signed | Reserve a worker (`endpoint.session(...)`) |
| `POST /session/end` | endpoint, signed | Release (`session.close()`) |
| `POST /release` | `127.0.0.1:18999`, no auth | Local consumer release, no `session_auth` needed |
## Healthcheck
Default: stub on `127.0.0.1:18999/health` returning `200`. Set
`BACKEND_HEALTH_URL=http://127.0.0.1:9090/health` (absolute URL) to point
the framework at your queue consumer's health endpoint instead — if the
consumer dies, the autoscaler sees the worker as broken.
## Deploying
1. Point `PYWORKER_REPO` at this repo (or your fork).
2. Set `BACKEND=null` in the template.
3. Run your queue consumer alongside the PyWorker. When it's done with
a unit of work:
```bash
curl -X POST http://127.0.0.1:18999/release
```
## Client demo
```bash
# Single reservation, hold 180s
python -m workers.null.client --endpoint <NAME> --instance alpha
# Three concurrent reservations, started 30s apart, each held 360s
python -m workers.null.client --endpoint <NAME> --instance alpha --count 3 --hold 360
```
Flags: `--count` (number of concurrent sessions, default 1), `--hold`
(seconds each session is held, default 180), `--interval` (seconds
between starts when `--count > 1`, default 30), `--cost` (cost reported
at session-create, default 100 = `max_perf`), `--instance` (`prod` |
`alpha` | `candidate` | `local`).
## Environment variables
- `BACKEND_HEALTH_URL` — absolute URL the framework healthchecks. Stub
is used when unset.
- `NULL_CONTROL_PORT` — internal control server port. Defaults to `18999`.
+64
View File
@@ -0,0 +1,64 @@
import argparse
import asyncio
import logging
import os
import sys
from vastai import Serverless
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
async def reserve(client: Serverless, endpoint_name: str, hold: float, cost: int, label: str):
endpoint = await client.get_endpoint(name=endpoint_name)
async with await endpoint.session(cost=cost, lifetime=hold + 60) as s:
sid = s.session_id
log.info("[%s] %s open, holding %.0fs", label, sid, hold)
await asyncio.sleep(hold)
log.info("[%s] %s closed", label, sid)
async def main_async():
p = argparse.ArgumentParser(description="Vast Null PyWorker demo client")
p.add_argument("--endpoint", default=os.environ.get("VAST_ENDPOINT", "null-prod"))
p.add_argument("--instance", choices=("prod", "alpha", "candidate", "local"),
default=os.environ.get("VAST_INSTANCE", "prod"))
p.add_argument("--count", type=int, default=1,
help="concurrent sessions to open (default: 1)")
p.add_argument("--interval", type=float, default=30.0,
help="seconds between session starts when count>1 (default: 30)")
p.add_argument("--hold", type=float, default=180.0,
help="seconds to hold each session (default: 180)")
p.add_argument("--cost", type=int, default=100,
help="cost reported at session-create (default: 100)")
args = p.parse_args()
print(f"endpoint={args.endpoint} instance={args.instance} "
f"count={args.count} hold={args.hold}s cost={args.cost}")
try:
async with Serverless(instance=args.instance) as client:
tasks = []
for i in range(args.count):
label = f"res-{i+1}" if args.count > 1 else "reservation"
tasks.append(asyncio.create_task(
reserve(client, args.endpoint, args.hold, args.cost, label),
name=label,
))
if i + 1 < args.count:
await asyncio.sleep(args.interval)
await asyncio.gather(*tasks, return_exceptions=True)
except KeyboardInterrupt:
log.info("Interrupted")
except Exception as e:
log.error("Error: %s", e, exc_info=True)
sys.exit(1)
if __name__ == "__main__":
asyncio.run(main_async())
+143
View File
@@ -0,0 +1,143 @@
import asyncio
import logging
import os
from contextlib import asynccontextmanager
from urllib.parse import urlsplit
from aiohttp import web
from vastai import (
Worker,
WorkerConfig,
HandlerConfig,
BenchmarkConfig,
LogActionConfig,
)
log = logging.getLogger(__file__)
TARGET_PERF = 100.0
BENCHMARK_SENTINEL = "__null_worker_benchmark__"
INTERNAL_HOST = "127.0.0.1"
INTERNAL_PORT = int(os.environ.get("NULL_CONTROL_PORT", 18999))
STUB_HEALTH_PATH = "/health"
BACKEND_HEALTH_URL = os.environ.get("BACKEND_HEALTH_URL", "").strip()
if BACKEND_HEALTH_URL:
_p = urlsplit(BACKEND_HEALTH_URL)
if not _p.scheme or not _p.hostname:
raise ValueError(f"BACKEND_HEALTH_URL must be absolute, got: {BACKEND_HEALTH_URL!r}")
HEALTH_BASE_URL = f"{_p.scheme}://{_p.hostname}"
HEALTH_PORT = _p.port or (443 if _p.scheme == "https" else 80)
HEALTH_PATH = _p.path or "/"
USE_STUB_HEALTH = False
else:
HEALTH_BASE_URL = f"http://{INTERNAL_HOST}"
HEALTH_PORT = INTERNAL_PORT
HEALTH_PATH = STUB_HEALTH_PATH
USE_STUB_HEALTH = True
_backend_ref: dict = {"backend": None}
def _build_internal_app() -> web.Application:
app = web.Application()
async def release_handler(_request: web.Request) -> web.Response:
# Closes the singleton session. Uses name-mangled __close_session
# to bypass the session_auth check — safe because this server is
# bound to 127.0.0.1, and it spares the consumer from threading
# session_auth through its queue.
backend = _backend_ref.get("backend")
if backend is None:
return web.json_response({"released": False, "reason": "backend not ready"}, status=503)
sids = list(backend.sessions.keys())
if not sids:
return web.json_response({"released": False, "reason": "no active session"}, status=200)
closed = []
for sid in sids:
try:
if await backend._Backend__close_session(sid):
closed.append(sid)
except Exception as e:
log.warning(f"Error closing session {sid}: {e}")
return web.json_response({"released": bool(closed), "session_ids": closed}, status=200)
app.router.add_post("/release", release_handler)
if USE_STUB_HEALTH:
async def stub_health(_request: web.Request) -> web.Response:
return web.Response(status=200, text="ok")
app.router.add_get(STUB_HEALTH_PATH, stub_health)
return app
@asynccontextmanager
async def null_lifecycle():
# Pin max_throughput to TARGET_PERF exactly — the framework's
# __run_benchmark short-circuits to float(file_contents) if this exists.
try:
with open(".has_benchmark", "w") as fh:
fh.write(str(int(TARGET_PERF)))
except OSError as e:
log.warning(f"Could not pin benchmark cache: {e}")
runner = web.AppRunner(_build_internal_app())
await runner.setup()
await web.TCPSite(runner, INTERNAL_HOST, INTERNAL_PORT).start()
log.info(
"Null pyworker control server: http://%s:%d (POST /release%s)",
INTERNAL_HOST,
INTERNAL_PORT,
f", GET {STUB_HEALTH_PATH}" if USE_STUB_HEALTH else "",
)
if not USE_STUB_HEALTH:
log.info("Framework healthcheck → %s", BACKEND_HEALTH_URL)
try:
yield
finally:
await runner.cleanup()
async def ping(**params: object) -> dict:
# Exists only to satisfy the framework's "at least one handler with a
# BenchmarkConfig" requirement. Sleep 1s on the benchmark path as a
# fallback in case the .has_benchmark cache pin failed; otherwise the
# benchmark cache short-circuits and this never runs.
if params.get(BENCHMARK_SENTINEL):
await asyncio.sleep(1.0)
return {"ok": True, "benchmark": True}
return {"ok": True}
worker_config = WorkerConfig(
model_server_url=HEALTH_BASE_URL,
model_server_port=HEALTH_PORT,
model_healthcheck_url=HEALTH_PATH,
lifecycle=null_lifecycle(),
max_sessions=1,
handlers=[
HandlerConfig(
route="/ping",
allow_parallel_requests=True,
remote_function=ping,
workload_calculator=lambda _payload: TARGET_PERF,
benchmark_config=BenchmarkConfig(
generator=lambda: {BENCHMARK_SENTINEL: True},
runs=1,
concurrency=1,
do_warmup=False,
),
),
],
log_action_config=LogActionConfig(),
)
_worker = Worker(worker_config)
_backend_ref["backend"] = _worker.backend
_worker.run()
+33 -26
View File
@@ -8,14 +8,13 @@ This is the base PyWorker for OpenAI compatible inference servers. See the [Ser
This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker.
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20%2B%20Qwen%2FQwen3-8B%20(Serverless)) (recommended)
- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20(Serverless)) (recommended)
- [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless))
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20%2B%20Qwen3-8B%20(Serverless))
All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected.
2. Follow the [getting started guide](https://docs.vast.ai/serverless/getting-started) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
## Client Setup (Demo)
@@ -34,38 +33,20 @@ uv pip install -r requirements.txt
Several examples have been provided in the client to help you get started with your own implementation.
### Completions
Call to `/v1/completions` with json response
First, set your API key as an environment variable:
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
export VAST_API_KEY=<your_api_key>
```
### Chat Completion (json)
Call to `/v1/chat/completions` with json response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
```
The `--model` and `--endpoint` flags are optional. If not provided, they default to `Qwen/Qwen3-8B` and `my-vllm-endpoint` respectively.
### Chat Completion (streaming)
Call to `/v1/chat/completions` with streaming response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
```
### Tool Use (json)
Call to `/v1/chat/completions` with tool and json response.
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
python -m workers.openai.client --chat-stream --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Interactive Chat (streaming)
@@ -75,6 +56,32 @@ Interactive session with calls to `/v1/chat/completions`.
Type `clear` to clear the chat history or `quit` to exit.
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
python -m workers.openai.client --interactive --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Chat Completion (json)
Call to `/v1/chat/completions` with json response
```bash
python -m workers.openai.client --chat --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Tool Use (json)
Call to `/v1/chat/completions` with tool and json response.
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
```bash
python -m workers.openai.client --tools --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
### Completions
Call to `/v1/completions` with json response
```bash
python -m workers.openai.client --completion --endpoint <ENDPOINT_NAME> --model <MODEL_NAME>
```
-77
View File
@@ -1,77 +0,0 @@
# <INFERENCE_SERVER> + <MODEL_NAME> (serverless)
Run <INFERENCE_SERVER> with our serverless autoscaling infrastructure.
See the [serverless documentation](https://docs.vast.ai/serverless) and the [Getting Started](https://docs.vast.ai/serverless/getting-started) guide for in-depth details about how to use these templates.
## Configuration
Two environment variables are provided to help you configure the <INFERENCE_SERVER> server:
| Variable | Default Value | Used For |
| --- | --- | --- |
| `MODEL_NAME` | `<MODEL_NAME>` | The model to load. Also accepts [hf.co/repo/model](#) links |
| `<ARGS_VAR>` | `<ARGS_VAL>` | Arguments to pass to the `<ARGS_RECEIVER>` command |
This template has been configured to work with <MIN_VRAM> VRAM. Setting alternative models and server arguments will change the VRAM requirements. Check model cards and <INFERENCE_SERVER_DOCS> for guidance.
## Usage
We have provided a demonstration client to help you implement this template into your own infrastructure
### Client Setup
Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
```bash
git clone https://github.com/vast-ai/pyworker
cd pyworker
pip install uv
uv venv -p 3.12
source .venv/bin/activate
uv pip install -r requirements.txt
```
### Completions
Call to `/v1/completions` with json response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --completion --model <MODEL_NAME>
```
### Chat Completion (json)
Call to `/v1/chat/completions` with json response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat --model <MODEL_NAME>
```
### Chat Completion (streaming)
Call to `/v1/chat/completions` with streaming response
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --chat-stream --model <MODEL_NAME>
```
### Tool Use (json)
Call to `/v1/chat/completions` with tool and json response.
This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model.
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --tools --model <MODEL_NAME>
```
### Interactive Chat (streaming)
Interactive session with calls to `/v1/chat/completions`.
Type `clear` to clear the chat history or `quit` to exit.
```bash
python -m workers.openai.client -k <API_KEY> -e <ENDPOINT_NAME> --interactive --model <MODEL_NAME>
```
+366 -428
View File
@@ -1,14 +1,15 @@
import logging
import sys
import json
import os
import sys
import subprocess
from urllib.parse import urljoin
from typing import Dict, Any, Optional, Iterator, Union, List
import requests
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from .data_types.client import CompletionConfig, ChatCompletionConfig
import argparse
from typing import Any, Dict, List, Optional
from vastai import Serverless
import asyncio
# ---------------------- Logging ----------------------
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
@@ -16,135 +17,20 @@ logging.basicConfig(
)
log = logging.getLogger(__file__)
COMPLETIONS_PROMPT = "the capital of USA is"
# ---------------------- Prompts ----------------------
COMPLETIONS_PROMPT = "Zebras are primarily grazers and can subsist on lower-quality vegetation. They are preyed on mainly by"
CHAT_PROMPT = "Think step by step: Tell me about the Python programming language."
TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?"
class APIClient:
"""Lightweight client focused solely on API communication"""
# Remove the generic WORKER_ENDPOINT since we're now going direct
DEFAULT_COST = 100
DEFAULT_TIMEOUT = 4
def __init__(
self,
endpoint_group_name: str,
api_key: str,
server_url: str,
endpoint_api_key: str,
):
self.endpoint_group_name = endpoint_group_name
self.api_key = api_key
self.server_url = server_url
self.endpoint_api_key = endpoint_api_key
def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]:
"""Get worker URL and auth data from routing service"""
if not self.endpoint_api_key:
raise ValueError("No valid endpoint API key available")
route_payload = {
"endpoint": self.endpoint_group_name,
"api_key": self.endpoint_api_key,
"cost": cost,
}
response = requests.post(
urljoin(self.server_url, "/route/"),
json=route_payload,
timeout=self.DEFAULT_TIMEOUT,
)
response.raise_for_status()
return response.json()
def _create_auth_data(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""Create auth data from routing response"""
return {
"signature": message["signature"],
"cost": message["cost"],
"endpoint": message["endpoint"],
"reqnum": message["reqnum"],
"url": message["url"],
}
def _make_request(
self,
payload: Dict[str, Any],
endpoint: str,
method: str = "POST",
stream: bool = False,
) -> Union[Dict[str, Any], Iterator[str]]:
"""Make request directly to the specific worker endpoint"""
# Get worker URL and auth data
cost = payload.get("max_tokens", self.DEFAULT_COST)
message = self._get_worker_url(cost=cost)
worker_url = message["url"]
auth_data = self._create_auth_data(message)
req_data = {"payload": {"input": payload}, "auth_data": auth_data}
url = urljoin(worker_url, endpoint)
log.debug(f"Making direct request to: {url}")
log.debug(f"Payload: {req_data}")
# Make the request using the specified method
if method.upper() == "POST":
response = requests.post(
url, json=req_data, stream=stream, verify=get_cert_file_path()
)
elif method.upper() == "GET":
response = requests.get(
url, params=req_data, stream=stream, verify=get_cert_file_path()
)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status()
if stream:
return self._handle_streaming_response(response)
else:
return response.json()
def _handle_streaming_response(self, response: requests.Response) -> Iterator[str]:
"""Handle streaming response and yield tokens"""
try:
for line in response.iter_lines(decode_unicode=True):
if line:
if line.startswith("data: "):
data_str = line[6:]
if data_str.strip() == "[DONE]":
break
try:
data = json.loads(data_str)
yield data # Yield the full chunk
except json.JSONDecodeError:
continue
except Exception as e:
log.error(f"Error handling streaming response: {e}")
raise
def call_completions(
self, config: CompletionConfig
) -> Union[Dict[str, Any], Iterator[str]]:
payload = config.to_dict()
return self._make_request(
payload=payload, endpoint="/v1/completions", stream=config.stream
)
def call_chat_completions(
self, config: ChatCompletionConfig
) -> Union[Dict[str, Any], Iterator[str]]:
payload = config.to_dict()
return self._make_request(
payload=payload, endpoint="/v1/chat/completions", stream=config.stream
)
TOOLS_PROMPT = (
"Can you list the files in the current working directory and tell me what you see? "
"What do you think this directory might be for?"
)
ENDPOINT_NAME = "my-vllm-endpoint" # change this to your vLLM endpoint name
DEFAULT_MODEL = "Qwen/Qwen3-8B" # must support tool calling
MAX_TOKENS = 1024
DEFAULT_TEMPERATURE = 0.7
# ---------------------- Tooling ----------------------
class ToolManager:
"""Handles tool definitions and execution"""
@@ -164,7 +50,7 @@ class ToolManager:
@staticmethod
def get_ls_tool_definition() -> List[Dict[str, Any]]:
"""Get the ls tool definition"""
"""OpenAI-compatible tool schema"""
return [
{
"type": "function",
@@ -178,98 +64,220 @@ class ToolManager:
def execute_tool_call(self, tool_call: Dict[str, Any]) -> str:
"""Execute a tool call and return the result"""
function_name = tool_call["function"]["name"]
function_name = (tool_call.get("function") or {}).get("name")
if function_name == "list_files":
return self.list_files()
else:
raise ValueError(f"Unknown tool function: {function_name}")
raise ValueError(f"Unknown tool function: {function_name}")
# ----- Helpers to handle streamed tool_calls assembly -----
def _merge_tool_call_delta(state: Dict[int, Dict[str, Any]], tc_delta: Dict[str, Any]) -> None:
"""
OpenAI-style streaming sends partial tool_calls with an index and partial fields.
We merge into a per-index state dict until the assistant message finishes.
"""
idx = tc_delta.get("index")
if idx is None:
return
entry = state.setdefault(idx, {"id": None, "function": {"name": None, "arguments": ""}, "type": "function"})
if tc_delta.get("id"):
entry["id"] = tc_delta["id"]
fn_delta = tc_delta.get("function") or {}
if "name" in fn_delta and fn_delta["name"]:
entry["function"]["name"] = fn_delta["name"]
if "arguments" in fn_delta and fn_delta["arguments"]:
entry["function"]["arguments"] += fn_delta["arguments"]
def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[Dict[str, Any]]:
return [state[i] for i in sorted(state.keys())]
# ---- OpenAI-compatible calls (non-streaming) ----
async def call_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"model": model,
"prompt": prompt,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
}
log.debug("POST /v1/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/completions", payload, cost=payload["max_tokens"])
return resp["response"]
async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs) -> Dict[str, Any]:
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"model": model,
"messages": messages,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
}
log.debug("POST /v1/chat/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["max_tokens"])
return resp["response"]
# ---- Streaming variants ----
async def stream_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs):
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"model": model,
"prompt": prompt,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"stream": True,
**({"stop": kwargs["stop"]} if "stop" in kwargs else {}),
}
log.debug("STREAM /v1/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/completions", payload, cost=payload["max_tokens"], stream=True)
return resp["response"] # async generator
async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs):
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"model": model,
"messages": messages,
"max_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"stream": True,
**({"tools": kwargs["tools"]} if "tools" in kwargs else {}),
**({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}),
}
log.debug("STREAM /v1/chat/completions %s", json.dumps(payload)[:500])
resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["max_tokens"], stream=True)
return resp["response"] # async generator
# ---------------------- Demo Runner ----------------------
class APIDemo:
"""Demo and testing functionality for the API client"""
def __init__(
self, client: APIClient, model: str, tool_manager: Optional[ToolManager] = None
):
def __init__(self, client: Serverless, model: str, endpoint_name: str, tool_manager: Optional[ToolManager] = None):
self.client = client
self.model = model
self.endpoint_name = endpoint_name
self.tool_manager = tool_manager or ToolManager()
def handle_streaming_response(
self, response_stream, show_reasoning: bool = True
) -> str:
"""
Handle streaming chat response and display all output.
"""
# ----- Streaming handler -----
async def handle_streaming_response(self, stream, show_reasoning: bool = True) -> str:
full_response = ""
reasoning_content = ""
reasoning_started = False
content_started = False
printed_reasoning = False
printed_answer = False
finish_reason = None
for chunk in response_stream:
# Normalize the chunk
if isinstance(chunk, str):
chunk = chunk.strip()
if chunk.startswith("data: "):
chunk = chunk[6:].strip()
if chunk in ["[DONE]", ""]:
continue
try:
parsed_chunk = json.loads(chunk)
except json.JSONDecodeError:
continue
elif isinstance(chunk, dict):
parsed_chunk = chunk
else:
continue
async for chunk in stream:
choice = (chunk.get("choices") or [{}])[0]
delta = choice.get("delta", {})
# Parse delta from the chunk
choices = parsed_chunk.get("choices", [])
if not choices:
continue
# Track finish reason
if choice.get("finish_reason"):
finish_reason = choice.get("finish_reason")
delta = choices[0].get("delta", {})
reasoning_token = delta.get("reasoning_content", "")
content_token = delta.get("content", "")
# Print reasoning token if applicable
if show_reasoning and reasoning_token:
if not reasoning_started:
# reasoning tokens
rc = delta.get("reasoning_content")
if rc and show_reasoning:
if not printed_reasoning:
print("\n🧠 Reasoning: ", end="", flush=True)
reasoning_started = True
print(f"\033[90m{reasoning_token}\033[0m", end="", flush=True)
reasoning_content += reasoning_token
printed_reasoning = True
print(rc, end="", flush=True)
reasoning_content += rc
# Print content token
if content_token:
if not content_started:
if show_reasoning and reasoning_started:
print(f"\n💬 Response: ", end="", flush=True)
# content tokens
content_part = delta.get("content")
if content_part:
if not printed_answer:
if show_reasoning and printed_reasoning:
print("\n💬 Response: ", end="", flush=True)
else:
print("Assistant: ", end="", flush=True)
content_started = True
print(content_token, end="", flush=True)
full_response += content_token
print() # Ensure newline after response
printed_answer = True
print(content_part, end="", flush=True)
full_response += content_part
print() # newline
if show_reasoning:
if reasoning_started or content_started:
if printed_reasoning or printed_answer:
print("\nStreaming completed.")
if reasoning_started:
if printed_reasoning:
print(f"Reasoning tokens: {len(reasoning_content.split())}")
if content_started:
if printed_answer:
print(f"Response tokens: {len(full_response.split())}")
if finish_reason:
print(f"Finish reason: {finish_reason}")
return full_response
def test_tool_support(self) -> bool:
"""Test if the endpoint supports function calling"""
log.debug("Testing endpoint tool calling support...")
async def demo_completions(self) -> None:
print("=" * 60)
print("COMPLETIONS DEMO")
print("=" * 60)
# Try a simple request with minimal tools to test support
response = await call_completions(
client=self.client,
model=self.model,
prompt=COMPLETIONS_PROMPT,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
print("\nResponse:")
print(json.dumps(response, indent=2))
async def demo_chat(self, use_streaming: bool = True) -> None:
print("=" * 60)
print(f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}")
print("=" * 60)
messages = [{"role": "user", "content": CHAT_PROMPT}]
if use_streaming:
stream = await stream_chat_completions(
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE
)
try:
await self.handle_streaming_response(stream, show_reasoning=True)
except Exception as e:
log.error("\nError during streaming: %s", e, exc_info=True)
else:
response = await call_chat_completions(
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE
)
choice = (response.get("choices") or [{}])[0]
message = choice.get("message", {})
content = message.get("content", "")
reasoning = message.get("reasoning_content", "") or message.get("reasoning", "")
if reasoning:
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
print(f"\n💬 Assistant: {content}")
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
async def test_tool_support(self) -> bool:
"""Probe that tool schema is accepted (no actual call)"""
messages = [{"role": "user", "content": "Hello"}]
minimal_tool = [
{
@@ -277,179 +285,158 @@ class APIDemo:
"function": {"name": "test_function", "description": "Test function"},
}
]
config = ChatCompletionConfig(
model=self.model,
messages=messages,
max_tokens=10,
tools=minimal_tool,
tool_choice="none", # Don't actually call the tool
)
try:
response = self.client.call_chat_completions(config)
_ = await call_chat_completions(
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
tools=minimal_tool,
tool_choice="none",
max_tokens=10
)
return True
except Exception as e:
log.error(f"Error: Endpoint does not support tool calling: {e}")
log.error("Endpoint does not support tool calling: %s", e)
return False
def demo_completions(self) -> None:
"""Demo: test basic completions endpoint"""
print("=" * 60)
print("COMPLETIONS DEMO")
print("=" * 60)
config = CompletionConfig(
model=self.model, prompt=COMPLETIONS_PROMPT, stream=False
)
log.info(
f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'"
)
response = self.client.call_completions(config)
if isinstance(response, dict):
print("\nResponse:")
print(json.dumps(response, indent=2))
else:
log.error("Unexpected response format")
def demo_chat(self, use_streaming: bool = True) -> None:
"""
Demo: test chat completions endpoint with optional streaming
"""
print("=" * 60)
print(
f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}"
)
print("=" * 60)
config = ChatCompletionConfig(
model=self.model,
messages=[{"role": "user", "content": CHAT_PROMPT}],
stream=use_streaming,
)
log.info(f"Testing chat completions with model '{self.model}'...")
response = self.client.call_chat_completions(config)
if use_streaming:
try:
self.handle_streaming_response(response, show_reasoning=True)
except Exception as e:
log.error(f"\nError during streaming: {e}")
import traceback
traceback.print_exc()
return
else:
if isinstance(response, dict):
choice = response.get("choices", [{}])[0]
message = choice.get("message", {})
content = message.get("content", "")
reasoning = message.get("reasoning_content", "") or message.get(
"reasoning", ""
)
if reasoning:
print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m")
print(f"\n💬 Assistant: {content}")
print(f"\nFull Response:")
print(json.dumps(response, indent=2))
else:
log.error("Unexpected response format")
def demo_ls_tool(self) -> None:
"""Demo: ask LLM to list files in the current directory and describe what it sees"""
async def demo_ls_tool(self) -> None:
"""Ask to list files using function calling, then provide final analysis"""
print("=" * 60)
print("TOOL USE DEMO: List Directory Contents")
print("=" * 60)
# Test if tools are supported first
if not self.test_tool_support():
if not await self.test_tool_support():
return
# Request with tool available
messages = [{"role": "user", "content": TOOLS_PROMPT}]
messages: List[Dict[str, Any]] = [{"role": "user", "content": TOOLS_PROMPT}]
config = ChatCompletionConfig(
# First pass: let the model decide tools, stream tool_calls and partial content
stream = await stream_chat_completions(
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
tools=self.tool_manager.get_ls_tool_definition(),
tool_choice="auto",
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
log.info(f"Making initial request with tool using model '{self.model}'...")
response = self.client.call_chat_completions(config)
assistant_content_buf: List[str] = []
tool_calls_state: Dict[int, Dict[str, Any]] = {}
printed_reasoning = False
printed_answer = False
if not isinstance(response, dict):
raise ValueError("Expected dict response for tool use")
async for chunk in stream:
choice = (chunk.get("choices") or [{}])[0]
delta = choice.get("delta", {})
choice = response.get("choices", [{}])[0]
message = choice.get("message", {})
rc = delta.get("reasoning_content")
if rc:
if not printed_reasoning:
printed_reasoning = True
print("🧠 Reasoning: ", end="", flush=True)
print(rc, end="", flush=True)
print(f"Assistant response: {message.get('content', 'No content')}")
content_part = delta.get("content")
if content_part:
assistant_content_buf.append(content_part)
if not printed_answer:
printed_answer = True
print("\n💬 Response: ", end="", flush=True)
print(content_part, end="", flush=True)
# Check for tool calls
tool_calls = message.get("tool_calls")
if not tool_calls:
raise ValueError(
"No tool calls made - model may not support function calling"
)
if "tool_calls" in delta and delta["tool_calls"]:
for tc_delta in delta["tool_calls"]:
_merge_tool_call_delta(tool_calls_state, tc_delta)
print(f"Tool calls detected: {len(tool_calls)}")
# If no tool calls, were done.
if not tool_calls_state:
print("\n(No tool calls were made.)")
return
# Execute the tool call
for tool_call in tool_calls:
function_name = tool_call["function"]["name"]
print(f"Executing tool: {function_name}")
# Build assistant message with tool_calls
assistant_message = {
"role": "assistant",
"content": "".join(assistant_content_buf) if assistant_content_buf else None,
"tool_calls": _tool_state_to_message_tool_calls(tool_calls_state),
}
messages.append(assistant_message)
tool_result = self.tool_manager.execute_tool_call(tool_call)
print(f"Tool result:\n{tool_result}")
# Execute tools and feed results back
for tc in assistant_message["tool_calls"]:
tool_name = (tc.get("function") or {}).get("name")
call_id = tc.get("id")
raw_args = (tc.get("function") or {}).get("arguments") or "{}"
# Add tool result and continue conversation
messages.append(message) # Add assistant's message with tool call
messages.append(
{
"role": "tool",
"tool_call_id": tool_call["id"],
"content": tool_result,
}
)
try:
args = json.loads(raw_args) if raw_args.strip() else {}
except Exception as e:
tool_result = json.dumps({"error": f"Argument parse failed: {str(e)}", "raw_arguments": raw_args})
messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result})
continue
# Get final response
final_config = ChatCompletionConfig(
model=self.model,
messages=messages,
tools=self.tool_manager.get_ls_tool_definition(),
)
try:
if tool_name == "list_files":
tool_result = self.tool_manager.list_files()
else:
tool_result = json.dumps({"error": f"Unknown tool '{tool_name}'"})
except Exception as e:
tool_result = json.dumps({"error": f"Tool '{tool_name}' failed: {str(e)}"})
print("Getting final response...")
final_response = self.client.call_chat_completions(final_config)
print("\n[Tool executed]", tool_name)
print(tool_result[:500] + ("..." if len(tool_result) > 500 else ""))
messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result})
if isinstance(final_response, dict):
final_choice = final_response.get("choices", [{}])[0]
final_message = final_choice.get("message", {})
final_content = final_message.get("content", "")
# Second pass: get final streamed answer after tool results
stream2 = await stream_chat_completions(
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
print("\n" + "=" * 60)
print("FINAL LLM ANALYSIS:")
print("=" * 60)
print(final_content)
print("=" * 60)
final_buf = []
printed_reasoning2 = False
printed_answer2 = False
def interactive_chat(self) -> None:
async for chunk in stream2:
choice = (chunk.get("choices") or [{}])[0]
delta = choice.get("delta", {})
rc2 = delta.get("reasoning_content")
if rc2:
if not printed_reasoning2:
printed_reasoning2 = True
print("\n🧠 Reasoning (post-tools): ", end="", flush=True)
print(rc2, end="", flush=True)
c2 = delta.get("content")
if c2:
final_buf.append(c2)
if not printed_answer2:
printed_answer2 = True
print("\n💬 Response (final): ", end="", flush=True)
print(c2, end="", flush=True)
print("\n" + "=" * 60)
print("FINAL LLM ANALYSIS:")
print("=" * 60)
print("".join(final_buf))
print("=" * 60)
async def interactive_chat(self) -> None:
"""Interactive chat session with streaming"""
print("=" * 60)
print("INTERACTIVE STREAMING CHAT")
print("=" * 60)
print(f"Using model: {self.model}")
print("Type 'quit' to exit, 'clear' to clear history")
print()
messages = []
messages: List[Dict[str, Any]] = []
while True:
try:
@@ -467,16 +454,16 @@ class APIDemo:
messages.append({"role": "user", "content": user_input})
config = ChatCompletionConfig(
model=self.model, messages=messages, stream=True, temperature=0.7
)
print("Assistant: ", end="", flush=True)
response = self.client.call_chat_completions(config)
assistant_content = self.handle_streaming_response(
response, show_reasoning=True
stream = await stream_chat_completions(
client=self.client,
model=self.model,
messages=messages,
endpoint_name=self.endpoint_name,
max_tokens=MAX_TOKENS,
temperature=0.7
)
assistant_content = await self.handle_streaming_response(stream, show_reasoning=True)
# Add assistant response to conversation history
messages.append({"role": "assistant", "content": assistant_content})
@@ -485,115 +472,66 @@ class APIDemo:
print("\n👋 Chat interrupted. Goodbye!")
break
except Exception as e:
log.error(f"\nError: {e}")
log.error("\nError: %s", e)
continue
def main():
"""Main function with CLI switches for different tests"""
from lib.test_utils import test_args
# ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)")
p.add_argument("--model", default=DEFAULT_MODEL, help=f"Model to use for requests (default: {DEFAULT_MODEL})")
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
# Add mandatory model argument
test_args.add_argument(
"--model", required=True, help="Model to use for requests (required)"
)
modes = p.add_mutually_exclusive_group(required=False)
modes.add_argument("--completion", action="store_true", help="Test completions endpoint")
modes.add_argument("--chat", action="store_true", help="Test chat completions endpoint (non-streaming)")
modes.add_argument("--chat-stream", action="store_true", help="Test chat completions endpoint with streaming")
modes.add_argument("--tools", action="store_true", help="Test function calling with ls tool (non-streaming+streamed phases)")
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming chat session")
return p
# Add test mode arguments
test_args.add_argument(
"--completion", action="store_true", help="Test completions endpoint"
)
test_args.add_argument(
"--chat",
action="store_true",
help="Test chat completions endpoint (non-streaming)",
)
test_args.add_argument(
"--chat-stream",
action="store_true",
help="Test chat completions endpoint with streaming",
)
test_args.add_argument(
"--tools",
action="store_true",
help="Test function calling with ls tool (non-streaming)",
)
test_args.add_argument(
"--interactive",
action="store_true",
help="Start interactive streaming chat session",
)
args = test_args.parse_args()
async def main_async():
args = build_arg_parser().parse_args()
# Check that only one test mode is selected
test_modes = [
args.completion,
args.chat,
args.chat_stream,
args.tools,
args.interactive,
]
selected_count = sum(test_modes)
if selected_count == 0:
selected = sum([args.completion, args.chat, args.chat_stream, args.tools, args.interactive])
if selected == 0:
print("Please specify exactly one test mode:")
print(" --completion : Test completions endpoint")
print(" --chat : Test chat completions endpoint (non-streaming)")
print(" --chat-stream : Test chat completions endpoint with streaming")
print(" --tools : Test function calling with ls tool (non-streaming)")
print(" --tools : Test function calling with ls tool")
print(" --interactive : Start interactive streaming chat session")
print(
f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT"
)
print(f"\nExample: python {os.path.basename(sys.argv[0])} --model Qwen/Qwen3-8B --chat-stream --endpoint my-vllm-endpoint")
sys.exit(1)
elif selected_count > 1:
elif selected > 1:
print("Please specify exactly one test mode")
sys.exit(1)
print("=" * 60)
print(f"Using model: {args.model}")
print(f"Using endpoint: {args.endpoint}")
try:
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
async with Serverless() as client:
demo = APIDemo(client, args.model, args.endpoint, ToolManager())
if not endpoint_api_key:
log.error(
f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting."
)
sys.exit(1)
# Create the core API client
client = APIClient(
endpoint_group_name=args.endpoint_group_name,
api_key=args.api_key,
server_url=Endpoint.get_autoscaler_server_url(args.instance),
endpoint_api_key=endpoint_api_key,
)
# Create tool manager and demo (passing the model parameter)
tool_manager = ToolManager()
demo = APIDemo(client, args.model, tool_manager)
print(f"Using model: {args.model}")
print("=" * 60)
# Run the selected test
if args.completion:
demo.demo_completions()
elif args.chat:
demo.demo_chat(use_streaming=False)
elif args.chat_stream:
demo.demo_chat(use_streaming=True)
elif args.tools:
demo.demo_ls_tool()
elif args.interactive:
demo.interactive_chat()
if args.completion:
await demo.demo_completions()
elif args.chat:
await demo.demo_chat(use_streaming=False)
elif args.chat_stream:
await demo.demo_chat(use_streaming=True)
elif args.tools:
await demo.demo_ls_tool()
elif args.interactive:
await demo.interactive_chat()
except Exception as e:
log.error(f"Error during test: {e}", exc_info=True)
log.error("Error during test: %s", e, exc_info=True)
sys.exit(1)
if __name__ == "__main__":
main()
asyncio.run(main_async())
-58
View File
@@ -1,58 +0,0 @@
import json
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Optional, List, Dict, Any
class SerializableDataclass:
def _serialize_recursive(self, obj: Any) -> Any:
if is_dataclass(obj):
return {
field.name: self._serialize_recursive(getattr(obj, field.name))
for field in fields(obj)
}
elif isinstance(obj, dict):
return {key: self._serialize_recursive(value) for key, value in obj.items()}
elif isinstance(obj, (list, tuple)):
return [self._serialize_recursive(item) for item in obj]
elif isinstance(obj, set):
return [self._serialize_recursive(item) for item in obj]
else:
return obj
def to_dict(self) -> Dict[str, Any]:
return self._serialize_recursive(self)
def to_json(self, indent: int = 2) -> str:
return json.dumps(self.to_dict(), indent=indent)
@dataclass
class CompletionConfig(SerializableDataclass):
"""Configuration for completion requests"""
model: str
prompt: str = "Hello"
max_tokens: int = 256
temperature: float = 0.7
top_k: int = 20
top_p: float = 0.4
stream: bool = False
@dataclass
class ChatCompletionConfig(SerializableDataclass):
"""Configuration for chat completion requests"""
model: str
messages: list = field(default_factory=list)
max_tokens: int = 2096
temperature: float = 0.7
top_k: int = 20
top_p: float = 0.4
stream: bool = False
tools: Optional[List[Dict[str, Any]]] = field(default_factory=list)
tool_choice: str = "auto"
def __post_init__(self):
if self.messages is None:
self.messages = [{"role": "user", "content": "Hello"}]
-182
View File
@@ -1,182 +0,0 @@
import os, json, random
from abc import ABC, abstractmethod
from dataclasses import dataclass
from lib.data_types import EndpointHandler, ApiPayload, JsonDataException
from typing import Union, Type, Dict, Any, Optional
from aiohttp import web, ClientResponse
import nltk
import logging
nltk.download("words")
WORD_LIST = nltk.corpus.words.words()
log = logging.getLogger(__name__)
"""
Generic dataclass accepts any dictionary in input.
"""
@dataclass
class GenericData(ApiPayload, ABC):
input: Dict[str, Any]
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GenericData":
return cls(input=data["input"])
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "GenericData":
errors = {}
# Validate required parameters
required_params = ["input"]
for param in required_params:
if param not in json_msg:
errors[param] = "missing parameter"
if errors:
raise JsonDataException(errors)
try:
# Create clean data dict and delegate to from_dict
clean_data = {"input": json_msg["input"]}
return cls.from_dict(clean_data)
except (json.JSONDecodeError, JsonDataException) as e:
errors["parameters"] = str(e)
raise JsonDataException(errors)
@classmethod
@abstractmethod
def for_test(cls) -> "GenericData":
pass
def generate_payload_json(self) -> Dict[str, Any]:
return self.input
def count_workload(self) -> int:
return self.input.get("max_tokens", 0)
@dataclass
class GenericHandler(EndpointHandler[GenericData], ABC):
@property
@abstractmethod
def endpoint(self) -> str:
pass
@property
def healthcheck_endpoint(self) -> Optional[str]:
return os.environ.get("MODEL_HEALTH_ENDPOINT")
@classmethod
def payload_cls(cls) -> Type[GenericData]:
return GenericData
@abstractmethod
def make_benchmark_payload(self) -> GenericData:
pass
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
match model_response.status:
case 200:
# Check if the response is actually streaming based on response headers/content-type
is_streaming_response = (
model_response.content_type == "text/event-stream"
or model_response.content_type == "application/x-ndjson"
or model_response.headers.get("Transfer-Encoding") == "chunked"
or "stream" in model_response.content_type.lower()
)
if is_streaming_response:
log.debug("Detected streaming response...")
res = web.StreamResponse()
res.content_type = model_response.content_type
await res.prepare(client_request)
async for chunk in model_response.content:
await res.write(chunk)
await res.write_eof()
log.debug("Done streaming response")
return res
else:
log.debug("Detected non-streaming response...")
content = await model_response.read()
return web.Response(
body=content,
status=200,
content_type=model_response.content_type,
)
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
@dataclass
class CompletionsData(GenericData):
@classmethod
def for_test(cls) -> "CompletionsData":
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
model = os.environ.get("MODEL_NAME")
if not model:
raise ValueError("MODEL_NAME environment variable not set")
test_input = {
"model": model,
"prompt": prompt,
"temperature": 0.7,
"max_tokens": 500,
}
return cls(input=test_input)
@dataclass
class CompletionsHandler(GenericHandler):
@property
def endpoint(self) -> str:
return "/v1/completions"
@classmethod
def payload_cls(cls) -> Type[CompletionsData]:
return CompletionsData
def make_benchmark_payload(self) -> CompletionsData:
return CompletionsData.for_test()
@dataclass
class ChatCompletionsData(GenericData):
"""Chat completions-specific data implementation"""
@classmethod
def for_test(cls) -> "ChatCompletionsData":
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
model = os.environ.get("MODEL_NAME")
if not model:
raise ValueError("MODEL_NAME environment variable not set")
# Chat completions use messages format instead of prompt
test_input = {
"model": model,
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.7,
"max_tokens": 500,
}
return cls(input=test_input)
@dataclass
class ChatCompletionsHandler(GenericHandler):
@property
def endpoint(self) -> str:
return "/v1/chat/completions"
@classmethod
def payload_cls(cls) -> Type[ChatCompletionsData]:
return ChatCompletionsData
def make_benchmark_payload(self) -> ChatCompletionsData:
return ChatCompletionsData.for_test()
-60
View File
@@ -1,60 +0,0 @@
import os
import logging
from .data_types.server import CompletionsHandler, ChatCompletionsHandler
from aiohttp import web
from lib.backend import Backend, LogAction
from lib.server import start_server
# This line indicates that the inference server is listening
MODEL_SERVER_START_LOG_MSG = [
"Application startup complete.", # vLLM
"llama runner started", # Ollama
'"message":"Connected","target":"text_generation_router"', # TGI
'"message":"Connected","target":"text_generation_router::server"', # TGI
]
MODEL_SERVER_ERROR_LOG_MSGS = [
"INFO exited: vllm", # vLLM
"RuntimeError: Engine", # vLLM
"Error: pull model manifest:", # Ollama
"stalled; retrying", # Ollama
"Error: WebserverFailed", # TGI
"Error: DownloadError", # TGI
"Error: ShardCannotStart", # TGI
]
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
backend = Backend(
model_server_url=os.environ["MODEL_SERVER_URL"],
model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=True,
benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256),
log_actions=[
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
(LogAction.Info, '"message":"Download'),
*[
(LogAction.ModelError, error_msg)
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
],
],
)
async def handle_ping(_):
return web.Response(body="pong")
routes = [
web.post("/v1/completions", backend.create_handler(CompletionsHandler())),
web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())),
web.get("/ping", handle_ping),
]
if __name__ == "__main__":
start_server(backend, routes)
-28
View File
@@ -1,28 +0,0 @@
from lib.test_utils import test_load_cmd, test_args
from .data_types.server import CompletionsData
import os
WORKER_ENDPOINT = "/v1/completions"
if __name__ == "__main__":
# Check if MODEL_NAME environment variable is set
model_name_set = os.environ.get("MODEL_NAME") is not None
# Add model argument - required only if MODEL_NAME is not set
test_args.add_argument(
"--model",
dest="model",
required=not model_name_set,
help="Model to use for completions request (required if MODEL_NAME env var not set)",
)
# Parse known args to get model early, before test_load_cmd adds its args
known_args, _ = test_args.parse_known_args()
# Set environment variable if model was provided
if hasattr(known_args, "model") and known_args.model:
os.environ["MODEL_NAME"] = known_args.model
print(f"Set MODEL_NAME environment variable to: {known_args.model}")
# Now call test_load_cmd normally - it will add its own args and re-parse
test_load_cmd(CompletionsData, WORKER_ENDPOINT, arg_parser=test_args)
+86
View File
@@ -0,0 +1,86 @@
import nltk
import random
import os
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
# vLLM model configuration
MODEL_SERVER_URL = 'http://127.0.0.1'
MODEL_SERVER_PORT = 18000
MODEL_LOG_FILE = '/var/log/portal/vllm.log'
MODEL_HEALTHCHECK_ENDPOINT = "/health"
# vLLM-specific log messages
MODEL_LOAD_LOG_MSG = [
"Application startup complete.",
]
MODEL_ERROR_LOG_MSGS = [
"INFO exited: vllm",
"RuntimeError: Engine",
"Traceback (most recent call last):"
]
MODEL_INFO_LOG_MSGS = [
'"message":"Download'
]
nltk.download("words")
WORD_LIST = nltk.corpus.words.words()
def request_parser(request):
data = request
if request.get("input") is not None:
data = request.get("input")
return data
def completions_benchmark_generator() -> dict:
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
model = os.environ.get("MODEL_NAME")
if not model:
raise ValueError("MODEL_NAME environment variable not set")
benchmark_data = {
"model": model,
"prompt": prompt,
"temperature": 0.7,
"max_tokens": 500,
}
return benchmark_data
worker_config = WorkerConfig(
model_server_url=MODEL_SERVER_URL,
model_server_port=MODEL_SERVER_PORT,
model_log_file=MODEL_LOG_FILE,
model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
handlers=[
HandlerConfig(
route="/v1/completions",
workload_calculator= lambda data: data.get("max_tokens", 0),
allow_parallel_requests=True,
request_parser=request_parser,
max_queue_time=600.0,
benchmark_config=BenchmarkConfig(
generator=completions_benchmark_generator,
concurrency=10,
runs=3
)
),
HandlerConfig(
route="/v1/chat/completions",
workload_calculator= lambda data: data.get("max_tokens", 0),
allow_parallel_requests=True,
request_parser=request_parser,
max_queue_time=600.0,
)
],
log_action_config=LogActionConfig(
on_load=MODEL_LOAD_LOG_MSG,
on_error=MODEL_ERROR_LOG_MSGS,
on_info=MODEL_INFO_LOG_MSGS
)
)
Worker(worker_config).run()
+93 -9
View File
@@ -1,19 +1,103 @@
This is the base PyWorker for TGI, designed to create PyWorkers that can utilize various LLMs. It offers two primary endpoints:
# HuggingFace TGI PyWorker
1. `generate`: Generates the LLM's response to a given prompt in a single request.
2. `generate_stream`: Streams the LLM's response token by token.
This is the base PyWorker for HuggingFace Text Generation Inference (TGI) servers. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's.
Both endpoints use the following API payload format:
## Instance Setup
1. Pick a template
This worker is compatible with any TGI backend. We have a template you can use or you can create your own.
- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20(Serverless))
The template can be configured via the template interface. You may want to change the model or startup arguments.
2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface.
## Client Setup (Demo)
1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client.
```bash
git clone https://github.com/vast-ai/pyworker
cd pyworker
pip install uv
uv venv -p 3.12
source .venv/bin/activate
uv pip install -r requirements.txt
```
## Using the Test Client
The test client demonstrates both streaming and non-streaming generation using TGI's native API.
First, set your API key as an environment variable:
```bash
export VAST_API_KEY=<your_api_key>
```
The `--endpoint` flag is optional. If not provided, it defaults to `my-tgi-endpoint`.
### Generate (Streaming)
Call to `/generate_stream` with streaming response:
```bash
python -m workers.tgi.client --generate-stream --endpoint <ENDPOINT_NAME>
```
### Generate (Non-Streaming)
Call to `/generate` with json response:
```bash
python -m workers.tgi.client --generate --endpoint <ENDPOINT_NAME>
```
### Interactive Session (Streaming)
Interactive session with streaming responses. Type `quit` to exit.
```bash
python -m workers.tgi.client --interactive --endpoint <ENDPOINT_NAME>
```
## API Endpoints
TGI provides two primary endpoints:
### Generate (Non-Streaming)
`/generate` - Returns the complete response in a single request.
```json
{
"inputs": "PROMPT",
"inputs": "Your prompt here",
"parameters": {
"max_new_tokens": 250
"max_new_tokens": 1024,
"temperature": 0.7,
"return_full_text": false
}
}
```
Note that the max_new_tokens parameter, rather than the prompt size, impacts performance. For example, if an
instance is benchmarked to process 100 tokens per second, a request with max_new_tokens = 200 will take
approximately 2 seconds to complete.
### Generate Stream (Streaming)
`/generate_stream` - Streams the response token by token.
```json
{
"inputs": "Your prompt here",
"parameters": {
"max_new_tokens": 1024,
"temperature": 0.7,
"do_sample": true,
"return_full_text": false
}
}
```
## Performance Notes
The `max_new_tokens` parameter (not the prompt size) primarily impacts performance. For example, if an instance is benchmarked to process 100 tokens per second, a request with `max_new_tokens = 200` will take approximately 2 seconds to complete.
+201 -104
View File
@@ -1,11 +1,13 @@
import logging
import sys
import json
from urllib.parse import urljoin
import requests
from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
import os
import sys
import argparse
from vastai import Serverless
import asyncio
# ---------------------- Logging ----------------------
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
@@ -13,113 +15,208 @@ logging.basicConfig(
)
log = logging.getLogger(__file__)
# ---------------------- Defaults ----------------------
DEFAULT_PROMPT = "Think step by step: Tell me about the Python programming language."
def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> None:
WORKER_ENDPOINT = "/generate"
COST = 100
route_payload = {
"endpoint": endpoint_group_name,
"api_key": api_key,
"cost": COST,
ENDPOINT_NAME = "TGI-Prod2" # change this to your TGI endpoint name
MAX_TOKENS = 1024
DEFAULT_TEMPERATURE = 0.7
# ---------------------- API Calls ----------------------
async def call_generate(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs) -> dict:
"""Non-streaming generation via /generate endpoint"""
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"inputs": prompt,
"parameters": {
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"return_full_text": False,
}
}
response = requests.post(
urljoin(server_url, "/route/"),
json=route_payload,
timeout=4,
)
response.raise_for_status() # Raise an exception for bad status codes
message = response.json()
url = message["url"]
auth_data = dict(
signature=message["signature"],
cost=message["cost"],
endpoint=message["endpoint"],
reqnum=message["reqnum"],
url=url,
)
payload = dict(inputs="tell me about cats", parameters=dict(max_new_tokens=500))
req_data = dict(payload=payload, auth_data=auth_data)
url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {url}")
response = requests.post(
url,
json=req_data,
verify=get_cert_file_path(),
)
response.raise_for_status()
res = response.json()
print(res)
log.debug("POST /generate %s", json.dumps(payload)[:500])
resp = await endpoint.request("/generate", payload, cost=payload["parameters"]["max_new_tokens"])
return resp["response"]
def call_generate_stream(
endpoint_group_name: str, api_key: str, server_url: str
) -> None:
WORKER_ENDPOINT = "/generate_stream"
COST = 100
route_payload = {
"endpoint": endpoint_group_name,
"api_key": api_key,
"cost": COST,
async def call_generate_stream(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs):
"""Streaming generation via /generate_stream endpoint"""
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {
"inputs": prompt,
"parameters": {
"max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS),
"temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
"do_sample": True,
"return_full_text": False,
}
}
response = requests.post(
urljoin(server_url, "/route/"),
json=route_payload,
timeout=4,
log.debug("STREAM /generate_stream %s", json.dumps(payload)[:500])
resp = await endpoint.request(
"/generate_stream",
payload,
cost=payload["parameters"]["max_new_tokens"],
stream=True,
)
response.raise_for_status() # Raise an exception for bad status codes
message = response.json()
url = message["url"]
print(f"url: {url}")
auth_data = dict(
signature=message["signature"],
cost=message["cost"],
endpoint=message["endpoint"],
reqnum=message["reqnum"],
url=message["url"],
)
payload = dict(inputs="tell me about dogs", parameters=dict(max_new_tokens=500))
req_data = dict(payload=payload, auth_data=auth_data)
url = urljoin(url, WORKER_ENDPOINT)
response = requests.post(url, json=req_data, stream=True)
response.raise_for_status() # Raise an exception for bad status codes
for line in response.iter_lines():
payload = line.decode().lstrip("data:").rstrip()
if payload:
return resp["response"] # async generator
# ---------------------- Demo Runner ----------------------
class APIDemo:
"""Demo and testing functionality for the TGI API client"""
def __init__(self, client: Serverless, endpoint_name: str):
self.client = client
self.endpoint_name = endpoint_name
async def handle_streaming_response(self, stream) -> str:
"""Process streaming response and print tokens"""
full_response = ""
printed_answer = False
async for event in stream:
tok = (event.get("token") or {}).get("text")
if tok:
if not printed_answer:
printed_answer = True
print("\n💬 Response: ", end="", flush=True)
print(tok, end="", flush=True)
full_response += tok
print() # newline
if printed_answer:
print(f"\nStreaming completed. Response tokens: {len(full_response.split())}")
return full_response
async def demo_generate(self) -> None:
"""Demo non-streaming generation"""
print("=" * 60)
print("GENERATE DEMO (NON-STREAMING)")
print("=" * 60)
response = await call_generate(
client=self.client,
endpoint_name=self.endpoint_name,
prompt=DEFAULT_PROMPT,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
print(f"\n💬 Response: {response.get('generated_text', '')}")
print(f"\nFull Response:\n{json.dumps(response, indent=2)}")
async def demo_generate_stream(self) -> None:
"""Demo streaming generation"""
print("=" * 60)
print("GENERATE DEMO (STREAMING)")
print("=" * 60)
stream = await call_generate_stream(
client=self.client,
endpoint_name=self.endpoint_name,
prompt=DEFAULT_PROMPT,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
try:
await self.handle_streaming_response(stream)
except Exception as e:
log.error("\nError during streaming: %s", e, exc_info=True)
async def interactive_chat(self) -> None:
"""Interactive session with streaming generation"""
print("=" * 60)
print("INTERACTIVE STREAMING SESSION")
print("=" * 60)
print(f"Using endpoint: {self.endpoint_name}")
print("Type 'quit' to exit")
print()
while True:
try:
data = json.loads(payload)
print(data["token"]["text"], end="")
sys.stdout.flush()
except (json.JSONDecodeError, KeyError) as e:
log.warning(f"Failed to parse streaming response: {e}")
user_input = input("You: ").strip()
if user_input.lower() == "quit":
print("👋 Goodbye!")
break
elif not user_input:
continue
print("Assistant: ", end="", flush=True)
stream = await call_generate_stream(
client=self.client,
endpoint_name=self.endpoint_name,
prompt=user_input,
max_tokens=MAX_TOKENS,
temperature=DEFAULT_TEMPERATURE,
)
full_response = ""
async for event in stream:
tok = (event.get("token") or {}).get("text")
if tok:
print(tok, end="", flush=True)
full_response += tok
print() # newline
except KeyboardInterrupt:
print("\n👋 Session interrupted. Goodbye!")
break
except Exception as e:
log.error("\nError: %s", e)
continue
print()
# ---------------------- CLI ----------------------
def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Vast TGI Demo (Serverless SDK)")
p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})")
modes = p.add_mutually_exclusive_group(required=False)
modes.add_argument("--generate", action="store_true", help="Test generate endpoint (non-streaming)")
modes.add_argument("--generate-stream", action="store_true", help="Test generate endpoint with streaming")
modes.add_argument("--interactive", action="store_true", help="Start interactive streaming session")
return p
async def main_async():
args = build_arg_parser().parse_args()
selected = sum([args.generate, args.generate_stream, args.interactive])
if selected == 0:
print("Please specify exactly one test mode:")
print(" --generate : Test generate endpoint (non-streaming)")
print(" --generate-stream : Test generate endpoint with streaming")
print(" --interactive : Start interactive streaming session")
print(f"\nExample: python {os.path.basename(sys.argv[0])} --generate-stream --endpoint my-tgi-endpoint")
sys.exit(1)
elif selected > 1:
print("Please specify exactly one test mode")
sys.exit(1)
print("=" * 60)
print(f"Using endpoint: {args.endpoint}")
try:
async with Serverless() as client:
demo = APIDemo(client, args.endpoint)
if args.generate:
await demo.demo_generate()
elif args.generate_stream:
await demo.demo_generate_stream()
elif args.interactive:
await demo.interactive_chat()
except Exception as e:
log.error("Error during test: %s", e, exc_info=True)
sys.exit(1)
if __name__ == "__main__":
from lib.test_utils import test_args
args = test_args.parse_args()
endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key,
instance=args.instance,
)
if endpoint_api_key:
try:
call_generate(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
call_generate_stream(
api_key=endpoint_api_key,
endpoint_group_name=args.endpoint_group_name,
server_url=args.server_url,
)
except Exception as e:
log.error(f"Error during API call: {e}")
else:
log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ")
asyncio.run(main_async())
-73
View File
@@ -1,73 +0,0 @@
import dataclasses
import random
import inspect
from typing import Dict, Any
from transformers import OpenAIGPTTokenizer
import nltk
from lib.data_types import ApiPayload, JsonDataException
nltk.download("words")
WORD_LIST = nltk.corpus.words.words()
tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
@dataclasses.dataclass
class InputParameters:
max_new_tokens: int = 256
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputParameters":
errors = {}
for param in inspect.signature(cls).parameters:
if param not in json_msg:
errors[param] = "missing parameter"
if errors:
raise JsonDataException(errors)
return cls(
**{
k: v
for k, v in json_msg.items()
if k in inspect.signature(cls).parameters
}
)
@dataclasses.dataclass
class InputData(ApiPayload):
inputs: str
parameters: InputParameters
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "InputData":
return cls(
inputs=data["inputs"], parameters=InputParameters(**data["parameters"])
)
@classmethod
def for_test(cls) -> "InputData":
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
return cls(inputs=prompt, parameters=InputParameters())
def generate_payload_json(self) -> Dict[str, Any]:
return dataclasses.asdict(self)
def count_workload(self) -> int:
return self.parameters.max_new_tokens
@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
errors = {}
for param in inspect.signature(cls).parameters:
if param not in json_msg:
errors[param] = "missing parameter"
if errors:
raise JsonDataException(errors)
try:
parameters = InputParameters.from_json_msg(json_msg["parameters"])
return cls(inputs=json_msg["inputs"], parameters=parameters)
except JsonDataException as e:
errors["parameters"] = e.message
raise JsonDataException(errors)
-130
View File
@@ -1,130 +0,0 @@
import os
import logging
from typing import Union, Type
import dataclasses
from aiohttp import web, ClientResponse
from lib.backend import Backend, LogAction
from lib.data_types import EndpointHandler
from lib.server import start_server
from .data_types import InputData
MODEL_SERVER_URL = "http://0.0.0.0:5001"
# This is the last log line that gets emitted once comfyui+extensions have been fully loaded
MODEL_SERVER_START_LOG_MSG = [
'"message":"Connected","target":"text_generation_router"',
'"message":"Connected","target":"text_generation_router::server"',
]
MODEL_SERVER_ERROR_LOG_MSGS = [
"Error: WebserverFailed",
"Error: DownloadError",
"Error: ShardCannotStart",
]
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)
@dataclasses.dataclass
class GenerateHandler(EndpointHandler[InputData]):
@property
def endpoint(self) -> str:
return "/generate"
@property
def healthcheck_endpoint(self) -> str:
return f"{MODEL_SERVER_URL}/health"
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
def make_benchmark_payload(self) -> InputData:
return InputData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
_ = client_request
match model_response.status:
case 200:
log.debug("SUCCESS")
data = await model_response.json()
return web.json_response(data=data)
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
class GenerateStreamHandler(EndpointHandler[InputData]):
@property
def endpoint(self) -> str:
return "/generate_stream"
@property
def healthcheck_endpoint(self) -> str:
return f"{MODEL_SERVER_URL}/health"
@classmethod
def payload_cls(cls) -> Type[InputData]:
return InputData
def make_benchmark_payload(self) -> InputData:
return InputData.for_test()
async def generate_client_response(
self, client_request: web.Request, model_response: ClientResponse
) -> Union[web.Response, web.StreamResponse]:
match model_response.status:
case 200:
log.debug("Streaming response...")
res = web.StreamResponse()
res.content_type = "text/event-stream"
await res.prepare(client_request)
async for chunk in model_response.content:
await res.write(chunk)
await res.write_eof()
log.debug("Done streaming response")
return res
case code:
log.debug("SENDING RESPONSE: ERROR: unknown code")
return web.Response(status=code)
backend = Backend(
model_server_url=MODEL_SERVER_URL,
model_log_file=os.environ["MODEL_LOG"],
allow_parallel_requests=True,
benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
log_actions=[
*[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG],
(LogAction.Info, '"message":"Download'),
*[
(LogAction.ModelError, error_msg)
for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
],
],
)
async def handle_ping(_):
return web.Response(body="pong")
routes = [
web.post("/generate", backend.create_handler(GenerateHandler())),
web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
web.get("/ping", handle_ping),
]
if __name__ == "__main__":
start_server(backend, routes)
-7
View File
@@ -1,7 +0,0 @@
from lib.test_utils import test_load_cmd, test_args
from .data_types import InputData
WORKER_ENDPOINT = "/generate"
if __name__ == "__main__":
test_load_cmd(InputData, WORKER_ENDPOINT, arg_parser=test_args)
+77
View File
@@ -0,0 +1,77 @@
import nltk
import random
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
# TGI model configuration
MODEL_SERVER_URL = 'http://0.0.0.0'
MODEL_SERVER_PORT = 5001
MODEL_LOG_FILE = "/workspace/infer.log"
MODEL_HEALTHCHECK_ENDPOINT = "/health"
# TGI-specific log messages
MODEL_LOAD_LOG_MSG = [
'"message":"Connected","target":"text_generation_router"',
'"message":"Connected","target":"text_generation_router::server"',
]
MODEL_ERROR_LOG_MSGS = [
"Error: WebserverFailed",
"Error: DownloadError",
"Error: ShardCannotStart",
]
MODEL_INFO_LOG_MSGS = [
'"message":"Download'
]
nltk.download("words")
WORD_LIST = nltk.corpus.words.words()
def benchmark_generator() -> dict:
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
benchmark_data = {
"inputs": prompt,
"parameters": {
"max_new_tokens": 500,
"temperature": 0.7,
"return_full_text": False
}
}
return benchmark_data
worker_config = WorkerConfig(
model_server_url=MODEL_SERVER_URL,
model_server_port=MODEL_SERVER_PORT,
model_log_file=MODEL_LOG_FILE,
model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
handlers=[
HandlerConfig(
route="/generate",
allow_parallel_requests=True,
max_queue_time=600.0,
benchmark_config=BenchmarkConfig(
generator=benchmark_generator,
concurrency=10,
runs=3
),
workload_calculator= lambda x: x["parameters"]["max_new_tokens"]
),
HandlerConfig(
route="/generate_stream",
allow_parallel_requests=True,
max_queue_time=600.0,
workload_calculator= lambda x: x["parameters"]["max_new_tokens"]
)
],
log_action_config=LogActionConfig(
on_load=MODEL_LOAD_LOG_MSG,
on_error=MODEL_ERROR_LOG_MSGS,
on_info=MODEL_INFO_LOG_MSGS
)
)
Worker(worker_config).run()
+170
View File
@@ -0,0 +1,170 @@
# ComfyUI Wan 2.2 PyWorker
This is the PyWorker implementation for running **Wan 2.2 T2V A14B** text-to-video workflows in ComfyUI. It provides a unified interface for executing complete ComfyUI video-generation workflows through a proxy-based architecture and returning generated video assets.
Each request has a static cost of `10000`. ComfyUI does not support concurrent workloads, and there is no provision to run multiple ComfyUI instances per worker node.
## Requirements
This worker requires the following components:
- ComfyUI (https://github.com/comfyanonymous/ComfyUI)
- ComfyUI API Wrapper (https://github.com/ai-dock/comfyui-api-wrapper)
- Wan 2.2 T2V A14B models and required custom nodes
A Docker image is provided with all required Wan 2.2 models pre-installed, but any image may be used if the above requirements are met.
## Endpoint
The worker exposes a single synchronous endpoint:
- `/generate/sync`: Processes a complete ComfyUI workflow JSON and generates video output
## Request Format
The Wan 2.2 worker **only supports custom workflow mode**. Modifier-based workflows are not supported.
```json
{
"input": {
"request_id": "uuid-string",
"workflow_json": {
// Complete ComfyUI Wan 2.2 workflow JSON
},
"s3": { },
"webhook": { }
}
}
```
## Request Fields
### Required Fields
- `input`: Container for all request parameters
- `input.workflow_json`: Complete ComfyUI workflow graph for Wan 2.2 video generation
### Optional Fields
- `input.request_id`: Client-defined request identifier
- `input.s3`: S3-compatible storage configuration
- `input.webhook`: Webhook configuration for completion notifications
The special string `"__RANDOM_INT__"` may be used in the workflow JSON and will be replaced with a random integer before submission to ComfyUI.
## S3 Configuration
Generated video assets can be automatically uploaded to S3-compatible storage. Configuration can be supplied per request or via environment variables. Request-level values take precedence.
### Via Request JSON
```json
"s3": {
"access_key_id": "your-s3-access-key",
"secret_access_key": "your-s3-secret-access-key",
"endpoint_url": "https://s3.amazonaws.com",
"bucket_name": "your-bucket",
"region": "us-east-1"
}
```
### Via Environment Variables
```bash
S3_ACCESS_KEY_ID=your-key
S3_SECRET_ACCESS_KEY=your-secret
S3_BUCKET_NAME=your-bucket
S3_ENDPOINT_URL=https://s3.amazonaws.com
S3_REGION=us-east-1
```
## Webhook Configuration
Webhooks are triggered on request completion or failure.
### Via Request JSON
```json
"webhook": {
"url": "https://your-webhook-url",
"extra_params": {
"custom_field": "value"
}
}
```
### Via Environment Variables
```bash
WEBHOOK_URL=https://your-webhook-url
WEBHOOK_TIMEOUT=30
```
## Example Request
### Wan 2.2 Text-to-Video Workflow
```json
{
"input": {
"workflow_json": {
"90": {
"inputs": {
"clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors",
"type": "wan",
"device": "default"
},
"class_type": "CLIPLoader"
},
"99": {
"inputs": {
"text": "A cinematic slow-motion portrait of a woman turning her head",
"clip": ["90", 0]
},
"class_type": "CLIPTextEncode"
},
"104": {
"inputs": {
"width": 640,
"height": 640,
"length": 81,
"batch_size": 1
},
"class_type": "EmptyHunyuanLatentVideo"
}
}
}
}
```
## Response Format
A successful response includes execution metadata, ComfyUI output details, and generated video assets.
### Response Fields
- `id`: Unique request identifier
- `status`: `completed`, `failed`, `processing`, `generating`, or `queued`
- `message`: Human-readable status message
- `comfyui_response`: Raw response from ComfyUI, including execution status and progress
- `output`: Array of generated outputs
- `timings`: Timing information for the request
### Output Object
Each entry in `output` includes:
- `filename`: Generated file name (e.g., `.mp4`)
- `local_path`: File path on the worker
- `url`: Pre-signed download URL (if S3 is configured)
- `type`: Output type (`output`)
- `subfolder`: Output directory (e.g., `video`)
- `node_id`: ComfyUI node that produced the output
- `output_type`: Output category (e.g., `images`)
## Notes and Limitations
- Only full ComfyUI workflow JSONs are supported
- Concurrent requests are not supported per worker
- Wan 2.2 models must be installed before processing requests
- Video generation workflows may take several minutes depending on resolution, length, and GPU performance
+205
View File
@@ -0,0 +1,205 @@
from vastai import Serverless
import asyncio
async def main():
async with Serverless() as client:
endpoint = await client.get_endpoint(name="my-wan-endpoint")
# ComfyUI API compatible json workflow for Wan 2.2 T2V
workflow = {
"90": {
"inputs": {
"clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors",
"type": "wan",
"device": "default"
},
"class_type": "CLIPLoader",
"_meta": {
"title": "Load CLIP"
}
},
"91": {
"inputs": {
"text": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,裸露,NSFW",
"clip": ["90", 0]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Negative Prompt)"
}
},
"92": {
"inputs": {
"vae_name": "wan_2.1_vae.safetensors"
},
"class_type": "VAELoader",
"_meta": {
"title": "Load VAE"
}
},
"93": {
"inputs": {
"shift": 8.000000000000002,
"model": ["101", 0]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "ModelSamplingSD3"
}
},
"94": {
"inputs": {
"shift": 8,
"model": ["102", 0]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "ModelSamplingSD3"
}
},
"95": {
"inputs": {
"add_noise": "disable",
"noise_seed": 0,
"steps": 20,
"cfg": 3.5,
"sampler_name": "euler",
"scheduler": "simple",
"start_at_step": 10,
"end_at_step": 10000,
"return_with_leftover_noise": "disable",
"model": ["94", 0],
"positive": ["99", 0],
"negative": ["91", 0],
"latent_image": ["96", 0]
},
"class_type": "KSamplerAdvanced",
"_meta": {
"title": "KSampler (Advanced)"
}
},
"96": {
"inputs": {
"add_noise": "enable",
"noise_seed": "__RANDOM_INT__",
"steps": 20,
"cfg": 3.5,
"sampler_name": "euler",
"scheduler": "simple",
"start_at_step": 0,
"end_at_step": 10,
"return_with_leftover_noise": "enable",
"model": ["93", 0],
"positive": ["99", 0],
"negative": ["91", 0],
"latent_image": ["104", 0]
},
"class_type": "KSamplerAdvanced",
"_meta": {
"title": "KSampler (Advanced)"
}
},
"97": {
"inputs": {
"samples": ["95", 0],
"vae": ["92", 0]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"98": {
"inputs": {
"filename_prefix": "video/ComfyUI",
"format": "auto",
"codec": "auto",
"video": ["100", 0]
},
"class_type": "SaveVideo",
"_meta": {
"title": "Save Video"
}
},
"99": {
"inputs": {
"text": "Beautiful young European woman with honey blonde hair gracefully turning her head back over shoulder, gentle smile, bright eyes looking at camera. Hair flowing in slow motion as she turns. Soft natural lighting, clean background, cinematic portrait.",
"clip": ["90", 0]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Positive Prompt)"
}
},
"100": {
"inputs": {
"fps": 16,
"images": ["97", 0]
},
"class_type": "CreateVideo",
"_meta": {
"title": "Create Video"
}
},
"101": {
"inputs": {
"unet_name": "wan2.2_t2v_high_noise_14B_fp8_scaled.safetensors",
"weight_dtype": "default"
},
"class_type": "UNETLoader",
"_meta": {
"title": "Load Diffusion Model"
}
},
"102": {
"inputs": {
"unet_name": "wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors",
"weight_dtype": "default"
},
"class_type": "UNETLoader",
"_meta": {
"title": "Load Diffusion Model"
}
},
"104": {
"inputs": {
"width": 640,
"height": 640,
"length": 81,
"batch_size": 1
},
"class_type": "EmptyHunyuanLatentVideo",
"_meta": {
"title": "EmptyHunyuanLatentVideo"
}
}
}
payload = {
"input": {
"request_id": "",
"workflow_json": workflow,
"s3": {
"access_key_id": "",
"secret_access_key": "",
"endpoint_url": "",
"bucket_name": "",
"region": ""
},
"webhook": {
"url": "",
"extra_params": {
"user_id": "12345",
"project_id": "abc-def"
}
}
}
}
response = await endpoint.request("/generate/sync", payload)
# Response contains status, output, and any errors
print(response["response"])
if __name__ == "__main__":
asyncio.run(main())
+288
View File
@@ -0,0 +1,288 @@
import random
import sys
from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig
# ComyUI model configuration
MODEL_SERVER_URL = 'http://127.0.0.1'
MODEL_SERVER_PORT = 18288
MODEL_LOG_FILE = '/var/log/portal/comfyui.log'
MODEL_HEALTHCHECK_ENDPOINT = "/health"
# ComyUI-specific log messages
MODEL_LOAD_LOG_MSG = [
"To see the GUI go to: "
]
MODEL_ERROR_LOG_MSGS = [
"MetadataIncompleteBuffer",
"Value not in list: ",
"[ERROR] Provisioning Script failed"
]
MODEL_INFO_LOG_MSGS = [
'"message":"Downloading'
]
benchmark_prompts = [
"Cartoon hoodie hero; orc, anime cat, bunny; black goo; buff; vector on white.",
"Cozy farming-game scene with fine details.",
"2D vector child with soccer ball; airbrush chrome; swagger; antique copper.",
"Realistic futuristic downtown of low buildings at sunset.",
"Perfect wave front view; sunny seascape; ultra-detailed water; artful feel.",
"Clear cup with ice, fruit, mint; creamy swirls; fluid-sim CGI; warm glow.",
"Male biker with backpack on motorcycle; oilpunk; award-worthy magazine cover.",
"Collage for textile; surreal cartoon cat in cap/jeans before poster; crisp.",
"Medieval village inside glass sphere; volumetric light; macro focus.",
"Iron Man with glowing axe; mecha sci-fi; jungle scene; dynamic light.",
"Pope Francis DJ in leather jacket, mixing on giant console; dramatic.",
]
benchmark_dataset = [
{
"input": {
"workflow_json": {
"90": {
"inputs": {
"clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors",
"type": "wan",
"device": "default"
},
"class_type": "CLIPLoader",
"_meta": {
"title": "Load CLIP"
}
},
"91": {
"inputs": {
"text": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,裸露,NSFW",
"clip": [
"90",
0
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Negative Prompt)"
}
},
"92": {
"inputs": {
"vae_name": "wan_2.1_vae.safetensors"
},
"class_type": "VAELoader",
"_meta": {
"title": "Load VAE"
}
},
"93": {
"inputs": {
"shift": 8.000000000000002,
"model": [
"101",
0
]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "ModelSamplingSD3"
}
},
"94": {
"inputs": {
"shift": 8,
"model": [
"102",
0
]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "ModelSamplingSD3"
}
},
"95": {
"inputs": {
"add_noise": "disable",
"noise_seed": 0,
"steps": 20,
"cfg": 3.5,
"sampler_name": "euler",
"scheduler": "simple",
"start_at_step": 10,
"end_at_step": 10000,
"return_with_leftover_noise": "disable",
"model": [
"94",
0
],
"positive": [
"99",
0
],
"negative": [
"91",
0
],
"latent_image": [
"96",
0
]
},
"class_type": "KSamplerAdvanced",
"_meta": {
"title": "KSampler (Advanced)"
}
},
"96": {
"inputs": {
"add_noise": "enable",
"noise_seed": "__RANDOM_INT__",
"steps": 20,
"cfg": 3.5,
"sampler_name": "euler",
"scheduler": "simple",
"start_at_step": 0,
"end_at_step": 10,
"return_with_leftover_noise": "enable",
"model": [
"93",
0
],
"positive": [
"99",
0
],
"negative": [
"91",
0
],
"latent_image": [
"104",
0
]
},
"class_type": "KSamplerAdvanced",
"_meta": {
"title": "KSampler (Advanced)"
}
},
"97": {
"inputs": {
"samples": [
"95",
0
],
"vae": [
"92",
0
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"98": {
"inputs": {
"filename_prefix": "video/ComfyUI",
"format": "auto",
"codec": "auto",
"video": [
"100",
0
]
},
"class_type": "SaveVideo",
"_meta": {
"title": "Save Video"
}
},
"99": {
"inputs": {
"text":prompt,
"clip": [
"90",
0
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Positive Prompt)"
}
},
"100": {
"inputs": {
"fps": 16,
"images": [
"97",
0
]
},
"class_type": "CreateVideo",
"_meta": {
"title": "Create Video"
}
},
"101": {
"inputs": {
"unet_name": "wan2.2_t2v_high_noise_14B_fp8_scaled.safetensors",
"weight_dtype": "default"
},
"class_type": "UNETLoader",
"_meta": {
"title": "Load Diffusion Model"
}
},
"102": {
"inputs": {
"unet_name": "wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors",
"weight_dtype": "default"
},
"class_type": "UNETLoader",
"_meta": {
"title": "Load Diffusion Model"
}
},
"104": {
"inputs": {
"width": 640,
"height": 640,
"length": 81,
"batch_size": 1
},
"class_type": "EmptyHunyuanLatentVideo",
"_meta": {
"title": "EmptyHunyuanLatentVideo"
}
}
}
}
} for prompt in benchmark_prompts
]
worker_config = WorkerConfig(
model_server_url=MODEL_SERVER_URL,
model_server_port=MODEL_SERVER_PORT,
model_log_file=MODEL_LOG_FILE,
model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT,
handlers=[
HandlerConfig(
route="/generate/sync",
allow_parallel_requests=False,
max_queue_time=10.0,
benchmark_config=BenchmarkConfig(
dataset=benchmark_dataset,
runs=1
),
workload_calculator= lambda _ : 10000.0
)
],
log_action_config=LogActionConfig(
on_load=MODEL_LOAD_LOG_MSG,
on_error=MODEL_ERROR_LOG_MSGS,
on_info=MODEL_INFO_LOG_MSGS
)
)
Worker(worker_config).run()