Compare commits

..

4 Commits

7 changed files with 78 additions and 31 deletions
+38 -27
View File
@@ -280,41 +280,52 @@ class Backend:
return float(f.readline()) return float(f.readline())
except FileNotFoundError: except FileNotFoundError:
pass pass
log.debug("Initial run to trigger model loading...")
payload = self.benchmark_handler.make_benchmark_payload()
await self.__call_api(handler=self.benchmark_handler, payload=payload)
max_throughput = 0 max_throughput = 0
last_throughput = 0
sum_throughput = 0 sum_throughput = 0
for run in range(self.benchmark_handler.benchmark_runs + 1): concurrent_requests = 10 if self.allow_parallel_requests else 1
for run in range(1, self.benchmark_handler.benchmark_runs + 1):
start = time.time() start = time.time()
payload = self.benchmark_handler.make_benchmark_payload() tasks = []
res = await self.__call_api( total_workload = 0
handler=self.benchmark_handler, payload=payload
) for _ in range(concurrent_requests):
data = await res.json() payload = self.benchmark_handler.make_benchmark_payload()
time_elapsed = time.time() - start total_workload += payload.count_workload()
# first run triggers one-time loading of the model which is very slow, so we skip counting it tasks.append(
if run == 0: self.__call_api(handler=self.benchmark_handler, payload=payload)
continue
else:
workload = payload.count_workload()
last_throughput = workload / time_elapsed
sum_throughput += last_throughput
max_throughput = max(max_throughput, last_throughput)
log.debug(
"\n".join(
[
"#" * 60,
f"Run: {run}, workload: {workload} time_elapsed: {time_elapsed}, throughput: {last_throughput}",
"",
f"response: {data}",
"#" * 60,
]
)
) )
responses = await gather(*tasks)
time_elapsed = time.time() - start
throughput = total_workload / time_elapsed
sum_throughput += throughput
max_throughput = max(max_throughput, throughput)
# Log results for debugging
log.debug(
"\n".join(
[
"#" * 60,
f"Run: {run}, concurrent_requests: {concurrent_requests}",
f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s",
f"Throughput: {throughput} workload/s",
f"Successful responses: {len([r for r in responses if r.status == 200])}",
"#" * 60,
]
)
)
average_throughput = sum_throughput / self.benchmark_handler.benchmark_runs average_throughput = sum_throughput / self.benchmark_handler.benchmark_runs
log.debug( log.debug(
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}" f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
) )
# save max_throughput so we don't have to run benchmark again on restart of cold instances
with open(BENCHMARK_INDICATOR_FILE, "w") as f: with open(BENCHMARK_INDICATOR_FILE, "w") as f:
f.write(str(max_throughput)) f.write(str(max_throughput))
return max_throughput return max_throughput
+3
View File
@@ -10,6 +10,7 @@ from collections import Counter
from dataclasses import dataclass, field, asdict from dataclasses import dataclass, field, asdict
from urllib.parse import urljoin from urllib.parse import urljoin
from utils.endpoint_util import Endpoint from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
import requests import requests
from lib.data_types import AuthData, ApiPayload from lib.data_types import AuthData, ApiPayload
@@ -120,9 +121,11 @@ class ClientState:
self.url = worker_address self.url = worker_address
url = urljoin(worker_address, self.worker_endpoint) url = urljoin(worker_address, self.worker_endpoint)
self.status = ClientStatus.Generating self.status = ClientStatus.Generating
response = requests.post( response = requests.post(
url, url,
json=req_data, json=req_data,
verify=get_cert_file_path(),
) )
if response.status_code != 200: if response.status_code != 200:
self.infer_error.append( self.infer_error.append(
+6 -1
View File
@@ -30,7 +30,12 @@ class Endpoint:
Returns: Returns:
Endpoint API key if successful, None otherwise Endpoint API key if successful, None otherwise
""" """
vast_console_url = "https://console.vast.ai/api/v0/endptjobs/" endpoints = {
"alpha": "alpha",
"candidate": "candidate",
"prod": "console",
}
vast_console_url = f"https://{endpoints[instance]}.vast.ai/api/v0/endptjobs/"
headers = {"Authorization": f"Bearer {account_api_key}"} headers = {"Authorization": f"Bearer {account_api_key}"}
try: try:
+15
View File
@@ -0,0 +1,15 @@
import tempfile
from functools import cache
import requests
@cache
def get_cert_file_path():
cert_url = "https://console.vast.ai/static/jvastai_root.cer"
response = requests.get(cert_url)
response.raise_for_status()
# Use a temporary file that is not deleted on close
with tempfile.NamedTemporaryFile(delete=False, suffix=".cer", mode="wb") as f:
f.write(response.content)
return f.name
+3
View File
@@ -5,6 +5,7 @@ import requests
from lib.test_utils import print_truncate_res from lib.test_utils import print_truncate_res
from utils.endpoint_util import Endpoint from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
""" """
NOTE: this client example uses a custom comfy workflow compatible with SD3 only NOTE: this client example uses a custom comfy workflow compatible with SD3 only
@@ -51,6 +52,7 @@ def call_default_workflow(
response = requests.post( response = requests.post(
url, url,
json=req_data, json=req_data,
verify=get_cert_file_path(),
) )
response.raise_for_status() response.raise_for_status()
print_truncate_res(str(response.json())) print_truncate_res(str(response.json()))
@@ -141,6 +143,7 @@ def call_custom_workflow_for_sd3(
response = requests.post( response = requests.post(
url, url,
json=req_data, json=req_data,
verify=get_cert_file_path(),
) )
response.raise_for_status() response.raise_for_status()
print_truncate_res(str(response.json())) print_truncate_res(str(response.json()))
+7 -2
View File
@@ -6,6 +6,7 @@ from urllib.parse import urljoin
from typing import Dict, Any, Optional, Iterator, Union, List from typing import Dict, Any, Optional, Iterator, Union, List
import requests import requests
from utils.endpoint_util import Endpoint from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
from .data_types.client import CompletionConfig, ChatCompletionConfig from .data_types.client import CompletionConfig, ChatCompletionConfig
logging.basicConfig( logging.basicConfig(
@@ -90,9 +91,13 @@ class APIClient:
# Make the request using the specified method # Make the request using the specified method
if method.upper() == "POST": if method.upper() == "POST":
response = requests.post(url, json=req_data, stream=stream) response = requests.post(
url, json=req_data, stream=stream, verify=get_cert_file_path()
)
elif method.upper() == "GET": elif method.upper() == "GET":
response = requests.get(url, params=req_data, stream=stream) response = requests.get(
url, params=req_data, stream=stream, verify=get_cert_file_path()
)
else: else:
raise ValueError(f"Unsupported HTTP method: {method}") raise ValueError(f"Unsupported HTTP method: {method}")
+6 -1
View File
@@ -4,6 +4,7 @@ import json
from urllib.parse import urljoin from urllib.parse import urljoin
import requests import requests
from utils.endpoint_util import Endpoint from utils.endpoint_util import Endpoint
from utils.ssl import get_cert_file_path
logging.basicConfig( logging.basicConfig(
level=logging.DEBUG, level=logging.DEBUG,
@@ -42,7 +43,11 @@ def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> No
req_data = dict(payload=payload, auth_data=auth_data) req_data = dict(payload=payload, auth_data=auth_data)
url = urljoin(url, WORKER_ENDPOINT) url = urljoin(url, WORKER_ENDPOINT)
print(f"url: {url}") print(f"url: {url}")
response = requests.post(url, json=req_data) response = requests.post(
url,
json=req_data,
verify=get_cert_file_path(),
)
response.raise_for_status() response.raise_for_status()
res = response.json() res = response.json()
print(res) print(res)