Compare commits

..

6 Commits

Author SHA1 Message Date
Colter Downing fd9d56e576 remove env var writing 2025-10-28 16:11:35 -07:00
Colter Downing 8d9ffb3a6c removed 5 sec sleep and warmup request on load 2025-10-28 15:28:30 -07:00
Colter Downing 5d5bc197d7 adding timings for cold start 2025-10-26 18:44:23 -07:00
Colter Downing bcecd6df40 Suppress matplot debug logs 2025-10-25 16:18:02 -07:00
Lucas Armand 4d9bf2048c Fix 2025-10-24 15:44:38 -07:00
Lucas Armand 7788bc4a62 Added some debug logs 2025-10-24 15:41:00 -07:00
6 changed files with 111 additions and 78 deletions
+41 -17
View File
@@ -26,7 +26,8 @@ from lib.data_types import (
LogAction, LogAction,
ApiPayload_T, ApiPayload_T,
JsonDataException, JsonDataException,
RequestMetrics RequestMetrics,
BenchmarkResult
) )
VERSION = "0.1.0" VERSION = "0.1.0"
@@ -72,6 +73,9 @@ class Backend:
self._total_pubkey_fetch_errors = 0 self._total_pubkey_fetch_errors = 0
self._pubkey = self._fetch_pubkey() self._pubkey = self._fetch_pubkey()
self.__start_healthcheck: bool = False self.__start_healthcheck: bool = False
self._model_tail_start_time = None
self._model_loaded_time = None
self._first_healthcheck_ok = False
@property @property
def pubkey(self) -> Optional[RSA.RsaKey]: def pubkey(self) -> Optional[RSA.RsaKey]:
@@ -103,6 +107,7 @@ class Backend:
#######################################Private####################################### #######################################Private#######################################
def _fetch_pubkey(self): def _fetch_pubkey(self):
t0 = time.time()
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"] command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
result = subprocess.check_output(command, universal_newlines=True) result = subprocess.check_output(command, universal_newlines=True)
log.debug("public key:") log.debug("public key:")
@@ -119,6 +124,8 @@ class Backend:
self._total_pubkey_fetch_errors += 1 self._total_pubkey_fetch_errors += 1
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS: if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
self.backend_errored("Failed to get autoscaler pubkey") self.backend_errored("Failed to get autoscaler pubkey")
else:
log.debug(f"pubkey fetch+parse took {time.time()-t0:.2f}s")
return key return key
async def __handle_request( async def __handle_request(
@@ -239,6 +246,10 @@ class Backend:
log.debug(f"Performing healthcheck on {health_check_url}") log.debug(f"Performing healthcheck on {health_check_url}")
async with self.healthcheck_session.get(health_check_url) as response: async with self.healthcheck_session.get(health_check_url) as response:
if response.status == 200: if response.status == 200:
if not self._first_healthcheck_ok:
if self._model_loaded_time:
log.debug(f"first healthcheck OK after {time.time()-self._model_loaded_time:.2f}s since ModelLoaded")
self._first_healthcheck_ok = True
log.debug("Healthcheck successful") log.debug("Healthcheck successful")
elif response.status == 503: elif response.status == 503:
log.debug(f"Healthcheck failed with status: {response.status}") log.debug(f"Healthcheck failed with status: {response.status}")
@@ -314,17 +325,20 @@ class Backend:
with open(BENCHMARK_INDICATOR_FILE, "r") as f: with open(BENCHMARK_INDICATOR_FILE, "r") as f:
log.debug("already ran benchmark") log.debug("already ran benchmark")
# trigger model load # trigger model load
payload = self.benchmark_handler.make_benchmark_payload() # payload = self.benchmark_handler.make_benchmark_payload()
_ = await self.__call_api( # _ = await self.__call_api(
handler=self.benchmark_handler, payload=payload # handler=self.benchmark_handler, payload=payload
) # )
return float(f.readline()) return float(f.readline())
except FileNotFoundError: except FileNotFoundError:
pass pass
log.debug("Initial run to trigger model loading...") log.debug("Initial run to trigger model loading...")
t_bench0 = time.time()
payload = self.benchmark_handler.make_benchmark_payload() payload = self.benchmark_handler.make_benchmark_payload()
await self.__call_api(handler=self.benchmark_handler, payload=payload) await self.__call_api(handler=self.benchmark_handler, payload=payload)
log.debug(f"warmup request took {time.time()-t_bench0:.2f}s")
t_benchmark_loop0 = time.time()
max_throughput = 0 max_throughput = 0
sum_throughput = 0 sum_throughput = 0
@@ -332,18 +346,23 @@ class Backend:
for run in range(1, self.benchmark_handler.benchmark_runs + 1): for run in range(1, self.benchmark_handler.benchmark_runs + 1):
start = time.time() start = time.time()
tasks = [] benchmark_requests = []
total_workload = 0
for _ in range(concurrent_requests): for i in range(concurrent_requests):
payload = self.benchmark_handler.make_benchmark_payload() payload = self.benchmark_handler.make_benchmark_payload()
total_workload += payload.count_workload() workload = payload.count_workload()
tasks.append( task = self.__call_api(handler=self.benchmark_handler, payload=payload)
self.__call_api(handler=self.benchmark_handler, payload=payload) benchmark_requests.append(
BenchmarkResult(request_idx=i, workload=workload, task=task)
) )
responses = await gather(*tasks) responses = await gather(*[br.task for br in benchmark_requests])
for br, response in zip(benchmark_requests, responses):
br.response = response
total_workload = sum(br.workload for br in benchmark_requests if br.is_successful)
time_elapsed = time.time() - start time_elapsed = time.time() - start
successful_responses = sum([1 for br in benchmark_requests if br.is_successful])
throughput = total_workload / time_elapsed throughput = total_workload / time_elapsed
sum_throughput += throughput sum_throughput += throughput
@@ -357,7 +376,7 @@ class Backend:
f"Run: {run}, concurrent_requests: {concurrent_requests}", f"Run: {run}, concurrent_requests: {concurrent_requests}",
f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s", f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s",
f"Throughput: {throughput} workload/s", f"Throughput: {throughput} workload/s",
f"Successful responses: {len([r for r in responses if r.status == 200])}", f"Successful responses: {successful_responses}/{concurrent_requests}",
"#" * 60, "#" * 60,
] ]
) )
@@ -367,6 +386,7 @@ class Backend:
log.debug( log.debug(
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}" f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
) )
log.debug(f"benchmark loop took {time.time()-t_benchmark_loop0:.2f}s")
with open(BENCHMARK_INDICATOR_FILE, "w") as f: with open(BENCHMARK_INDICATOR_FILE, "w") as f:
f.write(str(max_throughput)) f.write(str(max_throughput))
return max_throughput return max_throughput
@@ -379,14 +399,17 @@ class Backend:
for action, msg in self.log_actions: for action, msg in self.log_actions:
match action: match action:
case LogAction.ModelLoaded if msg in log_line: case LogAction.ModelLoaded if msg in log_line:
log.debug( now = time.time()
f"Got log line indicating model is loaded: {log_line}" elapsed = now - self._model_tail_start_time
) log.debug(f"ModelLoaded observed after {elapsed:.2f}s: {log_line}")
# some backends need a few seconds after logging successful startup before # some backends need a few seconds after logging successful startup before
# they can begin accepting requests # they can begin accepting requests
await sleep(5) # await sleep(5)
try: try:
t_bench0 = time.time()
max_throughput = await run_benchmark() max_throughput = await run_benchmark()
self._model_loaded_time = time.time()
log.debug(f"benchmark total took {self._model_loaded_time - t_bench0:.2f}s")
self.__start_healthcheck = True self.__start_healthcheck = True
self.metrics._model_loaded( self.metrics._model_loaded(
max_throughput=max_throughput, max_throughput=max_throughput,
@@ -405,6 +428,7 @@ class Backend:
async def tail_log(): async def tail_log():
log.debug(f"tailing file: {self.model_log_file}") log.debug(f"tailing file: {self.model_log_file}")
self._model_tail_start_time = time.time()
async with await open_file(self.model_log_file) as f: async with await open_file(self.model_log_file) as f:
while True: while True:
line = await f.readline() line = await f.readline()
+12 -1
View File
@@ -3,7 +3,7 @@ import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type, Awaitable
from aiohttp import web, ClientResponse from aiohttp import web, ClientResponse
import inspect import inspect
@@ -206,6 +206,17 @@ class RequestMetrics:
status: str status: str
success: bool = False success: bool = False
@dataclass
class BenchmarkResult:
request_idx: int
workload: float
task: Awaitable[ClientResponse]
response: Optional[ClientResponse] = None
@property
def is_successful(self) -> bool:
return self.response is not None and self.response.status == 200
@dataclass @dataclass
class ModelMetrics: class ModelMetrics:
"""Model specific metrics""" """Model specific metrics"""
+2
View File
@@ -152,11 +152,13 @@ class Metrics:
"request_idxs": [r.request_idx for r in self.model_metrics.requests_deleting if r.success == success], "request_idxs": [r.request_idx for r in self.model_metrics.requests_deleting if r.success == success],
"success": success "success": success
} }
log.debug(f"Deleting requests that {'succeeded' if success else 'failed'}: {data['request_idxs']}")
full_path = report_addr.rstrip("/") + "/delete_requests/" full_path = report_addr.rstrip("/") + "/delete_requests/"
for attempt in range(1, 4): for attempt in range(1, 4):
try: try:
session = await self.http() session = await self.http()
async with session.post(full_path, json=data) as res: async with session.post(full_path, json=data) as res:
log.debug(f"delete_requests response: {res.status}")
res.raise_for_status() res.raise_for_status()
return True return True
except asyncio.TimeoutError: except asyncio.TimeoutError:
+22 -55
View File
@@ -2,6 +2,9 @@
set -e -o pipefail set -e -o pipefail
log() { echo "$(date +'%Y-%m-%d %H:%M:%S') $*"; }
step(){ _t0=$(date +%s); eval "$1"; _dt=$(($(date +%s)-_t0)); log "$2 took ${_dt}s"; }
WORKSPACE_DIR="${WORKSPACE_DIR:-/workspace}" WORKSPACE_DIR="${WORKSPACE_DIR:-/workspace}"
SERVER_DIR="$WORKSPACE_DIR/vast-pyworker" SERVER_DIR="$WORKSPACE_DIR/vast-pyworker"
@@ -41,33 +44,28 @@ echo_var DEBUG_LOG
echo_var PYWORKER_LOG echo_var PYWORKER_LOG
echo_var MODEL_LOG echo_var MODEL_LOG
# Populate /etc/environment with quoted values # # Populate /etc/environment with quoted values
if ! grep -q "VAST" /etc/environment; then # if ! grep -q "VAST" /etc/environment; then
env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do # env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do
name=${line%%=*} # name=${line%%=*}
value=${line#*=} # value=${line#*=}
printf '%s="%s"\n' "$name" "$value" # printf '%s="%s"\n' "$name" "$value"
done > /etc/environment # done > /etc/environment
fi # fi
if [ ! -d "$ENV_PATH" ] if [ ! -d "$ENV_PATH" ]
then then
echo "setting up venv" echo "setting up venv"
if ! which uv; then step 'if ! which uv; then curl -LsSf https://astral.sh/uv/install.sh | sh; source ~/.local/bin/env; fi' "uv install"
curl -LsSf https://astral.sh/uv/install.sh | sh
source ~/.local/bin/env
fi
# Fork testing # Fork testing
[[ ! -d $SERVER_DIR ]] && git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR" step '[[ ! -d $SERVER_DIR ]] && git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"' "git clone"
if [[ -n ${PYWORKER_REF:-} ]]; then step 'if [[ -n ${PYWORKER_REF:-} ]]; then (cd "$SERVER_DIR" && git checkout "$PYWORKER_REF"); fi' "git checkout"
(cd "$SERVER_DIR" && git checkout "$PYWORKER_REF")
fi
uv venv --python-preference only-managed "$ENV_PATH" -p 3.10
source "$ENV_PATH/bin/activate"
uv pip install -r "${SERVER_DIR}/requirements.txt" step 'uv venv --python-preference only-managed "$ENV_PATH" -p 3.10' "venv create"
step 'source "$ENV_PATH/bin/activate"' "venv activate"
step 'uv pip install -r "${SERVER_DIR}/requirements.txt"' "pip install requirements"
touch ~/.no_auto_tmux touch ~/.no_auto_tmux
else else
@@ -80,39 +78,8 @@ fi
[ ! -d "$SERVER_DIR/workers/$BACKEND" ] && echo "$BACKEND not supported!" && exit 1 [ ! -d "$SERVER_DIR/workers/$BACKEND" ] && echo "$BACKEND not supported!" && exit 1
if [ "$USE_SSL" = true ]; then if [ "$USE_SSL" = true ]; then
step 'openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" -nodes -sha256 -keyout /etc/instance.key -out /etc/instance.csr -config /etc/openssl-san.cnf' "openssl csr"
cat << EOF > /etc/openssl-san.cnf step 'curl --header "Content-Type: application/octet-stream" --data-binary @//etc/instance.csr -X POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt' "sign cert"
[req]
default_bits = 2048
distinguished_name = req_distinguished_name
req_extensions = v3_req
[req_distinguished_name]
countryName = US
stateOrProvinceName = CA
organizationName = Vast.ai Inc.
commonName = vast.ai
[v3_req]
basicConstraints = CA:FALSE
keyUsage = nonRepudiation, digitalSignature, keyEncipherment
subjectAltName = @alt_names
[alt_names]
IP.1 = 0.0.0.0
EOF
openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \
-nodes \
-sha256 \
-keyout /etc/instance.key \
-out /etc/instance.csr \
-config /etc/openssl-san.cnf
curl --header 'Content-Type: application/octet-stream' \
--data-binary @//etc/instance.csr \
-X \
POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt;
fi fi
@@ -122,11 +89,11 @@ export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED
cd "$SERVER_DIR" cd "$SERVER_DIR"
echo "launching PyWorker server" log "launching PyWorker server"
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines # if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only # from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG" [ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") & _t0=$(date +%s); (python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") & _dt=$(($(date +%s)-_t0)); log "PyWorker spawn took ${_dt}s"
echo "launching PyWorker server done" log "launching PyWorker server done"
+33 -5
View File
@@ -6,10 +6,13 @@ from typing import Union, Type, Dict, Any, Optional
from aiohttp import web, ClientResponse from aiohttp import web, ClientResponse
import nltk import nltk
import logging import logging
import time
log = logging.getLogger(__name__)
t0 = time.time()
nltk.download("words") nltk.download("words")
WORD_LIST = nltk.corpus.words.words() WORD_LIST = nltk.corpus.words.words()
log = logging.getLogger(__name__) print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} NLTK words download+load took {time.time()-t0:.2f}s")
""" """
Generic dataclass accepts any dictionary in input. Generic dataclass accepts any dictionary in input.
@@ -119,14 +122,25 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
class CompletionsData(GenericData): class CompletionsData(GenericData):
@classmethod @classmethod
def for_test(cls) -> "CompletionsData": def for_test(cls) -> "CompletionsData":
prompt = " ".join(random.choices(WORD_LIST, k=int(250))) system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
genus Equus with horses and asses, the three groups being the only living members of the family
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
woodlands, shrublands, and mountainous areas.
Please answer the following question based on the above context."""
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
model = os.environ.get("MODEL_NAME") model = os.environ.get("MODEL_NAME")
if not model: if not model:
raise ValueError("MODEL_NAME environment variable not set") raise ValueError("MODEL_NAME environment variable not set")
test_input = { test_input = {
"model": model, "model": model,
"prompt": prompt, "prompt": f"{system_prompt}\n\n{unique_question}",
"temperature": 0.7, "temperature": 0.7,
"max_tokens": 500, "max_tokens": 500,
} }
@@ -153,7 +167,18 @@ class ChatCompletionsData(GenericData):
@classmethod @classmethod
def for_test(cls) -> "ChatCompletionsData": def for_test(cls) -> "ChatCompletionsData":
prompt = " ".join(random.choices(WORD_LIST, k=int(250))) system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
genus Equus with horses and asses, the three groups being the only living members of the family
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
woodlands, shrublands, and mountainous areas.
Please answer the following question based on the above context."""
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
model = os.environ.get("MODEL_NAME") model = os.environ.get("MODEL_NAME")
if not model: if not model:
raise ValueError("MODEL_NAME environment variable not set") raise ValueError("MODEL_NAME environment variable not set")
@@ -161,7 +186,10 @@ class ChatCompletionsData(GenericData):
# Chat completions use messages format instead of prompt # Chat completions use messages format instead of prompt
test_input = { test_input = {
"model": model, "model": model,
"messages": [{"role": "user", "content": prompt}], "messages": [
{"role": "system", "content": system_prompt}, # Shared prefix
{"role": "user", "content": unique_question} # Unique per request
],
"temperature": 0.7, "temperature": 0.7,
"max_tokens": 500, "max_tokens": 500,
} }
+1
View File
@@ -82,6 +82,7 @@ def do_one(endpoint_name: str,
# 1) Check if we got a worker back from route # 1) Check if we got a worker back from route
worker_url = msg.get("url", "") worker_url = msg.get("url", "")
if not worker_url: if not worker_url:
status = msg.get("status", "")
m = re.search(r"total workers:\s*(\d+).*loading workers:\s*(\d+).*standby workers:\s*(\d+).*error workers:\s*(\d+)", status, re.I | re.S) m = re.search(r"total workers:\s*(\d+).*loading workers:\s*(\d+).*standby workers:\s*(\d+).*error workers:\s*(\d+)", status, re.I | re.S)
if m: if m:
tot, loading, standby, err = map(int, m.groups()) tot, loading, standby, err = map(int, m.groups())