AUTO-421: clean up some issues

This commit is contained in:
Nader Arbabian
2025-07-11 15:04:54 -07:00
parent 3e49b7d04b
commit ce52419023
5 changed files with 42 additions and 21 deletions
+24 -9
View File
@@ -53,6 +53,13 @@ test_args.add_argument(
default="https://run.vast.ai", default="https://run.vast.ai",
help="Call local autoscaler instead of prod, for dev use only", help="Call local autoscaler instead of prod, for dev use only",
) )
test_args.add_argument(
"-i",
dest="instance",
type=str,
default="prod",
help="Autoscaler shard to run the command against, default: prod",
)
GetPayloadAndWorkload = Callable[[], Tuple[Dict[str, Any], float]] GetPayloadAndWorkload = Callable[[], Tuple[Dict[str, Any], float]]
@@ -70,6 +77,7 @@ class ClientState:
api_key: str api_key: str
server_url: str server_url: str
worker_endpoint: str worker_endpoint: str
instance: str
payload: ApiPayload payload: ApiPayload
url: str = "" url: str = ""
status: ClientStatus = ClientStatus.FetchEndpoint status: ClientStatus = ClientStatus.FetchEndpoint
@@ -79,11 +87,7 @@ class ClientState:
def make_call(self): def make_call(self):
self.status = ClientStatus.FetchEndpoint self.status = ClientStatus.FetchEndpoint
endpoint_api_key = Endpoint.get_endpoint_api_key( if not self.api_key:
endpoint_name=self.endpoint_group_name,
account_api_key=self.api_key,
)
if not endpoint_api_key:
self.as_error.append( self.as_error.append(
f"Endpoint {self.endpoint_group_name} not found for API key", f"Endpoint {self.endpoint_group_name} not found for API key",
) )
@@ -91,12 +95,14 @@ class ClientState:
return return
route_payload = { route_payload = {
"endpoint": self.endpoint_group_name, "endpoint": self.endpoint_group_name,
"api_key": endpoint_api_key, "api_key": self.api_key,
"cost": self.payload.count_workload(), "cost": self.payload.count_workload(),
} }
headers = {"Authorization": f"Bearer {self.api_key}"}
response = requests.post( response = requests.post(
urljoin(self.server_url, "/route/"), urljoin(self.server_url, "/route/"),
json=route_payload, json=route_payload,
headers=headers,
timeout=4, timeout=4,
) )
if response.status_code != 200: if response.status_code != 200:
@@ -135,6 +141,7 @@ class ClientState:
try: try:
self.make_call() self.make_call()
except Exception as e: except Exception as e:
print(e)
self.status = ClientStatus.Error self.status = ClientStatus.Error
_ = e _ = e
self.conn_errors[self.url] += 1 self.conn_errors[self.url] += 1
@@ -226,6 +233,7 @@ def run_test(
server_url: str, server_url: str,
worker_endpoint: str, worker_endpoint: str,
payload_cls: Type[ApiPayload], payload_cls: Type[ApiPayload],
instance: str,
): ):
threads = [] threads = []
@@ -234,8 +242,7 @@ def run_test(
print_thread.daemon = True # makes threads get killed on program exit print_thread.daemon = True # makes threads get killed on program exit
print_thread.start() print_thread.start()
endpoint_api_key = Endpoint.get_endpoint_api_key( endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=endpoint_group_name, endpoint_name=endpoint_group_name, account_api_key=api_key, instance=instance
account_api_key=api_key,
) )
if not endpoint_api_key: if not endpoint_api_key:
log.debug(f"Endpoint {endpoint_group_name} not found for API key") log.debug(f"Endpoint {endpoint_group_name} not found for API key")
@@ -248,6 +255,7 @@ def run_test(
server_url=server_url, server_url=server_url,
worker_endpoint=worker_endpoint, worker_endpoint=worker_endpoint,
payload=payload_cls.for_test(), payload=payload_cls.for_test(),
instance=instance,
) )
clients.append(client) clients.append(client)
thread = threading.Thread(target=client.simulate_user, args=()) thread = threading.Thread(target=client.simulate_user, args=())
@@ -281,12 +289,19 @@ def test_load_cmd(
args = arg_parser.parse_args() args = arg_parser.parse_args()
if hasattr(args, "comfy_model"): if hasattr(args, "comfy_model"):
os.environ["COMFY_MODEL"] = args.comfy_model os.environ["COMFY_MODEL"] = args.comfy_model
server_url = dict(
prod="https://run.vast.ai",
alpha="https://run-alpha.vast.ai",
candidate="https://run-candidate.vast.ai",
local="http://localhost:8080",
)[args.instance]
run_test( run_test(
num_requests=args.num_requests, num_requests=args.num_requests,
requests_per_second=args.requests_per_second, requests_per_second=args.requests_per_second,
api_key=args.api_key, api_key=args.api_key,
server_url=args.server_url, server_url=server_url,
endpoint_group_name=args.endpoint_group_name, endpoint_group_name=args.endpoint_group_name,
worker_endpoint=endpoint, worker_endpoint=endpoint,
payload_cls=payload_cls, payload_cls=payload_cls,
instance=args.instance,
) )
+10 -10
View File
@@ -87,17 +87,17 @@ if [ "$USE_SSL" = true ]; then
IP.1 = 0.0.0.0 IP.1 = 0.0.0.0
EOF EOF
openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \ openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \
-nodes \ -nodes \
-sha256 \ -sha256 \
-keyout /etc/instance.key \ -keyout /etc/instance.key \
-out /etc/instance.csr \ -out /etc/instance.csr \
-config /etc/openssl-san.cnf -config /etc/openssl-san.cnf
curl --header 'Content-Type: application/octet-stream' \ curl --header 'Content-Type: application/octet-stream' \
--data-binary @//etc/instance.csr \ --data-binary @//etc/instance.csr \
-X \ -X \
POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt; POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt;
fi fi
+6 -2
View File
@@ -17,7 +17,9 @@ class Endpoint:
""" """
@staticmethod @staticmethod
def get_endpoint_api_key(endpoint_name: str, account_api_key: str) -> Optional[str]: def get_endpoint_api_key(
endpoint_name: str, account_api_key: str, instance: str
) -> Optional[str]:
""" """
Fetch endpoint API key from VastAI console following the healthcheck pattern. Fetch endpoint API key from VastAI console following the healthcheck pattern.
@@ -33,7 +35,9 @@ class Endpoint:
try: try:
log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}") log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}")
response = requests.get(vast_console_url, headers=headers) response = requests.get(
f"{vast_console_url}?autoscaler_instance={instance}", headers=headers
)
if response.status_code != 200: if response.status_code != 200:
error_msg = f"Failed to fetch endpoint API key: {response.status_code} - {response.text}" error_msg = f"Failed to fetch endpoint API key: {response.status_code} - {response.text}"
+1
View File
@@ -153,6 +153,7 @@ if __name__ == "__main__":
endpoint_api_key = Endpoint.get_endpoint_api_key( endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name, endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key, account_api_key=args.api_key,
instance=args.instance,
) )
if endpoint_api_key: if endpoint_api_key:
try: try:
+1
View File
@@ -100,6 +100,7 @@ if __name__ == "__main__":
endpoint_api_key = Endpoint.get_endpoint_api_key( endpoint_api_key = Endpoint.get_endpoint_api_key(
endpoint_name=args.endpoint_group_name, endpoint_name=args.endpoint_group_name,
account_api_key=args.api_key, account_api_key=args.api_key,
instance=args.instance,
) )
if endpoint_api_key: if endpoint_api_key:
try: try: