Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e756f61b9a | |||
| 8cb98c84f9 | |||
| e251afda2b | |||
| 74bd932327 |
+17
-41
@@ -26,8 +26,7 @@ from lib.data_types import (
|
|||||||
LogAction,
|
LogAction,
|
||||||
ApiPayload_T,
|
ApiPayload_T,
|
||||||
JsonDataException,
|
JsonDataException,
|
||||||
RequestMetrics,
|
RequestMetrics
|
||||||
BenchmarkResult
|
|
||||||
)
|
)
|
||||||
|
|
||||||
VERSION = "0.1.0"
|
VERSION = "0.1.0"
|
||||||
@@ -73,9 +72,6 @@ class Backend:
|
|||||||
self._total_pubkey_fetch_errors = 0
|
self._total_pubkey_fetch_errors = 0
|
||||||
self._pubkey = self._fetch_pubkey()
|
self._pubkey = self._fetch_pubkey()
|
||||||
self.__start_healthcheck: bool = False
|
self.__start_healthcheck: bool = False
|
||||||
self._model_tail_start_time = None
|
|
||||||
self._model_loaded_time = None
|
|
||||||
self._first_healthcheck_ok = False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pubkey(self) -> Optional[RSA.RsaKey]:
|
def pubkey(self) -> Optional[RSA.RsaKey]:
|
||||||
@@ -107,7 +103,6 @@ class Backend:
|
|||||||
|
|
||||||
#######################################Private#######################################
|
#######################################Private#######################################
|
||||||
def _fetch_pubkey(self):
|
def _fetch_pubkey(self):
|
||||||
t0 = time.time()
|
|
||||||
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
|
command = ["curl", "-X", "GET", "https://run.vast.ai/pubkey/"]
|
||||||
result = subprocess.check_output(command, universal_newlines=True)
|
result = subprocess.check_output(command, universal_newlines=True)
|
||||||
log.debug("public key:")
|
log.debug("public key:")
|
||||||
@@ -124,8 +119,6 @@ class Backend:
|
|||||||
self._total_pubkey_fetch_errors += 1
|
self._total_pubkey_fetch_errors += 1
|
||||||
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
|
if self._total_pubkey_fetch_errors >= MAX_PUBKEY_FETCH_ATTEMPTS:
|
||||||
self.backend_errored("Failed to get autoscaler pubkey")
|
self.backend_errored("Failed to get autoscaler pubkey")
|
||||||
else:
|
|
||||||
log.debug(f"pubkey fetch+parse took {time.time()-t0:.2f}s")
|
|
||||||
return key
|
return key
|
||||||
|
|
||||||
async def __handle_request(
|
async def __handle_request(
|
||||||
@@ -246,10 +239,6 @@ class Backend:
|
|||||||
log.debug(f"Performing healthcheck on {health_check_url}")
|
log.debug(f"Performing healthcheck on {health_check_url}")
|
||||||
async with self.healthcheck_session.get(health_check_url) as response:
|
async with self.healthcheck_session.get(health_check_url) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
if not self._first_healthcheck_ok:
|
|
||||||
if self._model_loaded_time:
|
|
||||||
log.debug(f"first healthcheck OK after {time.time()-self._model_loaded_time:.2f}s since ModelLoaded")
|
|
||||||
self._first_healthcheck_ok = True
|
|
||||||
log.debug("Healthcheck successful")
|
log.debug("Healthcheck successful")
|
||||||
elif response.status == 503:
|
elif response.status == 503:
|
||||||
log.debug(f"Healthcheck failed with status: {response.status}")
|
log.debug(f"Healthcheck failed with status: {response.status}")
|
||||||
@@ -325,20 +314,17 @@ class Backend:
|
|||||||
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
with open(BENCHMARK_INDICATOR_FILE, "r") as f:
|
||||||
log.debug("already ran benchmark")
|
log.debug("already ran benchmark")
|
||||||
# trigger model load
|
# trigger model load
|
||||||
# payload = self.benchmark_handler.make_benchmark_payload()
|
payload = self.benchmark_handler.make_benchmark_payload()
|
||||||
# _ = await self.__call_api(
|
_ = await self.__call_api(
|
||||||
# handler=self.benchmark_handler, payload=payload
|
handler=self.benchmark_handler, payload=payload
|
||||||
# )
|
)
|
||||||
return float(f.readline())
|
return float(f.readline())
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
log.debug("Initial run to trigger model loading...")
|
log.debug("Initial run to trigger model loading...")
|
||||||
t_bench0 = time.time()
|
|
||||||
payload = self.benchmark_handler.make_benchmark_payload()
|
payload = self.benchmark_handler.make_benchmark_payload()
|
||||||
await self.__call_api(handler=self.benchmark_handler, payload=payload)
|
await self.__call_api(handler=self.benchmark_handler, payload=payload)
|
||||||
log.debug(f"warmup request took {time.time()-t_bench0:.2f}s")
|
|
||||||
t_benchmark_loop0 = time.time()
|
|
||||||
|
|
||||||
max_throughput = 0
|
max_throughput = 0
|
||||||
sum_throughput = 0
|
sum_throughput = 0
|
||||||
@@ -346,23 +332,18 @@ class Backend:
|
|||||||
|
|
||||||
for run in range(1, self.benchmark_handler.benchmark_runs + 1):
|
for run in range(1, self.benchmark_handler.benchmark_runs + 1):
|
||||||
start = time.time()
|
start = time.time()
|
||||||
benchmark_requests = []
|
tasks = []
|
||||||
|
total_workload = 0
|
||||||
|
|
||||||
for i in range(concurrent_requests):
|
for _ in range(concurrent_requests):
|
||||||
payload = self.benchmark_handler.make_benchmark_payload()
|
payload = self.benchmark_handler.make_benchmark_payload()
|
||||||
workload = payload.count_workload()
|
total_workload += payload.count_workload()
|
||||||
task = self.__call_api(handler=self.benchmark_handler, payload=payload)
|
tasks.append(
|
||||||
benchmark_requests.append(
|
self.__call_api(handler=self.benchmark_handler, payload=payload)
|
||||||
BenchmarkResult(request_idx=i, workload=workload, task=task)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
responses = await gather(*[br.task for br in benchmark_requests])
|
responses = await gather(*tasks)
|
||||||
for br, response in zip(benchmark_requests, responses):
|
|
||||||
br.response = response
|
|
||||||
|
|
||||||
total_workload = sum(br.workload for br in benchmark_requests if br.is_successful)
|
|
||||||
time_elapsed = time.time() - start
|
time_elapsed = time.time() - start
|
||||||
successful_responses = sum([1 for br in benchmark_requests if br.is_successful])
|
|
||||||
|
|
||||||
throughput = total_workload / time_elapsed
|
throughput = total_workload / time_elapsed
|
||||||
sum_throughput += throughput
|
sum_throughput += throughput
|
||||||
@@ -376,7 +357,7 @@ class Backend:
|
|||||||
f"Run: {run}, concurrent_requests: {concurrent_requests}",
|
f"Run: {run}, concurrent_requests: {concurrent_requests}",
|
||||||
f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s",
|
f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s",
|
||||||
f"Throughput: {throughput} workload/s",
|
f"Throughput: {throughput} workload/s",
|
||||||
f"Successful responses: {successful_responses}/{concurrent_requests}",
|
f"Successful responses: {len([r for r in responses if r.status == 200])}",
|
||||||
"#" * 60,
|
"#" * 60,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -386,7 +367,6 @@ class Backend:
|
|||||||
log.debug(
|
log.debug(
|
||||||
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
|
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
|
||||||
)
|
)
|
||||||
log.debug(f"benchmark loop took {time.time()-t_benchmark_loop0:.2f}s")
|
|
||||||
with open(BENCHMARK_INDICATOR_FILE, "w") as f:
|
with open(BENCHMARK_INDICATOR_FILE, "w") as f:
|
||||||
f.write(str(max_throughput))
|
f.write(str(max_throughput))
|
||||||
return max_throughput
|
return max_throughput
|
||||||
@@ -399,17 +379,14 @@ class Backend:
|
|||||||
for action, msg in self.log_actions:
|
for action, msg in self.log_actions:
|
||||||
match action:
|
match action:
|
||||||
case LogAction.ModelLoaded if msg in log_line:
|
case LogAction.ModelLoaded if msg in log_line:
|
||||||
now = time.time()
|
log.debug(
|
||||||
elapsed = now - self._model_tail_start_time
|
f"Got log line indicating model is loaded: {log_line}"
|
||||||
log.debug(f"ModelLoaded observed after {elapsed:.2f}s: {log_line}")
|
)
|
||||||
# some backends need a few seconds after logging successful startup before
|
# some backends need a few seconds after logging successful startup before
|
||||||
# they can begin accepting requests
|
# they can begin accepting requests
|
||||||
# await sleep(5)
|
await sleep(5)
|
||||||
try:
|
try:
|
||||||
t_bench0 = time.time()
|
|
||||||
max_throughput = await run_benchmark()
|
max_throughput = await run_benchmark()
|
||||||
self._model_loaded_time = time.time()
|
|
||||||
log.debug(f"benchmark total took {self._model_loaded_time - t_bench0:.2f}s")
|
|
||||||
self.__start_healthcheck = True
|
self.__start_healthcheck = True
|
||||||
self.metrics._model_loaded(
|
self.metrics._model_loaded(
|
||||||
max_throughput=max_throughput,
|
max_throughput=max_throughput,
|
||||||
@@ -428,7 +405,6 @@ class Backend:
|
|||||||
|
|
||||||
async def tail_log():
|
async def tail_log():
|
||||||
log.debug(f"tailing file: {self.model_log_file}")
|
log.debug(f"tailing file: {self.model_log_file}")
|
||||||
self._model_tail_start_time = time.time()
|
|
||||||
async with await open_file(self.model_log_file) as f:
|
async with await open_file(self.model_log_file) as f:
|
||||||
while True:
|
while True:
|
||||||
line = await f.readline()
|
line = await f.readline()
|
||||||
|
|||||||
+1
-12
@@ -3,7 +3,7 @@ import logging
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type, Awaitable
|
from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type
|
||||||
from aiohttp import web, ClientResponse
|
from aiohttp import web, ClientResponse
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
@@ -206,17 +206,6 @@ class RequestMetrics:
|
|||||||
status: str
|
status: str
|
||||||
success: bool = False
|
success: bool = False
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BenchmarkResult:
|
|
||||||
request_idx: int
|
|
||||||
workload: float
|
|
||||||
task: Awaitable[ClientResponse]
|
|
||||||
response: Optional[ClientResponse] = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_successful(self) -> bool:
|
|
||||||
return self.response is not None and self.response.status == 200
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelMetrics:
|
class ModelMetrics:
|
||||||
"""Model specific metrics"""
|
"""Model specific metrics"""
|
||||||
|
|||||||
@@ -152,13 +152,11 @@ class Metrics:
|
|||||||
"request_idxs": [r.request_idx for r in self.model_metrics.requests_deleting if r.success == success],
|
"request_idxs": [r.request_idx for r in self.model_metrics.requests_deleting if r.success == success],
|
||||||
"success": success
|
"success": success
|
||||||
}
|
}
|
||||||
log.debug(f"Deleting requests that {'succeeded' if success else 'failed'}: {data['request_idxs']}")
|
|
||||||
full_path = report_addr.rstrip("/") + "/delete_requests/"
|
full_path = report_addr.rstrip("/") + "/delete_requests/"
|
||||||
for attempt in range(1, 4):
|
for attempt in range(1, 4):
|
||||||
try:
|
try:
|
||||||
session = await self.http()
|
session = await self.http()
|
||||||
async with session.post(full_path, json=data) as res:
|
async with session.post(full_path, json=data) as res:
|
||||||
log.debug(f"delete_requests response: {res.status}")
|
|
||||||
res.raise_for_status()
|
res.raise_for_status()
|
||||||
return True
|
return True
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
|||||||
+55
-22
@@ -2,9 +2,6 @@
|
|||||||
|
|
||||||
set -e -o pipefail
|
set -e -o pipefail
|
||||||
|
|
||||||
log() { echo "$(date +'%Y-%m-%d %H:%M:%S') $*"; }
|
|
||||||
step(){ _t0=$(date +%s); eval "$1"; _dt=$(($(date +%s)-_t0)); log "$2 took ${_dt}s"; }
|
|
||||||
|
|
||||||
WORKSPACE_DIR="${WORKSPACE_DIR:-/workspace}"
|
WORKSPACE_DIR="${WORKSPACE_DIR:-/workspace}"
|
||||||
|
|
||||||
SERVER_DIR="$WORKSPACE_DIR/vast-pyworker"
|
SERVER_DIR="$WORKSPACE_DIR/vast-pyworker"
|
||||||
@@ -44,28 +41,33 @@ echo_var DEBUG_LOG
|
|||||||
echo_var PYWORKER_LOG
|
echo_var PYWORKER_LOG
|
||||||
echo_var MODEL_LOG
|
echo_var MODEL_LOG
|
||||||
|
|
||||||
# # Populate /etc/environment with quoted values
|
# Populate /etc/environment with quoted values
|
||||||
# if ! grep -q "VAST" /etc/environment; then
|
if ! grep -q "VAST" /etc/environment; then
|
||||||
# env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do
|
env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do
|
||||||
# name=${line%%=*}
|
name=${line%%=*}
|
||||||
# value=${line#*=}
|
value=${line#*=}
|
||||||
# printf '%s="%s"\n' "$name" "$value"
|
printf '%s="%s"\n' "$name" "$value"
|
||||||
# done > /etc/environment
|
done > /etc/environment
|
||||||
# fi
|
fi
|
||||||
|
|
||||||
if [ ! -d "$ENV_PATH" ]
|
if [ ! -d "$ENV_PATH" ]
|
||||||
then
|
then
|
||||||
echo "setting up venv"
|
echo "setting up venv"
|
||||||
step 'if ! which uv; then curl -LsSf https://astral.sh/uv/install.sh | sh; source ~/.local/bin/env; fi' "uv install"
|
if ! which uv; then
|
||||||
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
|
source ~/.local/bin/env
|
||||||
|
fi
|
||||||
|
|
||||||
# Fork testing
|
# Fork testing
|
||||||
step '[[ ! -d $SERVER_DIR ]] && git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"' "git clone"
|
[[ ! -d $SERVER_DIR ]] && git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"
|
||||||
step 'if [[ -n ${PYWORKER_REF:-} ]]; then (cd "$SERVER_DIR" && git checkout "$PYWORKER_REF"); fi' "git checkout"
|
if [[ -n ${PYWORKER_REF:-} ]]; then
|
||||||
|
(cd "$SERVER_DIR" && git checkout "$PYWORKER_REF")
|
||||||
|
fi
|
||||||
|
|
||||||
|
uv venv --python-preference only-managed "$ENV_PATH" -p 3.10
|
||||||
|
source "$ENV_PATH/bin/activate"
|
||||||
|
|
||||||
step 'uv venv --python-preference only-managed "$ENV_PATH" -p 3.10' "venv create"
|
uv pip install -r "${SERVER_DIR}/requirements.txt"
|
||||||
step 'source "$ENV_PATH/bin/activate"' "venv activate"
|
|
||||||
step 'uv pip install -r "${SERVER_DIR}/requirements.txt"' "pip install requirements"
|
|
||||||
|
|
||||||
touch ~/.no_auto_tmux
|
touch ~/.no_auto_tmux
|
||||||
else
|
else
|
||||||
@@ -78,8 +80,39 @@ fi
|
|||||||
[ ! -d "$SERVER_DIR/workers/$BACKEND" ] && echo "$BACKEND not supported!" && exit 1
|
[ ! -d "$SERVER_DIR/workers/$BACKEND" ] && echo "$BACKEND not supported!" && exit 1
|
||||||
|
|
||||||
if [ "$USE_SSL" = true ]; then
|
if [ "$USE_SSL" = true ]; then
|
||||||
step 'openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" -nodes -sha256 -keyout /etc/instance.key -out /etc/instance.csr -config /etc/openssl-san.cnf' "openssl csr"
|
|
||||||
step 'curl --header "Content-Type: application/octet-stream" --data-binary @//etc/instance.csr -X POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt' "sign cert"
|
cat << EOF > /etc/openssl-san.cnf
|
||||||
|
[req]
|
||||||
|
default_bits = 2048
|
||||||
|
distinguished_name = req_distinguished_name
|
||||||
|
req_extensions = v3_req
|
||||||
|
|
||||||
|
[req_distinguished_name]
|
||||||
|
countryName = US
|
||||||
|
stateOrProvinceName = CA
|
||||||
|
organizationName = Vast.ai Inc.
|
||||||
|
commonName = vast.ai
|
||||||
|
|
||||||
|
[v3_req]
|
||||||
|
basicConstraints = CA:FALSE
|
||||||
|
keyUsage = nonRepudiation, digitalSignature, keyEncipherment
|
||||||
|
subjectAltName = @alt_names
|
||||||
|
|
||||||
|
[alt_names]
|
||||||
|
IP.1 = 0.0.0.0
|
||||||
|
EOF
|
||||||
|
|
||||||
|
openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \
|
||||||
|
-nodes \
|
||||||
|
-sha256 \
|
||||||
|
-keyout /etc/instance.key \
|
||||||
|
-out /etc/instance.csr \
|
||||||
|
-config /etc/openssl-san.cnf
|
||||||
|
|
||||||
|
curl --header 'Content-Type: application/octet-stream' \
|
||||||
|
--data-binary @//etc/instance.csr \
|
||||||
|
-X \
|
||||||
|
POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt;
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
@@ -89,11 +122,11 @@ export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED
|
|||||||
|
|
||||||
cd "$SERVER_DIR"
|
cd "$SERVER_DIR"
|
||||||
|
|
||||||
log "launching PyWorker server"
|
echo "launching PyWorker server"
|
||||||
|
|
||||||
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
|
# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
|
||||||
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
|
# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only
|
||||||
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
|
[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG"
|
||||||
|
|
||||||
_t0=$(date +%s); (python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") & _dt=$(($(date +%s)-_t0)); log "PyWorker spawn took ${_dt}s"
|
(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") &
|
||||||
log "launching PyWorker server done"
|
echo "launching PyWorker server done"
|
||||||
|
|||||||
@@ -6,13 +6,10 @@ from typing import Union, Type, Dict, Any, Optional
|
|||||||
from aiohttp import web, ClientResponse
|
from aiohttp import web, ClientResponse
|
||||||
import nltk
|
import nltk
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
t0 = time.time()
|
|
||||||
nltk.download("words")
|
nltk.download("words")
|
||||||
WORD_LIST = nltk.corpus.words.words()
|
WORD_LIST = nltk.corpus.words.words()
|
||||||
print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} NLTK words download+load took {time.time()-t0:.2f}s")
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Generic dataclass accepts any dictionary in input.
|
Generic dataclass accepts any dictionary in input.
|
||||||
@@ -122,25 +119,14 @@ class GenericHandler(EndpointHandler[GenericData], ABC):
|
|||||||
class CompletionsData(GenericData):
|
class CompletionsData(GenericData):
|
||||||
@classmethod
|
@classmethod
|
||||||
def for_test(cls) -> "CompletionsData":
|
def for_test(cls) -> "CompletionsData":
|
||||||
system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
|
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
||||||
|
|
||||||
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
|
|
||||||
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
|
|
||||||
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
|
|
||||||
genus Equus with horses and asses, the three groups being the only living members of the family
|
|
||||||
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
|
|
||||||
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
|
|
||||||
woodlands, shrublands, and mountainous areas.
|
|
||||||
|
|
||||||
Please answer the following question based on the above context."""
|
|
||||||
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
|
|
||||||
model = os.environ.get("MODEL_NAME")
|
model = os.environ.get("MODEL_NAME")
|
||||||
if not model:
|
if not model:
|
||||||
raise ValueError("MODEL_NAME environment variable not set")
|
raise ValueError("MODEL_NAME environment variable not set")
|
||||||
|
|
||||||
test_input = {
|
test_input = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"prompt": f"{system_prompt}\n\n{unique_question}",
|
"prompt": prompt,
|
||||||
"temperature": 0.7,
|
"temperature": 0.7,
|
||||||
"max_tokens": 500,
|
"max_tokens": 500,
|
||||||
}
|
}
|
||||||
@@ -167,18 +153,7 @@ class ChatCompletionsData(GenericData):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def for_test(cls) -> "ChatCompletionsData":
|
def for_test(cls) -> "ChatCompletionsData":
|
||||||
system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base:
|
prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
|
||||||
|
|
||||||
Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines
|
|
||||||
with distinctive black-and-white striped coats. There are three living species: Grévy's zebra
|
|
||||||
(Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the
|
|
||||||
genus Equus with horses and asses, the three groups being the only living members of the family
|
|
||||||
Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern
|
|
||||||
and southern Africa and can be found in a variety of habitats such as savannahs, grasslands,
|
|
||||||
woodlands, shrublands, and mountainous areas.
|
|
||||||
|
|
||||||
Please answer the following question based on the above context."""
|
|
||||||
unique_question = " ".join(random.choices(WORD_LIST, k=int(100)))
|
|
||||||
model = os.environ.get("MODEL_NAME")
|
model = os.environ.get("MODEL_NAME")
|
||||||
if not model:
|
if not model:
|
||||||
raise ValueError("MODEL_NAME environment variable not set")
|
raise ValueError("MODEL_NAME environment variable not set")
|
||||||
@@ -186,10 +161,7 @@ class ChatCompletionsData(GenericData):
|
|||||||
# Chat completions use messages format instead of prompt
|
# Chat completions use messages format instead of prompt
|
||||||
test_input = {
|
test_input = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"messages": [
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
{"role": "system", "content": system_prompt}, # Shared prefix
|
|
||||||
{"role": "user", "content": unique_question} # Unique per request
|
|
||||||
],
|
|
||||||
"temperature": 0.7,
|
"temperature": 0.7,
|
||||||
"max_tokens": 500,
|
"max_tokens": 500,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -82,7 +82,6 @@ def do_one(endpoint_name: str,
|
|||||||
# 1) Check if we got a worker back from route
|
# 1) Check if we got a worker back from route
|
||||||
worker_url = msg.get("url", "")
|
worker_url = msg.get("url", "")
|
||||||
if not worker_url:
|
if not worker_url:
|
||||||
status = msg.get("status", "")
|
|
||||||
m = re.search(r"total workers:\s*(\d+).*loading workers:\s*(\d+).*standby workers:\s*(\d+).*error workers:\s*(\d+)", status, re.I | re.S)
|
m = re.search(r"total workers:\s*(\d+).*loading workers:\s*(\d+).*standby workers:\s*(\d+).*error workers:\s*(\d+)", status, re.I | re.S)
|
||||||
if m:
|
if m:
|
||||||
tot, loading, standby, err = map(int, m.groups())
|
tot, loading, standby, err = map(int, m.groups())
|
||||||
|
|||||||
Reference in New Issue
Block a user