Files
pyworker/workers/null/worker.py
T
Rob Ballantyne 18974873e5 Add null pyworker for queue-driven autoscaling
A PyWorker that does not forward to any model server. POST /reserve holds
the worker busy until the client disconnects (or the duration cap elapses),
so users with their own job queue can drive Vast autoscaling without
exposing inbound model traffic on the instance.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-11 16:48:52 +01:00

84 lines
2.4 KiB
Python

import asyncio
import logging
import os
from contextlib import asynccontextmanager
from vastai import (
Worker,
WorkerConfig,
HandlerConfig,
BenchmarkConfig,
LogActionConfig,
)
# Module logger, named after this file's path.
log = logging.getLogger(__file__)

# Upper bound, in seconds, on how long one reservation may hold the worker.
# It is also the default when a request carries no `duration`, so a client
# that neither disconnects nor sets one cannot pin the worker indefinitely.
# Set the MAX_RESERVATION_SECONDS environment variable to change it.
MAX_RESERVATION_SECONDS = float(os.getenv("MAX_RESERVATION_SECONDS", "3600"))

# Payload key emitted by the benchmark generator. The reserve handler returns
# immediately when it sees this key, so capacity estimation never sleeps.
BENCHMARK_SENTINEL = "__null_worker_benchmark__"
@asynccontextmanager
async def null_lifecycle():
    """Async context manager that only logs worker start-up and shutdown.

    There is no model server to start or stop, so entering the context
    writes one log line and leaving it (normally or through an exception)
    writes another. Nothing is yielded to the caller.
    """
    log.info("Null pyworker active (no model server)")
    try:
        yield
    finally:
        # `finally` keeps the shutdown line in the log even when the body
        # of the `async with` block raises or is cancelled.
        log.info("Null pyworker shutting down")
async def reserve_worker(**params: object) -> dict:
    """Hold this worker busy for a bounded time, or until the call is cancelled.

    Payload fields read from ``params``:
        duration: optional number of seconds to hold the reservation; any
            value ``float()`` accepts. It is clamped to the range
            ``[0, MAX_RESERVATION_SECONDS]``. A missing, ``None`` or
            unconvertible value falls back to ``MAX_RESERVATION_SECONDS``.
        BENCHMARK_SENTINEL: when truthy, the call returns immediately
            without sleeping.

    Returns:
        ``{"ok": True, "benchmark": True}`` for benchmark calls, otherwise
        ``{"released": "duration_elapsed", "duration": <seconds>}`` once the
        full duration has elapsed.

    Raises:
        asyncio.CancelledError: re-raised, after logging, when the sleep is
            cancelled. The log message attributes this to the client
            disconnecting; the cancellation itself is presumably issued by
            the serving framework (not visible in this file).
    """
    if params.get(BENCHMARK_SENTINEL):
        return {"ok": True, "benchmark": True}

    requested = params.get("duration")
    duration = MAX_RESERVATION_SECONDS
    if requested is not None:
        try:
            # Note: NaN compares false against both bounds, so this clamp
            # resolves a NaN request to 0.0.
            duration = max(0.0, min(float(requested), MAX_RESERVATION_SECONDS))
        except (TypeError, ValueError, OverflowError):
            # OverflowError: float() rejects integers too large for a float
            # (e.g. a several-hundred-digit JSON integer). Treat it like any
            # other unusable value instead of failing the request.
            duration = MAX_RESERVATION_SECONDS

    log.info(
        f"Reservation acquired; holding worker busy for up to {duration:.1f}s "
        f"(release early by disconnecting the HTTP request)"
    )
    try:
        await asyncio.sleep(duration)
    except asyncio.CancelledError:
        log.info("Reservation released by client disconnect")
        # Propagate so the caller's cancellation handling still runs.
        raise
    log.info("Reservation duration elapsed; releasing worker")
    return {"released": "duration_elapsed", "duration": duration}
# Benchmark settings for /reserve: one run, no concurrency, no warm-up. The
# generated payload carries only the sentinel key, which makes
# `reserve_worker` return immediately instead of sleeping.
_reserve_benchmark = BenchmarkConfig(
    generator=lambda: {BENCHMARK_SENTINEL: True},
    runs=1,
    concurrency=1,
    do_warmup=False,
)

# The single route this worker serves. Parallel requests are disabled, and
# every request is given the same fixed workload value of 100.0.
_reserve_handler = HandlerConfig(
    route="/reserve",
    allow_parallel_requests=False,
    max_queue_time=30.0,
    remote_function=reserve_worker,
    workload_calculator=lambda _payload: 100.0,
    benchmark_config=_reserve_benchmark,
)

# NOTE(review): the URL/port pair looks like a placeholder, since this worker
# does not forward to any model server -- confirm that WorkerConfig requires
# these fields to be set.
worker_config = WorkerConfig(
    model_server_url="http://127.0.0.1",
    model_server_port=1,
    lifecycle=null_lifecycle(),
    handlers=[_reserve_handler],
    log_action_config=LogActionConfig(),
)

# Start the worker as soon as this module is executed.
Worker(worker_config).run()