Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d3be9fe7db |
+27
-38
@@ -280,52 +280,41 @@ class Backend:
|
||||
return float(f.readline())
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
log.debug("Initial run to trigger model loading...")
|
||||
payload = self.benchmark_handler.make_benchmark_payload()
|
||||
await self.__call_api(handler=self.benchmark_handler, payload=payload)
|
||||
|
||||
max_throughput = 0
|
||||
last_throughput = 0
|
||||
sum_throughput = 0
|
||||
concurrent_requests = 10 if self.allow_parallel_requests else 1
|
||||
|
||||
for run in range(1, self.benchmark_handler.benchmark_runs + 1):
|
||||
for run in range(self.benchmark_handler.benchmark_runs + 1):
|
||||
start = time.time()
|
||||
tasks = []
|
||||
total_workload = 0
|
||||
|
||||
for _ in range(concurrent_requests):
|
||||
payload = self.benchmark_handler.make_benchmark_payload()
|
||||
total_workload += payload.count_workload()
|
||||
tasks.append(
|
||||
self.__call_api(handler=self.benchmark_handler, payload=payload)
|
||||
)
|
||||
|
||||
responses = await gather(*tasks)
|
||||
time_elapsed = time.time() - start
|
||||
|
||||
throughput = total_workload / time_elapsed
|
||||
sum_throughput += throughput
|
||||
max_throughput = max(max_throughput, throughput)
|
||||
|
||||
# Log results for debugging
|
||||
log.debug(
|
||||
"\n".join(
|
||||
[
|
||||
"#" * 60,
|
||||
f"Run: {run}, concurrent_requests: {concurrent_requests}",
|
||||
f"Total workload: {total_workload}, time_elapsed: {time_elapsed}s",
|
||||
f"Throughput: {throughput} workload/s",
|
||||
f"Successful responses: {len([r for r in responses if r.status == 200])}",
|
||||
"#" * 60,
|
||||
]
|
||||
)
|
||||
payload = self.benchmark_handler.make_benchmark_payload()
|
||||
res = await self.__call_api(
|
||||
handler=self.benchmark_handler, payload=payload
|
||||
)
|
||||
|
||||
data = await res.json()
|
||||
time_elapsed = time.time() - start
|
||||
# first run triggers one-time loading of the model which is very slow, so we skip counting it
|
||||
if run == 0:
|
||||
continue
|
||||
else:
|
||||
workload = payload.count_workload()
|
||||
last_throughput = workload / time_elapsed
|
||||
sum_throughput += last_throughput
|
||||
max_throughput = max(max_throughput, last_throughput)
|
||||
log.debug(
|
||||
"\n".join(
|
||||
[
|
||||
"#" * 60,
|
||||
f"Run: {run}, workload: {workload} time_elapsed: {time_elapsed}, throughput: {last_throughput}",
|
||||
"",
|
||||
f"response: {data}",
|
||||
"#" * 60,
|
||||
]
|
||||
)
|
||||
)
|
||||
average_throughput = sum_throughput / self.benchmark_handler.benchmark_runs
|
||||
log.debug(
|
||||
f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}"
|
||||
)
|
||||
# save max_throughput so we don't have to run benchmark again on restart of cold instances
|
||||
with open(BENCHMARK_INDICATOR_FILE, "w") as f:
|
||||
f.write(str(max_throughput))
|
||||
return max_throughput
|
||||
|
||||
@@ -10,7 +10,6 @@ from collections import Counter
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from urllib.parse import urljoin
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
import requests
|
||||
|
||||
from lib.data_types import AuthData, ApiPayload
|
||||
@@ -121,11 +120,9 @@ class ClientState:
|
||||
self.url = worker_address
|
||||
url = urljoin(worker_address, self.worker_endpoint)
|
||||
self.status = ClientStatus.Generating
|
||||
|
||||
response = requests.post(
|
||||
url,
|
||||
json=req_data,
|
||||
verify=get_cert_file_path(),
|
||||
)
|
||||
if response.status_code != 200:
|
||||
self.infer_error.append(
|
||||
|
||||
@@ -30,12 +30,7 @@ class Endpoint:
|
||||
Returns:
|
||||
Endpoint API key if successful, None otherwise
|
||||
"""
|
||||
endpoints = {
|
||||
"alpha": "alpha",
|
||||
"candidate": "candidate",
|
||||
"prod": "console",
|
||||
}
|
||||
vast_console_url = f"https://{endpoints[instance]}.vast.ai/api/v0/endptjobs/"
|
||||
vast_console_url = "https://console.vast.ai/api/v0/endptjobs/"
|
||||
headers = {"Authorization": f"Bearer {account_api_key}"}
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
import tempfile
|
||||
from functools import cache
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
@cache
|
||||
def get_cert_file_path():
|
||||
cert_url = "https://console.vast.ai/static/jvastai_root.cer"
|
||||
response = requests.get(cert_url)
|
||||
response.raise_for_status()
|
||||
# Use a temporary file that is not deleted on close
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".cer", mode="wb") as f:
|
||||
f.write(response.content)
|
||||
return f.name
|
||||
@@ -5,7 +5,6 @@ import requests
|
||||
|
||||
from lib.test_utils import print_truncate_res
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
|
||||
"""
|
||||
NOTE: this client example uses a custom comfy workflow compatible with SD3 only
|
||||
@@ -52,7 +51,6 @@ def call_default_workflow(
|
||||
response = requests.post(
|
||||
url,
|
||||
json=req_data,
|
||||
verify=get_cert_file_path(),
|
||||
)
|
||||
response.raise_for_status()
|
||||
print_truncate_res(str(response.json()))
|
||||
@@ -143,7 +141,6 @@ def call_custom_workflow_for_sd3(
|
||||
response = requests.post(
|
||||
url,
|
||||
json=req_data,
|
||||
verify=get_cert_file_path(),
|
||||
)
|
||||
response.raise_for_status()
|
||||
print_truncate_res(str(response.json()))
|
||||
|
||||
@@ -6,7 +6,6 @@ from urllib.parse import urljoin
|
||||
from typing import Dict, Any, Optional, Iterator, Union, List
|
||||
import requests
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
from .data_types.client import CompletionConfig, ChatCompletionConfig
|
||||
|
||||
logging.basicConfig(
|
||||
@@ -91,13 +90,9 @@ class APIClient:
|
||||
|
||||
# Make the request using the specified method
|
||||
if method.upper() == "POST":
|
||||
response = requests.post(
|
||||
url, json=req_data, stream=stream, verify=get_cert_file_path()
|
||||
)
|
||||
response = requests.post(url, json=req_data, stream=stream)
|
||||
elif method.upper() == "GET":
|
||||
response = requests.get(
|
||||
url, params=req_data, stream=stream, verify=get_cert_file_path()
|
||||
)
|
||||
response = requests.get(url, params=req_data, stream=stream)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import json
|
||||
from urllib.parse import urljoin
|
||||
import requests
|
||||
from utils.endpoint_util import Endpoint
|
||||
from utils.ssl import get_cert_file_path
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
@@ -43,11 +42,7 @@ def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> No
|
||||
req_data = dict(payload=payload, auth_data=auth_data)
|
||||
url = urljoin(url, WORKER_ENDPOINT)
|
||||
print(f"url: {url}")
|
||||
response = requests.post(
|
||||
url,
|
||||
json=req_data,
|
||||
verify=get_cert_file_path(),
|
||||
)
|
||||
response = requests.post(url, json=req_data)
|
||||
response.raise_for_status()
|
||||
res = response.json()
|
||||
print(res)
|
||||
|
||||
Reference in New Issue
Block a user