Files
pyworker/workers/null/client.py
T
Rob Ballantyne 2aada7b210 Add --plateau to null pyworker demo (default 5min)
Previously the first release fired only 30s after the third reservation
started, so the autoscaler often hadn't even finished provisioning the
third worker yet. Default plateau to 300s so all three workers are
visibly running before scale-down begins; configurable via --plateau.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-11 18:26:31 +01:00

176 lines
5.2 KiB
Python

import argparse
import asyncio
import logging
import os
import sys
import time
from vastai import Serverless
# Configure root logging once, at import time: timestamped, level-tagged lines.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s[%(levelname)-5s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

# Name the logger after the module rather than the file path. Logger names
# are dot-separated hierarchies, so a path such as ".../client.py" would be
# interpreted as nested loggers. The format above does not print the logger
# name, so emitted lines are unchanged.
log = logging.getLogger(__name__)

# Endpoint used when neither --endpoint nor $VAST_ENDPOINT supplies one.
ENDPOINT_NAME = "null-prod"
async def reserve(
    client: Serverless,
    *,
    endpoint_name: str,
    duration: float,
    label: str = "reservation",
) -> dict:
    """Issue one ``/reserve`` request and wait for it to come back.

    The endpoint called ``endpoint_name`` is looked up on ``client`` and
    ``{"duration": duration}`` is posted to its ``/reserve`` route. The
    worker counts itself busy for as long as this call is outstanding; the
    call returns once the reservation is over, i.e. when ``/release`` was
    called on the worker's internal control port, the duration cap fired,
    or the HTTP request was cancelled.

    Args:
        client: Open Serverless client used to resolve the endpoint.
        endpoint_name: Name of the endpoint to reserve a worker on.
        duration: Upper bound, in seconds, for the reservation.
        label: Tag put in front of every log line from this call.

    Returns:
        The ``"response"`` entry of the reply.

    Raises:
        asyncio.CancelledError: Re-raised after logging when the pending
            request is cancelled.
    """
    endpoint = await client.get_endpoint(name=endpoint_name)
    body = {"duration": duration}
    began = time.monotonic()
    log.info("[%s] POST /reserve duration=%ss", label, duration)

    # Only the awaited request can be cancelled, so it alone sits in the try.
    try:
        # cost=150 is handed to the SDK unchanged.
        reply = await endpoint.request("/reserve", body, cost=150)
    except asyncio.CancelledError:
        log.info(
            "[%s] cancelled after %.1fs (HTTP connection dropped)",
            label,
            time.monotonic() - began,
        )
        raise

    log.info(
        "[%s] returned after %.1fs: %s",
        label,
        time.monotonic() - began,
        reply.get("response"),
    )
    return reply["response"]
async def run_demo(
    client: Serverless,
    *,
    endpoint_name: str,
    interval: float,
    plateau: float,
) -> None:
    """Drive a trapezoid-shaped load: ramp up, hold, ramp down.

    Three reservations are launched ``interval`` seconds apart, all with the
    same duration of ``2 * interval + plateau`` seconds. Because the last
    one starts ``2 * interval`` seconds after the first, the first duration
    cap fires ``plateau`` seconds after the last reservation began, leaving
    the autoscaler time to have all three workers up before scale-down.
    The remaining caps then fire ``interval`` seconds apart, mirroring the
    ramp-up. Nothing is cancelled here; each reservation runs to its cap.

    Args:
        client: Open Serverless client shared by all reservations.
        endpoint_name: Name of the endpoint to reserve workers on.
        interval: Seconds between consecutive reservation starts.
        plateau: Seconds all reservations overlap before the first ends.
    """
    count = 3
    hold_for = (count - 1) * interval + plateau
    in_flight: list[asyncio.Task] = []

    for index in range(count):
        tag = f"res-{index + 1}"
        log.info("[%s] starting (auto-release after %.0fs)", tag, hold_for)
        reservation = reserve(
            client,
            endpoint_name=endpoint_name,
            duration=hold_for,
            label=tag,
        )
        in_flight.append(asyncio.create_task(reservation, name=tag))

        # No pause after the final reservation has been launched.
        if index == count - 1:
            break
        log.info("Waiting %.0fs before next reservation...", interval)
        await asyncio.sleep(interval)

    log.info(
        "All %d reservations in flight; holding plateau for %.0fs, "
        "then scaling down %.0fs apart",
        count,
        plateau,
        interval,
    )

    # return_exceptions=True: one failed reservation must not hide the rest.
    outcomes = await asyncio.gather(*in_flight, return_exceptions=True)
    for finished, outcome in zip(in_flight, outcomes):
        log.info("[%s] final: %r", finished.get_name(), outcome)
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the demo client.

    Options:
        --endpoint: Endpoint name; defaults to ``$VAST_ENDPOINT`` when that
            environment variable is set, otherwise to ``ENDPOINT_NAME``.
        --duration: Hold time for the single-reservation mode.
        --reserve / --demo: Mutually exclusive mode switches; with neither
            given, the caller performs a single reservation.
        --interval / --plateau: Timing knobs for the demo mode.

    Returns:
        The configured parser; nothing is parsed here.
    """
    p = argparse.ArgumentParser(description="Vast Null PyWorker demo client")
    p.add_argument(
        "--endpoint",
        default=os.environ.get("VAST_ENDPOINT", ENDPOINT_NAME),
        # The environment variable overrides the built-in name, so say so.
        help=(
            "Vast endpoint name (default: $VAST_ENDPOINT if set, "
            f"otherwise {ENDPOINT_NAME})"
        ),
    )
    p.add_argument(
        "--duration",
        type=float,
        default=180.0,
        # Demo mode derives its own hold time from --interval and --plateau.
        help=(
            "Single-reservation mode: seconds to hold the worker busy "
            "(default: 180)"
        ),
    )

    modes = p.add_mutually_exclusive_group(required=False)
    modes.add_argument(
        "--reserve",
        action="store_true",
        help="Make a single /reserve call (default if no mode given)",
    )
    modes.add_argument(
        "--demo",
        action="store_true",
        # The demo no longer cancels anything: every reservation ends via
        # its duration cap, so the help text must not promise a cancel.
        help=(
            "Run the staggered 3-reservation demo: ramp up, hold a "
            "plateau, then release one by one"
        ),
    )

    p.add_argument(
        "--interval",
        type=float,
        default=30.0,
        help="Demo mode: seconds between reservation steps (default: 30)",
    )
    p.add_argument(
        "--plateau",
        type=float,
        default=300.0,
        help=(
            "Demo mode: seconds to hold all 3 reservations active before "
            "scale-down starts. Gives the autoscaler time to fully spin "
            "up the third worker (default: 300)"
        ),
    )
    return p
async def main_async():
    """Parse the command line and run the selected mode.

    Prints a banner naming the endpoint, opens a Serverless client and then
    either runs the staggered demo (``--demo``) or performs one reservation
    and prints its result. Any ordinary exception is logged with its
    traceback and the process exits with status 1.
    """
    args = build_arg_parser().parse_args()

    rule = "=" * 60
    print(rule)
    print(f"Endpoint: {args.endpoint}")
    print(rule)

    try:
        async with Serverless() as client:
            if not args.demo:
                # Single-reservation mode is the default.
                outcome = await reserve(
                    client,
                    endpoint_name=args.endpoint,
                    duration=args.duration,
                    label="reservation",
                )
                print(f"Reservation result: {outcome}")
            else:
                await run_demo(
                    client,
                    endpoint_name=args.endpoint,
                    interval=args.interval,
                    plateau=args.plateau,
                )
    except KeyboardInterrupt:
        log.info("Interrupted; dropping any in-flight reservations")
    except Exception as exc:
        # log.exception == log.error(..., exc_info=True)
        log.exception("Error: %s", exc)
        sys.exit(1)
if __name__ == "__main__":
    try:
        asyncio.run(main_async())
    except KeyboardInterrupt:
        # Ctrl-C surfaces as KeyboardInterrupt from asyncio.run() itself (on
        # Python 3.11+ the main task is cancelled first), so it has to be
        # handled here to avoid a traceback on a deliberate interrupt.
        log.info("Interrupted; dropping any in-flight reservations")
        # 128 + SIGINT: keep signalling "interrupted" to the calling shell.
        sys.exit(130)