Add staggered --demo mode to null pyworker client

Three concurrent /reserve calls 30s apart, then cancel the first to show
the early-release path. The remaining two run until their duration cap.
Useful for watching scale-up/scale-down behaviour in the autoscaler
dashboard.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Rob Ballantyne
2026-05-11 17:08:44 +01:00
parent ed0db198c3
commit 463f3de8ea
2 changed files with 130 additions and 22 deletions
+16 -4
View File
@@ -135,14 +135,26 @@ authentication on it.
## Client example
Single reservation:
```bash
python -m workers.null.client --endpoint <ENDPOINT_NAME> --duration 600
```
This POSTs once to `/reserve`, which causes exactly one worker to be
provisioned (if none is free) and held busy. To exercise the full flow,
shell into the worker and run `curl -X POST http://127.0.0.1:18999/release`
— the client will return with `{"released": "explicit", ...}`.
To exercise the full flow, shell into the worker and run
`curl -X POST http://127.0.0.1:18999/release` — the client returns with
`{"released": "explicit", ...}`.
Staggered demo:
```bash
python -m workers.null.client --endpoint <ENDPOINT_NAME> --demo
```
Starts three reservations 30s apart (all held concurrently), waits another
30s, then cancels the first by dropping its HTTP connection. The remaining
two run until their duration cap. Useful for watching scale-up and
scale-down behaviour in the autoscaler dashboard.
## Notes and caveats
+108 -12
View File
@@ -3,11 +3,12 @@ import asyncio
import logging
import os
import sys
import time
from vastai import Serverless
logging.basicConfig(
level=logging.DEBUG,
level=logging.INFO,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
@@ -16,18 +17,83 @@ log = logging.getLogger(__file__)
ENDPOINT_NAME = "null-prod"
async def reserve(client: Serverless, *, endpoint_name: str, duration: float) -> dict:
async def reserve(
client: Serverless,
*,
endpoint_name: str,
duration: float,
label: str = "reservation",
) -> dict:
"""Hold a Vast worker open for `duration` seconds (or until we disconnect).
The worker counts itself busy for the lifetime of this call, so the
autoscaler will keep it provisioned. Returning here means the reservation
has ended — either the worker hit its duration cap or the request errored.
The worker counts itself busy for the lifetime of this call. Returning
here means the reservation has ended — either /release was called on
the worker's internal control port, or the duration cap fired, or the
HTTP request was cancelled.
"""
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {"duration": duration}
log.info("POST /reserve duration=%ss", duration)
start = time.monotonic()
log.info("[%s] POST /reserve duration=%ss", label, duration)
try:
resp = await endpoint.request("/reserve", payload, cost=100)
elapsed = time.monotonic() - start
log.info("[%s] returned after %.1fs: %s", label, elapsed, resp.get("response"))
return resp["response"]
except asyncio.CancelledError:
elapsed = time.monotonic() - start
log.info("[%s] cancelled after %.1fs (HTTP connection dropped)", label, elapsed)
raise
async def run_demo(
client: Serverless,
*,
endpoint_name: str,
duration: float,
interval: float,
) -> None:
"""Reserve, wait, reserve, wait, reserve, wait, cancel one.
All three reservations run concurrently as separate held HTTP requests.
After all three are in flight, we cancel the first to demonstrate the
early-release path. The remaining two are left to run to their natural
duration cap (or you can ctrl-c to drop them).
"""
tasks: list[asyncio.Task] = []
for i in range(1, 4):
label = f"res-{i}"
task = asyncio.create_task(
reserve(
client,
endpoint_name=endpoint_name,
duration=duration,
label=label,
),
name=label,
)
tasks.append(task)
if i < 3:
log.info("Waiting %.0fs before starting next reservation...", interval)
await asyncio.sleep(interval)
log.info(
"All 3 reservations in flight. Waiting %.0fs, then cancelling res-1...",
interval,
)
await asyncio.sleep(interval)
log.info("Cancelling res-1 (drops the HTTP connection — produces a 499)")
tasks[0].cancel()
log.info(
"res-2 and res-3 left running. They will end at their duration cap "
"(%.0fs), or you can ctrl-c to drop them.",
duration,
)
results = await asyncio.gather(*tasks, return_exceptions=True)
for task, result in zip(tasks, results):
log.info("[%s] final: %r", task.get_name(), result)
def build_arg_parser() -> argparse.ArgumentParser:
@@ -40,8 +106,27 @@ def build_arg_parser() -> argparse.ArgumentParser:
p.add_argument(
"--duration",
type=float,
default=60.0,
help="Seconds to hold the worker busy (default: 60)",
default=180.0,
help="Seconds to hold each worker busy (default: 180)",
)
modes = p.add_mutually_exclusive_group(required=False)
modes.add_argument(
"--reserve",
action="store_true",
help="Make a single /reserve call (default if no mode given)",
)
modes.add_argument(
"--demo",
action="store_true",
help="Run the staggered 3-reservation demo, cancelling one mid-way",
)
p.add_argument(
"--interval",
type=float,
default=30.0,
help="Demo mode: seconds between reservation steps (default: 30)",
)
return p
@@ -50,19 +135,30 @@ async def main_async():
args = build_arg_parser().parse_args()
print("=" * 60)
print(f"Reserving 1 worker on endpoint '{args.endpoint}' for {args.duration}s")
print(f"Endpoint: {args.endpoint}")
print("=" * 60)
try:
async with Serverless() as client:
response = await reserve(
client=client,
if args.demo:
await run_demo(
client,
endpoint_name=args.endpoint,
duration=args.duration,
interval=args.interval,
)
else:
response = await reserve(
client,
endpoint_name=args.endpoint,
duration=args.duration,
label="reservation",
)
print(f"Reservation result: {response}")
except KeyboardInterrupt:
log.info("Interrupted; dropping any in-flight reservations")
except Exception as e:
log.error("Error during reservation: %s", e, exc_info=True)
log.error("Error: %s", e, exc_info=True)
sys.exit(1)