Add staggered --demo mode to null pyworker client

Three concurrent /reserve calls 30s apart, then cancel the first to show
the early-release path. The remaining two run until their duration cap.
Useful for watching scale-up/scale-down behaviour in the autoscaler
dashboard.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Rob Ballantyne
2026-05-11 17:08:44 +01:00
parent ed0db198c3
commit 463f3de8ea
2 changed files with 130 additions and 22 deletions
+114 -18
View File
@@ -3,11 +3,12 @@ import asyncio
import logging
import os
import sys
import time
from vastai import Serverless
logging.basicConfig(
level=logging.DEBUG,
level=logging.INFO,
format="%(asctime)s[%(levelname)-5s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
@@ -16,18 +17,83 @@ log = logging.getLogger(__file__)
ENDPOINT_NAME = "null-prod"
async def reserve(client: Serverless, *, endpoint_name: str, duration: float) -> dict:
async def reserve(
client: Serverless,
*,
endpoint_name: str,
duration: float,
label: str = "reservation",
) -> dict:
"""Hold a Vast worker open for `duration` seconds (or until we disconnect).
The worker counts itself busy for the lifetime of this call, so the
autoscaler will keep it provisioned. Returning here means the reservation
has ended — either the worker hit its duration cap or the request errored.
The worker counts itself busy for the lifetime of this call. Returning
here means the reservation has ended — either /release was called on
the worker's internal control port, or the duration cap fired, or the
HTTP request was cancelled.
"""
endpoint = await client.get_endpoint(name=endpoint_name)
payload = {"duration": duration}
log.info("POST /reserve duration=%ss", duration)
resp = await endpoint.request("/reserve", payload, cost=100)
return resp["response"]
start = time.monotonic()
log.info("[%s] POST /reserve duration=%ss", label, duration)
try:
resp = await endpoint.request("/reserve", payload, cost=100)
elapsed = time.monotonic() - start
log.info("[%s] returned after %.1fs: %s", label, elapsed, resp.get("response"))
return resp["response"]
except asyncio.CancelledError:
elapsed = time.monotonic() - start
log.info("[%s] cancelled after %.1fs (HTTP connection dropped)", label, elapsed)
raise
async def run_demo(
client: Serverless,
*,
endpoint_name: str,
duration: float,
interval: float,
) -> None:
"""Reserve, wait, reserve, wait, reserve, wait, cancel one.
All three reservations run concurrently as separate held HTTP requests.
After all three are in flight, we cancel the first to demonstrate the
early-release path. The remaining two are left to run to their natural
duration cap (or you can ctrl-c to drop them).
"""
tasks: list[asyncio.Task] = []
for i in range(1, 4):
label = f"res-{i}"
task = asyncio.create_task(
reserve(
client,
endpoint_name=endpoint_name,
duration=duration,
label=label,
),
name=label,
)
tasks.append(task)
if i < 3:
log.info("Waiting %.0fs before starting next reservation...", interval)
await asyncio.sleep(interval)
log.info(
"All 3 reservations in flight. Waiting %.0fs, then cancelling res-1...",
interval,
)
await asyncio.sleep(interval)
log.info("Cancelling res-1 (drops the HTTP connection — produces a 499)")
tasks[0].cancel()
log.info(
"res-2 and res-3 left running. They will end at their duration cap "
"(%.0fs), or you can ctrl-c to drop them.",
duration,
)
results = await asyncio.gather(*tasks, return_exceptions=True)
for task, result in zip(tasks, results):
log.info("[%s] final: %r", task.get_name(), result)
def build_arg_parser() -> argparse.ArgumentParser:
@@ -40,8 +106,27 @@ def build_arg_parser() -> argparse.ArgumentParser:
p.add_argument(
"--duration",
type=float,
default=60.0,
help="Seconds to hold the worker busy (default: 60)",
default=180.0,
help="Seconds to hold each worker busy (default: 180)",
)
modes = p.add_mutually_exclusive_group(required=False)
modes.add_argument(
"--reserve",
action="store_true",
help="Make a single /reserve call (default if no mode given)",
)
modes.add_argument(
"--demo",
action="store_true",
help="Run the staggered 3-reservation demo, cancelling one mid-way",
)
p.add_argument(
"--interval",
type=float,
default=30.0,
help="Demo mode: seconds between reservation steps (default: 30)",
)
return p
@@ -50,19 +135,30 @@ async def main_async():
args = build_arg_parser().parse_args()
print("=" * 60)
print(f"Reserving 1 worker on endpoint '{args.endpoint}' for {args.duration}s")
print(f"Endpoint: {args.endpoint}")
print("=" * 60)
try:
async with Serverless() as client:
response = await reserve(
client=client,
endpoint_name=args.endpoint,
duration=args.duration,
)
print(f"Reservation result: {response}")
if args.demo:
await run_demo(
client,
endpoint_name=args.endpoint,
duration=args.duration,
interval=args.interval,
)
else:
response = await reserve(
client,
endpoint_name=args.endpoint,
duration=args.duration,
label="reservation",
)
print(f"Reservation result: {response}")
except KeyboardInterrupt:
log.info("Interrupted; dropping any in-flight reservations")
except Exception as e:
log.error("Error during reservation: %s", e, exc_info=True)
log.error("Error: %s", e, exc_info=True)
sys.exit(1)