Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 72a5f6ad13 |
+37
-37
@@ -5,10 +5,9 @@ import base64
|
|||||||
import subprocess
|
import subprocess
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
from asyncio import sleep, gather, Semaphore
|
from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
|
||||||
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
|
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from distutils.util import strtobool
|
|
||||||
|
|
||||||
from anyio import open_file
|
from anyio import open_file
|
||||||
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError
|
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError
|
||||||
@@ -56,15 +55,11 @@ class Backend:
|
|||||||
reqnum = -1
|
reqnum = -1
|
||||||
msg_history = []
|
msg_history = []
|
||||||
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
|
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
|
||||||
unsecured: bool = dataclasses.field(
|
|
||||||
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
|
|
||||||
)
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.metrics = Metrics()
|
self.metrics = Metrics()
|
||||||
self._total_pubkey_fetch_errors = 0
|
self._total_pubkey_fetch_errors = 0
|
||||||
self._pubkey = self._fetch_pubkey()
|
self._pubkey = self._fetch_pubkey()
|
||||||
self.__start_healthcheck: bool = False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pubkey(self) -> Optional[RSA.RsaKey]:
|
def pubkey(self) -> Optional[RSA.RsaKey]:
|
||||||
@@ -123,6 +118,16 @@ class Backend:
|
|||||||
return web.json_response(dict(error="invalid JSON"), status=422)
|
return web.json_response(dict(error="invalid JSON"), status=422)
|
||||||
workload = payload.count_workload()
|
workload = payload.count_workload()
|
||||||
|
|
||||||
|
async def wait_for_disconnection() -> None:
|
||||||
|
while request.transport and not request.transport.is_closing():
|
||||||
|
await sleep(0.5)
|
||||||
|
|
||||||
|
async def cancel_api_call_if_disconnected() -> web.Response:
|
||||||
|
await wait_for_disconnection()
|
||||||
|
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
|
||||||
|
self.metrics._request_canceled(workload=workload, reqnum=auth_data.reqnum)
|
||||||
|
return web.Response(status=500)
|
||||||
|
|
||||||
async def make_request() -> Union[web.Response, web.StreamResponse]:
|
async def make_request() -> Union[web.Response, web.StreamResponse]:
|
||||||
log.debug(f"got request, {auth_data.reqnum}")
|
log.debug(f"got request, {auth_data.reqnum}")
|
||||||
self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum)
|
self.metrics._request_start(workload=workload, reqnum=auth_data.reqnum)
|
||||||
@@ -168,42 +173,41 @@ class Backend:
|
|||||||
return web.Response(status=401)
|
return web.Response(status=401)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await make_request()
|
done, pending = await wait(
|
||||||
|
[
|
||||||
|
create_task(make_request()),
|
||||||
|
create_task(cancel_api_call_if_disconnected()),
|
||||||
|
],
|
||||||
|
return_when=FIRST_COMPLETED,
|
||||||
|
)
|
||||||
|
[task.cancel() for task in pending]
|
||||||
|
return done.pop().result()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug(f"Exception in main handler loop {e}")
|
log.debug(f"Exception in main handler loop {e}")
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
finally:
|
|
||||||
if request.task.cancelled():
|
|
||||||
log.debug(f"request with reqnum: {auth_data.reqnum} was canceled")
|
|
||||||
self.metrics._request_canceled(
|
|
||||||
workload=workload, reqnum=auth_data.reqnum
|
|
||||||
)
|
|
||||||
|
|
||||||
async def __healthcheck(self):
|
async def __healthcheck(self):
|
||||||
health_check_url = self.benchmark_handler.healthcheck_endpoint
|
health_check_url = self.benchmark_handler.healthcheck_endpoint
|
||||||
if health_check_url is None:
|
if health_check_url is None:
|
||||||
log.debug("No healthcheck endpoint defined, skipping healthcheck")
|
log.debug("No healthcheck endpoint defined, skipping healthcheck")
|
||||||
return
|
return
|
||||||
while True:
|
await sleep(5)
|
||||||
await sleep(10)
|
try:
|
||||||
if self.__start_healthcheck is False:
|
log.debug(f"Performing healthcheck on {health_check_url}")
|
||||||
continue
|
async with self.session.get(health_check_url) as response:
|
||||||
try:
|
if response.status == 200:
|
||||||
log.debug(f"Performing healthcheck on {health_check_url}")
|
log.debug("Healthcheck successful")
|
||||||
async with self.session.get(health_check_url) as response:
|
elif response.status == 503:
|
||||||
if response.status == 200:
|
log.debug(f"Healthcheck failed with status: {response.status}")
|
||||||
log.debug("Healthcheck successful")
|
self.backend_errored(
|
||||||
elif response.status == 503:
|
f"Healthcheck failed with status: {response.status}"
|
||||||
log.debug(f"Healthcheck failed with status: {response.status}")
|
)
|
||||||
self.backend_errored(
|
else:
|
||||||
f"Healthcheck failed with status: {response.status}"
|
# endpoint not ready yet so bail
|
||||||
)
|
log.debug(f"Healthcheck Endpoint not ready: {response.status}")
|
||||||
else:
|
except Exception as e:
|
||||||
# endpoint not ready yet so bail
|
log.debug(f"Healthcheck failed with exception: {e}")
|
||||||
log.debug(f"Healthcheck Endpoint not ready: {response.status}")
|
self.backend_errored(str(e))
|
||||||
except Exception as e:
|
|
||||||
log.debug(f"Healthcheck failed with exception: {e}")
|
|
||||||
self.backend_errored(str(e))
|
|
||||||
|
|
||||||
async def _start_tracking(self) -> None:
|
async def _start_tracking(self) -> None:
|
||||||
await gather(
|
await gather(
|
||||||
@@ -221,9 +225,6 @@ class Backend:
|
|||||||
return await self.session.post(url=handler.endpoint, json=api_payload)
|
return await self.session.post(url=handler.endpoint, json=api_payload)
|
||||||
|
|
||||||
def __check_signature(self, auth_data: AuthData) -> bool:
|
def __check_signature(self, auth_data: AuthData) -> bool:
|
||||||
if self.unsecured is True:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def verify_signature(message, signature):
|
def verify_signature(message, signature):
|
||||||
if self.pubkey is None:
|
if self.pubkey is None:
|
||||||
log.debug(f"No Public Key!")
|
log.debug(f"No Public Key!")
|
||||||
@@ -330,7 +331,6 @@ class Backend:
|
|||||||
await sleep(5)
|
await sleep(5)
|
||||||
try:
|
try:
|
||||||
max_throughput = await run_benchmark()
|
max_throughput = await run_benchmark()
|
||||||
self.__start_healthcheck = True
|
|
||||||
self.metrics._model_loaded(
|
self.metrics._model_loaded(
|
||||||
max_throughput=max_throughput,
|
max_throughput=max_throughput,
|
||||||
)
|
)
|
||||||
|
|||||||
+2
-1
@@ -5,6 +5,7 @@ import json
|
|||||||
from asyncio import sleep
|
from asyncio import sleep
|
||||||
from dataclasses import dataclass, asdict, field
|
from dataclasses import dataclass, asdict, field
|
||||||
from functools import cache
|
from functools import cache
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@@ -118,7 +119,7 @@ class Metrics:
|
|||||||
|
|
||||||
def send_data(report_addr: str) -> None:
|
def send_data(report_addr: str) -> None:
|
||||||
data = compute_autoscaler_data()
|
data = compute_autoscaler_data()
|
||||||
full_path = report_addr.rstrip("/") + "/worker_status/"
|
full_path = urljoin(report_addr, "/worker_status/")
|
||||||
log.debug(
|
log.debug(
|
||||||
"\n".join(
|
"\n".join(
|
||||||
[
|
[
|
||||||
|
|||||||
+1
-1
@@ -27,7 +27,7 @@ def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs):
|
|||||||
log.debug("starting server...")
|
log.debug("starting server...")
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
app.add_routes(routes)
|
app.add_routes(routes)
|
||||||
runner = web.AppRunner(app, handler_cancellation=True)
|
runner = web.AppRunner(app)
|
||||||
await runner.setup()
|
await runner.setup()
|
||||||
site = web.TCPSite(
|
site = web.TCPSite(
|
||||||
runner,
|
runner,
|
||||||
|
|||||||
+9
-24
@@ -53,13 +53,6 @@ test_args.add_argument(
|
|||||||
default="https://run.vast.ai",
|
default="https://run.vast.ai",
|
||||||
help="Call local autoscaler instead of prod, for dev use only",
|
help="Call local autoscaler instead of prod, for dev use only",
|
||||||
)
|
)
|
||||||
test_args.add_argument(
|
|
||||||
"-i",
|
|
||||||
dest="instance",
|
|
||||||
type=str,
|
|
||||||
default="prod",
|
|
||||||
help="Autoscaler shard to run the command against, default: prod",
|
|
||||||
)
|
|
||||||
|
|
||||||
GetPayloadAndWorkload = Callable[[], Tuple[Dict[str, Any], float]]
|
GetPayloadAndWorkload = Callable[[], Tuple[Dict[str, Any], float]]
|
||||||
|
|
||||||
@@ -77,7 +70,6 @@ class ClientState:
|
|||||||
api_key: str
|
api_key: str
|
||||||
server_url: str
|
server_url: str
|
||||||
worker_endpoint: str
|
worker_endpoint: str
|
||||||
instance: str
|
|
||||||
payload: ApiPayload
|
payload: ApiPayload
|
||||||
url: str = ""
|
url: str = ""
|
||||||
status: ClientStatus = ClientStatus.FetchEndpoint
|
status: ClientStatus = ClientStatus.FetchEndpoint
|
||||||
@@ -87,7 +79,11 @@ class ClientState:
|
|||||||
|
|
||||||
def make_call(self):
|
def make_call(self):
|
||||||
self.status = ClientStatus.FetchEndpoint
|
self.status = ClientStatus.FetchEndpoint
|
||||||
if not self.api_key:
|
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||||
|
endpoint_name=self.endpoint_group_name,
|
||||||
|
account_api_key=self.api_key,
|
||||||
|
)
|
||||||
|
if not endpoint_api_key:
|
||||||
self.as_error.append(
|
self.as_error.append(
|
||||||
f"Endpoint {self.endpoint_group_name} not found for API key",
|
f"Endpoint {self.endpoint_group_name} not found for API key",
|
||||||
)
|
)
|
||||||
@@ -95,14 +91,12 @@ class ClientState:
|
|||||||
return
|
return
|
||||||
route_payload = {
|
route_payload = {
|
||||||
"endpoint": self.endpoint_group_name,
|
"endpoint": self.endpoint_group_name,
|
||||||
"api_key": self.api_key,
|
"api_key": endpoint_api_key,
|
||||||
"cost": self.payload.count_workload(),
|
"cost": self.payload.count_workload(),
|
||||||
}
|
}
|
||||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
urljoin(self.server_url, "/route/"),
|
urljoin(self.server_url, "/route/"),
|
||||||
json=route_payload,
|
json=route_payload,
|
||||||
headers=headers,
|
|
||||||
timeout=4,
|
timeout=4,
|
||||||
)
|
)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
@@ -141,7 +135,6 @@ class ClientState:
|
|||||||
try:
|
try:
|
||||||
self.make_call()
|
self.make_call()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
|
||||||
self.status = ClientStatus.Error
|
self.status = ClientStatus.Error
|
||||||
_ = e
|
_ = e
|
||||||
self.conn_errors[self.url] += 1
|
self.conn_errors[self.url] += 1
|
||||||
@@ -233,7 +226,6 @@ def run_test(
|
|||||||
server_url: str,
|
server_url: str,
|
||||||
worker_endpoint: str,
|
worker_endpoint: str,
|
||||||
payload_cls: Type[ApiPayload],
|
payload_cls: Type[ApiPayload],
|
||||||
instance: str,
|
|
||||||
):
|
):
|
||||||
threads = []
|
threads = []
|
||||||
|
|
||||||
@@ -242,7 +234,8 @@ def run_test(
|
|||||||
print_thread.daemon = True # makes threads get killed on program exit
|
print_thread.daemon = True # makes threads get killed on program exit
|
||||||
print_thread.start()
|
print_thread.start()
|
||||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||||
endpoint_name=endpoint_group_name, account_api_key=api_key, instance=instance
|
endpoint_name=endpoint_group_name,
|
||||||
|
account_api_key=api_key,
|
||||||
)
|
)
|
||||||
if not endpoint_api_key:
|
if not endpoint_api_key:
|
||||||
log.debug(f"Endpoint {endpoint_group_name} not found for API key")
|
log.debug(f"Endpoint {endpoint_group_name} not found for API key")
|
||||||
@@ -255,7 +248,6 @@ def run_test(
|
|||||||
server_url=server_url,
|
server_url=server_url,
|
||||||
worker_endpoint=worker_endpoint,
|
worker_endpoint=worker_endpoint,
|
||||||
payload=payload_cls.for_test(),
|
payload=payload_cls.for_test(),
|
||||||
instance=instance,
|
|
||||||
)
|
)
|
||||||
clients.append(client)
|
clients.append(client)
|
||||||
thread = threading.Thread(target=client.simulate_user, args=())
|
thread = threading.Thread(target=client.simulate_user, args=())
|
||||||
@@ -289,19 +281,12 @@ def test_load_cmd(
|
|||||||
args = arg_parser.parse_args()
|
args = arg_parser.parse_args()
|
||||||
if hasattr(args, "comfy_model"):
|
if hasattr(args, "comfy_model"):
|
||||||
os.environ["COMFY_MODEL"] = args.comfy_model
|
os.environ["COMFY_MODEL"] = args.comfy_model
|
||||||
server_url = dict(
|
|
||||||
prod="https://run.vast.ai",
|
|
||||||
alpha="https://run-alpha.vast.ai",
|
|
||||||
candidate="https://run-candidate.vast.ai",
|
|
||||||
local="http://localhost:8080",
|
|
||||||
)[args.instance]
|
|
||||||
run_test(
|
run_test(
|
||||||
num_requests=args.num_requests,
|
num_requests=args.num_requests,
|
||||||
requests_per_second=args.requests_per_second,
|
requests_per_second=args.requests_per_second,
|
||||||
api_key=args.api_key,
|
api_key=args.api_key,
|
||||||
server_url=server_url,
|
server_url=args.server_url,
|
||||||
endpoint_group_name=args.endpoint_group_name,
|
endpoint_group_name=args.endpoint_group_name,
|
||||||
worker_endpoint=endpoint,
|
worker_endpoint=endpoint,
|
||||||
payload_cls=payload_cls,
|
payload_cls=payload_cls,
|
||||||
instance=args.instance,
|
|
||||||
)
|
)
|
||||||
|
|||||||
+11
-11
@@ -87,23 +87,23 @@ if [ "$USE_SSL" = true ]; then
|
|||||||
IP.1 = 0.0.0.0
|
IP.1 = 0.0.0.0
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \
|
openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \
|
||||||
-nodes \
|
-nodes \
|
||||||
-sha256 \
|
-sha256 \
|
||||||
-keyout /etc/instance.key \
|
-keyout /etc/instance.key \
|
||||||
-out /etc/instance.csr \
|
-out /etc/instance.csr \
|
||||||
-config /etc/openssl-san.cnf
|
-config /etc/openssl-san.cnf
|
||||||
|
|
||||||
curl --header 'Content-Type: application/octet-stream' \
|
curl --header 'Content-Type: application/octet-stream' \
|
||||||
--data-binary @//etc/instance.csr \
|
--data-binary @//etc/instance.csr \
|
||||||
-X \
|
-X \
|
||||||
POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt;
|
POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt;
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED
|
export REPORT_ADDR WORKER_PORT USE_SSL
|
||||||
|
|
||||||
cd "$SERVER_DIR"
|
cd "$SERVER_DIR"
|
||||||
|
|
||||||
|
|||||||
@@ -17,9 +17,7 @@ class Endpoint:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_endpoint_api_key(
|
def get_endpoint_api_key(endpoint_name: str, account_api_key: str) -> Optional[str]:
|
||||||
endpoint_name: str, account_api_key: str, instance: str
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""
|
"""
|
||||||
Fetch endpoint API key from VastAI console following the healthcheck pattern.
|
Fetch endpoint API key from VastAI console following the healthcheck pattern.
|
||||||
|
|
||||||
@@ -35,9 +33,7 @@ class Endpoint:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}")
|
log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}")
|
||||||
response = requests.get(
|
response = requests.get(vast_console_url, headers=headers)
|
||||||
f"{vast_console_url}?autoscaler_instance={instance}", headers=headers
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
error_msg = f"Failed to fetch endpoint API key: {response.status_code} - {response.text}"
|
error_msg = f"Failed to fetch endpoint API key: {response.status_code} - {response.text}"
|
||||||
|
|||||||
@@ -153,7 +153,6 @@ if __name__ == "__main__":
|
|||||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||||
endpoint_name=args.endpoint_group_name,
|
endpoint_name=args.endpoint_group_name,
|
||||||
account_api_key=args.api_key,
|
account_api_key=args.api_key,
|
||||||
instance=args.instance,
|
|
||||||
)
|
)
|
||||||
if endpoint_api_key:
|
if endpoint_api_key:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -100,7 +100,6 @@ if __name__ == "__main__":
|
|||||||
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
endpoint_api_key = Endpoint.get_endpoint_api_key(
|
||||||
endpoint_name=args.endpoint_group_name,
|
endpoint_name=args.endpoint_group_name,
|
||||||
account_api_key=args.api_key,
|
account_api_key=args.api_key,
|
||||||
instance=args.instance,
|
|
||||||
)
|
)
|
||||||
if endpoint_api_key:
|
if endpoint_api_key:
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user